Commit 6d7276b8 authored by Yuxin Wu

use LOCAL_VARIABLES in replicated trainer, so duplicated vars won't get saved

parent 4fa66545
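The idea behind this commit, sketched below under TF 1.x semantics: a default `tf.train.Saver` only collects the (saveable) global variables, so per-tower duplicates registered under `LOCAL_VARIABLES` never reach the checkpoint. A minimal sketch; the variable names are illustrative only:

```python
import tensorflow as tf

w = tf.get_variable('w', shape=[2])            # default collection: GLOBAL_VARIABLES
w_copy = tf.get_variable('w_copy', shape=[2],
                         collections=[tf.GraphKeys.LOCAL_VARIABLES])

print([v.name for v in tf.global_variables()])   # ['w:0']
print([v.name for v in tf.local_variables()])    # ['w_copy:0']

# A Saver's default var_list is built from the global (saveable) variables,
# so 'w_copy' is never written to the checkpoint.
saver = tf.train.Saver()
```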
@@ -31,6 +31,7 @@ class NewSessionCreator(tf.train.SessionCreator):
    def create_session(self):
        sess = tf.Session(target=self.target, graph=self.graph, config=self.config)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        logger.info("Global variables initialized.")
        return sess
......
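Why the extra initializer is needed in `create_session`: `tf.global_variables_initializer()` does not touch variables that live only in `LOCAL_VARIABLES`, so the tower copies introduced by this commit would stay uninitialized without the new line. A minimal sketch (the variable name is illustrative):

```python
import tensorflow as tf

v = tf.get_variable('v', shape=[2], initializer=tf.zeros_initializer(),
                    collections=[tf.GraphKeys.LOCAL_VARIABLES])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   # does not initialize `v`
    sess.run(tf.local_variables_initializer())    # initializes `v`
    print(sess.run(v))                            # [0. 0.]
```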
@@ -20,7 +20,7 @@ class TowerContext(object):
            tower_name (str): The name scope of the tower.
            is_training (bool): if None, automatically determine from tower_name.
            index (int): index of this tower, only used in training.
            use_vs (bool): Open a variable scope with this name.
            use_vs (bool): Open a new variable scope with this name.
        """
        self._name = tower_name
        self._is_training = bool(is_training)
......
@@ -8,30 +8,15 @@ import os
from six.moves import range

from ..utils import logger
from .multigpu import MultiGPUTrainerBase
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name

__all__ = ['DistributedReplicatedTrainer', 'DistributedTrainerReplicated']

from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable


class OverrideToLocalVariable(object):
    """
    Ensures the created variable
    is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
    """
    def __call__(self, getter, name, *args, **kwargs):
        if 'collections' in kwargs:
            collections = kwargs['collections']
        if not collections:
            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
        else:
            collections = set(collections.copy())
        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
        kwargs['collections'] = list(collections)
        return getter(name, *args, **kwargs)


__all__ = ['DistributedReplicatedTrainer', 'DistributedTrainerReplicated']
class DistributedTrainerReplicated(MultiGPUTrainerBase):
@@ -220,9 +205,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
        cbs = self._input_source.setup(self.model.get_inputs_desc())
        self.config.callbacks.extend(cbs)

        with tf.variable_scope(
                tf.get_variable_scope(),
                custom_getter=OverrideToLocalVariable()):
        with override_to_local_variable():
            # Ngpu * Nvar * 2
            grad_list = MultiGPUTrainerBase.build_on_multi_tower(
                self.config.tower,
......
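Usage sketch for the `override_to_local_variable()` helper that replaces the inline `custom_getter` scope above; it is defined in the new `utility.py` shown further down. The import path and the variable name are assumptions for illustration:

```python
import tensorflow as tf
from tensorpack.train.utility import override_to_local_variable   # module path assumed

with override_to_local_variable():
    v = tf.get_variable('v', shape=[3])

print(v in tf.local_variables())    # True  -- rerouted by the custom getter
print(v in tf.global_variables())   # False -- kept out of the checkpoint
```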
@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import operator
from six.moves import zip, range

from ..utils import logger
@@ -17,6 +16,7 @@ from ..callbacks.graph import RunOp
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput

from .base import Trainer
from .utility import LeastLoadedDeviceSetter, override_to_local_variable

__all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter',
           'SyncMultiGPUTrainerReplicated',
@@ -69,25 +69,28 @@ class MultiGPUTrainerBase(Trainer):
        ret = []
        if devices is not None:
            assert len(devices) == len(towers)
        if use_vs is not None:
            assert len(use_vs) == len(towers)
        tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]
        keys_to_freeze = TOWER_FREEZE_KEYS[:]
        if use_vs is None:
            use_vs = [False] * len(towers)
        assert len(use_vs) == len(towers)

        for idx, t in enumerate(towers):
            device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
            usevs = use_vs[idx] if use_vs is not None else False
            with tf.device(device), TowerContext(
                    tower_names[idx],
                    is_training=True,
                    index=idx,
                    use_vs=use_vs[idx]):
                    use_vs=usevs):
                if idx == t:
                    logger.info("Building graph for training tower {}...".format(idx))
                else:
                    logger.info("Building graph for training tower {} on device {}...".format(idx, device))

                # When use_vs is True, use LOCAL_VARIABLES,
                # so these duplicated variables won't be saved by default.
                with override_to_local_variable(enable=usevs):
                    ret.append(func())

                if idx == 0:
@@ -111,37 +114,6 @@ class MultiGPUTrainerBase(Trainer):
            return model.get_cost_and_grad()[1]
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
    """ Helper class to assign variables on the least loaded ps-device."""
    def __init__(self, worker_device, ps_devices):
        """
        Args:
            worker_device: the device to use for compute ops.
            ps_devices: a list of device to use for Variable ops.
        """
        self.ps_devices = ps_devices
        self.worker_device = worker_device
        self.ps_sizes = [0] * len(self.ps_devices)

    def __call__(self, op):
        def sanitize_name(name):    # tensorflow/tensorflow#11484
            return tf.DeviceSpec.from_string(name).to_string()

        if op.device:
            return op.device
        if op.type not in ['Variable', 'VariableV2']:
            return sanitize_name(self.worker_device)

        device_index, _ = min(enumerate(
            self.ps_sizes), key=operator.itemgetter(1))
        device_name = self.ps_devices[device_index]
        var_size = op.outputs[0].get_shape().num_elements()
        self.ps_sizes[device_index] += var_size
        return sanitize_name(device_name)


class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
    """
    A data-parallel multi-GPU trainer. It builds one tower on each GPU with
@@ -308,6 +280,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
        for idx in range(len(tower)):
            with tf.device(raw_devices[idx]):
                grad_and_vars = [x[idx] for x in grads]
                # apply_gradients may create variables. Make them LOCAL_VARIABLES
                with override_to_local_variable(enable=idx > 0):
                    train_ops.append(opt.apply_gradients(
                        grad_and_vars, name='apply_grad_{}'.format(idx)))
        train_op = tf.group(*train_ops, name='train_op')
......
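The comment in the hunk above ("apply_gradients may create variables") refers to optimizer slot variables. A hedged sketch of the effect, using `MomentumOptimizer` as an example of a stateful optimizer; without the `override_to_local_variable(enable=idx > 0)` wrapper, each tower's copy of these slots would land in `GLOBAL_VARIABLES` and be checkpointed:

```python
import tensorflow as tf

w = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.square(w))

opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
train_op = opt.apply_gradients(opt.compute_gradients(loss))

# apply_gradients created the momentum slot for `w` as an additional variable,
# which now shows up next to 'w:0' in the global variables:
print([v.name for v in tf.global_variables()])
```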
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utility.py

import tensorflow as tf
from contextlib import contextmanager
import operator


@contextmanager
def override_to_local_variable(enable=True):
    if enable:
        with tf.variable_scope(
                tf.get_variable_scope(),
                custom_getter=OverrideToLocalVariable()):
            yield
    else:
        yield


class OverrideToLocalVariable(object):
    """
    Ensures the created variable
    is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
    """
    def __call__(self, getter, name, *args, **kwargs):
        if 'collections' in kwargs:
            collections = kwargs['collections']
        if not collections:
            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
        else:
            collections = set(collections.copy())
        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
        kwargs['collections'] = list(collections)
        return getter(name, *args, **kwargs)


# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
    """ Helper class to assign variables on the least loaded ps-device."""
    def __init__(self, worker_device, ps_devices):
        """
        Args:
            worker_device: the device to use for compute ops.
            ps_devices: a list of device to use for Variable ops.
        """
        self.ps_devices = ps_devices
        self.worker_device = worker_device
        self.ps_sizes = [0] * len(self.ps_devices)

    def __call__(self, op):
        def sanitize_name(name):    # tensorflow/tensorflow#11484
            return tf.DeviceSpec.from_string(name).to_string()

        if op.device:
            return op.device
        if op.type not in ['Variable', 'VariableV2']:
            return sanitize_name(self.worker_device)

        device_index, _ = min(enumerate(
            self.ps_sizes), key=operator.itemgetter(1))
        device_name = self.ps_devices[device_index]
        var_size = op.outputs[0].get_shape().num_elements()
        self.ps_sizes[device_index] += var_size
        return sanitize_name(device_name)
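A short usage sketch for `LeastLoadedDeviceSetter`: `tf.device` accepts it as a device function, so Variable ops get spread over `ps_devices` by the number of parameters already assigned, while every other op stays on `worker_device`. The import path and device names are assumptions for illustration:

```python
import tensorflow as tf
from tensorpack.train.utility import LeastLoadedDeviceSetter   # module path assumed

setter = LeastLoadedDeviceSetter(worker_device='/gpu:0',
                                 ps_devices=['/gpu:0', '/gpu:1'])
with tf.device(setter):
    a = tf.get_variable('a', shape=[1000])   # placed on /gpu:0 (both loads are 0)
    b = tf.get_variable('b', shape=[10])     # placed on /gpu:1 (now least loaded)
    c = tf.add(a[:10], b)                    # not a Variable op -> worker device

print(a.device, b.device, c.device)
```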