Commit ca16fb7e authored by Yuxin Wu

changes in tower to allow replicated training

parent 118c2a26
@@ -20,7 +20,7 @@ Alternative link to this page: [http://dorefa.net](http://dorefa.net)
 To use the script. You'll need:
-+ TensorFlow >= 1.0.0rc0
++ TensorFlow >= 1.0.0 (>=1.1 for MultiGPU)
 + OpenCV bindings for Python
...
@@ -38,12 +38,16 @@ def regularize_cost(regex, func, name='regularize_cost'):
         cost = cost + regularize_cost("fc.*/W", l2_regularizer(1e-5))
     """
+    ctx = get_current_tower_context()
     G = tf.get_default_graph()
     params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

     costs = []
     for p in params:
         para_name = p.name
+        # in replicated mode, only regularize variables inside this tower
+        if ctx.has_own_variables and (not para_name.startswith(ctx.name)):
+            continue
         if re.search(regex, para_name):
             costs.append(func(p))
             _log_regularizer(para_name)
...
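For context, a minimal sketch (not part of the diff; the variable names are hypothetical) of what the new `has_own_variables` check guards against: in replicated mode every tower owns its own copy of each variable, prefixed with the tower name, so `regularize_cost` should only pick up the current tower's copies.

```python
import re

# Hypothetical trainable-variable names with two replicated towers:
# each tower holds its own copy of every weight.
trainable = ['tower0/conv0/W:0', 'tower0/fc0/W:0',
             'tower1/conv0/W:0', 'tower1/fc0/W:0']
tower_name = 'tower1'   # what ctx.name would be inside tower 1

# Mirrors the new filter in regularize_cost(): skip other towers'
# variables, then apply the user's regex as before.
selected = [v for v in trainable
            if v.startswith(tower_name) and re.search('fc.*/W', v)]
print(selected)   # ['tower1/fc0/W:0']
```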
@@ -6,6 +6,7 @@ import tensorflow as tf
 from termcolor import colored
 from tabulate import tabulate

+from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from .summary import add_moving_summary
@@ -62,7 +63,9 @@ def apply_slim_collections(cost):
         a scalar tensor, the cost after applying the collections.
     """
     regulization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+    ctx = get_current_tower_context()
     if len(regulization_losses) > 0:
+        assert not ctx.has_own_variables, "REGULARIZATION_LOSSES collection doesn't work in replicated mode!"
         logger.info("Applying REGULARIZATION_LOSSES on cost.")
         reg_loss = tf.add_n(list(regulization_losses), name="regularize_loss")
         cost = tf.add(reg_loss, cost, name='total_cost')
...
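A short illustration (not from the commit) of why the assert above is needed: `tf.GraphKeys.REGULARIZATION_LOSSES` is a graph-global collection, so with per-tower variable copies every tower would sum up the regularization losses created by all towers, not just its own.

```python
import tensorflow as tf

reg = tf.contrib.layers.l2_regularizer(1e-4)
with tf.variable_scope('tower0'):
    w0 = tf.get_variable('W', shape=[2], regularizer=reg)
with tf.variable_scope('tower1'):
    w1 = tf.get_variable('W', shape=[2], regularizer=reg)

# Both towers' losses end up in the same global collection.
losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
print([l.name for l in losses])   # one loss per tower, regardless of the current tower
```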
@@ -15,12 +15,15 @@ _CurrentTowerContext = None
 class TowerContext(object):
     """ A context where the current model is being built in. """

-    def __init__(self, tower_name, device=None, is_training=None):
+    def __init__(self, tower_name,
+                 device=None, is_training=None,
+                 var_strategy='shared'):
         """
         Args:
             tower_name (str): 'tower0', 'towerp0', or ''
             device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
             is_training (bool): if None, automatically determine from tower_name.
+            var_strategy (str): either 'shared' or 'replicated'.
         """
         self._name = tower_name
         if device is None:
@@ -31,6 +34,11 @@ class TowerContext(object):
             is_training = not self._name.startswith(PREDICT_TOWER)
         self._is_training = is_training

+        assert var_strategy in ['replicated', 'shared'], var_strategy
+        self._var_strategy = var_strategy
+        if self._var_strategy == 'replicated':
+            assert self._name
+
     @property
     def is_main_training_tower(self):
         return self.is_training and (self._name == '' or self._name == 'tower0')
@@ -43,6 +51,10 @@ class TowerContext(object):
     def is_training(self):
         return self._is_training

+    @property
+    def has_own_variables(self):
+        return self._var_strategy == 'replicated'
+
     @property
     def name(self):
         return self._name
@@ -88,18 +100,25 @@ class TowerContext(object):
         assert _CurrentTowerContext is None, \
             "Nesting TowerContext!"
         _CurrentTowerContext = self
+        self._ctxs = []
         if len(self._name):
-            self._scope_ctx = tf.name_scope(self._name)
-            self._scope_ctx.__enter__()
-        self._device_ctx = tf.device(self._device)
-        self._device_ctx.__enter__()
+            if self.has_own_variables:
+                # open new variable scopes
+                self._ctxs.append(tf.variable_scope(self._name))
+            else:
+                # use existing variable scope
+                self._ctxs.append(tf.variable_scope(
+                    tf.get_variable_scope(), reuse=self.index > 0))
+            self._ctxs.append(tf.name_scope(self._name))
+        self._ctxs.append(tf.device(self._device))
+        for c in self._ctxs:
+            c.__enter__()

     def __exit__(self, exc_type, exc_val, exc_tb):
         global _CurrentTowerContext
         _CurrentTowerContext = None
-        if len(self._name):
-            self._scope_ctx.__exit__(exc_type, exc_val, exc_tb)
-        self._device_ctx.__exit__(exc_type, exc_val, exc_tb)
+        for c in self._ctxs[::-1]:
+            c.__exit__(exc_type, exc_val, exc_tb)
         return False

     def __str__(self):
...
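A rough usage sketch of the reworked context manager, assuming this commit's `TowerContext` (variable names are only illustrative): with the default `'shared'` strategy the tower reuses the enclosing variable scope, while `'replicated'` opens a per-tower variable scope, so each tower keeps its own variable copies.

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext

# 'shared' (default): variables are created in the enclosing (root) scope
# and reused by later towers.
with TowerContext('tower0', is_training=True):
    w_shared = tf.get_variable('W_shared', shape=[3, 3])

# 'replicated': the context opens variable scope 'tower1', so this tower
# owns its own copy of the variable.
with TowerContext('tower1', is_training=True, var_strategy='replicated'):
    w_own = tf.get_variable('W', shape=[3, 3])

print(w_shared.name)   # W_shared:0
print(w_own.name)      # tower1/W:0
```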
@@ -27,12 +27,13 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
 class MultiGPUTrainerBase(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
-    def build_on_multi_tower(towers, func, devices=None):
+    def build_on_multi_tower(towers, func, devices=None, var_strategy='shared'):
         """
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in towers.
+            var_strategy (str):

         Returns:
             List of outputs of ``func``, evaluated on each tower.
@@ -40,17 +41,19 @@ class MultiGPUTrainerBase(Trainer):
         logger.info("Training a model of {} tower".format(len(towers)))

         ret = []
-        global_scope = tf.get_variable_scope()
         if devices is not None:
             assert len(devices) == len(towers)
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
-            with tf.variable_scope(global_scope, reuse=idx > 0), \
-                    TowerContext(
+            with TowerContext(
                     'tower{}'.format(idx),
-                    device=device,
-                    is_training=True):
-                logger.info("Building graph for training tower {}...".format(idx))
+                    device=device, is_training=True,
+                    var_strategy=var_strategy):
+                if idx == t:
+                    logger.info("Building graph for training tower {}...".format(idx))
+                else:
+                    logger.info("Building graph for training tower {} on device {}...".format(idx, t))
                 ret.append(func())
@@ -92,14 +95,15 @@ class LeastLoadedDeviceSetter(object):
 class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer which synchronoizes the gradients computed
-    from each tower, averages them and update to variables stored on PS.
+    from each tower, averages them and update to variables stored across all
+    GPUs or on CPU.
     """

     def __init__(self, config, ps_device='gpu'):
         """
         Args:
             config: same as in :class:`QueueInputTrainer`.
-            ps_device: either 'gpu' or 'cpu'
+            ps_device: either 'gpu' or 'cpu', where variables are stored.
         """
         if config.dataflow is not None:
             # use queueinput by default. May need to avoid this in the future (when more input type is available)
...
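A hedged sketch (not part of this commit) of how a caller might use the new `var_strategy` argument; `get_cost` is a hypothetical stand-in for the per-tower graph-building closure that a real trainer would pass in, and the import path is assumed from the repository layout.

```python
import tensorflow as tf
from tensorpack.train.multigpu import MultiGPUTrainerBase

def get_cost():
    # hypothetical per-tower closure; a real trainer builds the model here
    return tf.constant(0., name='dummy_cost')

# Default behaviour: towers share one set of variables (parameter-server style).
shared = MultiGPUTrainerBase.build_on_multi_tower([0, 1], get_cost)

# With this commit: each tower builds under its own variable scope
# ('tower0/...', 'tower1/...'), which is what replicated training needs.
replicated = MultiGPUTrainerBase.build_on_multi_tower(
    [0, 1], get_cost, var_strategy='replicated')
```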
@@ -123,9 +123,11 @@ def get_caffe_pb():
     if not os.path.isfile(caffe_pb_file):
         download(CAFFE_PROTO_URL, dir)
         assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
-        ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
+        cmd = 'cd {} && protoc caffe.proto --python_out .'.format(dir)
+        ret = os.system(cmd)
         assert ret == 0, \
-            "Command `protoc caffe.proto --python_out .` failed!"
+            "Command `{}` failed!".format(cmd)
+        assert os.path.isfile(caffe_pb_file), caffe_pb_file
     import imp
     return imp.load_source('caffepb', caffe_pb_file)
...
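As an aside (not in the commit), a similar "report the exact failing command" effect can be had with `subprocess`, which raises a `CalledProcessError` carrying the command; a small sketch with a hypothetical helper:

```python
import subprocess

def compile_caffe_proto(proto_dir):
    # Equivalent to `cd proto_dir && protoc caffe.proto --python_out .`;
    # on failure, CalledProcessError already contains the full command.
    subprocess.check_call(
        ['protoc', 'caffe.proto', '--python_out', '.'], cwd=proto_dir)
```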