Commit adb45736 authored by Yuxin Wu

fix trainconfig

parent acd7f798
@@ -55,6 +55,13 @@ class TrainConfig(object):
         self.max_epoch = int(kwargs.pop('max_epoch', 99999))
         assert self.step_per_epoch > 0 and self.max_epoch > 0
+        if 'nr_tower' in kwargs or 'tower' in kwargs:
+            self.set_tower(**kwargs)
+        self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
+        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
+    def set_tower(self, **kwargs):
+        nr_tower = kwargs.pop('nr_tower', None)
+        tower = kwargs.pop('tower', None)
+        assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
@@ -64,7 +71,4 @@ class TrainConfig(object):
         if isinstance(tower, int):
             tower = list(range(tower))
         self.tower = tower
-        self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
-        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
         assert isinstance(self.tower, list)
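
For illustration only, a small runnable sketch (not the library's actual code; the function name normalize_tower and the single-tower default are hypothetical) of the tower normalization this commit moves into set_tower(): both nr_tower=N and tower=N become an explicit list of tower indices, and passing both at once is rejected.

    # Hypothetical standalone sketch of the normalization shown in the diff above.
    def normalize_tower(nr_tower=None, tower=None):
        assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
        if nr_tower is not None:
            tower = nr_tower
        if isinstance(tower, int):
            tower = list(range(tower))   # e.g. nr_tower=2 -> towers [0, 1]
        if tower is None:
            tower = [0]                  # assumed default: a single tower
        assert isinstance(tower, list)
        return tower

    print(normalize_tower(nr_tower=2))     # [0, 1]
    print(normalize_tower(tower=[1, 3]))   # [1, 3]
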
@@ -93,7 +93,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         # sync have consistent effective learning rate
         def scale(grads):
             with tf.name_scope('async_scale_grad'):
-                return [(grad / self.config.nr_tower if grad is not None else None, var)
+                return [(grad / len(self.config.tower) if grad is not None else None, var)
                         for grad, var in grads]
         grad_list = map(scale, grad_list)
         grad_list = [self.process_grads(g) for g in grad_list]
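
The change above only swaps self.config.nr_tower for len(self.config.tower), since the config now stores an explicit list of tower indices. The scaling itself divides every per-tower gradient by the tower count so that N asynchronous updates have roughly the same effective learning rate as one synchronous averaged update. A hypothetical sketch, with plain numbers standing in for TensorFlow tensors:

    # Sketch only: gradients are plain floats here, not tf tensors.
    def scale(grads, num_towers):
        # Divide each gradient by the tower count; None gradients stay None.
        return [(g / num_towers if g is not None else None, v) for g, v in grads]

    tower_grads = [(0.9, 'w'), (None, 'b')]   # (gradient, variable) pairs
    print(scale(tower_grads, num_towers=2))   # [(0.45, 'w'), (None, 'b')]
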
@@ -113,7 +113,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         # itertools.count is atomic w.r.t. python threads
         self.async_step_counter = itertools.count()
         self.training_threads = []
-        for k in range(1, self.config.nr_tower):
+        for k in range(1, len(self.config.tower)):
             train_op = self.config.optimizer.apply_gradients(grad_list[k])
             def f(op=train_op): # avoid late-binding
                 self.sess.run([op])
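
The op=train_op default argument in the loop above is the standard Python idiom for avoiding late binding: a plain closure would look up train_op only when the thread runs, so every thread would see the last value from the loop, whereas a default argument captures the value at definition time. A small self-contained demonstration (unrelated to the library itself):

    ops = ['op0', 'op1', 'op2']

    late = [lambda: op for op in ops]          # every closure sees the final `op`
    bound = [lambda op=op: op for op in ops]   # defaults capture each value immediately

    print([f() for f in late])    # ['op2', 'op2', 'op2']
    print([f() for f in bound])   # ['op0', 'op1', 'op2']
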