Commit c2edd999 authored by Yuxin Wu's avatar Yuxin Wu

Allow >=10 GPUs. (fix #310)

parent c1a280a2
......@@ -72,11 +72,15 @@ class TowerContext(object):
def vs_name(self):
return self._vs_name
# TODO pass index into the constructor
@property
def index(self):
if self._name == '':
return 0
return int(self._name[-1])
idx = re.findall('[0-9]+$', self._name)
if len(idx) == 0:
return 0
return int(idx[0])
@property
def device(self):
......
......@@ -179,7 +179,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
if self.job_name == 'ps':
logger.info("Running ps {}".format(self.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this will never return #4713
self.server.join() # this will never return tensorflow#4713
return
with tf.device(self.param_server_device):
gs = get_global_step_var()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment