Commit adb45736 authored by Yuxin Wu

fix trainconfig

parent acd7f798
@@ -55,6 +55,13 @@ class TrainConfig(object):
         self.max_epoch = int(kwargs.pop('max_epoch', 99999))
         assert self.step_per_epoch > 0 and self.max_epoch > 0
+        if 'nr_tower' in kwargs or 'tower' in kwargs:
+            self.set_tower(**kwargs)
+        self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
+        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
+
+    def set_tower(self, **kwargs):
         nr_tower = kwargs.pop('nr_tower', None)
         tower = kwargs.pop('tower', None)
         assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
@@ -64,7 +71,4 @@ class TrainConfig(object):
         if isinstance(tower, int):
             tower = list(range(tower))
         self.tower = tower
-        self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
-        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
+        assert isinstance(self.tower, list)
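
The first file moves the nr_tower/tower handling into a set_tower method and guarantees that config.tower ends up as a list of GPU indices. Below is a minimal sketch of that normalization for illustration; normalize_tower is a hypothetical stand-in, and the behaviour hidden behind the elided hunk lines (how nr_tower and the default are folded into tower) is an assumption rather than the exact library code.

    # Sketch (not the library code) of the normalization set_tower performs.
    # How nr_tower is folded into tower is assumed from context, not shown in the hunk.
    def normalize_tower(nr_tower=None, tower=None):
        assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
        if tower is None:
            tower = nr_tower if nr_tower is not None else 1   # assumed default: one tower
        if isinstance(tower, int):
            tower = list(range(tower))   # e.g. nr_tower=2 -> [0, 1]
        assert isinstance(tower, list)
        return tower

    # Either spelling ends up as an explicit list of tower (GPU) indices.
    assert normalize_tower(nr_tower=2) == [0, 1]
    assert normalize_tower(tower=[0, 2]) == [0, 2]
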
@@ -93,7 +93,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
             # sync have consistent effective learning rate
             def scale(grads):
                 with tf.name_scope('async_scale_grad'):
-                    return [(grad / self.config.nr_tower if grad is not None else None, var)
+                    return [(grad / len(self.config.tower) if grad is not None else None, var)
                             for grad, var in grads]
             grad_list = map(scale, grad_list)
             grad_list = [self.process_grads(g) for g in grad_list]
@@ -113,7 +113,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         # itertools.count is atomic w.r.t. python threads
         self.async_step_counter = itertools.count()
         self.training_threads = []
-        for k in range(1, self.config.nr_tower):
+        for k in range(1, len(self.config.tower)):
             train_op = self.config.optimizer.apply_gradients(grad_list[k])
             def f(op=train_op):    # avoid late-binding
                 self.sess.run([op])
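In the second file, AsyncMultiGPUTrainer stops relying on config.nr_tower and derives the tower count from len(self.config.tower), both when scaling gradients and when spawning the extra training threads. A small plain-Python illustration of the gradient scaling follows; scale_grads is a hypothetical stand-in for the TensorFlow closure shown in the hunk.

    # Plain-Python stand-in for the scale() closure above: each tower's gradients
    # are divided by the number of towers, so the async trainer keeps the same
    # effective learning rate as the synchronous one. The real code operates on
    # TensorFlow tensors inside tf.name_scope('async_scale_grad').
    def scale_grads(grads, num_towers):
        return [(g / num_towers if g is not None else None, v) for g, v in grads]

    # With two towers, every per-variable gradient is halved before
    # apply_gradients; variables without a gradient stay None.
    grads = [(1.0, 'weights'), (None, 'frozen_var')]
    assert scale_grads(grads, num_towers=2) == [(0.5, 'weights'), (None, 'frozen_var')]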