Commit ca0f0bd0 authored by Yuxin Wu

put device ctx into TowerContext

parent b75ed18c
......@@ -186,10 +186,10 @@ class PredictorTowerBuilder(object):
msg = "Building predictor graph {} on gpu={} ...".format(towername, tower)
logger.info(msg)
# No matter where this gets called, clear any existing name scope.
device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
with tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS), \
tf.device('/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'), \
TowerContext(towername, is_training=False):
TowerContext(towername, device=device, is_training=False):
self._fn(tower)
# useful only when the placeholders don't have tower prefix
......
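The predictor builder no longer opens its own tf.device() scope; it computes the device string once and hands it to TowerContext, which now owns device placement. A minimal sketch of the new call pattern (the helper name, the tower-name construction, and the import path are assumptions for illustration):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext  # import path assumed

def build_predictor_tower(build_fn, tower):
    # Keep the tower name and the device consistent, matching the assert
    # added to TowerContext.__init__ below; tower == -1 means CPU.
    towername = 'towerp{}'.format(max(tower, 0))
    device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
    # TowerContext enters tf.device() itself, so the explicit
    # tf.device(...) wrapper from the old code is gone.
    with tf.name_scope(None), \
            TowerContext(towername, device=device, is_training=False):
        build_fn(tower)
```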
......@@ -15,13 +15,20 @@ _CurrentTowerContext = None
class TowerContext(object):
""" A context where the current model is being built in. """
def __init__(self, tower_name, is_training=None):
def __init__(self, tower_name, device=None, is_training=None):
"""
Args:
tower_name (str): 'tower0', 'towerp0', or ''
device (str): the device to use. Defaults to '/gpu:0' if a GPU is available, otherwise '/cpu:0'.
is_training (bool): if None, automatically determine from tower_name.
"""
self._name = tower_name
if device is None:
device = '/gpu:0' if tf.test.is_gpu_available() else '/cpu:0'
assert self.index == int(device[-1]), \
"Tower name {} and device {} mismatch!".format(self._name, device)
self._device = device
if is_training is None:
is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = is_training
......@@ -48,6 +55,10 @@ class TowerContext(object):
return 0
return int(self._name[-1])
@property
def device(self):
return self._device
def find_tensor_in_main_tower(self, graph, name):
if self.is_main_tower:
return graph.get_tensor_by_name(name)
......@@ -79,16 +90,18 @@ class TowerContext(object):
assert _CurrentTowerContext is None, \
"Nesting TowerContext!"
_CurrentTowerContext = self
# TODO enter name_scope(None) first
if len(self._name):
self._scope = tf.name_scope(self._name)
return self._scope.__enter__()
self._scope_ctx = tf.name_scope(self._name)
self._scope_ctx.__enter__()
self._device_ctx = tf.device(self._device)
self._device_ctx.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
global _CurrentTowerContext
_CurrentTowerContext = None
if len(self._name):
self._scope.__exit__(exc_type, exc_val, exc_tb)
self._scope_ctx.__exit__(exc_type, exc_val, exc_tb)
self._device_ctx.__exit__(exc_type, exc_val, exc_tb)
return False
def __str__(self):
......
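With the device stored on the context, entering a TowerContext now pins ops to the tower's device in addition to opening its name scope. A small usage sketch (TF 1.x API; import path and the exact printed device string are assumptions):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext  # import path assumed

with TowerContext('tower1', device='/gpu:1', is_training=True):
    # Ops created here live under the 'tower1/' name scope and are placed on
    # /gpu:1. Note the new assert requires the tower index to match the
    # trailing digit of the device string.
    x = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')
    y = tf.reduce_mean(x, name='mean')

print(y.name)    # tower1/mean:0
print(y.device)  # /device:GPU:1 (exact form depends on the TF version)
```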
......@@ -27,8 +27,8 @@ class FeedfreeTrainerBase(Trainer):
self._input_tensors = self._input_method.get_input_tensors()
self.model.build_graph(self._input_tensors)
ctx = get_current_tower_context()
if ctx is None:
with TowerContext(''):
if ctx is None: # called without a context, use a default one
with TowerContext('', is_training=True):
f()
else:
assert ctx.is_training, ctx
......
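The guard in FeedfreeTrainerBase now spells out that a call without a context gets a default training one. Restated as a self-contained helper (the function name is illustrative):

```python
from tensorpack.tfutils.tower import TowerContext, get_current_tower_context

def run_under_training_context(f):
    # If the caller did not open a TowerContext, wrap the call in an anonymous
    # training one; otherwise insist the surrounding context is a training
    # context. The anonymous context picks /gpu:0 or /cpu:0 automatically.
    ctx = get_current_tower_context()
    if ctx is None:
        with TowerContext('', is_training=True):
            f()
    else:
        assert ctx.is_training, ctx
        f()
```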
......@@ -17,6 +17,7 @@ from six.moves import range
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import memoized
from ..utils.concurrency import ShareSessionThread
......@@ -168,13 +169,14 @@ class QueueInput(FeedfreeInput):
trainer.register_callback(StartProcOrThread(self.thread))
def get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
qv.set_shape(v.get_shape())
return ret
with tf.device('/cpu:0'):
ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
qv.set_shape(v.get_shape())
return ret
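Because get_input_tensors() is now called inside a TowerContext that sets a GPU device, the dequeue must be pinned to the CPU explicitly or it would land on the tower's GPU. A standalone sketch of the placement (queue capacity and shapes are made up):

```python
import tensorflow as tf

placehdr = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')
queue = tf.FIFOQueue(50, [placehdr.dtype], name='input_queue')

with tf.device('/gpu:0'):            # stand-in for the tower's device context
    with tf.device('/cpu:0'):        # keep the dequeue op on the CPU
        ret = queue.dequeue(name='input_deque')
        if isinstance(ret, tf.Tensor):  # single-component queues return a Tensor
            ret = [ret]
        ret[0].set_shape(placehdr.get_shape())
```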
class BatchQueueInput(FeedfreeInput):
......@@ -232,15 +234,16 @@ class BatchQueueInput(FeedfreeInput):
trainer.register_callback(StartProcOrThread(self.thread))
def get_input_tensors(self):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
shp = v.get_shape().as_list()
shp[0] = self.batch_size
qv.set_shape(shp)
return ret
with tf.device('/cpu:0'):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
shp = v.get_shape().as_list()
shp[0] = self.batch_size
qv.set_shape(shp)
return ret
class DummyConstantInput(FeedfreeInput):
......@@ -254,7 +257,6 @@ class DummyConstantInput(FeedfreeInput):
"""
self.shapes = shapes
logger.warn("Using dummy input for debug!")
self._cnt = 0
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
......@@ -271,7 +273,6 @@ class DummyConstantInput(FeedfreeInput):
# don't share variables
for tower in range(nr_tower):
tlist = []
# TODO. keep device info in tower
with tf.device('/gpu:{}'.format(tower)):
for idx, p in enumerate(placehdrs):
tlist.append(tf.get_variable(
......@@ -280,9 +281,8 @@ class DummyConstantInput(FeedfreeInput):
self.tensors.append(tlist)
def get_input_tensors(self):
# TODO XXX call with tower index
ret = self.tensors[self._cnt]
self._cnt += 1
ctx = get_current_tower_context()
ret = self.tensors[ctx.index]
return ret
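Instead of a stateful counter that assumed towers request their inputs in order, the input source now asks the surrounding TowerContext which tower is being built; the same ctx.index lookup replaces the _cnt_unstage counter in StagingInputWrapper in the next hunk. A sketch of the pattern (class name is illustrative):

```python
from tensorpack.tfutils.tower import get_current_tower_context

class PerTowerTensors(object):
    # Illustrative stand-in for DummyConstantInput / StagingInputWrapper:
    # per-tower tensors are selected by the index of the current TowerContext
    # rather than by an internal call counter.
    def __init__(self, tensors_per_tower):
        self.tensors = tensors_per_tower  # one list of tensors per tower

    def get_input_tensors(self):
        ctx = get_current_tower_context()
        return self.tensors[ctx.index]
```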
......@@ -359,8 +359,6 @@ class StagingInputWrapper(FeedfreeInput):
self._stage_ops = []
self._unstage_ops = []
self._cnt_unstage = 0
def setup(self, model):
self._input.setup(model)
self.setup_staging_areas()
......@@ -390,10 +388,8 @@ class StagingInputWrapper(FeedfreeInput):
return self._input.size()
def get_input_tensors(self):
assert self._cnt_unstage < len(self._areas)
assert len(self._areas) == len(self._devices)
ret = self._unstage_ops[self._cnt_unstage]
self._cnt_unstage += 1
ctx = get_current_tower_context()
ret = self._unstage_ops[ctx.index]
return ret
@staticmethod
......
......@@ -40,9 +40,11 @@ class MultiGPUTrainer(Trainer):
ret = []
global_scope = tf.get_variable_scope()
for idx, t in enumerate(towers):
with tf.device('/gpu:{}'.format(t)), \
tf.variable_scope(global_scope, reuse=idx > 0), \
TowerContext('tower{}'.format(idx)):
with tf.variable_scope(global_scope, reuse=idx > 0), \
TowerContext(
'tower{}'.format(idx),
device='/gpu:{}'.format(t),
is_training=True):
logger.info("Building graph for training tower {}...".format(idx))
ret.append(func())
......
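The multi-GPU trainer follows the same move: the per-GPU tf.device() wrapper disappears and the device is passed to TowerContext, while variable reuse across towers is unchanged. Restated as a self-contained sketch (function name and import path assumed):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext  # import path assumed

def build_on_multi_tower(towers, func):
    # `towers` is a list of GPU ids; `func` builds one tower's graph and
    # returns e.g. its cost. Variables are shared by reusing the global
    # variable scope from the second tower onward.
    ret = []
    global_scope = tf.get_variable_scope()
    for idx, t in enumerate(towers):
        with tf.variable_scope(global_scope, reuse=idx > 0), \
                TowerContext('tower{}'.format(idx),
                             device='/gpu:{}'.format(t),
                             is_training=True):
            ret.append(func())
    return ret
```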