Commit 67fdbf26 authored by Yuxin Wu

reuse_input_vars

parent ae985fc4
@@ -28,11 +28,8 @@ IMAGE_SIZE = 28
 class Model(ModelDesc):
     def _get_input_vars(self):
-        return [
-            tf.placeholder(
-                tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
-            tf.placeholder(
-                tf.int32, shape=(None,), name='label')
-        ]
+        return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
+                InputVar(tf.int32, (None,), 'label')
+        ]

     def _get_cost(self, input_vars, is_training):
@@ -92,7 +89,6 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    step_per_epoch = 20
     # prepare session
     sess_config = get_default_sess_config()
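Note: this commit moves placeholder construction out of user code. A Model now declares its inputs as InputVar specs (a plain namedtuple, introduced in the model_desc change below), and the framework builds a tf.placeholder for each spec. A minimal standalone sketch of that translation, assuming the TF 0.x/1.x placeholder API:

    import tensorflow as tf
    from collections import namedtuple

    InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])

    specs = [InputVar(tf.float32, (None, 28, 28), 'input'),
             InputVar(tf.int32, (None,), 'label')]
    # the framework, not the model, turns specs into placeholders
    placeholders = [tf.placeholder(v.type, shape=v.shape, name=v.name)
                    for v in specs]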
@@ -14,14 +14,12 @@ __all__ = ['Callbacks']
 @contextmanager
 def create_test_graph(trainer):
-    model = trainer.model.__class__()
+    model = trainer.model
     with tf.Graph().as_default() as Gtest:
         # create a global step var in test graph
         global_step_var = tf.Variable(
             0, trainable=False, name=GLOBAL_STEP_OP_NAME)
         input_vars = model.get_input_vars()
-        for v in input_vars:
-            tf.add_to_collection(INPUT_VARS_KEY, v)
         cost = model.get_cost(input_vars, is_training=False)
     yield Gtest
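Note: the old code rebuilt the model with trainer.model.__class__(), which only worked because each instance cached its own placeholders. With caching gone, the same instance is reused, and get_input_vars() simply creates fresh placeholders inside the test graph. The property this relies on, sketched standalone:

    import tensorflow as tf

    g1 = tf.Graph()
    g2 = tf.Graph()
    with g1.as_default():
        a = tf.placeholder(tf.float32, name='input')
    with g2.as_default():
        b = tf.placeholder(tf.float32, name='input')   # same name, different graph
    assert a.graph is g1 and b.graph is g2             # placeholders are graph-scoped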
@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback):
         self.cost_var_name = cost_var_name

     def _before_train(self):
-        self.input_vars = tf.get_collection(INPUT_VARS_KEY)
+        self.input_vars = self.trainer.model.reuse_input_vars()
         self.cost_var = self.get_tensor(self.cost_var_name)
         self._find_output_vars()
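Note: with the INPUT_VARS_KEY collection removed, reuse_input_vars() (added in model_desc below) recovers existing placeholders from the default graph by name. The lookup it performs is just:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 28, 28), name='input')
    g = tf.get_default_graph()
    # ':0' selects the first output tensor of the op named 'input';
    # in graph-mode TF this is the identical Tensor object
    assert g.get_tensor_by_name('input:0') is x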
@@ -5,44 +5,53 @@
 from abc import ABCMeta, abstractmethod
 import tensorflow as tf
+from collections import namedtuple

-__all__ = ['ModelDesc']
+__all__ = ['ModelDesc', 'InputVar']

+InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])

 class ModelDesc(object):
     __metaclass__ = ABCMeta

     def __init__(self):
-        self.input_vars = None
+        pass

     def get_input_vars(self):
         """
-        return the list of input vars in the graph
-        results will be cached, to avoid creating the same variable
+        return the list of raw input vars in the graph
+        if reuse=True, results will be cached, to avoid creating the same variable
         """
-        if self.input_vars is None:
-            self.input_vars = self._get_input_vars()
-            for i in self.input_vars:
-                assert isinstance(i, tf.Tensor), tf.Tensor.__class__
-        return self.input_vars
+        input_vars = self._get_input_vars()
+        ret = []
+        for v in input_vars:
+            ret.append(tf.placeholder(v.type, shape=v.shape, name=v.name))
+        return ret

+    def reuse_input_vars(self):
+        """ find input_vars in default graph"""
+        input_var_names = [k.name for k in self._get_input_vars()]
+        g = tf.get_default_graph()
+        return [g.get_tensor_by_name(name + ":0") for name in input_var_names]

     @abstractmethod
     def _get_input_vars(self):
         """
         return the list of input vars in the graph
         """
-        pass

-    def get_input_queue(self):
+    def get_input_queue(self, input_vars):
         """
         return the queue for input. the dequeued elements will be fed to self.get_cost
         if queue is None, datapoints from dataflow will be fed to the graph directly.
         when running with multiGPU, queue cannot be None
         """
-        assert self.input_vars is not None
-        return tf.FIFOQueue(50, [x.dtype for x in self.input_vars], name='input_queue')
+        assert input_vars is not None
+        return tf.FIFOQueue(50, [x.dtype for x in input_vars], name='input_queue')

     def get_cost(self, input_vars, is_training):
-        assert len(input_vars) == len(self.input_vars)
         assert type(is_training) == bool
         return self._get_cost(input_vars, is_training)
@@ -57,10 +66,6 @@ class ModelDesc(object):
         is_training: a python bool variable
         Returns:
             the cost to minimize. scalar variable
-
-        input_vars might be different from self.input_vars
-        (inputs might go through the queue for faster input),
-        but must have the same length
         """

     def get_lr_multiplier(self):
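Note: get_input_vars() and reuse_input_vars() now form a create/find pair keyed by placeholder name. A usage sketch, where MyModel is a hypothetical subclass mirroring the MNIST example (ModelDesc and InputVar as defined in the diff above):

    import tensorflow as tf

    class MyModel(ModelDesc):
        def _get_input_vars(self):
            return [InputVar(tf.float32, (None, 28, 28), 'input'),
                    InputVar(tf.int32, (None,), 'label')]
        def _get_cost(self, input_vars, is_training):
            return tf.constant(0.)   # dummy cost for illustration

    model = MyModel()
    with tf.Graph().as_default():
        created = model.get_input_vars()   # creates placeholders 'input:0', 'label:0'
        found = model.reuse_input_vars()   # finds the same tensors again by name
        assert all(a is b for a, b in zip(created, found))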
@@ -94,8 +94,8 @@ class QueueInputTrainer(Trainer):
         model = self.model
         input_vars = model.get_input_vars()
-        input_queue = model.get_input_queue()
+        input_queue = model.get_input_queue(input_vars)
         enqueue_op = input_queue.enqueue(input_vars)

         def get_model_inputs():
             model_inputs = input_queue.dequeue()
             if isinstance(model_inputs, tf.Tensor): # only one input
@@ -144,7 +144,7 @@ class QueueInputTrainer(Trainer):
         self.init_session_and_coord()
         # create a thread that keeps filling the queue
-        input_th = EnqueueThread(self, enqueue_op, self.config.dataset, input_queue)
+        input_th = EnqueueThread(self, input_queue, enqueue_op, input_vars)
         input_th.start()
         self.main_loop()
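Note: the trainer now threads input_vars through explicitly; the raw placeholders only feed the enqueue op, while the cost is built on the dequeued tensors. The queue wiring in miniature, under the same TF 0.x/1.x queue API:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 28, 28), name='input')
    y = tf.placeholder(tf.int32, shape=(None,), name='label')
    queue = tf.FIFOQueue(50, [x.dtype, y.dtype], name='input_queue')
    enqueue_op = queue.enqueue([x, y])   # run with feed_dict by the enqueue thread
    model_inputs = queue.dequeue()       # tensors the cost is actually built on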
@@ -23,13 +23,13 @@ class StoppableThread(threading.Thread):

 class EnqueueThread(threading.Thread):
-    def __init__(self, trainer, enqueue_op, dataflow, queue):
+    def __init__(self, trainer, queue, enqueue_op, raw_input_var):
         super(EnqueueThread, self).__init__()
         self.sess = trainer.sess
         self.coord = trainer.coord
-        self.input_vars = trainer.model.get_input_vars()
-        self.dataflow = dataflow
+        self.dataflow = trainer.config.dataset
+        self.input_vars = raw_input_var

         self.op = enqueue_op
         self.queue = queue
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
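Note: EnqueueThread no longer calls model.get_input_vars() itself, which would now create brand-new placeholders; it receives the trainer's raw placeholders instead. Its run loop is not shown in this diff, but presumably it pairs each datapoint with those placeholders, along these lines (a sketch only, not the file's actual loop; DataFlow.get_data() is the iteration API assumed here):

    def run(self):
        # feed each datapoint from the dataflow into the queue until stopped
        while not self.coord.should_stop():
            for dp in self.dataflow.get_data():
                feed = dict(zip(self.input_vars, dp))
                self.sess.run(self.op, feed_dict=feed)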
@@ -8,7 +8,6 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
 # extra variables to summarize during training in a moving-average way
 MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
-INPUT_VARS_KEY = 'INPUT_VARS'

 # export all upper case variables
 all_local_names = locals().keys()