Commit 26edfabe authored by Yuxin Wu

misc fix

parent f698a04d
@@ -98,7 +98,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    #step_per_epoch = 30
+    step_per_epoch = 30
     #dataset_test = FixedSizeData(dataset_test, 20)
     sess_config = get_default_sess_config()
@@ -130,6 +130,7 @@ def start_train(config):
     with sess.as_default(), \
             coordinator_guard(sess, coord):
         logger.info("Start with global_step={}".format(get_global_step()))
+        callbacks.before_train()
         for epoch in xrange(1, config.max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
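The added `callbacks.before_train()` call wires the one-time setup hook into the training loop before the first epoch; `trigger_epoch` then fires after each epoch (see the `Callbacks` hunk further down). A minimal pure-Python sketch of that lifecycle (class bodies here are illustrative, not tensorpack's actual implementation):

class Callback(object):
    def before_train(self):
        pass                # runs once, after the session is set up
    def trigger_epoch(self):
        pass                # runs at the end of every epoch

class Callbacks(Callback):
    """Dispatch lifecycle events to a list of callbacks."""
    def __init__(self, cbs):
        self.cbs = cbs
    def before_train(self):
        for cb in self.cbs:
            cb.before_train()
    def trigger_epoch(self):
        for cb in self.cbs:
            cb.trigger_epoch()

callbacks = Callbacks([Callback(), Callback()])  # e.g. saver, summary writer
callbacks.before_train()        # must come before the first epoch
for epoch in xrange(1, 10):
    # ... run the epoch's training steps ...
    callbacks.trigger_epoch()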
@@ -37,6 +37,9 @@ def create_test_graph():
     input_vars_train = G.get_collection(INPUT_VARS_KEY)
     forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
     with tf.Graph().as_default() as Gtest:
+        # create a global step var in test graph
+        global_step_var = tf.Variable(
+            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
         input_vars = []
         for v in input_vars_train:
             name = v.name
@@ -99,11 +102,20 @@ class memoized(object):
         '''Support instance methods.'''
         return functools.partial(self.__call__, obj)

 @memoized
 def get_global_step_var():
-    global_step_var = tf.Variable(
-        0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-    return global_step_var
+    """ get global_step variable in the current graph"""
+    try:
+        return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
+    except KeyError:
+        var = tf.Variable(
+            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+        return var
+
+def get_global_step():
+    """ get global_step value with current graph and session"""
+    return tf.train.global_step(
+        tf.get_default_session(),
+        get_global_step_var())

 def get_rng(self):
     return np.random.RandomState()
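The rewritten `get_global_step_var` looks the variable up by name first and only creates it if missing, so every caller (plus the `@memoized` cache) gets one variable per graph; the `create_test_graph` hunk above still makes its own copy because variables don't cross `tf.Graph` boundaries. A standalone sketch of the same memoize-or-lookup pattern, with the TF graph replaced by a plain dict (the `_graph` registry is a stand-in, not TF API):

import functools

class memoized(object):
    """Cache a function's result so repeated calls return the same object."""
    def __init__(self, func):
        self.func = func
        self.cache = {}
        functools.update_wrapper(self, func)
    def __call__(self, *args):
        if args not in self.cache:
            self.cache[args] = self.func(*args)
        return self.cache[args]

_graph = {}  # stand-in for tf.get_default_graph()

@memoized
def get_global_step_var():
    try:                               # reuse an existing variable by name...
        return _graph['global_step:0']
    except KeyError:                   # ...or create it exactly once
        _graph['global_step:0'] = {'name': 'global_step:0', 'value': 0}
        return _graph['global_step:0']

assert get_global_step_var() is get_global_step_var()  # same object every call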
@@ -10,7 +10,7 @@ import os
 import time
 from abc import abstractmethod, ABCMeta

-from . import create_test_session
+from . import create_test_session, get_global_step
 from .naming import *
 import logger
......@@ -53,6 +53,7 @@ class PeriodicCallback(Callback):
def trigger_epoch(self):
self.epoch_num += 1
if self.epoch_num % self.__period == 0:
self.global_step = get_global_step()
self._trigger()
@abstractmethod
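Snapshotting `get_global_step()` into `self.global_step` just before `_trigger()` means every periodic callback (saver, validation) stamps its output with the same step value for that epoch. A sketch of the gating logic with the step counter faked (pure Python, not the tensorpack classes):

class PeriodicCallback(object):
    def __init__(self, period):
        self.__period = period
        self.epoch_num = 0

    def trigger_epoch(self):
        self.epoch_num += 1
        if self.epoch_num % self.__period == 0:
            self.global_step = get_global_step()  # snapshot once per trigger
            self._trigger()

    def _trigger(self):
        print('epoch {}, step {}'.format(self.epoch_num, self.global_step))

step = 0
def get_global_step():         # stand-in for tf.train.global_step(...)
    return step

cb = PeriodicCallback(period=2)
for _ in range(4):
    step += 100                # pretend each epoch runs 100 steps
    cb.trigger_epoch()         # fires on epochs 2 and 4 only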
......@@ -72,13 +73,14 @@ class PeriodicSaver(PeriodicCallback):
keep_checkpoint_every_n_hours=self.keep_freq)
def _trigger(self):
self.saver.save(tf.get_default_session(), self.path,
global_step=self.epoch_num)
self.saver.save(
tf.get_default_session(),
self.path,
global_step=self.global_step)
class SummaryWriter(Callback):
def __init__(self):
self.log_dir = logger.LOG_DIR
self.epoch_num = 0
def _before_train(self):
self.writer = tf.train.SummaryWriter(
@@ -91,9 +93,7 @@ class SummaryWriter(Callback):
         if self.summary_op is None:
             return
         summary_str = self.summary_op.eval()
-        self.epoch_num += 1
-        self.writer.add_summary(summary_str, self.epoch_num)
+        self.writer.add_summary(summary_str, get_global_step())

 class CallbackTimeLogger(object):
     def __init__(self):
@@ -214,5 +214,6 @@ class Callbacks(Callback):
     def trigger_epoch(self):
         self.train.trigger_epoch()
+        # TODO test callbacks can be run async?
         self.test.trigger_epoch()
@@ -51,7 +51,7 @@ def summary_moving_average(cost_var):
     """
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     averager = tf.train.ExponentialMovingAverage(
-        0.9, num_updates=global_step_var, name='avg')
+        0.9, num_updates=global_step_var, name='moving_averages')
     vars_to_summary = [cost_var] + \
         tf.get_collection(SUMMARY_VARS_KEY) + \
         tf.get_collection(COST_VARS_KEY)
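Passing `num_updates=global_step_var` makes `tf.train.ExponentialMovingAverage` warm up: the effective decay is `min(decay, (1 + num_updates) / (10 + num_updates))`, so early steps average aggressively and the decay approaches 0.9 as training proceeds. A NumPy sketch of that schedule (my reading of the documented formula, not tensorpack code; the shadow starts at 0.0 here for simplicity):

import numpy as np

def ema(values, decay=0.9, warmup=True):
    """Exponential moving average with TF-style num_updates warm-up."""
    avg, out = 0.0, []
    for step, v in enumerate(values):
        d = min(decay, (1.0 + step) / (10.0 + step)) if warmup else decay
        avg = d * avg + (1.0 - d) * v   # shadow = d * shadow + (1 - d) * value
        out.append(avg)
    return np.array(out)

costs = np.random.rand(1000)
smoothed = ema(costs)   # the curve a cost summary would show in TensorBoard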
@@ -8,7 +8,7 @@ import numpy as np
 __all__ = ['one_hot', 'batch_flatten', 'logSoftmax']

 def one_hot(y, num_labels):
-    with tf.variable_scope('one_hot'):
+    with tf.op_scope([y, num_labels], 'one_hot'):
         batch_size = tf.size(y)
         y = tf.expand_dims(y, 1)
         indices = tf.expand_dims(tf.range(0, batch_size), 1)
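The switch from `tf.variable_scope` to `tf.op_scope` is the right tool here: these helpers create no variables, so all that's needed is a name scope grouping the generated ops. Functionally, `one_hot` computes the following, shown in NumPy (equivalent computation, not the TF op graph):

import numpy as np

def one_hot(y, num_labels):
    """Map integer labels (batch,) to a one-hot matrix (batch, num_labels)."""
    out = np.zeros((len(y), num_labels))
    out[np.arange(len(y)), y] = 1.0  # same (row, label) pairs the indices build
    return out

print(one_hot(np.array([0, 2, 1]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]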
@@ -23,7 +23,7 @@ def batch_flatten(x):
     return tf.reshape(x, [-1, total_dim])

 def logSoftmax(x):
-    with tf.variable_scope('logSoftmax'):
+    with tf.op_scope([x], 'logSoftmax'):
         z = x - tf.reduce_max(x, 1, keep_dims=True)
         logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
         return logprob
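`logSoftmax` uses the standard max-subtraction trick: subtracting the per-row max before exponentiating keeps `exp` from overflowing without changing the result, since log-softmax is invariant to adding a constant per row. A NumPy check of the same two lines:

import numpy as np

def log_softmax(x):
    z = x - x.max(axis=1, keepdims=True)        # z = x - reduce_max(x)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

x = np.array([[1000.0, 1001.0, 1002.0]])        # naive exp(x) would overflow
print(log_softmax(x))                           # [[-2.4076 -1.4076 -0.4076]]
print(np.exp(log_softmax(x)).sum())             # probabilities sum to 1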
@@ -58,11 +58,11 @@ class ValidationError(PeriodicCallback):
         self.writer.add_summary(
             create_summary('{}_error'.format(self.prefix),
                            err_stat.accuracy),
-            self.epoch_num)
+            self.global_step)
         self.writer.add_summary(
             create_summary('{}_cost'.format(self.prefix),
                            cost_avg),
-            self.epoch_num)
+            self.global_step)
         logger.info(
-            "{} validation after epoch {}: err={:.4f}, cost={:.3f}".format(
-                self.prefix, self.epoch_num, err_stat.accuracy, cost_avg))
+            "{} validation after epoch{},step{}: err={:.4f}, cost={:.3f}".format(
+                self.prefix, self.epoch_num, self.global_step, err_stat.accuracy, cost_avg))