Commit 67fdbf26 authored by Yuxin Wu

reuse_input_vars

parent ae985fc4
......@@ -28,11 +28,8 @@ IMAGE_SIZE = 28

 class Model(ModelDesc):
     def _get_input_vars(self):
-        return [
-            tf.placeholder(
-                tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
-            tf.placeholder(
-                tf.int32, shape=(None,), name='label')
+        return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
+                InputVar(tf.int32, (None,), 'label')
         ]

     def _get_cost(self, input_vars, is_training):
......@@ -92,7 +89,6 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    step_per_epoch = 20

     # prepare session
     sess_config = get_default_sess_config()
......
......@@ -14,14 +14,12 @@ __all__ = ['Callbacks']

 @contextmanager
 def create_test_graph(trainer):
-    model = trainer.model.__class__()
+    model = trainer.model
     with tf.Graph().as_default() as Gtest:
         # create a global step var in test graph
         global_step_var = tf.Variable(
             0, trainable=False, name=GLOBAL_STEP_OP_NAME)
         input_vars = model.get_input_vars()
-        for v in input_vars:
-            tf.add_to_collection(INPUT_VARS_KEY, v)
         cost = model.get_cost(input_vars, is_training=False)
         yield Gtest
......
......@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback):
         self.cost_var_name = cost_var_name

     def _before_train(self):
-        self.input_vars = tf.get_collection(INPUT_VARS_KEY)
+        self.input_vars = self.trainer.model.reuse_input_vars()
         self.cost_var = self.get_tensor(self.cost_var_name)
         self._find_output_vars()
......
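The collection-based lookup above becomes redundant because a placeholder can be recovered from a graph purely by its name. A minimal standalone sketch of the mechanism `reuse_input_vars()` relies on (the tensor name here is illustrative, not taken from the commit):

```python
import tensorflow as tf

with tf.Graph().as_default() as g:
    # Creating a placeholder named 'input' registers the tensor 'input:0'
    # in the graph itself, so no extra bookkeeping collection is needed.
    x = tf.placeholder(tf.float32, shape=(None, 28, 28), name='input')

    # This is the lookup reuse_input_vars() performs: fetch by name.
    assert g.get_tensor_by_name('input:0') is x
```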
......@@ -5,44 +5,53 @@
 from abc import ABCMeta, abstractmethod
 import tensorflow as tf
+from collections import namedtuple

-__all__ = ['ModelDesc']
+__all__ = ['ModelDesc', 'InputVar']

+InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])

 class ModelDesc(object):
     __metaclass__ = ABCMeta

     def __init__(self):
-        self.input_vars = None
+        pass

     def get_input_vars(self):
         """
-        return the list of input vars in the graph
-        results will be cached, to avoid creating the same variable
+        return the list of raw input vars in the graph
+        new placeholders are created on each call; use reuse_input_vars()
+        to fetch the placeholders already present in the default graph
         """
-        if self.input_vars is None:
-            self.input_vars = self._get_input_vars()
-            for i in self.input_vars:
-                assert isinstance(i, tf.Tensor), tf.Tensor.__class__
-        return self.input_vars
+        input_vars = self._get_input_vars()
+        ret = []
+        for v in input_vars:
+            ret.append(tf.placeholder(v.type, shape=v.shape, name=v.name))
+        return ret

+    def reuse_input_vars(self):
+        """ find input_vars in default graph"""
+        input_var_names = [k.name for k in self._get_input_vars()]
+        g = tf.get_default_graph()
+        return [g.get_tensor_by_name(name + ":0") for name in input_var_names]

     @abstractmethod
     def _get_input_vars(self):
         """
         return the list of input vars in the graph
         """
         pass

-    def get_input_queue(self):
+    def get_input_queue(self, input_vars):
         """
         return the queue for input. the dequeued elements will be fed to self.get_cost
         if queue is None, datapoints from dataflow will be fed to the graph directly.
         when running with multiGPU, queue cannot be None
         """
-        assert self.input_vars is not None
-        return tf.FIFOQueue(50, [x.dtype for x in self.input_vars], name='input_queue')
+        assert input_vars is not None
+        return tf.FIFOQueue(50, [x.dtype for x in input_vars], name='input_queue')

     def get_cost(self, input_vars, is_training):
-        assert len(input_vars) == len(self.input_vars)
         assert type(is_training) == bool
         return self._get_cost(input_vars, is_training)
......@@ -57,10 +66,6 @@ class ModelDesc(object):
             is_training: a python bool variable

         Returns:
             the cost to minimize. scalar variable
-
-        input_vars might be different from self.input_vars
-        (inputs might go through the queue for faster input),
-        but must have the same length
         """

     def get_lr_multiplier(self):
......
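To make the new `get_input_queue(input_vars)` contract concrete, here is a rough sketch of the queue-based input path (queue size and names mirror the code above, but this is an illustration, not the trainer itself): the raw placeholders are only touched by the enqueue op, while `get_cost` consumes the dequeued tensors.

```python
import tensorflow as tf

# Raw placeholders, as produced by get_input_vars()
image = tf.placeholder(tf.float32, shape=(None, 28, 28), name='input')
label = tf.placeholder(tf.int32, shape=(None,), name='label')
input_vars = [image, label]

# What get_input_queue(input_vars) returns: a queue typed after the inputs
queue = tf.FIFOQueue(50, [x.dtype for x in input_vars], name='input_queue')
enqueue_op = queue.enqueue(input_vars)

# Training consumes the dequeued tensors instead of the placeholders;
# these are what reach get_cost(input_vars, is_training).
model_inputs = queue.dequeue()
```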
......@@ -94,8 +94,8 @@ class QueueInputTrainer(Trainer):
         model = self.model
         input_vars = model.get_input_vars()
-        input_queue = model.get_input_queue()
+        input_queue = model.get_input_queue(input_vars)
         enqueue_op = input_queue.enqueue(input_vars)

         def get_model_inputs():
             model_inputs = input_queue.dequeue()
             if isinstance(model_inputs, tf.Tensor): # only one input
......@@ -144,7 +144,7 @@ class QueueInputTrainer(Trainer):
         self.init_session_and_coord()
         # create a thread that keeps filling the queue
-        input_th = EnqueueThread(self, enqueue_op, self.config.dataset, input_queue)
+        input_th = EnqueueThread(self, input_queue, enqueue_op, input_vars)
         input_th.start()
         self.main_loop()
......
......@@ -23,13 +23,13 @@ class StoppableThread(threading.Thread):

 class EnqueueThread(threading.Thread):
-    def __init__(self, trainer, enqueue_op, dataflow, queue):
+    def __init__(self, trainer, queue, enqueue_op, raw_input_var):
         super(EnqueueThread, self).__init__()
         self.sess = trainer.sess
         self.coord = trainer.coord
-        self.input_vars = trainer.model.get_input_vars()
-        self.dataflow = dataflow
+        self.dataflow = trainer.config.dataset
+        self.input_vars = raw_input_var

         self.op = enqueue_op
         self.queue = queue
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
......
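For context, a condensed sketch of the loop a thread like EnqueueThread runs (error handling and the real tensorpack classes omitted; this is an assumption-laden outline, not the actual implementation): each datapoint is zipped with the raw placeholders and pushed through the enqueue op until the coordinator signals a stop.

```python
import threading

class EnqueueThreadSketch(threading.Thread):
    """Illustrative only: keep feeding datapoints into the input queue."""
    def __init__(self, sess, coord, dataflow, enqueue_op, input_vars):
        super(EnqueueThreadSketch, self).__init__()
        self.daemon = True
        self.sess, self.coord = sess, coord
        self.dataflow = dataflow        # yields lists of numpy arrays
        self.op = enqueue_op
        self.input_vars = input_vars    # the raw placeholders

    def run(self):
        while not self.coord.should_stop():
            for dp in self.dataflow.get_data():
                feed = dict(zip(self.input_vars, dp))
                self.sess.run(self.op, feed_dict=feed)
```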
......@@ -8,7 +8,6 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'

 # extra variables to summarize during training in a moving-average way
 MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
-INPUT_VARS_KEY = 'INPUT_VARS'

 # export all upper case variables
 all_local_names = locals().keys()
......
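The unchanged `all_local_names = locals().keys()` context line belongs to the module's export idiom; a sketch of how such a module can derive `__all__` from its upper-case constants (only the two constants appear in the hunk, the rest is assumed):

```python
GLOBAL_STEP_VAR_NAME = 'global_step:0'
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'

# export all upper case variables
all_local_names = list(locals().keys())
__all__ = [name for name in all_local_names if name.isupper()]
```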