Commit e6a136ce authored by Yuxin Wu's avatar Yuxin Wu

some update

parent fc7f0aac
......@@ -92,6 +92,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
step_per_epoch = 3
# prepare session
sess_config = get_default_sess_config()
......
......@@ -85,18 +85,27 @@ class FixedSizeData(DataFlow):
return
class RepeatedData(DataFlow):
""" repeat another dataflow for certain times"""
""" repeat another dataflow for certain times
if nr == -1, repeat infinitely many times
"""
def __init__(self, ds, nr):
self.nr = nr
self.ds = ds
def size(self):
if self.nr == -1:
raise RuntimeError("size() is unavailable for an infinite dataflow")
return self.ds.size() * self.nr
def get_data(self):
for _ in xrange(self.nr):
for dp in self.ds.get_data():
yield dp
if self.nr == -1:
while True:
for dp in self.ds.get_data():
yield dp
else:
for _ in xrange(self.nr):
for dp in self.ds.get_data():
yield dp
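For reference, a minimal usage sketch of the new nr == -1 behaviour (ToyData is a hypothetical stand-in for any DataFlow, not part of this commit):

class ToyData(DataFlow):
    def size(self):
        return 3
    def get_data(self):
        for i in xrange(3):
            yield [i]

ds = RepeatedData(ToyData(), 2)      # finite: yields [0],[1],[2],[0],[1],[2]
assert ds.size() == 6                # 3 * 2
inf = RepeatedData(ToyData(), -1)    # infinite stream; inf.size() raises RuntimeError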
class FakeData(DataFlow):
""" Build fake random data of given shapes"""
......
......@@ -35,6 +35,8 @@ class ModelDesc(object):
def get_input_queue(self):
"""
return the queue for input. The dequeued elements will be fed to self.get_cost.
If the queue is None, datapoints from the dataflow will be fed to the graph directly.
When running with multiple GPUs, the queue cannot be None.
"""
assert self.input_vars is not None
return tf.FIFOQueue(50, [x.dtype for x in self.input_vars], name='input_queue')
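For context, a rough sketch of the producer/consumer pattern this queue plugs into (placeholder names are illustrative, not from the repo):

image = tf.placeholder(tf.float32, [None, 28, 28], name='image')
label = tf.placeholder(tf.int32, [None], name='label')
input_vars = [image, label]
queue = tf.FIFOQueue(50, [x.dtype for x in input_vars], name='input_queue')
enqueue_op = queue.enqueue(input_vars)   # run repeatedly by the feeding thread
dequeued = queue.dequeue()               # consumed by the training graph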
......
......@@ -11,6 +11,7 @@ import re
import tqdm
from models import ModelDesc
from dataflow.common import RepeatedData
from utils import *
from utils.concurrency import EnqueueThread
from callbacks import *
......@@ -25,8 +26,7 @@ class TrainConfig(object):
"""
Args:
dataset: the dataset to train. a tensorpack.dataflow.DataFlow instance.
optimizer: a tf.train.Optimizer instance defining the optimizer
for trainig. default to an AdamOptimizer
optimizer: a tf.train.Optimizer instance defining the optimizer for training.
callbacks: a tensorpack.utils.callback.Callbacks instance. Defines
the callbacks to perform during training. Has to contain a
SummaryWriter and a PeriodicSaver.
......@@ -44,16 +44,17 @@ class TrainConfig(object):
assert isinstance(v, tp), v.__class__
self.dataset = kwargs.pop('dataset')
assert_type(self.dataset, DataFlow)
self.optimizer = kwargs.pop('optimizer', tf.train.AdamOptimizer())
self.optimizer = kwargs.pop('optimizer')
assert_type(self.optimizer, tf.train.Optimizer)
self.callbacks = kwargs.pop('callbacks')
assert_type(self.callbacks, Callbacks)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
self.session_config = kwargs.pop('session_config', get_default_sess_config())
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init', NewSession())
assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0
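Since the AdamOptimizer default is gone, the optimizer must now be passed explicitly. A hypothetical construction matching the asserts above (MyModel is a made-up ModelDesc subclass; SummaryWriter and PeriodicSaver are the callbacks named in the docstring):

config = TrainConfig(
    dataset=BatchData(dataset.Mnist('train'), 128),   # a DataFlow
    optimizer=tf.train.AdamOptimizer(1e-3),           # now required, no default
    callbacks=Callbacks([SummaryWriter(), PeriodicSaver()]),
    model=MyModel(),                                  # hypothetical ModelDesc subclass
    step_per_epoch=100,
    max_epoch=50)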
......@@ -101,46 +102,45 @@ def start_train(config):
input_vars = model.get_input_vars()
input_queue = model.get_input_queue()
callbacks = config.callbacks
tf.add_to_collection(MODEL_KEY, model)
enqueue_op = input_queue.enqueue(input_vars)
def get_model_inputs():
model_inputs = input_queue.dequeue()
if isinstance(model_inputs, tf.Tensor):
if isinstance(model_inputs, tf.Tensor): # only one input
model_inputs = [model_inputs]
for qv, v in zip(model_inputs, input_vars):
qv.set_shape(v.get_shape())   # dequeue() loses static shape info; restore it from the placeholders
return model_inputs
enqueue_op = input_queue.enqueue(input_vars)
# get gradients to update:
logger.info("Training a model of {} tower".format(config.nr_tower))
if config.nr_tower > 1:
logger.info("Training a model of {} tower".format(config.nr_tower))
# to avoid repeated summary from each device
coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {}
grads = []
grad_list = []
for i in range(config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
model_inputs = get_model_inputs()
cost_var = model.get_cost(model_inputs, is_training=True)
grads.append(
grad_list.append(
config.optimizer.compute_gradients(cost_var))
if i == 0:
tf.get_variable_scope().reuse_variables()
for k in coll_keys:
kept_summaries[k] = copy.copy(tf.get_collection(k))
for k in coll_keys: # avoid repeating summary on multiple devices
for k in coll_keys:
del tf.get_collection(k)[:]
tf.get_collection(k).extend(kept_summaries[k])
grads = average_grads(grads)
grads = average_grads(grad_list)
else:
model_inputs = get_model_inputs()
cost_var = model.get_cost(model_inputs, is_training=True)
grads = config.optimizer.compute_gradients(cost_var)
avg_maintain_op = summary_moving_average(cost_var)
avg_maintain_op = summary_moving_average(cost_var) # TODO(multigpu) average the cost from each device?
check_grads(grads)
grads = scale_grads(grads, model.get_lr_multiplier())
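average_grads itself is not part of this diff; a typical implementation averages each variable's gradient across towers, roughly as follows (assuming every tower emits its (grad, var) pairs in the same order):

def average_grads(grad_list):
    # grad_list holds one [(grad, var), ...] list per tower; variables are
    # shared, so zip(*grad_list) groups each variable's per-tower gradients
    ret = []
    for grad_and_vars in zip(*grad_list):
        grads = [g for g, _ in grad_and_vars]
        avg = tf.add_n(grads) / float(len(grads))
        ret.append((avg, grad_and_vars[0][1]))   # every tower shares the var
    return ret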
......@@ -156,9 +156,10 @@ def start_train(config):
# start training:
coord = tf.train.Coordinator()
# a thread that keeps filling the queue
model_th = tf.train.start_queue_runners(
tf.train.start_queue_runners(
sess=sess, coord=coord, daemon=True, start=True)
# create a thread that keeps filling the queue
input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
input_th.start()
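EnqueueThread lives in utils.concurrency and is not shown in this diff; conceptually it drains the dataflow into the queue until the coordinator stops, roughly like the sketch below (the real signature differs, it receives the queue itself, and the dataflow is typically wrapped in RepeatedData(ds, -1) so it never runs dry):

import threading

class EnqueueThreadSketch(threading.Thread):
    # rough sketch only; see utils.concurrency for the real EnqueueThread
    def __init__(self, sess, coord, enqueue_op, dataflow, input_vars):
        super(EnqueueThreadSketch, self).__init__()
        self.daemon = True
        self.sess, self.coord = sess, coord
        self.enqueue_op, self.dataflow = enqueue_op, dataflow
        self.input_vars = input_vars

    def run(self):
        with self.coord.stop_on_exception():
            while not self.coord.should_stop():
                for dp in self.dataflow.get_data():
                    if self.coord.should_stop():
                        return
                    self.sess.run(self.enqueue_op,
                                  feed_dict=dict(zip(self.input_vars, dp)))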
......@@ -169,21 +170,23 @@ def start_train(config):
tf.get_default_graph().finalize()
for epoch in xrange(1, config.max_epoch + 1):
with timed_operation('epoch {}'.format(epoch)):
with timed_operation('Epoch {}'.format(epoch)):
for step in tqdm.trange(
config.step_per_epoch, leave=True, mininterval=0.5, dynamic_ncols=True):
config.step_per_epoch,
leave=True, mininterval=0.5,
dynamic_ncols=True, ascii=True):
if coord.should_stop():
return
sess.run([train_op]) # faster since train_op returns None
callbacks.trigger_step()
# note that summary_op will take a data from the queue.
# note that summary_op will take a datapoint from the queue
callbacks.trigger_epoch()
except (KeyboardInterrupt, Exception):
raise
finally:
coord.request_stop()
# Do I need to run queue.close
# Do I need to run queue.close?
callbacks.after_train()
sess.close()
......@@ -28,7 +28,7 @@ def timed_operation(msg, log_start=False):
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('finished {}, time={:.2f}sec.'.format(
logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start))
def get_default_sess_config():
......
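For reference, timed_operation is used as a context manager; with the reworded message, the training loop above now logs like this:

with timed_operation('Epoch 1'):
    run_one_epoch()   # any block to be timed (hypothetical helper)
# logger output: "Epoch 1 finished, time=12.34sec."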