Commit e6a136ce authored by Yuxin Wu

some update

parent fc7f0aac
@@ -92,6 +92,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
+    step_per_epoch = 3
     # prepare session
     sess_config = get_default_sess_config()
...
@@ -85,15 +85,24 @@ class FixedSizeData(DataFlow):
             return

 class RepeatedData(DataFlow):
-    """ repeat another dataflow for certain times"""
+    """ repeat another dataflow for certain times
+        if nr == -1, repeat infinitely many times
+    """
     def __init__(self, ds, nr):
         self.nr = nr
         self.ds = ds

     def size(self):
+        if self.nr == -1:
+            raise RuntimeError("size() is unavailable for infinite dataflow")
         return self.ds.size() * self.nr

     def get_data(self):
+        if self.nr == -1:
+            while True:
+                for dp in self.ds.get_data():
+                    yield dp
+        else:
             for _ in xrange(self.nr):
                 for dp in self.ds.get_data():
                     yield dp
...
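
For context, a small usage sketch of the new infinite mode. The import paths and the MNIST setup are taken from the other hunks of this commit and may differ slightly in the actual tree:

import itertools
# assumed import locations for this era of the repo:
from dataflow.common import BatchData, RepeatedData
from dataflow import dataset

ds = BatchData(dataset.Mnist('train'), 128)
inf_ds = RepeatedData(ds, -1)     # nr == -1: repeat the dataflow forever
# size() would now raise RuntimeError, so take a bounded slice instead:
first_ten = list(itertools.islice(inf_ds.get_data(), 10))
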
@@ -35,6 +35,8 @@ class ModelDesc(object):
     def get_input_queue(self):
         """
         return the queue for input. the dequeued elements will be fed to self.get_cost
+        if queue is None, datapoints from dataflow will be fed to the graph directly.
+        when running with multiGPU, queue cannot be None
         """
         assert self.input_vars is not None
         return tf.FIFOQueue(50, [x.dtype for x in self.input_vars], name='input_queue')
...
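
A hypothetical subclass illustrating the contract the new docstring describes (sketch only; the other abstract methods of ModelDesc are omitted):

class DirectFeedModel(ModelDesc):   # hypothetical name, for illustration
    def get_input_queue(self):
        # None means datapoints from the dataflow are fed to the graph
        # directly; per the docstring this is not allowed with multiple GPUs.
        return None
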
@@ -11,6 +11,7 @@ import re
 import tqdm

 from models import ModelDesc
+from dataflow.common import RepeatedData
 from utils import *
 from utils.concurrency import EnqueueThread
 from callbacks import *
@@ -25,8 +26,7 @@ class TrainConfig(object):
     """
     Args:
         dataset: the dataset to train. a tensorpack.dataflow.DataFlow instance.
-        optimizer: a tf.train.Optimizer instance defining the optimizer
-            for trainig. default to an AdamOptimizer
+        optimizer: a tf.train.Optimizer instance defining the optimizer for training.
         callbacks: a tensorpack.utils.callback.Callbacks instance. Define
             the callbacks to perform during training. has to contain a
             SummaryWriter and a PeriodicSaver
@@ -44,16 +44,17 @@ class TrainConfig(object):
             assert isinstance(v, tp), v.__class__
         self.dataset = kwargs.pop('dataset')
         assert_type(self.dataset, DataFlow)
-        self.optimizer = kwargs.pop('optimizer', tf.train.AdamOptimizer())
+        self.optimizer = kwargs.pop('optimizer')
         assert_type(self.optimizer, tf.train.Optimizer)
         self.callbacks = kwargs.pop('callbacks')
         assert_type(self.callbacks, Callbacks)
+        self.model = kwargs.pop('model')
+        assert_type(self.model, ModelDesc)
         self.session_config = kwargs.pop('session_config', get_default_sess_config())
         assert_type(self.session_config, tf.ConfigProto)
         self.session_init = kwargs.pop('session_init', NewSession())
         assert_type(self.session_init, SessionInit)
-        self.model = kwargs.pop('model')
-        assert_type(self.model, ModelDesc)
         self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
         self.max_epoch = int(kwargs.pop('max_epoch', 100))
         assert self.step_per_epoch > 0 and self.max_epoch > 0
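
After this change, 'optimizer' and 'model' are required arguments with no default. A sketch of building a config (the callback and model names are illustrative, based on the docstring above and the MNIST hunk):

config = TrainConfig(
    dataset=dataset_train,                       # a DataFlow
    optimizer=tf.train.AdamOptimizer(1e-4),      # must now be passed explicitly
    callbacks=Callbacks([SummaryWriter(), PeriodicSaver()]),
    model=Model(),                               # a user-defined ModelDesc
    step_per_epoch=step_per_epoch,
    max_epoch=100)
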
@@ -101,46 +102,45 @@ def start_train(config):
     input_vars = model.get_input_vars()
     input_queue = model.get_input_queue()
     callbacks = config.callbacks

     tf.add_to_collection(MODEL_KEY, model)
+    enqueue_op = input_queue.enqueue(input_vars)

     def get_model_inputs():
         model_inputs = input_queue.dequeue()
-        if isinstance(model_inputs, tf.Tensor):
+        if isinstance(model_inputs, tf.Tensor): # only one input
             model_inputs = [model_inputs]
         for qv, v in zip(model_inputs, input_vars):
             qv.set_shape(v.get_shape())
         return model_inputs

-    enqueue_op = input_queue.enqueue(input_vars)
     # get gradients to update:
-    logger.info("Training a model of {} tower".format(config.nr_tower))
     if config.nr_tower > 1:
+        logger.info("Training a model of {} tower".format(config.nr_tower))
+        # to avoid repeated summary from each device
         coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
         kept_summaries = {}
-        grads = []
+        grad_list = []
         for i in range(config.nr_tower):
             with tf.device('/gpu:{}'.format(i)), \
                     tf.name_scope('tower{}'.format(i)) as scope:
                 model_inputs = get_model_inputs()
                 cost_var = model.get_cost(model_inputs, is_training=True)
-                grads.append(
+                grad_list.append(
                     config.optimizer.compute_gradients(cost_var))

                 if i == 0:
                     tf.get_variable_scope().reuse_variables()
                     for k in coll_keys:
                         kept_summaries[k] = copy.copy(tf.get_collection(k))
-        for k in coll_keys: # avoid repeating summary on multiple devices
+        for k in coll_keys:
             del tf.get_collection(k)[:]
             tf.get_collection(k).extend(kept_summaries[k])
-        grads = average_grads(grads)
+        grads = average_grads(grad_list)
     else:
         model_inputs = get_model_inputs()
         cost_var = model.get_cost(model_inputs, is_training=True)
         grads = config.optimizer.compute_gradients(cost_var)
-    avg_maintain_op = summary_moving_average(cost_var)
+    avg_maintain_op = summary_moving_average(cost_var) # TODO(multigpu) average the cost from each device?
     check_grads(grads)
     grads = scale_grads(grads, model.get_lr_multiplier())
...
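
average_grads itself is not part of this diff. A minimal sketch of what it presumably does, in the style of the TensorFlow multi-GPU example (an assumption, not the repo's actual implementation): average each variable's gradient over the per-tower (grad, var) lists.

import tensorflow as tf

def average_grads(all_grads):
    # all_grads: one list of (grad, var) pairs per tower, same variable order
    ret = []
    for grad_and_vars in zip(*all_grads):
        grad = tf.add_n([g for g, _ in grad_and_vars]) / float(len(grad_and_vars))
        ret.append((grad, grad_and_vars[0][1]))   # variables are shared across towers
    return ret
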
@@ -156,9 +156,10 @@ def start_train(config):
     # start training:
     coord = tf.train.Coordinator()
-    # a thread that keeps filling the queue
-    model_th = tf.train.start_queue_runners(
+    tf.train.start_queue_runners(
         sess=sess, coord=coord, daemon=True, start=True)

+    # create a thread that keeps filling the queue
     input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
     input_th.start()
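
The EnqueueThread class is imported from utils.concurrency and not shown in this commit. A rough sketch of what such a feeding thread does, under the assumption that it iterates the dataflow and runs the enqueue op with a feed_dict (hypothetical class name and constructor; the real signature takes the dataflow and queue as above):

import threading

class FeedQueueThread(threading.Thread):      # hypothetical stand-in for EnqueueThread
    def __init__(self, sess, coord, enqueue_op, dataflow, input_vars):
        super(FeedQueueThread, self).__init__()
        self.daemon = True
        self.sess, self.coord = sess, coord
        self.enqueue_op, self.dataflow, self.input_vars = enqueue_op, dataflow, input_vars

    def run(self):
        try:
            # an infinite RepeatedData upstream keeps this loop supplied forever
            for dp in self.dataflow.get_data():
                if self.coord.should_stop():
                    return
                self.sess.run(self.enqueue_op,
                              feed_dict=dict(zip(self.input_vars, dp)))
        except Exception:
            self.coord.request_stop()
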
@@ -169,21 +170,23 @@ def start_train(config):
         tf.get_default_graph().finalize()

         for epoch in xrange(1, config.max_epoch):
-            with timed_operation('epoch {}'.format(epoch)):
+            with timed_operation('Epoch {}'.format(epoch)):
                 for step in tqdm.trange(
-                        config.step_per_epoch, leave=True, mininterval=0.5, dynamic_ncols=True):
+                        config.step_per_epoch,
+                        leave=True, mininterval=0.5,
+                        dynamic_ncols=True, ascii=True):
                     if coord.should_stop():
                         return
                     sess.run([train_op])  # faster since train_op returns None
                     callbacks.trigger_step()
-                # note that summary_op will take a data from the queue.
+                # note that summary_op will take a data from the queue
                 callbacks.trigger_epoch()
     except (KeyboardInterrupt, Exception):
         raise
     finally:
         coord.request_stop()
-        # Do I need to run queue.close
+        # Do I need to run queue.close?
         callbacks.after_train()
         sess.close()
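
The loop above drives callbacks through trigger_step(), trigger_epoch() and after_train(). A hypothetical callback showing where those hooks fire (method names are taken from the calls in this diff; the real Callback base class may look different):

class StepCounter(object):           # hypothetical example callback
    def __init__(self):
        self.steps = 0

    def trigger_step(self):           # after every sess.run([train_op])
        self.steps += 1

    def trigger_epoch(self):          # once per epoch; this is where summary_op
        print('{} steps so far'.format(self.steps))   # may dequeue a datapoint

    def after_train(self):            # in the finally block, after request_stop()
        print('training finished')
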
@@ -28,7 +28,7 @@ def timed_operation(msg, log_start=False):
         logger.info('start {} ...'.format(msg))
     start = time.time()
     yield
-    logger.info('finished {}, time={:.2f}sec.'.format(
+    logger.info('{} finished, time={:.2f}sec.'.format(
         msg, time.time() - start))

 def get_default_sess_config():
...
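
Usage of timed_operation, as in the epoch loop above (assuming it can be imported from the utils module); the message is logged with the elapsed time when the block exits:

import time

with timed_operation('Epoch 3'):
    time.sleep(1.0)        # stand-in for the per-step work
# logs something like: Epoch 3 finished, time=1.00sec.
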