Commit bbb47815 authored by Yuxin Wu

get_global_step -> get_global_step_value to avoid confusion with tf.train.get_global_step

parent c7021a87
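The rename separates two things that are easy to conflate: `tf.train.get_global_step` returns the global-step *variable* (a graph node), while tensorpack's `get_global_step_value` returns its current *value* as a plain integer. A minimal TF1-style sketch of the distinction, using only stock TensorFlow:

    import tensorflow as tf

    tf.train.create_global_step()             # make sure a global step variable exists
    step_var = tf.train.get_global_step()     # the tf.Variable itself, a graph node
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # the current value, a Python int -- what get_global_step_value() returns
        step_val = tf.train.global_step(sess, step_var)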
@@ -9,7 +9,6 @@ import shutil
 from .base import Callback
 from ..utils import logger
 from ..tfutils.varmanip import get_savename_from_varname
-from ..tfutils import get_global_step

 __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
@@ -76,7 +75,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
             self.saver.save(
                 tf.get_default_session(),
                 self.path,
-                global_step=get_global_step(),
+                global_step=tf.train.get_global_step(),
                 write_meta_graph=False)
             logger.info("Model saved to %s" % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
         except (OSError, IOError):  # disk error sometimes.. just ignore it
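Passing `tf.train.get_global_step()` straight to the saver works because `Saver.save` accepts an integer or a `Tensor` for its `global_step` argument and evaluates the tensor itself, so no explicit session run is needed here. A hedged sketch (the path is a placeholder):

    import tensorflow as tf

    tf.train.create_global_step()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # global_step may be an int or a Tensor/Variable; the Saver resolves it
        saver.save(sess, '/tmp/model',
                   global_step=tf.train.get_global_step(),
                   write_meta_graph=False)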
@@ -8,7 +8,7 @@ import json
 from .base import Callback
 from ..utils import logger
-from ..tfutils.common import get_global_step
+from ..tfutils.common import get_global_step_value

 __all__ = ['StatHolder', 'StatPrinter', 'SendStat']
@@ -134,7 +134,7 @@ class StatPrinter(Callback):
     def _trigger_epoch(self):
         # by default, add these two stats
-        self._stat_holder.add_stat('global_step', get_global_step())
+        self._stat_holder.add_stat('global_step', get_global_step_value())
         self._stat_holder.finalize()
         self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
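Here the plain integer is the right choice: `StatHolder` serializes its stats (note the `import json` context above), and a `tf.Variable` is not JSON-serializable while the `int` from `get_global_step_value()` is. A rough illustration of the difference:

    import json
    import tensorflow as tf

    tf.train.create_global_step()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        step = tf.train.global_step(sess, tf.train.get_global_step())
        json.dumps({'global_step': step})                          # fine: step is an int
        # json.dumps({'global_step': tf.train.get_global_step()})  # would raise TypeError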
@@ -10,7 +10,7 @@ import six
 from contextlib import contextmanager

 __all__ = ['get_default_sess_config',
-           'get_global_step',
+           'get_global_step_value',
            'get_global_step_var',
            'get_op_tensor_name',
            'get_tensors_by_names',
@@ -56,16 +56,16 @@ def get_global_step_var():
         assert scope.name == '', \
             "Creating global_step_var under a variable scope would cause problems!"
         with tf.variable_scope(scope, reuse=False):
-            var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
-                                  initializer=tf.constant_initializer(dtype=tf.int32),
+            var = tf.get_variable(GLOBAL_STEP_OP_NAME,
+                                  initializer=0,
                                   trainable=False, dtype=tf.int32)
     return var


-def get_global_step():
+def get_global_step_value():
     """
     Returns:
-        float: global_step value in current graph and session"""
+        int: global_step value in current graph and session"""
     return tf.train.global_step(
         tf.get_default_session(),
         get_global_step_var())
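The docstring fix matches the actual behavior: `tf.train.global_step` evaluates the variable in the given session and returns a Python `int`, and initializing from the Python scalar `0` yields the same scalar `int32` variable as the old `shape=[]` plus `constant_initializer` spelling. A small check under those assumptions:

    import tensorflow as tf

    step_var = tf.get_variable('global_step', initializer=0,
                               trainable=False, dtype=tf.int32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        value = tf.train.global_step(sess, step_var)
        assert isinstance(value, int) and value == 0
        assert step_var.shape == tf.TensorShape([])   # scalar; no shape=[] needed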
@@ -140,8 +140,7 @@ def get_scalar_var(name, init_value, summary=False, trainable=False):
     Returns:
         tf.Variable: the variable
     """
-    ret = tf.get_variable(name, shape=[],
-                          initializer=tf.constant_initializer(init_value),
+    ret = tf.get_variable(name, initializer=init_value,
                           trainable=trainable)
     if summary:
         # this is recognized in callbacks.StatHolder
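Same simplification as in `get_global_step_var`: when `initializer` is a Python scalar, `tf.get_variable` infers both the scalar shape and the dtype, so the explicit `shape=[]` and `tf.constant_initializer(init_value)` were redundant. For example, both spellings below create equivalent variables:

    import tensorflow as tf

    a = tf.get_variable('lr_old', shape=[],
                        initializer=tf.constant_initializer(0.1),
                        trainable=False)
    b = tf.get_variable('lr_new', initializer=0.1, trainable=False)
    assert a.shape == b.shape == tf.TensorShape([])
    assert a.dtype.base_dtype == b.dtype.base_dtype == tf.float32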
@@ -13,7 +13,7 @@ from .config import TrainConfig
 from ..utils import logger
 from ..utils.timer import timed_operation
 from ..callbacks import StatHolder
-from ..tfutils import get_global_step, get_global_step_var
+from ..tfutils import get_global_step_var, get_global_step_value
 from ..tfutils.modelutils import describe_model
 from ..tfutils.summary import create_scalar_summary
@@ -121,7 +121,7 @@ class Trainer(object):
                 if val.tag.endswith(suffix):
                     val.tag = val.tag[:-len(suffix)]
                 self.stat_holder.add_stat(val.tag, val.simple_value)
-        self.summary_writer.add_summary(summary, get_global_step())
+        self.summary_writer.add_summary(summary, get_global_step_value())

     def add_scalar_summary(self, name, val):
         """
@@ -144,7 +144,7 @@ class Trainer(object):
         """
         self._setup()
         describe_model()
-        get_global_step_var()
+        get_global_step_var()   # ensure such var exists
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
@@ -178,12 +178,12 @@ class Trainer(object):
         with self.sess.as_default():
             try:
                 callbacks.before_train()
-                logger.info("Start training with global_step={}".format(get_global_step()))
+                logger.info("Start training with global_step={}".format(get_global_step_value()))
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch + 1):
                     with timed_operation(
                         'Epoch {} (global_step {})'.format(
-                            self.epoch_num, get_global_step() + self.config.step_per_epoch),
+                            self.epoch_num, get_global_step_value() + self.config.step_per_epoch),
                         log_start=True):
                         for self.step_num in range(self.config.step_per_epoch):
                             if self.coord.should_stop():
@@ -3,6 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf

+# this is also the name used by tf.train.get_global_step
 GLOBAL_STEP_OP_NAME = 'global_step'
 GLOBAL_STEP_VAR_NAME = 'global_step:0'
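The new comment documents why this constant is load-bearing: as I understand the TF1 lookup, `tf.train.get_global_step` first checks the `GLOBAL_STEP` graph collection and, failing that, falls back to the tensor literally named `'global_step:0'`. Keeping the op name `'global_step'` is therefore what lets `tf.train.get_global_step` find tensorpack's variable. A sketch of that fallback:

    import tensorflow as tf

    GLOBAL_STEP_OP_NAME = 'global_step'

    var = tf.get_variable(GLOBAL_STEP_OP_NAME, initializer=0,
                          trainable=False, dtype=tf.int32)
    # no GLOBAL_STEP collection bookkeeping was done; the name alone suffices
    found = tf.train.get_global_step()
    assert found is not None and found.name == 'global_step:0'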