Commit c2edd999 authored by Yuxin Wu's avatar Yuxin Wu

Allow >=10 GPUs. (fix #310)

parent c1a280a2
...@@ -72,11 +72,15 @@ class TowerContext(object): ...@@ -72,11 +72,15 @@ class TowerContext(object):
def vs_name(self): def vs_name(self):
return self._vs_name return self._vs_name
# TODO pass index into the constructor
@property @property
def index(self): def index(self):
if self._name == '': if self._name == '':
return 0 return 0
return int(self._name[-1]) idx = re.findall('[0-9]+$', self._name)
if len(idx) == 0:
return 0
return int(idx[0])
@property @property
def device(self): def device(self):
......
...@@ -179,7 +179,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -179,7 +179,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
if self.job_name == 'ps': if self.job_name == 'ps':
logger.info("Running ps {}".format(self.task_index)) logger.info("Running ps {}".format(self.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid())) logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this will never return #4713 self.server.join() # this will never return tensorflow#4713
return return
with tf.device(self.param_server_device): with tf.device(self.param_server_device):
gs = get_global_step_var() gs = get_global_step_var()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment