Commit e9a6a5af authored by Yuxin Wu

train/ directory

parent 2264b5a3
@@ -8,7 +8,7 @@ import argparse
import numpy as np
import os
from tensorpack.train import TrainConfig, start_train
from tensorpack.train import TrainConfig, QueueInputTrainer
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.utils import *
@@ -158,4 +158,4 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
start_train(config)
QueueInputTrainer(config).train()
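The examples switch from the start_train free function to naming a Trainer class explicitly. A minimal before/after sketch, assuming get_config() builds a TrainConfig as these scripts do:

    config = get_config()
    # before this commit:
    #   start_train(config)
    # after: pick the trainer explicitly
    QueueInputTrainer(config).train()     # queue-fed input pipeline (this example)
    # or: SimpleTrainer(config).train()   # plain feed_dict training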
@@ -10,7 +10,7 @@ import numpy as np
import os, sys
import argparse
from tensorpack.train import TrainConfig, start_train
from tensorpack.train import TrainConfig, SimpleTrainer
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import *
@@ -92,7 +92,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
step_per_epoch = 20
#step_per_epoch = 20
# prepare session
sess_config = get_default_sess_config()
@@ -131,5 +131,5 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
start_train(config)
SimpleTrainer(config).train()
@@ -47,7 +47,6 @@ def BatchNorm(x, is_training, gamma_init=1.0):
x.set_shape(hack_shape)
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
print batch_mean
ema = tf.train.ExponentialMovingAverage(decay=0.999)
ema_apply_op = ema.apply([batch_mean, batch_var])
......
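The BatchNorm fragment maintains moving averages of the batch moments with an ExponentialMovingAverage. A sketch of how such an apply op is typically wired in, assuming the elided code follows the standard control-dependency pattern (not shown in this hunk):

    # typical usage of the EMA op (illustrative; this part is elided above):
    ema = tf.train.ExponentialMovingAverage(decay=0.999)
    ema_apply_op = ema.apply([batch_mean, batch_var])
    with tf.control_dependencies([ema_apply_op]):
        # force the moving-average update to run whenever the moments are used
        mean, var = tf.identity(batch_mean), tf.identity(batch_var)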
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import walk_packages
import os
import os.path

def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = p.__dict__[k]

for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
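global_import re-exports each submodule's public names (its __all__, or everything in dir() when __all__ is absent) at the package level, which is what lets the example scripts above write:

    # effect of the auto-import (as used by the examples in this commit):
    from tensorpack.train import TrainConfig, SimpleTrainer, QueueInputTrainer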
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
from abc import ABCMeta, abstractmethod
import tqdm
import re

from .config import TrainConfig
from ..utils import *
from ..callbacks import StatHolder
from ..utils.modelutils import describe_model

__all__ = ['Trainer']

class Trainer(object):
    __metaclass__ = ABCMeta

    def __init__(self, config):
        """
        config: a `TrainConfig` instance
        """
        assert isinstance(config, TrainConfig), type(config)
        self.config = config
        tf.add_to_collection(MODEL_KEY, config.model)

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def run_step(self):
        pass

    def trigger_epoch(self):
        self.global_step += self.config.step_per_epoch
        self._trigger_epoch()
        self.config.callbacks.trigger_epoch()
        self.summary_writer.flush()
        logger.stat_holder.finalize()

    @abstractmethod
    def _trigger_epoch(self):
        pass

    def _init_summary(self):
        if not hasattr(logger, 'LOG_DIR'):
            raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
        self.summary_writer = tf.train.SummaryWriter(
            logger.LOG_DIR, graph_def=self.sess.graph_def)
        logger.writer = self.summary_writer
        self.summary_op = tf.merge_all_summaries()
        # create an empty StatHolder
        logger.stat_holder = StatHolder(logger.LOG_DIR, [])

    def _process_summary(self, summary_str):
        summary = tf.Summary.FromString(summary_str)
        for val in summary.value:
            if val.WhichOneof('value') == 'simple_value':
                # strip the tower prefix so all towers share one stat tag
                val.tag = re.sub('tower[0-9]*/', '', val.tag)  # TODO move to subclasses
                logger.stat_holder.add_stat(val.tag, val.simple_value)
        self.summary_writer.add_summary(summary, self.global_step)

    def main_loop(self):
        callbacks = self.config.callbacks
        with self.sess.as_default():
            try:
                self._init_summary()
                self.global_step = get_global_step()
                logger.info("Start training with global_step={}".format(self.global_step))
                callbacks.before_train()
                tf.get_default_graph().finalize()
                for epoch in xrange(1, self.config.max_epoch + 1):
                    with timed_operation(
                            'Epoch {}, global_step={}'.format(
                                epoch, self.global_step + self.config.step_per_epoch)):
                        for step in tqdm.trange(
                                self.config.step_per_epoch,
                                leave=True, mininterval=0.5,
                                dynamic_ncols=True, ascii=True):
                            if self.coord.should_stop():
                                return
                            self.run_step()
                            callbacks.trigger_step()
                        self.trigger_epoch()
            except (KeyboardInterrupt, Exception):
                raise
            finally:
                self.coord.request_stop()
                # Do I need to run queue.close?
                callbacks.after_train()
                self.summary_writer.close()
                self.sess.close()

    def init_session_and_coord(self):
        describe_model()
        self.sess = tf.Session(config=self.config.session_config)
        self.config.session_init.init(self.sess)
        # start the queue runners
        self.coord = tf.train.Coordinator()
        tf.train.start_queue_runners(
            sess=self.sess, coord=self.coord, daemon=True, start=True)
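Trainer leaves train(), run_step() and _trigger_epoch() abstract. A minimal subclass sketch, assuming a hypothetical build_train_op helper (not part of this commit) that returns the per-step op; SimpleTrainer and QueueInputTrainer below are the real implementations:

    class MinimalTrainer(Trainer):
        # illustrative sketch only
        def train(self):
            self.train_op = build_train_op(self.config)  # hypothetical helper
            self.init_session_and_coord()
            self.main_loop()

        def run_step(self):
            self.sess.run([self.train_op])

        def _trigger_epoch(self):
            # typically: evaluate the merged summaries and record them
            if self.summary_op is not None:
                self._process_summary(self.sess.run(self.summary_op))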
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: config.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf

from ..callbacks import Callbacks
from ..models import ModelDesc
from ..utils import *
from ..dataflow import DataFlow

__all__ = ['TrainConfig']

class TrainConfig(object):
    """
    Config for training a model with a single loss
    """
    def __init__(self, **kwargs):
        """
        Args:
            dataset: the dataset to train on. a tensorpack.dataflow.DataFlow instance.
            optimizer: a tf.train.Optimizer instance defining the optimizer for training.
            callbacks: a tensorpack.utils.callback.Callbacks instance. Defines
                the callbacks to perform during training. Has to contain a
                SummaryWriter and a PeriodicSaver.
            session_config: a tf.ConfigProto instance to instantiate the
                session. Defaults to a session running 1 GPU.
            session_init: a tensorpack.utils.sessinit.SessionInit instance to
                initialize variables of a session. Defaults to a new session.
            model: a ModelDesc instance.
            step_per_epoch: the number of steps (parameter updates) to perform
                in each epoch. Defaults to dataset.size().
            max_epoch: maximum number of epochs to run training. Defaults to 100.
            nr_tower: int. number of towers. Defaults to 1.
        """
        def assert_type(v, tp):
            assert isinstance(v, tp), v.__class__
        self.dataset = kwargs.pop('dataset')
        assert_type(self.dataset, DataFlow)
        self.optimizer = kwargs.pop('optimizer')
        assert_type(self.optimizer, tf.train.Optimizer)
        self.callbacks = kwargs.pop('callbacks')
        assert_type(self.callbacks, Callbacks)
        self.model = kwargs.pop('model')
        assert_type(self.model, ModelDesc)
        self.session_config = kwargs.pop('session_config', get_default_sess_config())
        assert_type(self.session_config, tf.ConfigProto)
        self.session_init = kwargs.pop('session_init', NewSession())
        assert_type(self.session_init, SessionInit)
        self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
        self.max_epoch = int(kwargs.pop('max_epoch', 100))
        assert self.step_per_epoch > 0 and self.max_epoch > 0
        self.nr_tower = int(kwargs.pop('nr_tower', 1))
        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
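A construction sketch following the documented arguments. MyModel is a hypothetical ModelDesc subclass, the dataflow mirrors the mnist example above, and the callback constructors are shown without arguments for brevity:

    config = TrainConfig(
        dataset=BatchData(dataset.Mnist('train'), 128),
        optimizer=tf.train.AdamOptimizer(1e-4),
        callbacks=Callbacks([SummaryWriter(), PeriodicSaver()]),
        model=MyModel(),         # hypothetical ModelDesc subclass
        step_per_epoch=100,      # optional; defaults to dataset.size()
        max_epoch=100)
    SimpleTrainer(config).train()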
@@ -4,65 +4,16 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from itertools import count
import copy
import argparse
import re
import tqdm
from abc import ABCMeta, abstractmethod
from .models import ModelDesc
from .dataflow.common import RepeatedData
from .utils import *
from .utils.concurrency import EnqueueThread
from .callbacks import *
from .utils.summary import summary_moving_average
from .utils.modelutils import describe_model
from .utils import logger
from .dataflow import DataFlow
from .base import Trainer
from ..dataflow.common import RepeatedData
from ..utils import *
from ..utils.concurrency import EnqueueThread
from ..utils.summary import summary_moving_average
class TrainConfig(object):
"""
Config for training a model with a single loss
"""
def __init__(self, **kwargs):
"""
Args:
dataset: the dataset to train. a tensorpack.dataflow.DataFlow instance.
optimizer: a tf.train.Optimizer instance defining the optimizer for training.
callbacks: a tensorpack.utils.callback.Callbacks instance. Define
the callbacks to perform during training. has to contain a
SummaryWriter and a PeriodicSaver
session_config: a tf.ConfigProto instance to instantiate the
session. default to a session running 1 GPU.
session_init: a tensorpack.utils.sessinit.SessionInit instance to
initialize variables of a session. default to a new session.
model: a ModelDesc instance
step_per_epoch: the number of steps (parameter updates) to perform
in each epoch. default to dataset.size()
max_epoch: maximum number of epochs to run training. default to 100
nr_tower: int. number of towers. default to 1.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
self.dataset = kwargs.pop('dataset')
assert_type(self.dataset, DataFlow)
self.optimizer = kwargs.pop('optimizer')
assert_type(self.optimizer, tf.train.Optimizer)
self.callbacks = kwargs.pop('callbacks')
assert_type(self.callbacks, Callbacks)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
self.session_config = kwargs.pop('session_config', get_default_sess_config())
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init', NewSession())
assert_type(self.session_init, SessionInit)
self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = int(kwargs.pop('nr_tower', 1))
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
__all__ = ['SimpleTrainer', 'QueueInputTrainer']
def summary_grads(grads):
for grad, var in grads:
@@ -88,95 +39,6 @@ def scale_grads(grads, multiplier):
ret.append((grad, var))
return ret
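The hunk shows only the tail of scale_grads. A rough sketch of a gradient-scaling helper with that signature; the matching rule (substring lookup on the variable name) is an assumption for illustration, not recovered from this diff:

    def scale_grads(grads, multiplier):
        # grads: list of (gradient, variable) pairs
        # multiplier: dict mapping a variable-name substring to a scale factor
        ret = []
        for grad, var in grads:
            for name, scale in multiplier.items():
                if name in var.name:
                    grad = grad * scale   # scale matching gradients
                    break
            ret.append((grad, var))
        return ret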
class Trainer(object):
__metaclass__ = ABCMeta
def __init__(self, config):
"""
Config: a `TrainConfig` instance
"""
assert isinstance(config, TrainConfig), type(config)
self.config = config
tf.add_to_collection(MODEL_KEY, config.model)
@abstractmethod
def train(self):
pass
@abstractmethod
def run_step(self):
pass
def trigger_epoch(self):
self.global_step += self.config.step_per_epoch
self._trigger_epoch()
self.config.callbacks.trigger_epoch()
self.summary_writer.flush()
logger.stat_holder.finalize()
@abstractmethod
def _trigger_epoch(self):
pass
def _init_summary(self):
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.summary_writer = tf.train.SummaryWriter(
logger.LOG_DIR, graph_def=self.sess.graph_def)
logger.writer = self.summary_writer
self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder
logger.stat_holder = StatHolder(logger.LOG_DIR, [])
def _process_summary(self, summary_str):
summary = tf.Summary.FromString(summary_str)
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[0-9]*/', '', val.tag) # TODO move to subclasses
logger.stat_holder.add_stat(val.tag, val.simple_value)
self.summary_writer.add_summary(summary, self.global_step)
def main_loop(self):
callbacks = self.config.callbacks
with self.sess.as_default():
try:
self._init_summary()
self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step))
callbacks.before_train()
tf.get_default_graph().finalize()
for epoch in xrange(1, self.config.max_epoch + 1):
with timed_operation(
'Epoch {}, global_step={}'.format(
epoch, self.global_step + self.config.step_per_epoch)):
for step in tqdm.trange(
self.config.step_per_epoch,
leave=True, mininterval=0.5,
dynamic_ncols=True, ascii=True):
if self.coord.should_stop():
return
self.run_step()
callbacks.trigger_step()
self.trigger_epoch()
except (KeyboardInterrupt, Exception):
raise
finally:
self.coord.request_stop()
# Do I need to run queue.close?
callbacks.after_train()
self.summary_writer.close()
self.sess.close()
def init_session_and_coord(self):
self.sess = tf.Session(config=self.config.session_config)
self.config.session_init.init(self.sess)
# start the queue runners
self.coord = tf.train.Coordinator()
tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True)
class SimpleTrainer(Trainer):
def run_step(self):
@@ -200,7 +62,6 @@ class SimpleTrainer(Trainer):
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
describe_model()
self.init_session_and_coord()
# create an infinite data producer
self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
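SimpleTrainer feeds one datapoint per step from the endlessly repeated DataFlow set up above. A sketch of the matching run_step, with self.input_vars assumed to hold the model's input placeholders (the actual body is elided from this hunk):

    def run_step(self):
        data = next(self.data_producer)                  # one datapoint
        feed = dict(zip(self.input_vars, data))          # self.input_vars assumed
        self.sess.run([self.train_op], feed_dict=feed)   # self.train_op assumed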
@@ -280,7 +141,6 @@ class QueueInputTrainer(Trainer):
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
describe_model()
self.init_session_and_coord()
# create a thread that keeps filling the queue
@@ -299,11 +159,5 @@ class QueueInputTrainer(Trainer):
def start_train(config):
#if config.model.get_input_queue() is not None:
## XXX get_input_queue is called twice
#tr = QueueInputTrainer()
#else:
#tr = SimpleTrainer()
tr = SimpleTrainer(config)
#tr = QueueInputTrainer(config)
tr.train()
@@ -7,7 +7,6 @@ import threading
from contextlib import contextmanager
import tensorflow as tf
from .utils import expand_dim_if_necessary
from .naming import *
from . import logger
@@ -44,7 +43,6 @@ class EnqueueThread(threading.Thread):
return
feed = dict(zip(self.input_vars, dp))
self.sess.run([self.op], feed_dict=feed)
#print '\nExhausted!!!'
except tf.errors.CancelledError as e:
pass
except Exception:
......
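EnqueueThread keeps a TensorFlow input queue filled from a DataFlow. The surrounding setup, in outline (illustrative; the queue construction is not part of this hunk, and the capacity is an assumption):

    # sketch of the pieces the run() loop above relies on:
    queue = tf.FIFOQueue(50, [v.dtype for v in input_vars])
    enqueue_op = queue.enqueue(input_vars)
    # each datapoint dp from the DataFlow is then fed as:
    #   sess.run(enqueue_op, feed_dict=dict(zip(input_vars, dp)))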
@@ -4,19 +4,19 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
def expand_dim_if_necessary(var, dp):
"""
Args:
var: a tensor
dp: a numpy array
Return a reshaped version of dp, if that makes it match the valid dimension of var
"""
shape = var.get_shape().as_list()
valid_shape = [k for k in shape if k]
if dp.shape == tuple(valid_shape):
new_shape = [k if k else 1 for k in shape]
dp = dp.reshape(new_shape)
return dp
#def expand_dim_if_necessary(var, dp):
# """
# Args:
# var: a tensor
# dp: a numpy array
# Return a reshaped version of dp, if that makes it match the valid dimension of var
# """
# shape = var.get_shape().as_list()
# valid_shape = [k for k in shape if k]
# if dp.shape == tuple(valid_shape):
# new_shape = [k if k else 1 for k in shape]
# dp = dp.reshape(new_shape)
# return dp
def mkdir_p(dirname):
......