Commit bbb47815 authored by Yuxin Wu

get_global_step -> get_global_step_value to avoid confusion with tf.train.get_global_step

parent c7021a87
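The rename matters because the two similarly named functions return different kinds of objects: tf.train.get_global_step returns the step variable (a tensor), while tensorpack's get_global_step_value evaluates it to a Python int in the current session. A minimal sketch of the distinction, assuming TF 1.x and a scalar int32 counter (variable name illustrative, not from this commit):

    import tensorflow as tf

    # Hypothetical illustration, not code from this commit.
    step_var = tf.get_variable('global_step', initializer=0,
                               trainable=False, dtype=tf.int32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(tf.train.get_global_step())            # the tensor 'global_step:0'
        print(tf.train.global_step(sess, step_var))  # the evaluated int value: 0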
@@ -9,7 +9,6 @@ import shutil
 from .base import Callback
 from ..utils import logger
 from ..tfutils.varmanip import get_savename_from_varname
-from ..tfutils import get_global_step

 __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
@@ -76,7 +75,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
             self.saver.save(
                 tf.get_default_session(),
                 self.path,
-                global_step=get_global_step(),
+                global_step=tf.train.get_global_step(),
                 write_meta_graph=False)
             logger.info("Model saved to %s" % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
         except (OSError, IOError):   # disk error sometimes.. just ignore it
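Passing tf.train.get_global_step() straight into Saver.save works because the global_step argument accepts an integer tensor as well as a plain int; the saver evaluates it when naming the checkpoint. A hedged sketch, with the path and variable name invented for illustration:

    import tensorflow as tf

    # Hypothetical sketch: the tensor itself is accepted as global_step.
    step_var = tf.get_variable('global_step', initializer=0,
                               trainable=False, dtype=tf.int32)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Checkpoint is written as /tmp/model-0 (step value appended).
        saver.save(sess, '/tmp/model', global_step=step_var,
                   write_meta_graph=False)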
@@ -8,7 +8,7 @@ import json
 from .base import Callback
 from ..utils import logger
-from ..tfutils.common import get_global_step
+from ..tfutils.common import get_global_step_value

 __all__ = ['StatHolder', 'StatPrinter', 'SendStat']
@@ -134,7 +134,7 @@ class StatPrinter(Callback):
     def _trigger_epoch(self):
         # by default, add these two stats
-        self._stat_holder.add_stat('global_step', get_global_step())
+        self._stat_holder.add_stat('global_step', get_global_step_value())
         self._stat_holder.finalize()
         self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
@@ -10,7 +10,7 @@ import six
 from contextlib import contextmanager

 __all__ = ['get_default_sess_config',
-           'get_global_step',
+           'get_global_step_value',
            'get_global_step_var',
            'get_op_tensor_name',
            'get_tensors_by_names',
@@ -56,16 +56,16 @@ def get_global_step_var():
         assert scope.name == '', \
             "Creating global_step_var under a variable scope would cause problems!"
         with tf.variable_scope(scope, reuse=False):
-            var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
-                                  initializer=tf.constant_initializer(dtype=tf.int32),
+            var = tf.get_variable(GLOBAL_STEP_OP_NAME,
+                                  initializer=0,
                                   trainable=False, dtype=tf.int32)
     return var


-def get_global_step():
+def get_global_step_value():
     """
     Returns:
-        float: global_step value in current graph and session"""
+        int: global_step value in current graph and session"""
     return tf.train.global_step(
         tf.get_default_session(),
         get_global_step_var())
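Two things change above. First, get_global_step_value is a thin wrapper over tf.train.global_step, which runs the variable in the session and returns a Python int, hence the docstring fix from float to int. Second, when tf.get_variable receives a constant such as the Python int 0 as initializer, the shape is inferred from the constant and shape=[] must be omitted; the old tf.constant_initializer(dtype=tf.int32) form relied on the initializer's default value of 0. A sketch of the equivalence, under those assumptions and with illustrative names:

    import tensorflow as tf

    # Assumed-equivalent ways to create a scalar int32 step counter.
    a = tf.get_variable('step_a', shape=[], dtype=tf.int32,
                        initializer=tf.constant_initializer(0),
                        trainable=False)
    b = tf.get_variable('step_b', initializer=0,   # scalar shape inferred
                        dtype=tf.int32, trainable=False)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run([a, b]))   # [0, 0]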
@@ -140,8 +140,7 @@ def get_scalar_var(name, init_value, summary=False, trainable=False):
     Returns:
         tf.Variable: the variable
     """
-    ret = tf.get_variable(name, shape=[],
-                          initializer=tf.constant_initializer(init_value),
+    ret = tf.get_variable(name, initializer=init_value,
                           trainable=trainable)
     if summary:
         # this is recognized in callbacks.StatHolder
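One side effect of initializer=init_value worth noting (my observation, not stated in the commit): the variable's dtype now follows the Python value passed in, whereas the old constant_initializer form defaulted to float32 regardless:

    import tensorflow as tf

    # Hypothetical check: dtype is inferred from the initializer value.
    lr = tf.get_variable('learning_rate', initializer=0.01, trainable=False)
    ct = tf.get_variable('counter', initializer=0, trainable=False)
    print(lr.dtype, ct.dtype)   # float32_ref, int32_ref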
@@ -13,7 +13,7 @@ from .config import TrainConfig
 from ..utils import logger
 from ..utils.timer import timed_operation
 from ..callbacks import StatHolder
-from ..tfutils import get_global_step, get_global_step_var
+from ..tfutils import get_global_step_var, get_global_step_value
 from ..tfutils.modelutils import describe_model
 from ..tfutils.summary import create_scalar_summary
@@ -121,7 +121,7 @@ class Trainer(object):
                 if val.tag.endswith(suffix):
                     val.tag = val.tag[:-len(suffix)]
             self.stat_holder.add_stat(val.tag, val.simple_value)
-        self.summary_writer.add_summary(summary, get_global_step())
+        self.summary_writer.add_summary(summary, get_global_step_value())

     def add_scalar_summary(self, name, val):
         """
@@ -144,7 +144,7 @@ class Trainer(object):
         """
         self._setup()
         describe_model()
-        get_global_step_var()
+        get_global_step_var()   # ensure such var exists

         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
@@ -178,12 +178,12 @@ class Trainer(object):
         with self.sess.as_default():
             try:
                 callbacks.before_train()
-                logger.info("Start training with global_step={}".format(get_global_step()))
+                logger.info("Start training with global_step={}".format(get_global_step_value()))
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch + 1):
                     with timed_operation(
                         'Epoch {} (global_step {})'.format(
-                            self.epoch_num, get_global_step() + self.config.step_per_epoch),
+                            self.epoch_num, get_global_step_value() + self.config.step_per_epoch),
                             log_start=True):
                         for self.step_num in range(self.config.step_per_epoch):
                             if self.coord.should_stop():
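A reading of the arithmetic in the timed_operation label (inferred, not spelled out in the commit): the counter advances once per step, so the value logged at the start of an epoch is the global_step the epoch is expected to end at:

    # Illustrative numbers only.
    step_per_epoch = 100
    current = 300                    # get_global_step_value() after 3 epochs
    print(current + step_per_epoch)  # 400 -> "Epoch 4 (global_step 400)"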
@@ -3,6 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf

+# this is also the name used by tf.train.get_global_step
 GLOBAL_STEP_OP_NAME = 'global_step'
 GLOBAL_STEP_VAR_NAME = 'global_step:0'
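The new comment is the crux of the rename: tf.train.get_global_step first checks the GLOBAL_STEP collection and then falls back to looking up a tensor literally named 'global_step:0', so a variable created under GLOBAL_STEP_OP_NAME is discoverable by TensorFlow's own helper. A hedged check of that behavior:

    import tensorflow as tf

    var = tf.get_variable('global_step', initializer=0,
                          trainable=False, dtype=tf.int32)
    found = tf.train.get_global_step()   # found by name, not by collection
    print(found.name)                    # global_step:0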