Commit c04c1ef8 authored by Yuxin Wu

pass Server to trainer

parent b0677681
@@ -177,6 +177,7 @@ class Trainer(object):
                     # trigger epoch outside the timing region.
                     self._trigger_epoch()
                     self._callbacks.trigger_epoch()
+                logger.info("Training has finished!")
             except (StopTraining, tf.errors.OutOfRangeError):
                 logger.info("Training was stopped.")
             except KeyboardInterrupt:
@@ -49,11 +49,14 @@ class OverrideToLocalVariableIfNotPsVar(object):

 class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
-    def __init__(self, config, job_name, task_index, cluster):
-        assert job_name in ['ps', 'worker'], job_name
-        self.job_name = job_name
-        self.task_index = task_index
-        self.cluster = cluster
+    def __init__(self, config, server):
+        self.server = server
+        server_def = server.server_def
+        self.cluster = tf.train.ClusterSpec(server_def.cluster)
+        self.job_name = server_def.job_name
+        self.task_index = server_def.task_index
+        assert self.job_name in ['ps', 'worker'], self.job_name
         self._input_source = config.data
         self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
         super(DistributedReplicatedTrainer, self).__init__(config)
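For context, a minimal sketch of the new calling convention: the trainer now derives the cluster description from a `tf.train.Server` instead of taking `job_name`/`task_index`/`cluster` separately. The two-task cluster and ports below are illustrative, not part of this commit:

```python
import tensorflow as tf

# Illustrative cluster; hosts and ports are placeholders.
cluster = tf.train.ClusterSpec({
    'ps': ['localhost:2222'],
    'worker': ['localhost:2223'],
})
server = tf.train.Server(cluster, job_name='worker', task_index=0)

# Everything the old constructor took can be recovered from the Server,
# exactly as the new __init__ does:
server_def = server.server_def
assert server_def.job_name == 'worker'
assert server_def.task_index == 0
assert tf.train.ClusterSpec(server_def.cluster).jobs  # ['ps', 'worker']

# trainer = DistributedReplicatedTrainer(config, server)
```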
@@ -76,9 +79,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         if self.nr_gpu > 1:
             assert tf.test.is_gpu_available()

             # TODO staging doesn't work with dummy (require context)
             # seem to only improve on >1 GPUs
-            if not isinstance(self._input_source, StagingInputWrapper):
-                self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)
+            # if not isinstance(self._input_source, StagingInputWrapper):
+            #     self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)

     @staticmethod
     def _average_grads(tower_grads, devices):
@@ -96,7 +100,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             return tower_grads[0]
         new_tower_grads = []
         with tf.name_scope('AvgGrad'):
-            for i, grad_and_vars in enumerate(zip(*grad_list)):
+            for i, grad_and_vars in enumerate(zip(*tower_grads)):
                 # Ngpu * 2
                 with tf.device(devices[i % nr_device]):
                     v = grad_and_vars[0][1]
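The fix above replaces the undefined name `grad_list` with the actual argument `tower_grads` (the old name would raise a `NameError` at graph-build time). A standalone sketch of what the full averaging loop computes; the reduction body after the lines shown in the hunk is an assumption, not the verbatim method:

```python
import tensorflow as tf

def average_grads(tower_grads, devices):
    """tower_grads: one [(grad, var), ...] list per tower/GPU."""
    nr_device = len(devices)
    if nr_device == 1:
        return tower_grads[0]
    new_tower_grads = []
    with tf.name_scope('AvgGrad'):
        # zip(*tower_grads) groups the i-th (grad, var) pair of every tower.
        for i, grad_and_vars in enumerate(zip(*tower_grads)):
            # Round-robin the reduction ops over the available devices.
            with tf.device(devices[i % nr_device]):
                v = grad_and_vars[0][1]  # the variable is shared across towers
                all_grads = [g for g, _ in grad_and_vars]
                avg_grad = tf.multiply(tf.add_n(all_grads), 1.0 / nr_device)
            new_tower_grads.append((avg_grad, v))
    return new_tower_grads
```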
@@ -150,18 +154,12 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         return var_update_ops

     def _setup(self):
-        conf = get_default_sess_config()
-        self.server = tf.train.Server(
-            self.cluster, job_name=self.job_name,
-            task_index=self.task_index,
-            config=conf  # TODO sessconfig
-        )
         if self.job_name == 'ps':
             logger.info("Running ps {}".format(self.task_index))
             self.server.join()
-            return
-        opt = self.model.get_optimizer()  # in global scope, not local
+            return  # TODO exit and skip mainloop how?

         super(DistributedReplicatedTrainer, self)._setup()
         with tf.variable_scope(
                 tf.get_variable_scope(),
                 custom_getter=OverrideToLocalVariableIfNotPsVar()):
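On the `return  # TODO` line: `server.join()` blocks forever while the ps task serves variables, so the return is never actually reached, and the TODO asks how the main loop should be skipped cleanly. A common pattern (an assumption here, not what this commit implements) is to branch on the job name before building any training graph at all:

```python
import tensorflow as tf

cluster = tf.train.ClusterSpec({'ps': ['localhost:2222'],       # illustrative
                                'worker': ['localhost:2223']})
server = tf.train.Server(cluster, job_name='ps', task_index=0)

if server.server_def.job_name == 'ps':
    server.join()   # blocks forever, serving variables to workers
else:
    # build DistributedReplicatedTrainer(config, server) and run training
    pass
```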
@@ -185,9 +183,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         with tf.device(self.param_server_device):
             gs = get_global_step_var()
+            opt = self.model.get_optimizer()  # in global scope, not local
-        assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
-        self._input_source.setup_training(self)
-        self._setup()
-        self.monitors = Monitors(self.monitors)
@@ -367,6 +367,7 @@ class DummyConstantInput(TensorInput):
         def fn():
             tlist = []
             ctx = get_current_tower_context()
+            assert ctx is not None
             assert len(self.shapes) == len(self.input_placehdrs)
             for idx, p in enumerate(self.input_placehdrs):
                 tlist.append(tf.get_variable(