Commit ae985fc4 authored by Yuxin Wu's avatar Yuxin Wu

ugly fix of MODEL_KEY

parent 24f898ec
...@@ -92,6 +92,7 @@ def get_config(): ...@@ -92,6 +92,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128) dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True) dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
step_per_epoch = 20
# prepare session # prepare session
sess_config = get_default_sess_config() sess_config = get_default_sess_config()
......
...@@ -13,22 +13,22 @@ from ..utils import * ...@@ -13,22 +13,22 @@ from ..utils import *
__all__ = ['Callbacks'] __all__ = ['Callbacks']
@contextmanager @contextmanager
def create_test_graph(): def create_test_graph(trainer):
G = tf.get_default_graph() model = trainer.model.__class__()
model = G.get_collection(MODEL_KEY)[0]
with tf.Graph().as_default() as Gtest: with tf.Graph().as_default() as Gtest:
# create a global step var in test graph # create a global step var in test graph
global_step_var = tf.Variable( global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME) 0, trainable=False, name=GLOBAL_STEP_OP_NAME)
new_model = model.__class__() input_vars = model.get_input_vars()
input_vars = new_model.get_input_vars() for v in input_vars:
cost = new_model.get_cost(input_vars, is_training=False) tf.add_to_collection(INPUT_VARS_KEY, v)
Gtest.add_to_collection(MODEL_KEY, new_model) cost = model.get_cost(input_vars, is_training=False)
yield Gtest yield Gtest
@contextmanager @contextmanager
def create_test_session(): def create_test_session(trainer):
with create_test_graph(): """ create a test-time session from trainer"""
with create_test_graph(trainer):
with tf.Session() as sess: with tf.Session() as sess:
yield sess yield sess
...@@ -66,16 +66,13 @@ class TestCallbackContext(object): ...@@ -66,16 +66,13 @@ class TestCallbackContext(object):
def __init__(self): def __init__(self):
self.sess = None self.sess = None
def _init_test_sess(self): @contextmanager
with create_test_session() as sess: def before_train_context(self, trainer):
if self.sess is None:
with create_test_session(trainer) as sess:
self.sess = sess self.sess = sess
self.graph = sess.graph self.graph = sess.graph
self.saver = tf.train.Saver() self.saver = tf.train.Saver()
@contextmanager
def before_train_context(self):
if self.sess is None:
self._init_test_sess()
with self.graph.as_default(), self.sess.as_default(): with self.graph.as_default(), self.sess.as_default():
yield yield
...@@ -112,7 +109,7 @@ class Callbacks(Callback): ...@@ -112,7 +109,7 @@ class Callbacks(Callback):
if isinstance(cb.type, TrainCallback): if isinstance(cb.type, TrainCallback):
cb.before_train(self.trainer) cb.before_train(self.trainer)
else: else:
with self.test_callback_context.before_train_context(): with self.test_callback_context.before_train_context(self.trainer):
cb.before_train(self.trainer) cb.before_train(self.trainer)
def _after_train(self): def _after_train(self):
......
...@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback): ...@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback):
self.cost_var_name = cost_var_name self.cost_var_name = cost_var_name
def _before_train(self): def _before_train(self):
self.input_vars = tf.get_collection(MODEL_KEY)[0].get_input_vars() self.input_vars = tf.get_collection(INPUT_VARS_KEY)
self.cost_var = self.get_tensor(self.cost_var_name) self.cost_var = self.get_tensor(self.cost_var_name)
self._find_output_vars() self._find_output_vars()
......
...@@ -24,7 +24,7 @@ class Trainer(object): ...@@ -24,7 +24,7 @@ class Trainer(object):
""" """
assert isinstance(config, TrainConfig), type(config) assert isinstance(config, TrainConfig), type(config)
self.config = config self.config = config
tf.add_to_collection(MODEL_KEY, config.model) self.model = config.model
@abstractmethod @abstractmethod
def train(self): def train(self):
......
#!/usr/bin/env python2 #!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: train.py # File: trainer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
...@@ -47,7 +47,7 @@ class SimpleTrainer(Trainer): ...@@ -47,7 +47,7 @@ class SimpleTrainer(Trainer):
self.sess.run([self.train_op], feed_dict=feed) # faster since train_op return None self.sess.run([self.train_op], feed_dict=feed) # faster since train_op return None
def train(self): def train(self):
model = self.config.model model = self.model
input_vars = model.get_input_vars() input_vars = model.get_input_vars()
self.input_vars = input_vars self.input_vars = input_vars
cost_var = model.get_cost(input_vars, is_training=True) cost_var = model.get_cost(input_vars, is_training=True)
...@@ -91,7 +91,7 @@ class QueueInputTrainer(Trainer): ...@@ -91,7 +91,7 @@ class QueueInputTrainer(Trainer):
return ret return ret
def train(self): def train(self):
model = self.config.model model = self.model
input_vars = model.get_input_vars() input_vars = model.get_input_vars()
input_queue = model.get_input_queue() input_queue = model.get_input_queue()
...@@ -144,7 +144,7 @@ class QueueInputTrainer(Trainer): ...@@ -144,7 +144,7 @@ class QueueInputTrainer(Trainer):
self.init_session_and_coord() self.init_session_and_coord()
# create a thread that keeps filling the queue # create a thread that keeps filling the queue
input_th = EnqueueThread(self.sess, self.coord, enqueue_op, self.config.dataset, input_queue) input_th = EnqueueThread(self, enqueue_op, self.config.dataset, input_queue)
input_th.start() input_th.start()
self.main_loop() self.main_loop()
......
...@@ -23,11 +23,12 @@ class StoppableThread(threading.Thread): ...@@ -23,11 +23,12 @@ class StoppableThread(threading.Thread):
class EnqueueThread(threading.Thread): class EnqueueThread(threading.Thread):
def __init__(self, sess, coord, enqueue_op, dataflow, queue): def __init__(self, trainer, enqueue_op, dataflow, queue):
super(EnqueueThread, self).__init__() super(EnqueueThread, self).__init__()
self.sess = sess self.sess = trainer.sess
self.coord = coord self.coord = trainer.coord
self.input_vars = sess.graph.get_collection(MODEL_KEY)[0].get_input_vars() self.input_vars = trainer.model.get_input_vars()
self.dataflow = dataflow self.dataflow = dataflow
self.op = enqueue_op self.op = enqueue_op
self.queue = queue self.queue = queue
......
...@@ -8,7 +8,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0' ...@@ -8,7 +8,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
# extra variables to summarize during training in a moving-average way # extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES' MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
MODEL_KEY = 'MODEL' INPUT_VARS_KEY = 'INPUT_VARS'
# export all upper case variables # export all upper case variables
all_local_names = locals().keys() all_local_names = locals().keys()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment