Commit 38d26977 authored by Yuxin Wu

backward-compat with tf1.0

parent 3898d354
@@ -59,7 +59,13 @@ def get_global_step_var():
         "The global_step variable should be created under the root variable scope!"
     assert not scope.reuse, \
         "The global_step variable shouldn't be called under a reuse variable scope!"
-    var = training_util.get_or_create_global_step()
+    if get_tf_version_number() <= 1.0:
+        var = tf.get_variable('global_step',
+                              initializer=tf.constant(0, dtype=tf.int64),
+                              trainable=False, dtype=tf.int64)
+        tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, var)
+    else:
+        var = training_util.get_or_create_global_step()
     return var
...
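For reference, a minimal standalone sketch of this version-gated fallback, assuming a helper that reduces `tf.__version__` to a float in the way `get_tf_version_number()` is used here (the names `_tf_version_number` and `create_global_step` below are illustrative, not tensorpack's API):

```python
# Illustrative sketch, not the exact tensorpack implementation.
import tensorflow as tf
from tensorflow.python.training import training_util


def _tf_version_number():
    # Assumption: reduce a version string like "1.0.1" to the float 1.0,
    # mirroring what get_tf_version_number() is expected to return.
    return float('.'.join(tf.__version__.split('.')[:2]))


def create_global_step():
    if _tf_version_number() <= 1.0:
        # Older TF: build the int64 global_step manually and register it
        # in the GLOBAL_STEP collection so tf.train.get_global_step()
        # can still find it.
        var = tf.get_variable('global_step',
                              initializer=tf.constant(0, dtype=tf.int64),
                              trainable=False, dtype=tf.int64)
        tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, var)
    else:
        # Newer TF: delegate to the built-in helper.
        var = training_util.get_or_create_global_step()
    return var
```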
@@ -64,12 +64,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         """ get the cost and gradient"""
         self.build_train_tower()
         cost = self.model.get_cost()  # assume single cost
+        # opt may be created under first-tower variable scope (which is '')
         opt = self.model.get_optimizer()
         # GATE_NONE faster?
         varlist = tf.trainable_variables()
         ctx = get_current_tower_context()
         if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
             # TODO assumption on the first-tower empty variable scope
             varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
         grads = opt.compute_gradients(
             cost,
...
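The second hunk restricts gradient computation to variables created under the current tower's variable scope. A self-contained sketch of that filtering is below; the function name `compute_tower_gradients` and the explicit `vs_name` argument are illustrative (tensorpack takes the scope name from `get_current_tower_context()`), and the `gate_gradients` choice is only prompted by the "GATE_NONE faster?" note since the actual arguments are truncated in the diff.

```python
# Illustrative sketch, not the trainer's actual code: filter trainable
# variables by a tower's variable-scope prefix before computing gradients.
import tensorflow as tf


def compute_tower_gradients(opt, cost, vs_name=''):
    varlist = tf.trainable_variables()
    if vs_name:
        # Keep only variables created under this tower's scope,
        # e.g. 'tower1/conv0/W' for vs_name == 'tower1'.
        varlist = [v for v in varlist if v.op.name.startswith(vs_name + '/')]
    # GATE_NONE skips gradient gating, which can be faster; this argument
    # is an assumption, not necessarily what the commit passes.
    return opt.compute_gradients(
        cost, var_list=varlist,
        gate_gradients=tf.train.Optimizer.GATE_NONE)
```

With the first tower living in the empty scope `''`, the prefix filter is skipped there, which is the assumption flagged by the TODO comment in the diff.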