Commit 6d7276b8 authored by Yuxin Wu

use LOCAL_VARIABLES in replicated trainer, so duplicated vars won't get saved

parent 4fa66545
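Why this commit works, in one picture: tf.train.Saver with no explicit var_list only gathers variables from the GLOBAL_VARIABLES collection, so anything registered solely under LOCAL_VARIABLES never reaches the checkpoint. A minimal TF1 sketch (illustrative only, not part of the commit; variable names are made up):

import tensorflow as tf

with tf.Graph().as_default():
    # master copy of a weight: registered in GLOBAL_VARIABLES, so it is checkpointed
    master = tf.get_variable("w", shape=[2])
    # per-tower duplicate: registered only as a LOCAL variable
    dup = tf.get_variable("w_tower1", shape=[2],
                          collections=[tf.GraphKeys.LOCAL_VARIABLES])
    print([v.name for v in tf.global_variables()])  # ['w:0']         -> saved
    print([v.name for v in tf.local_variables()])   # ['w_tower1:0']  -> skipped
    saver = tf.train.Saver()  # default var_list covers only the global/saveable variables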
tensorpack/tfutils/sesscreate.py
@@ -31,6 +31,7 @@ class NewSessionCreator(tf.train.SessionCreator):
     def create_session(self):
         sess = tf.Session(target=self.target, graph=self.graph, config=self.config)
         sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
         logger.info("Global variables initialized.")
         return sess
tensorpack/tfutils/tower.py
@@ -20,7 +20,7 @@ class TowerContext(object):
             tower_name (str): The name scope of the tower.
             is_training (bool): if None, automatically determine from tower_name.
             index (int): index of this tower, only used in training.
-            use_vs (bool): Open a variable scope with this name.
+            use_vs (bool): Open a new variable scope with this name.
         """
         self._name = tower_name
         self._is_training = bool(is_training)
tensorpack/train/distributed.py
@@ -8,30 +8,15 @@ import os
 from six.moves import range

 from ..utils import logger
-from .multigpu import MultiGPUTrainerBase
 from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.common import get_global_step_var, get_op_tensor_name

-__all__ = ['DistributedReplicatedTrainer', 'DistributedTrainerReplicated']
-
-
-class OverrideToLocalVariable(object):
-    """
-    Ensures the created variable
-    is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
-    """
-    def __call__(self, getter, name, *args, **kwargs):
-        if 'collections' in kwargs:
-            collections = kwargs['collections']
-        if not collections:
-            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
-        else:
-            collections = set(collections.copy())
-        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
-        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
-        kwargs['collections'] = list(collections)
-        return getter(name, *args, **kwargs)
+from .multigpu import MultiGPUTrainerBase
+from .utility import override_to_local_variable
+
+__all__ = ['DistributedReplicatedTrainer', 'DistributedTrainerReplicated']


 class DistributedTrainerReplicated(MultiGPUTrainerBase):
@@ -220,9 +205,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
             cbs = self._input_source.setup(self.model.get_inputs_desc())
             self.config.callbacks.extend(cbs)

-            with tf.variable_scope(
-                    tf.get_variable_scope(),
-                    custom_getter=OverrideToLocalVariable()):
+            with override_to_local_variable():
                 # Ngpu * Nvar * 2
                 grad_list = MultiGPUTrainerBase.build_on_multi_tower(
                     self.config.tower,
tensorpack/train/multigpu.py
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
-import operator
 from six.moves import zip, range

 from ..utils import logger
@@ -17,6 +16,7 @@
 from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
 from .base import Trainer
+from .utility import LeastLoadedDeviceSetter, override_to_local_variable

 __all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter',
            'SyncMultiGPUTrainerReplicated',
@@ -69,26 +69,29 @@ class MultiGPUTrainerBase(Trainer):
         ret = []
         if devices is not None:
             assert len(devices) == len(towers)
+        if use_vs is not None:
+            assert len(use_vs) == len(towers)

         tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
-        if use_vs is None:
-            use_vs = [False] * len(towers)
-        assert len(use_vs) == len(towers)

         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
+            usevs = use_vs[idx] if use_vs is not None else False
             with tf.device(device), TowerContext(
                     tower_names[idx],
                     is_training=True,
                     index=idx,
-                    use_vs=use_vs[idx]):
+                    use_vs=usevs):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
                 else:
                     logger.info("Building graph for training tower {} on device {}...".format(idx, device))

-                ret.append(func())
+                # When use_vs is True, use LOCAL_VARIABLES,
+                # so these duplicated variables won't be saved by default.
+                with override_to_local_variable(enable=usevs):
+                    ret.append(func())

                 if idx == 0:
                     # avoid duplicated summary & update_ops from each device
@@ -111,37 +114,6 @@ class MultiGPUTrainerBase(Trainer):
             return model.get_cost_and_grad()[1]

-
-# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
-class LeastLoadedDeviceSetter(object):
-    """ Helper class to assign variables on the least loaded ps-device."""
-    def __init__(self, worker_device, ps_devices):
-        """
-        Args:
-            worker_device: the device to use for compute ops.
-            ps_devices: a list of device to use for Variable ops.
-        """
-        self.ps_devices = ps_devices
-        self.worker_device = worker_device
-        self.ps_sizes = [0] * len(self.ps_devices)
-
-    def __call__(self, op):
-        def sanitize_name(name):   # tensorflow/tensorflow#11484
-            return tf.DeviceSpec.from_string(name).to_string()
-
-        if op.device:
-            return op.device
-        if op.type not in ['Variable', 'VariableV2']:
-            return sanitize_name(self.worker_device)
-
-        device_index, _ = min(enumerate(
-            self.ps_sizes), key=operator.itemgetter(1))
-        device_name = self.ps_devices[device_index]
-        var_size = op.outputs[0].get_shape().num_elements()
-        self.ps_sizes[device_index] += var_size
-        return sanitize_name(device_name)
-
-
 class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
     """
     A data-parallel multi-GPU trainer. It builds one tower on each GPU with
@@ -308,8 +280,10 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
         for idx in range(len(tower)):
             with tf.device(raw_devices[idx]):
                 grad_and_vars = [x[idx] for x in grads]
-                train_ops.append(opt.apply_gradients(
-                    grad_and_vars, name='apply_grad_{}'.format(idx)))
+                # apply_gradients may create variables. Make them LOCAL_VARIABLES
+                with override_to_local_variable(enable=idx > 0):
+                    train_ops.append(opt.apply_gradients(
+                        grad_and_vars, name='apply_grad_{}'.format(idx)))
         train_op = tf.group(*train_ops, name='train_op')

         cb = RunOp(
             SyncMultiGPUTrainerReplicated.get_post_init_ops,
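Context for the last hunk above: TF1 optimizers create their slot variables (momentum accumulators, Adam moments, and so on) inside apply_gradients(), so on the replicated towers those slots would be duplicated as well; wrapping the call in override_to_local_variable(enable=idx > 0) keeps the copies on towers 1..N-1 out of the checkpoint. A small sketch of where such a slot variable appears (illustrative only, not from the commit):

import tensorflow as tf

with tf.Graph().as_default():
    x = tf.get_variable("x", shape=[3])
    loss = tf.reduce_sum(tf.square(x))
    opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
    grads_and_vars = opt.compute_gradients(loss)
    # the "x/Momentum" slot variable only comes into existence here
    train_op = opt.apply_gradients(grads_and_vars)
    print([v.name for v in tf.global_variables()])  # ['x:0', 'x/Momentum:0']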
tensorpack/train/utility.py (new file)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utility.py

import tensorflow as tf
from contextlib import contextmanager
import operator


@contextmanager
def override_to_local_variable(enable=True):
    if enable:
        with tf.variable_scope(
                tf.get_variable_scope(),
                custom_getter=OverrideToLocalVariable()):
            yield
    else:
        yield


class OverrideToLocalVariable(object):
    """
    Ensures the created variable
    is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
    """
    def __call__(self, getter, name, *args, **kwargs):
        if 'collections' in kwargs:
            collections = kwargs['collections']
        if not collections:
            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
        else:
            collections = set(collections.copy())
        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
        kwargs['collections'] = list(collections)
        return getter(name, *args, **kwargs)


# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
    """ Helper class to assign variables on the least loaded ps-device."""
    def __init__(self, worker_device, ps_devices):
        """
        Args:
            worker_device: the device to use for compute ops.
            ps_devices: a list of device to use for Variable ops.
        """
        self.ps_devices = ps_devices
        self.worker_device = worker_device
        self.ps_sizes = [0] * len(self.ps_devices)

    def __call__(self, op):
        def sanitize_name(name):   # tensorflow/tensorflow#11484
            return tf.DeviceSpec.from_string(name).to_string()

        if op.device:
            return op.device
        if op.type not in ['Variable', 'VariableV2']:
            return sanitize_name(self.worker_device)

        device_index, _ = min(enumerate(
            self.ps_sizes), key=operator.itemgetter(1))
        device_name = self.ps_devices[device_index]
        var_size = op.outputs[0].get_shape().num_elements()
        self.ps_sizes[device_index] += var_size
        return sanitize_name(device_name)
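A quick usage sketch for the new context manager (the import path tensorpack.train.utility is assumed from the new file's location above): any variable created under it lands in LOCAL_VARIABLES instead of GLOBAL_VARIABLES, which is exactly what the trainers above rely on.

import tensorflow as tf
# assumed import path, based on where the new file lives in the repo
from tensorpack.train.utility import override_to_local_variable

with tf.Graph().as_default():
    with override_to_local_variable(enable=True):
        dup = tf.get_variable("dup", shape=[1])      # becomes a local variable
    master = tf.get_variable("master", shape=[1])    # normal global variable
    print(dup in tf.local_variables(), dup in tf.global_variables())  # True False
    print(master in tf.global_variables())                            # True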