Commit ae985fc4 authored by Yuxin Wu's avatar Yuxin Wu

ugly fix of MODEL_KEY

parent 24f898ec
......@@ -92,6 +92,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
step_per_epoch = 20
# prepare session
sess_config = get_default_sess_config()
......
......@@ -13,22 +13,22 @@ from ..utils import *
__all__ = ['Callbacks']
@contextmanager
def create_test_graph():
G = tf.get_default_graph()
model = G.get_collection(MODEL_KEY)[0]
def create_test_graph(trainer):
model = trainer.model.__class__()
with tf.Graph().as_default() as Gtest:
# create a global step var in test graph
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
new_model = model.__class__()
input_vars = new_model.get_input_vars()
cost = new_model.get_cost(input_vars, is_training=False)
Gtest.add_to_collection(MODEL_KEY, new_model)
input_vars = model.get_input_vars()
for v in input_vars:
tf.add_to_collection(INPUT_VARS_KEY, v)
cost = model.get_cost(input_vars, is_training=False)
yield Gtest
@contextmanager
def create_test_session():
with create_test_graph():
def create_test_session(trainer):
""" create a test-time session from trainer"""
with create_test_graph(trainer):
with tf.Session() as sess:
yield sess
......@@ -66,16 +66,13 @@ class TestCallbackContext(object):
def __init__(self):
self.sess = None
def _init_test_sess(self):
with create_test_session() as sess:
@contextmanager
def before_train_context(self, trainer):
if self.sess is None:
with create_test_session(trainer) as sess:
self.sess = sess
self.graph = sess.graph
self.saver = tf.train.Saver()
@contextmanager
def before_train_context(self):
if self.sess is None:
self._init_test_sess()
with self.graph.as_default(), self.sess.as_default():
yield
......@@ -112,7 +109,7 @@ class Callbacks(Callback):
if isinstance(cb.type, TrainCallback):
cb.before_train(self.trainer)
else:
with self.test_callback_context.before_train_context():
with self.test_callback_context.before_train_context(self.trainer):
cb.before_train(self.trainer)
def _after_train(self):
......
......@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback):
self.cost_var_name = cost_var_name
def _before_train(self):
self.input_vars = tf.get_collection(MODEL_KEY)[0].get_input_vars()
self.input_vars = tf.get_collection(INPUT_VARS_KEY)
self.cost_var = self.get_tensor(self.cost_var_name)
self._find_output_vars()
......
......@@ -24,7 +24,7 @@ class Trainer(object):
"""
assert isinstance(config, TrainConfig), type(config)
self.config = config
tf.add_to_collection(MODEL_KEY, config.model)
self.model = config.model
@abstractmethod
def train(self):
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: train.py
# File: trainer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
......@@ -47,7 +47,7 @@ class SimpleTrainer(Trainer):
self.sess.run([self.train_op], feed_dict=feed) # faster since train_op return None
def train(self):
model = self.config.model
model = self.model
input_vars = model.get_input_vars()
self.input_vars = input_vars
cost_var = model.get_cost(input_vars, is_training=True)
......@@ -91,7 +91,7 @@ class QueueInputTrainer(Trainer):
return ret
def train(self):
model = self.config.model
model = self.model
input_vars = model.get_input_vars()
input_queue = model.get_input_queue()
......@@ -144,7 +144,7 @@ class QueueInputTrainer(Trainer):
self.init_session_and_coord()
# create a thread that keeps filling the queue
input_th = EnqueueThread(self.sess, self.coord, enqueue_op, self.config.dataset, input_queue)
input_th = EnqueueThread(self, enqueue_op, self.config.dataset, input_queue)
input_th.start()
self.main_loop()
......
......@@ -23,11 +23,12 @@ class StoppableThread(threading.Thread):
class EnqueueThread(threading.Thread):
def __init__(self, sess, coord, enqueue_op, dataflow, queue):
def __init__(self, trainer, enqueue_op, dataflow, queue):
super(EnqueueThread, self).__init__()
self.sess = sess
self.coord = coord
self.input_vars = sess.graph.get_collection(MODEL_KEY)[0].get_input_vars()
self.sess = trainer.sess
self.coord = trainer.coord
self.input_vars = trainer.model.get_input_vars()
self.dataflow = dataflow
self.op = enqueue_op
self.queue = queue
......
......@@ -8,7 +8,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
MODEL_KEY = 'MODEL'
INPUT_VARS_KEY = 'INPUT_VARS'
# export all upper case variables
all_local_names = locals().keys()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment