Commit c04c1ef8 authored by Yuxin Wu

pass Server to trainer

parent b0677681
...@@ -177,6 +177,7 @@ class Trainer(object): ...@@ -177,6 +177,7 @@ class Trainer(object):
# trigger epoch outside the timing region. # trigger epoch outside the timing region.
self._trigger_epoch() self._trigger_epoch()
self._callbacks.trigger_epoch() self._callbacks.trigger_epoch()
logger.info("Training has finished!")
except (StopTraining, tf.errors.OutOfRangeError): except (StopTraining, tf.errors.OutOfRangeError):
logger.info("Training was stopped.") logger.info("Training was stopped.")
except KeyboardInterrupt: except KeyboardInterrupt:
......
...@@ -49,11 +49,14 @@ class OverrideToLocalVariableIfNotPsVar(object): ...@@ -49,11 +49,14 @@ class OverrideToLocalVariableIfNotPsVar(object):
class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
def __init__(self, config, job_name, task_index, cluster): def __init__(self, config, server):
assert job_name in ['ps', 'worker'], job_name self.server = server
self.job_name = job_name server_def = server.server_def
self.task_index = task_index self.cluster = tf.train.ClusterSpec(server_def.cluster)
self.cluster = cluster self.job_name = server_def.job_name
self.task_index = server_def.task_index
assert self.job_name in ['ps', 'worker'], job_name
self._input_source = config.data self._input_source = config.data
self.is_chief = (self.task_index == 0 and self.job_name == 'worker') self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
super(DistributedReplicatedTrainer, self).__init__(config) super(DistributedReplicatedTrainer, self).__init__(config)
...@@ -76,9 +79,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -76,9 +79,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
if self.nr_gpu > 1: if self.nr_gpu > 1:
assert tf.test.is_gpu_available() assert tf.test.is_gpu_available()
# TODO staging doesn't work with dummy (require context)
# seem to only improve on >1 GPUs # seem to only improve on >1 GPUs
if not isinstance(self._input_source, StagingInputWrapper): #if not isinstance(self._input_source, StagingInputWrapper):
self._input_source = StagingInputWrapper(self._input_source, self.raw_devices) #self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)
@staticmethod @staticmethod
def _average_grads(tower_grads, devices): def _average_grads(tower_grads, devices):
...@@ -96,7 +100,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -96,7 +100,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
return tower_grads[0] return tower_grads[0]
new_tower_grads = [] new_tower_grads = []
with tf.name_scope('AvgGrad'): with tf.name_scope('AvgGrad'):
for i, grad_and_vars in enumerate(zip(*grad_list)): for i, grad_and_vars in enumerate(zip(*tower_grads)):
# Ngpu * 2 # Ngpu * 2
with tf.device(devices[i % nr_device]): with tf.device(devices[i % nr_device]):
v = grad_and_vars[0][1] v = grad_and_vars[0][1]
...@@ -150,18 +154,12 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -150,18 +154,12 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
return var_update_ops return var_update_ops
def _setup(self): def _setup(self):
conf = get_default_sess_config()
self.server = tf.train.Server(
self.cluster, job_name=self.job_name,
task_index=self.task_index,
config=conf # TODO sessconfig
)
if self.job_name == 'ps': if self.job_name == 'ps':
logger.info("Running ps {}".format(self.task_index)) logger.info("Running ps {}".format(self.task_index))
self.server.join() self.server.join()
return return # TODO exit and skip mainloop how?
opt = self.model.get_optimizer() # in global scope, not local
super(DistributedReplicatedTrainer, self)._setup()
with tf.variable_scope( with tf.variable_scope(
tf.get_variable_scope(), tf.get_variable_scope(),
custom_getter=OverrideToLocalVariableIfNotPsVar()): custom_getter=OverrideToLocalVariableIfNotPsVar()):
...@@ -185,9 +183,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -185,9 +183,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
with tf.device(self.param_server_device): with tf.device(self.param_server_device):
gs = get_global_step_var() gs = get_global_step_var()
opt = self.model.get_optimizer() # in global scope, not local opt = self.model.get_optimizer() # in global scope, not local
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._input_source.setup_training(self)
self._setup() self._setup()
self.monitors = Monitors(self.monitors) self.monitors = Monitors(self.monitors)
......
...@@ -367,6 +367,7 @@ class DummyConstantInput(TensorInput): ...@@ -367,6 +367,7 @@ class DummyConstantInput(TensorInput):
def fn(): def fn():
tlist = [] tlist = []
ctx = get_current_tower_context() ctx = get_current_tower_context()
assert ctx is not None
assert len(self.shapes) == len(self.input_placehdrs) assert len(self.shapes) == len(self.input_placehdrs)
for idx, p in enumerate(self.input_placehdrs): for idx, p in enumerate(self.input_placehdrs):
tlist.append(tf.get_variable( tlist.append(tf.get_variable(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment