Commit dc68ce0d authored by Yuxin Wu

Fix chief_only for input callbacks

parent f2d2501b
@@ -30,6 +30,9 @@ class Model(ImageNetModel):
     def __init__(self, depth, data_format='NCHW', mode='resnet'):
         super(Model, self).__init__(data_format)
+        if mode == 'se':
+            assert depth >= 50
         self.mode = mode
         basicblock = preresnet_basicblock if mode == 'preact' else resnet_basicblock
         bottleneck = {
@@ -115,9 +118,6 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    if args.mode == 'se':
-        assert args.depth >= 50
     model = Model(args.depth, args.data_format, args.mode)
     if args.eval:
         batch = 128    # something that can run on one gpu
...
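The hunks above move the SE-ResNet depth check from the script's `__main__` block into the `Model` constructor, so every caller gets the check automatically. A minimal, self-contained sketch of that pattern (an illustrative stand-in, not the full tensorpack `Model` class, which also builds the network):

```python
class Model(object):
    """Illustrative stand-in for the example's Model class."""
    def __init__(self, depth, data_format='NCHW', mode='resnet'):
        if mode == 'se':
            # The SE variant is only wired up for the bottleneck depths,
            # so reject the shallow 18/34 configurations early.
            assert depth >= 50, "SE-ResNet requires depth >= 50"
        self.depth = depth
        self.data_format = data_format
        self.mode = mode

Model(50, mode='se')        # fine
# Model(34, mode='se')      # would raise AssertionError at construction time
```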
@@ -142,8 +142,6 @@ class ImageNetModel(ModelDesc):
     image_dtype = tf.uint8

     def __init__(self, data_format='NCHW'):
-        if data_format == 'NCHW':
-            assert tf.test.is_gpu_available()
         self.data_format = data_format

     def _get_inputs(self):
...
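This hunk drops the hard `tf.test.is_gpu_available()` assert from `ImageNetModel.__init__`, so constructing the model with `data_format='NCHW'` no longer fails outright on CPU-only machines. If a caller still wants that guard, a hedged sketch of doing it at the call site (the helper name is made up for this example; `tf.test.is_gpu_available()` is the TF 1.x call used in the removed lines):

```python
import tensorflow as tf

def pick_data_format(preferred='NCHW'):
    # NCHW convolutions are effectively GPU-only in TF 1.x,
    # so fall back to NHWC when no GPU is visible.
    if preferred == 'NCHW' and not tf.test.is_gpu_available():
        return 'NHWC'
    return preferred
```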
@@ -32,9 +32,7 @@ __all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput',

 def _get_reset_callback(df):
-    ret = CallbackFactory(setup_graph=lambda _: df.reset_state())
-    ret.chief_only = False
-    return ret
+    return CallbackFactory(setup_graph=lambda _: df.reset_state())


 class PlaceholderInput(InputSource):
@@ -240,7 +238,6 @@ class QueueInput(FeedfreeInput):
     def _get_callbacks(self):
         from ..callbacks.concurrency import StartProcOrThread
         cb = StartProcOrThread(self.thread)
-        cb.chief_only = False
         return [cb, self._create_ema_callback(), _get_reset_callback(self._inf_ds)]

     def _get_input_tensors(self):
...
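Both hunks above delete per-callback `chief_only = False` lines; the flag is now set centrally in `InputSource.get_callbacks()` (next file). For context, a hedged sketch of how `CallbackFactory` is used here: it wraps a plain function into a tensorpack callback, which stays chief-only by default until the base class flips the flag. The `FakeData` dataflow is only a stand-in for the demo:

```python
from tensorpack.callbacks import CallbackFactory
from tensorpack.dataflow import FakeData

df = FakeData([[224, 224, 3]], size=100)
# Same shape as _get_reset_callback(): reset the dataflow once the graph is set up.
reset_cb = CallbackFactory(setup_graph=lambda _: df.reset_state())
print(reset_cb.chief_only)   # True by default; get_callbacks() now sets it to False
```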
@@ -115,13 +115,20 @@ class InputSource(object):
         which is done also through the Callback interface.
         This method returns the callbacks and the return value will be memoized.

+        All callbacks will be automatically marked as `chief_only=False`,
+        so they will run on all nodes.
+
         Returns:
             list[Callback]: extra callbacks needed by this InputSource.
         """
         assert self.setup_done()
-        return [CallbackFactory(
-            before_train=lambda _: self.reset_state())] + self._get_callbacks()
+        ret = [CallbackFactory(
+            before_train=lambda _: self.reset_state())] + self._get_callbacks()
+        for r in ret:
+            r.chief_only = False    # no input callbacks should be chief-only
+        return ret

     def _get_callbacks(self):
         return []
...
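This is the core of the commit: instead of each `InputSource` remembering to mark its callbacks as not chief-only, the base class now marks everything it returns. A minimal plain-Python sketch of that pattern (illustrative names, not tensorpack's actual classes):

```python
class Callback(object):
    chief_only = True           # tensorpack callbacks default to chief-only

class InputSourceBase(object):
    def get_callbacks(self):
        # Collect the base callback plus whatever the subclass adds,
        # then force all of them to run on every node.
        ret = [Callback()] + self._get_callbacks()
        for cb in ret:
            cb.chief_only = False
        return ret

    def _get_callbacks(self):
        return []               # subclasses override this

assert all(not cb.chief_only for cb in InputSourceBase().get_callbacks())
```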
@@ -183,7 +183,6 @@ class DistributedTrainerReplicated(SingleCostTrainer):
         logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
         super(DistributedTrainerReplicated, self).__init__()

-    def _setup_input(self, inputs_desc, input):
         if self.job_name == 'ps':
             # ps shouldn't setup input either
             logger.info("Running ps {}".format(self.server.server_def.task_index))
@@ -191,6 +190,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             self.server.join()  # this function will never return tensorflow#4713
             raise RuntimeError("This is a bug. Server.join() for ps should never return!")

+    def _setup_input(self, inputs_desc, input):
         with override_to_local_variable():
             get_global_step_var()  # gs should be local
             # input source may create variable (queue size summary)
@@ -205,13 +205,13 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)

         callbacks = []
-        # initial local_vars syncing
+        # Initial syncing vars from PS
         cb = RunOp(lambda: initial_sync_op,
                    run_before=True, run_as_trigger=False, verbose=True)
         cb.chief_only = False
         callbacks.append(cb)

-        # model_variables syncing
+        # Sync model_variables to PS, only chief needs to do this
         if model_sync_op:
             cb = RunOp(lambda: model_sync_op,
                        run_before=False, run_as_trigger=True, verbose=True)
...
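The trainer hunks mostly reword comments and move the `def _setup_input` line, so the parameter-server early-exit now happens during construction rather than during input setup; the `chief_only` handling for the sync ops stays explicit. A hedged sketch of that wiring, using tensorpack's `RunOp` with the same arguments as the diff but with the sync ops left as opaque inputs:

```python
from tensorpack.callbacks import RunOp

def make_sync_callbacks(initial_sync_op, model_sync_op=None):
    callbacks = []

    # Initial syncing of variables from PS: every worker runs it once before
    # training starts, so the chief-only default must be turned off.
    cb = RunOp(lambda: initial_sync_op,
               run_before=True, run_as_trigger=False, verbose=True)
    cb.chief_only = False
    callbacks.append(cb)

    # Copying model_variables back to PS after each epoch: only the chief
    # needs to do this, so the default chief_only=True is kept.
    if model_sync_op is not None:
        cb = RunOp(lambda: model_sync_op,
                   run_before=False, run_as_trigger=True, verbose=True)
        callbacks.append(cb)
    return callbacks
```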