Commit 34dfab52 authored by Yuxin Wu's avatar Yuxin Wu

don't assume default device in towercontext (#295)

parent 7c7957fc
......@@ -29,8 +29,6 @@ class TowerContext(object):
'replicated' mode. Defaults to be tower_name.
"""
self._name = tower_name
if device is None:
device = '/gpu:0' if tf.test.is_gpu_available() else '/cpu:0'
self._device = device
if is_training is None:
......@@ -129,7 +127,8 @@ class TowerContext(object):
tf.get_variable_scope(), reuse=True))
# if not training, should handle vs outside (TODO not good)
self._ctxs.append(tf.name_scope(self._name))
self._ctxs.append(tf.device(self._device))
if self._device is not None:
self._ctxs.append(tf.device(self._device))
for c in self._ctxs:
c.__enter__()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment