Commit 775f5c9a authored by Yuxin Wu

bug fix in multigpu

parent aaf7eeda
@@ -85,7 +85,7 @@ def Deconv2D(x, out_shape, kernel_shape,
     :returns: a NHWC tensor
     """
     in_shape = x.get_shape().as_list()[1:]
-    assert None is not in in_shape, "Input to Deconv2D cannot have unknown shape!"
+    assert None not in in_shape, "Input to Deconv2D cannot have unknown shape!"
     in_channel = in_shape[-1]
     kernel_shape = shape2d(kernel_shape)
     stride2d = shape2d(stride)
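The first hunk fixes a syntax error: `None is not in in_shape` is not valid Python ("is not in" is not an operator, so the module fails to import with a SyntaxError); the intended membership test is `None not in in_shape`. A minimal sketch of the corrected check, using a hand-written static shape rather than a real graph tensor:

# Illustrative only: `in_shape` here is a made-up list, not x.get_shape().
in_shape = [28, 28, 64]            # fully-known HWC shape -> check passes
assert None not in in_shape, "Input to Deconv2D cannot have unknown shape!"

try:
    # A shape with an unknown (None) dimension trips the assert.
    assert None not in [None, None, 64], "Input to Deconv2D cannot have unknown shape!"
except AssertionError as err:
    print(err)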
@@ -15,7 +15,7 @@ from ..tfutils import (backup_collection, restore_collection,
         get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
-from .trainer import FeedlessTrainer, SingleCostFeedlessTrainer
+from .trainer import FeedlessTrainer, SingleCostFeedlessTrainer, MultiPredictorTowerTrainer
 from .queue import QueueInputTrainer, QueueInputTrainerBase

 __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
@@ -37,13 +37,15 @@ class MultiGPUTrainer(FeedlessTrainer):
                 grad_list.append(get_tower_grad_func())
                 if idx == 0:
                     add_moving_summary(cost_var)
                     # avoid repeated summary from each device
                     backup = backup_collection(SUMMARY_BACKUP_KEYS)
         restore_collection(backup)
         return grad_list

-class SyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
+class SyncMultiGPUTrainer(QueueInputTrainerBase,
+                          MultiGPUTrainer,
+                          SingleCostFeedlessTrainer,
+                          MultiPredictorTowerTrainer):
     def __init__(self, config, input_queue=None, predict_tower=None):
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
         super(SyncMultiGPUTrainer, self).__init__(config)
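Both trainer classes gain MultiPredictorTowerTrainer as an extra base, which presumably supplies the `_setup_predictor_factory` helper called in their constructors. A hypothetical sketch of that mixin pattern (every name except the two quoted from the diff is invented for illustration):

# Stand-in for MultiPredictorTowerTrainer: contributes predictor setup.
class PredictorTowerMixin(object):
    def _setup_predictor_factory(self, predict_tower):
        # Assumed behaviour: remember which towers will host predictors.
        self._predict_tower = predict_tower if predict_tower else [0]

class FeedlessTrainerStub(object):
    def __init__(self, config):
        self.config = config

class DemoTrainer(FeedlessTrainerStub, PredictorTowerMixin):
    def __init__(self, config, predict_tower=None):
        super(DemoTrainer, self).__init__(config)
        # Resolvable only because PredictorTowerMixin is in the base list.
        self._setup_predictor_factory(predict_tower)

DemoTrainer(config={'tower': [0, 1]}, predict_tower=[0])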
@@ -85,10 +87,12 @@ class SyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeed
     def run_step(self):
         self.sess.run(self.train_op)

-class AsyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
+class AsyncMultiGPUTrainer(QueueInputTrainerBase,
+                           MultiGPUTrainer,
+                           SingleCostFeedlessTrainer,
+                           MultiPredictorTowerTrainer):
     def __init__(self, config, input_queue=None, predict_tower=None):
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
-        super(SyncMultiGPUTrainer, self).__init__(config)
+        super(AsyncMultiGPUTrainer, self).__init__(config)
         self._setup_predictor_factory(predict_tower)
         self._build_enque_thread(input_queue)
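One fix in this file is the `super()` call above: `AsyncMultiGPUTrainer.__init__` invoked `super(SyncMultiGPUTrainer, self)`, an apparent copy-paste slip. Because `AsyncMultiGPUTrainer` is not a subclass of `SyncMultiGPUTrainer`, that call raises "TypeError: super(type, obj): obj must be an instance or subtype of type" as soon as the trainer is constructed. A stripped-down sketch with invented class names:

class BaseTrainer(object):
    def __init__(self, config):
        self.config = config

class SyncTrainer(BaseTrainer):
    def __init__(self, config):
        super(SyncTrainer, self).__init__(config)

class AsyncTrainer(BaseTrainer):
    def __init__(self, config):
        # The buggy form copied from the sibling class would fail here:
        #   super(SyncTrainer, self).__init__(config)   # TypeError
        super(AsyncTrainer, self).__init__(config)       # fixed form

AsyncTrainer(config={'tower': [0, 1]})  # constructs without error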
@@ -132,7 +136,7 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFee
             for th in self.training_threads: # resume all threads
                 th.resume()
         next(self.async_step_counter)
-        super(AsyncMultiGPUTrainer, self).run_step()
+        self.sess.run(self.train_op)

     def _trigger_epoch(self):
         self.async_running = False