Commit a53da5ab authored by Yuxin Wu

[distributed] Can save and sync MODEL_VARIABLES

parent 5b18f8be
......@@ -160,7 +160,7 @@ def get_checkpoint_path(model_path):
new_path = model_path.split('.index')[0]
if new_path != model_path:
logger.warn(
"[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
"Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
model_path = new_path
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
return model_path
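A minimal sketch of the path auto-correction above, using a hypothetical checkpoint path: a user-supplied '.index' file is silently reduced to the checkpoint prefix that TensorFlow savers expect, and the warning is emitted only when a correction actually happened.

model_path = '/tmp/train_log/model-10000.index'   # hypothetical user-supplied path
new_path = model_path.split('.index')[0]          # '/tmp/train_log/model-10000'
if new_path != model_path:                        # warn only when the path changed
    print("Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))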
......@@ -183,7 +183,8 @@ def dump_chkpt_vars(model_path):
def is_training_name(name):
"""
This is a hack temporarily used to improve logging. Do not use this function.
Guess if a name belongs to a training-only variable.
Only used internally to avoid too much logging. Do not use it.
Returns:
bool: Guess whether this tensor is something only used in training.
......
......@@ -3,6 +3,7 @@
# File: distributed.py
import tensorflow as tf
import re
from six.moves import range
from ..utils import logger
......@@ -110,6 +111,31 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
ps_var_grads.append((grad, new_v))
return ps_var_grads
@staticmethod
def _shadow_model_variables(shadow_vars):
"""
Create shadow vars for model_variables as well, and add them to the list ``shadow_vars``.
Returns:
list of (shadow_model_var, local_model_var) used for syncing.
"""
curr_shadow_vars = set([v.name for v in shadow_vars])
model_vars = tf.model_variables()
shadow_model_vars = []
for v in model_vars:
assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the model!"
stripped_name = get_op_tensor_name(re.sub('tower[0-9]+/', '', v.name))[0]
if stripped_name in curr_shadow_vars:
continue
new_v = tf.get_variable(stripped_name, dtype=v.dtype.base_dtype,
initializer=v.initial_value,
trainable=False)
curr_shadow_vars.add(stripped_name) # avoid duplicated shadow_model_vars
shadow_vars.append(new_v)
shadow_model_vars.append((new_v, v)) # only need to sync model_var from one tower
return shadow_model_vars
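The shadow variable's name is the tower-local name with the 'towerN/' prefix and the ':0' tensor suffix removed; a rough sketch with a hypothetical BatchNorm EMA variable:

import re

v_name = 'tower0/group1/bn/mean/EMA:0'         # hypothetical MODEL_VARIABLE name
stripped = re.sub('tower[0-9]+/', '', v_name)  # 'group1/bn/mean/EMA:0'
op_name = stripped.split(':')[0]               # 'group1/bn/mean/EMA', the op-name form returned by get_op_tensor_name(...)[0]
# op_name becomes the name of the non-trainable shadow variable created on the PS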
def _apply_gradients_and_copy(self, raw_grad_list, ps_var_grads):
"""
Args:
......@@ -142,7 +168,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
with tf.device(self.param_server_device):
gs = get_global_step_var()
assert gs.device, gs.device
self.model.get_optimizer() # TODO in global scope, not local
# do this before super.setup because input_source may need the global step
super(DistributedReplicatedTrainer, self)._setup()
......@@ -161,16 +186,27 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
with tf.device(self.param_server_device):
ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
self._shadow_vars = [v for (_, v) in ps_var_grads]
self._shadow_model_vars = DistributedReplicatedTrainer._shadow_model_variables(self._shadow_vars)
main_fetch = tf.group(*var_update_ops, name='main_fetches')
self.train_op = self.add_sync_queues_and_barrier(
'post_copy_barrier', [main_fetch])
cb = RunOp(self.get_post_init_ops,
# initial local_vars syncing
cb = RunOp(self.get_initial_sync_op,
run_before=True, run_as_trigger=False, verbose=True)
cb.chief_only = False
self.register_callback(cb)
# model_variables syncing
if len(self._shadow_model_vars) and self.is_chief:
cb = RunOp(self.get_sync_model_vars_op,
run_before=False, run_as_trigger=True, verbose=True)
logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
"every epoch. Be careful if you save the model more frequently.")
self.register_callback(cb)
self._set_session_creator()
def _set_session_creator(self):
......@@ -230,26 +266,38 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
return tf.group(*queue_ops)
def get_post_init_ops(self):
# Copy initialized variables for variables on the parameter server
# to the local copy of the variable.
def get_initial_sync_op(self):
"""
Get the op to copy-initialize all local variables from the PS.
"""
def strip_port(s):
if s.endswith(':0'):
return s[:-2]
return s
local_vars = tf.local_variables()
local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
post_init_ops = []
ops = []
nr_shadow_vars = len(self._shadow_vars)
for v in self._shadow_vars:
vname = strip_port(v.name)
for i in range(self.nr_gpu):
name = 'tower%s/%s' % (i, vname)
if name in local_var_by_name:
copy_to = local_var_by_name[name]
post_init_ops.append(copy_to.assign(v.read_value()))
else:
logger.warn("Global variable {} doesn't match a corresponding local var".format(v.name))
return tf.group(*post_init_ops, name='sync_variables_from_ps')
assert name in local_var_by_name, \
"Shadow variable {} doesn't match a corresponding local variable!".format(v.name)
copy_to = local_var_by_name[name]
# logger.info("{} -> {}".format(v.name, copy_to.name))
ops.append(copy_to.assign(v.read_value()))
return tf.group(*ops, name='sync_{}_variables_from_ps'.format(nr_shadow_vars))
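The matching above re-attaches each tower prefix to the port-stripped PS-side shadow name and looks it up among the local variables; a small sketch with hypothetical names:

nr_gpu = 2
shadow_name = 'conv0/W'   # hypothetical shadow variable on the PS, ':0' already stripped
local_names = ['tower%s/%s' % (i, shadow_name) for i in range(nr_gpu)]
# ['tower0/conv0/W', 'tower1/conv0/W']; each must exist among tf.local_variables(), otherwise the assert fires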
def get_sync_model_vars_op(self):
"""
Get the op to sync local model_variables to PS.
"""
ops = []
for (shadow_v, local_v) in self._shadow_model_vars:
ops.append(shadow_v.assign(local_v.read_value()))
assert len(ops)
return tf.group(*ops, name='sync_{}_model_variables_to_ps'.format(len(ops)))
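This op runs in the opposite direction of get_initial_sync_op: tower-local MODEL_VARIABLES (e.g. BatchNorm statistics) are pushed back to their PS shadows, and the RunOp callback registered in _setup triggers it once per epoch, hence the warning above. A toy TF 1.x-style illustration of the assign pattern, not the trainer's actual variables:

import tensorflow as tf

shadow_v = tf.Variable([0.0, 0.0], trainable=False, name='shadow_bn_mean')  # hypothetical PS-side shadow
local_v = tf.Variable([1.0, 2.0], trainable=False, name='tower0_bn_mean')   # hypothetical tower-local copy
sync_op = tf.group(shadow_v.assign(local_v.read_value()),
                   name='sync_1_model_variables_to_ps')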
@property
def vs_name_for_predictor(self):
......
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from six.moves import zip
from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_source import QueueInput, FeedfreeInput
......@@ -64,20 +65,18 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" get the cost and gradient"""
self.build_train_tower()
cost = self.model.get_cost() # assume single cost
# opt may be created under first-tower variable scope (which is '')
opt = self.model.get_optimizer()
# GATE_NONE faster?
varlist = tf.trainable_variables()
ctx = get_current_tower_context()
if ctx is not None and ctx.has_own_variables and ctx.vs_name:
# only optimize w.r.t vars in this tower
# TODO use ctx.vars?
varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
grads = opt.compute_gradients(
grads = tf.gradients(
cost,
var_list=varlist,
gate_gradients=tf.train.Optimizer.GATE_NONE,
varlist,
gate_gradients=False,
colocate_gradients_with_ops=True)
grads = list(zip(grads, varlist))
return cost, grads
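tf.gradients returns one gradient per entry of varlist, in order, so zipping with varlist rebuilds the (grad, var) pairs that Optimizer.compute_gradients would have produced, and the tower code no longer has to instantiate the optimizer just to compute gradients (the comment above notes the optimizer could otherwise be created under the first tower's variable scope). A rough TF 1.x sketch on toy variables:

import tensorflow as tf

x = tf.Variable(3.0, name='x')
y = tf.Variable(4.0, name='y')
cost = x * x + 2.0 * y
varlist = [x, y]

grads = tf.gradients(cost, varlist,
                     gate_gradients=False,
                     colocate_gradients_with_ops=True)  # [d(cost)/dx, d(cost)/dy]
grad_and_vars = list(zip(grads, varlist))               # same structure as compute_gradients()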
......