Commit 26edfabe authored by Yuxin Wu

misc fix

parent f698a04d
@@ -98,7 +98,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    #step_per_epoch = 30
+    step_per_epoch = 30
     #dataset_test = FixedSizeData(dataset_test, 20)
     sess_config = get_default_sess_config()
...
@@ -130,6 +130,7 @@ def start_train(config):
     with sess.as_default(), \
            coordinator_guard(sess, coord):
+        logger.info("Start with global_step={}".format(get_global_step()))
         callbacks.before_train()
         for epoch in xrange(1, config.max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
...
@@ -37,6 +37,9 @@ def create_test_graph():
     input_vars_train = G.get_collection(INPUT_VARS_KEY)
     forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
     with tf.Graph().as_default() as Gtest:
+        # create a global step var in test graph
+        global_step_var = tf.Variable(
+            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
         input_vars = []
         for v in input_vars_train:
             name = v.name
@@ -99,11 +102,20 @@ class memoized(object):
         '''Support instance methods.'''
         return functools.partial(self.__call__, obj)
 
-@memoized
 def get_global_step_var():
-    global_step_var = tf.Variable(
-        0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-    return global_step_var
+    """ get global_step variable in the current graph"""
+    try:
+        return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
+    except KeyError:
+        var = tf.Variable(
+            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+        return var
+
+def get_global_step():
+    """ get global_step value with current graph and session"""
+    return tf.train.global_step(
+        tf.get_default_session(),
+        get_global_step_var())
 
 def get_rng(self):
     return np.random.RandomState()
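Note: the @memoized decorator is dropped here, presumably because memoization caches one variable per process, while the lookup above is per-graph, so a fresh graph (such as the test graph built in create_test_graph) gets its own global_step. Below is a minimal, purely illustrative sketch of the same look-up-or-create pattern; FakeGraph and STEP_NAME are stand-ins, not names from the repo, and a plain dict plays the role of the TensorFlow graph.

# Illustrative only: dict-backed stand-in for tf.Graph, mirroring the
# try/except KeyError pattern of the new get_global_step_var().
STEP_NAME = 'global_step:0'

class FakeGraph(object):
    def __init__(self):
        self._tensors = {}

    def get_tensor_by_name(self, name):
        # raises KeyError on a miss, which is what the real code catches
        return self._tensors[name]

    def create(self, name, value):
        self._tensors[name] = value
        return value

def get_or_create_step(graph):
    try:
        return graph.get_tensor_by_name(STEP_NAME)
    except KeyError:
        return graph.create(STEP_NAME, 0)

train_g, test_g = FakeGraph(), FakeGraph()
get_or_create_step(train_g)   # created in the train graph
get_or_create_step(test_g)    # created again, independently, in the test graph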
@@ -10,7 +10,7 @@ import os
 import time
 from abc import abstractmethod, ABCMeta
 
-from . import create_test_session
+from . import create_test_session, get_global_step
 from .naming import *
 import logger
@@ -53,6 +53,7 @@ class PeriodicCallback(Callback):
     def trigger_epoch(self):
         self.epoch_num += 1
         if self.epoch_num % self.__period == 0:
+            self.global_step = get_global_step()
             self._trigger()
 
     @abstractmethod
@@ -72,13 +73,14 @@ class PeriodicSaver(PeriodicCallback):
             keep_checkpoint_every_n_hours=self.keep_freq)
 
     def _trigger(self):
-        self.saver.save(tf.get_default_session(), self.path,
-                        global_step=self.epoch_num)
+        self.saver.save(
+            tf.get_default_session(),
+            self.path,
+            global_step=self.global_step)
 
 class SummaryWriter(Callback):
     def __init__(self):
         self.log_dir = logger.LOG_DIR
-        self.epoch_num = 0
 
     def _before_train(self):
         self.writer = tf.train.SummaryWriter(
@@ -91,9 +93,7 @@ class SummaryWriter(Callback):
         if self.summary_op is None:
             return
         summary_str = self.summary_op.eval()
-        self.epoch_num += 1
-        self.writer.add_summary(summary_str, self.epoch_num)
+        self.writer.add_summary(summary_str, get_global_step())
 
 class CallbackTimeLogger(object):
     def __init__(self):
@@ -214,5 +214,6 @@ class Callbacks(Callback):
     def trigger_epoch(self):
         self.train.trigger_epoch()
+        # TODO test callbacks can be run async?
         self.test.trigger_epoch()
@@ -51,7 +51,7 @@ def summary_moving_average(cost_var):
     """
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     averager = tf.train.ExponentialMovingAverage(
-        0.9, num_updates=global_step_var, name='avg')
+        0.9, num_updates=global_step_var, name='moving_averages')
     vars_to_summary = [cost_var] + \
         tf.get_collection(SUMMARY_VARS_KEY) + \
         tf.get_collection(COST_VARS_KEY)
...
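Passing num_updates=global_step_var matters early in training: tf.train.ExponentialMovingAverage then uses the smaller of the configured decay and (1 + num_updates) / (10 + num_updates), so the averages warm up quickly instead of being dominated by their zero initialization. A plain-Python sketch of that schedule (for intuition only, not the TF op):

# Effective decay used by ExponentialMovingAverage when num_updates is given:
# min(decay, (1 + step) / (10 + step)).
def effective_decay(decay, step):
    return min(decay, (1.0 + step) / (10.0 + step))

for step in (0, 10, 100, 1000):
    print(step, effective_decay(0.9, step))
# 0 -> 0.1, 10 -> 0.55, 100 -> 0.9, 1000 -> 0.9 (capped at the configured decay)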
@@ -8,7 +8,7 @@ import numpy as np
 __all__ = ['one_hot', 'batch_flatten', 'logSoftmax']
 
 def one_hot(y, num_labels):
-    with tf.variable_scope('one_hot'):
+    with tf.op_scope([y, num_labels], 'one_hot'):
         batch_size = tf.size(y)
         y = tf.expand_dims(y, 1)
         indices = tf.expand_dims(tf.range(0, batch_size), 1)
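The switch to tf.op_scope only changes op naming; these helpers create no variables, so a variable scope was unnecessary. For reference, the contract one_hot implements, sketched in NumPy independently of the TF ops above: each integer label becomes a row with a single 1 at that index.

# NumPy sketch of what one_hot(y, num_labels) computes; illustrative only.
import numpy as np

def one_hot_np(y, num_labels):
    out = np.zeros((len(y), num_labels), dtype='float32')
    out[np.arange(len(y)), y] = 1.0
    return out

one_hot_np(np.array([0, 2, 1]), 3)
# array([[1., 0., 0.],
#        [0., 0., 1.],
#        [0., 1., 0.]], dtype=float32)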
@@ -23,7 +23,7 @@ def batch_flatten(x):
     return tf.reshape(x, [-1, total_dim])
 
 def logSoftmax(x):
-    with tf.variable_scope('logSoftmax'):
+    with tf.op_scope([x], 'logSoftmax'):
         z = x - tf.reduce_max(x, 1, keep_dims=True)
         logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
         return logprob
...
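The max-subtraction in logSoftmax is the usual numerical-stability trick: log-softmax is invariant to shifting x by a per-row constant, and shifting by the row max keeps exp() from overflowing. A NumPy mirror of the two lines above, as a sketch:

# NumPy mirror of logSoftmax for intuition; subtracting the per-row max
# does not change the result, it only keeps exp() in range.
import numpy as np

def log_softmax_np(x):
    z = x - x.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

log_softmax_np(np.array([[1000., 1001., 1002.]]))
# array([[-2.40760596, -1.40760596, -0.40760596]])  (naive exp(1000) would overflow)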
@@ -58,11 +58,11 @@ class ValidationError(PeriodicCallback):
         self.writer.add_summary(
             create_summary('{}_error'.format(self.prefix),
                            err_stat.accuracy),
-            self.epoch_num)
+            self.global_step)
         self.writer.add_summary(
             create_summary('{}_cost'.format(self.prefix),
                            cost_avg),
-            self.epoch_num)
+            self.global_step)
         logger.info(
-            "{} validation after epoch {}: err={:.4f}, cost={:.3f}".format(
-                self.prefix, self.epoch_num, err_stat.accuracy, cost_avg))
+            "{} validation after epoch{},step{}: err={:.4f}, cost={:.3f}".format(
+                self.prefix, self.epoch_num, self.global_step, err_stat.accuracy, cost_avg))