Commit 775f5c9a authored by Yuxin Wu

bug fix in multigpu

parent aaf7eeda
@@ -85,7 +85,7 @@ def Deconv2D(x, out_shape, kernel_shape,
     :returns: a NHWC tensor
     """
     in_shape = x.get_shape().as_list()[1:]
-    assert None is not in in_shape, "Input to Deconv2D cannot have unknown shape!"
+    assert None not in in_shape, "Input to Deconv2D cannot have unknown shape!"
     in_channel = in_shape[-1]
     kernel_shape = shape2d(kernel_shape)
     stride2d = shape2d(stride)
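The removed assertion was not just a logic bug but a `SyntaxError`: Python has the operators `is not` and `not in`, but no `is not in`, so the module failed at import time. A minimal, self-contained sketch of the corrected membership test (the shape lists are made up for illustration):

```python
# Hypothetical static shapes, as returned by x.get_shape().as_list()[1:]
known = [28, 28, 3]        # fully defined: the check passes
unknown = [None, None, 3]  # unknown spatial dims: the check fires

for in_shape in (known, unknown):
    try:
        # "None is not in in_shape" is a SyntaxError; "not in" is the
        # membership operator actually intended here.
        assert None not in in_shape, \
            "Input to Deconv2D cannot have unknown shape!"
        print(in_shape, "-> ok")
    except AssertionError as err:
        print(in_shape, "->", err)
```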
......
@@ -15,7 +15,7 @@ from ..tfutils import (backup_collection, restore_collection,
                        get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
-from .trainer import FeedlessTrainer, SingleCostFeedlessTrainer
+from .trainer import FeedlessTrainer, SingleCostFeedlessTrainer, MultiPredictorTowerTrainer
 from .queue import QueueInputTrainer, QueueInputTrainerBase

 __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
@@ -37,13 +37,15 @@ class MultiGPUTrainer(FeedlessTrainer):
                 grad_list.append(get_tower_grad_func())

                 if idx == 0:
-                    add_moving_summary(cost_var)
                     # avoid repeated summary from each device
                     backup = backup_collection(SUMMARY_BACKUP_KEYS)
         restore_collection(backup)
         return grad_list

-class SyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
+class SyncMultiGPUTrainer(QueueInputTrainerBase,
+                          MultiGPUTrainer,
+                          SingleCostFeedlessTrainer,
+                          MultiPredictorTowerTrainer):
     def __init__(self, config, input_queue=None, predict_tower=None):
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
         super(SyncMultiGPUTrainer, self).__init__(config)
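The backup/restore pair around the tower loop is what prevents duplicated summaries: a snapshot of the summary collections is taken right after tower 0 is built, and restoring it once all towers are built discards whatever towers 1..N-1 appended. Roughly what `backup_collection`/`restore_collection` do, assuming TF1-style graph collections (a simplified sketch, not tensorpack's actual implementation):

```python
import tensorflow as tf  # TF1-style graph collections assumed

SUMMARY_KEYS = [tf.GraphKeys.SUMMARIES]  # stand-in for SUMMARY_BACKUP_KEYS

def backup_collection(keys):
    # Snapshot the current contents of each named collection.
    return {k: list(tf.get_collection(k)) for k in keys}

def restore_collection(backup):
    # Overwrite each collection with its snapshot, dropping anything
    # added since the backup (e.g. summaries from towers 1..N-1).
    for k, items in backup.items():
        ref = tf.get_collection_ref(k)
        del ref[:]
        ref.extend(items)

# Pattern from the hunk above:
#   backup = backup_collection(SUMMARY_KEYS)   # inside "if idx == 0"
#   ... build towers 1..N-1 ...
#   restore_collection(backup)                 # after the loop
```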
@@ -85,10 +87,12 @@ class SyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
     def run_step(self):
         self.sess.run(self.train_op)

-class AsyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
+class AsyncMultiGPUTrainer(QueueInputTrainerBase,
+                           MultiGPUTrainer,
+                           SingleCostFeedlessTrainer,
+                           MultiPredictorTowerTrainer):
     def __init__(self, config, input_queue=None, predict_tower=None):
-        assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
-        super(SyncMultiGPUTrainer, self).__init__(config)
+        super(AsyncMultiGPUTrainer, self).__init__(config)
         self._setup_predictor_factory(predict_tower)
         self._build_enque_thread(input_queue)
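The old `__init__` was copied from `SyncMultiGPUTrainer`, and since `AsyncMultiGPUTrainer` is a sibling rather than a subclass, `super(SyncMultiGPUTrainer, self)` raises `TypeError` the moment the trainer is constructed. A stripped-down reproduction with hypothetical class names:

```python
class Base(object):
    def __init__(self):
        print("Base.__init__")

class SyncTrainer(Base):
    def __init__(self):
        super(SyncTrainer, self).__init__()  # correct: names its own class

class AsyncTrainer(Base):
    def __init__(self):
        # The bug: a copy-pasted super() call naming the sibling class.
        # super(SyncTrainer, self) requires self to be a SyncTrainer.
        super(SyncTrainer, self).__init__()

SyncTrainer()   # prints Base.__init__
try:
    AsyncTrainer()
except TypeError as err:
    print(err)  # super(type, obj): obj must be an instance or subtype of type
```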
@@ -132,7 +136,7 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
             for th in self.training_threads:  # resume all threads
                 th.resume()
             next(self.async_step_counter)
-            super(AsyncMultiGPUTrainer, self).run_step()
+            self.sess.run(self.train_op)

     def _trigger_epoch(self):
         self.async_running = False
......