Commit 3b2f7df1 authored by Yuxin Wu

load model from meta

parent 03b92aba
......@@ -7,10 +7,10 @@ from abc import ABCMeta, abstractmethod
import tensorflow as tf
from collections import namedtuple
from ..utils import logger
from ..utils import logger, INPUT_VARS_KEY
from ..tfutils import *
__all__ = ['ModelDesc', 'InputVar']
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
......@@ -32,6 +32,8 @@ class ModelDesc(object):
ret = []
for v in input_vars:
ret.append(tf.placeholder(v.type, shape=v.shape, name=v.name))
for v in ret:
tf.add_to_collection(INPUT_VARS_KEY, v)
return ret
def reuse_input_vars(self):
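A minimal standalone sketch of what the two added lines above do (the key's value comes from the naming.py hunk at the end of this commit): every input placeholder is registered in a named graph collection, so later consumers, including a graph re-imported from a metagraph, can recover the inputs without hard-coding tensor names.

```python
import tensorflow as tf

INPUT_VARS_KEY = 'INPUT_VARIABLES'

# producer side: create the inputs and register them in the collection
image = tf.placeholder(tf.float32, shape=[None, 28, 28], name='input')
label = tf.placeholder(tf.int32, shape=[None], name='label')
for v in [image, label]:
    tf.add_to_collection(INPUT_VARS_KEY, v)

# consumer side: recover the inputs from the (possibly imported) default graph
assert tf.get_collection(INPUT_VARS_KEY) == [image, label]
```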
......@@ -57,27 +59,11 @@ class ModelDesc(object):
"""
self._build_graph(model_inputs, is_training)
#@abstractmethod
@abstractmethod
def _build_graph(self, inputs, is_training):
if self._old_version():
self.model_inputs = inputs
self.is_training = is_training
else:
raise NotImplementedError()
def _old_version(self):
# for backward-compat only.
import inspect
args = inspect.getargspec(self._get_cost)
return len(args.args) == 3
pass
def get_cost(self):
if self._old_version():
assert type(self.is_training) == bool
logger.warn("!!!using _get_cost to setup the graph is deprecated in favor of _build_graph")
logger.warn("See examples for details.")
return self._get_cost(self.model_inputs, self.is_training)
else:
return self._get_cost()
def _get_cost(self, *args):
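With `_build_graph` promoted to an abstractmethod above, user models are expected to override it directly; the deprecated `_get_cost(inputs, is_training)` path is removed. A hypothetical subclass sketch (toy cost, illustrative shapes):

```python
import tensorflow as tf

class MyModel(ModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, [None, 784], 'input'),
                InputVar(tf.int32, [None], 'label')]

    def _build_graph(self, inputs, is_training):
        image, label = inputs
        W = tf.Variable(tf.zeros([784, 10]), name='W')
        b = tf.Variable(tf.zeros([10]), name='b')
        logits = tf.matmul(image, W) + b
        self.cost = tf.reduce_mean(tf.square(logits))   # toy cost, for illustration only

    def _get_cost(self):
        return self.cost
```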
......@@ -87,3 +73,27 @@ class ModelDesc(object):
""" Return a list of GradientProcessor. They will be executed in order"""
return [CheckGradient()]#, SummaryGradient()]
class ModelFromMetaGraph(ModelDesc):
"""
Load the exact TF graph from a saved meta_graph.
Only useful for inference.
"""
def __init__(self, filename):
tf.train.import_meta_graph(filename)
all_coll = tf.get_default_graph().get_all_collection_keys()
for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
tf.GraphKeys.VARIABLES]:
assert k in all_coll, \
"Collection {} not found in metagraph!".format(k)
def get_input_vars(self):
return tf.get_collection(INPUT_VARS_KEY)
def _get_input_vars(self):
raise NotImplementedError("Shouldn't call here")
def _build_graph(self, _, __):
""" Do nothing. Graph was imported already """
pass
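A rough usage sketch of the new class (file paths are illustrative; `export_meta_graph` is the standard TF call that produces the `.meta` file this class consumes):

```python
import tensorflow as tf

# training side: export the graph structure alongside the checkpoint
saver = tf.train.Saver()
saver.export_meta_graph('/path/to/model.meta')

# inference side: rebuild exactly the same graph and recover its inputs
# through the INPUT_VARS_KEY collection populated by get_input_vars() above
model = ModelFromMetaGraph('/path/to/model.meta')
input_vars = model.get_input_vars()
```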
......@@ -50,8 +50,9 @@ class PredictConfig(object):
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
self.session_config = kwargs.pop('session_config',
get_default_sess_config(0.3))
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.3))
self.session_init = kwargs.pop('session_init')
self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
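For reference, a guess at what `get_default_sess_config(0.3)` builds (its body is not part of this diff): a ConfigProto that claims roughly 30% of GPU memory up front and relies on `allow_growth` beyond that, which is the behaviour the `XXX` comment above is questioning.

```python
import tensorflow as tf

def default_sess_config_sketch(mem_fraction=0.3):
    conf = tf.ConfigProto()
    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
    conf.gpu_options.allow_growth = True   # "doesn't seem to work very well in TF"
    conf.allow_soft_placement = True
    return conf
```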
......@@ -61,7 +62,7 @@ class PredictConfig(object):
def get_predict_func(config):
"""
Produce a simple predictor function in a newly-created session without any parallelism.
Produce a simple predictor function that runs inside a new session.
:param config: a `PredictConfig` instance.
:returns: A prediction function that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
......@@ -77,10 +78,8 @@ def get_predict_func(config):
input_map = [input_vars[k] for k in config.input_data_mapping]
# check output_var_names against output_vars
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names]
output_vars = get_vars_by_names(output_var_names)
# XXX does it work? start with minimal memory, but allow growth
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
......
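Putting the pieces together, a hypothetical inference setup using the options visible in this diff (`SaverRestore` and the output tensor name are assumptions, not shown here):

```python
config = PredictConfig(
    model=ModelFromMetaGraph('/path/to/model.meta'),
    session_init=SaverRestore('/path/to/checkpoint'),  # assumed tensorpack session init
    output_var_names=['output:0'],                     # illustrative tensor name
)
predict = get_predict_func(config)
outputs = predict([input_batch])   # values of output_var_names, in order
```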
......@@ -105,9 +105,9 @@ class PredictorWorkerThread(threading.Thread):
for k in range(self.nr_input_var):
batched[k].append(inp[k])
futures.append(f)
cnt += 1
except queue.Empty:
break
cnt += 1
return batched, futures
#self.xxx = None
while True:
......@@ -116,12 +116,9 @@ class PredictorWorkerThread(threading.Thread):
outputs = self.func(batched)
# debug, for speed testing
#if self.xxx is None:
#outputs = self.func([batched])
#self.xxx = outputs
#self.xxx = outputs = self.func([batched])
#else:
#outputs = [None, None]
#outputs[0] = [self.xxx[0][0]] * len(batched)
#outputs[1] = [self.xxx[1][0]] * len(batched)
#outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
for idx, f in enumerate(futures):
f.set_result([k[idx] for k in outputs])
......
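The first hunk above moves `cnt += 1` inside the `try`, so only items actually dequeued are counted toward the batch size. A self-contained sketch of that batching pattern (one input per queue element, unlike the real multi-input code):

```python
import queue

def fetch_batch(q, max_size):
    """Block for one element, then drain the queue up to max_size items."""
    inp, f = q.get()
    batched, futures, cnt = [inp], [f], 1
    while cnt < max_size:
        try:
            inp, f = q.get_nowait()
        except queue.Empty:
            break
        batched.append(inp)
        futures.append(f)
        cnt += 1   # incremented only after a successful get, matching the fix above
    return batched, futures
```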
......@@ -13,8 +13,7 @@ __all__ = ['argscope', 'get_arg_scope']
_ArgScopeStack = []
@contextmanager
def argscope(layers, **kwargs):
param = kwargs
def argscope(layers, **param):
if not isinstance(layers, list):
layers = [layers]
......@@ -35,6 +34,10 @@ def argscope(layers, **kwargs):
del _ArgScopeStack[-1]
def get_arg_scope():
""" return the current argscope
an argscope is a dict of dict:
dict[layername] = {arg: val}
"""
if len(_ArgScopeStack) > 0:
return _ArgScopeStack[-1]
else:
......
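A self-contained sketch of the mechanism this file implements (the dummy `MyLayer` is hypothetical; real tensorpack layers presumably consult `get_arg_scope()` through their registration decorator):

```python
from contextlib import contextmanager

_ArgScopeStack = []

@contextmanager
def argscope(layers, **param):
    if not isinstance(layers, list):
        layers = [layers]
    new_scope = dict(_ArgScopeStack[-1]) if _ArgScopeStack else {}
    for l in layers:
        merged = dict(new_scope.get(l.__name__, {}))
        merged.update(param)
        new_scope[l.__name__] = merged   # dict[layername] = {arg: val}
    _ArgScopeStack.append(new_scope)
    yield
    del _ArgScopeStack[-1]

def get_arg_scope():
    return _ArgScopeStack[-1] if _ArgScopeStack else {}

def MyLayer(**kwargs):
    # a layer would merge scope defaults with explicit arguments like this
    args = dict(get_arg_scope().get('MyLayer', {}))
    args.update(kwargs)
    return args

with argscope(MyLayer, padding='SAME', channels=32):
    print(MyLayer())             # {'padding': 'SAME', 'channels': 32}
    print(MyLayer(channels=64))  # explicit argument overrides the scoped default
```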
......@@ -138,7 +138,7 @@ class QueueInputTrainer(Trainer):
inputs = self.model.get_input_vars()
tf.get_variable_scope().reuse_variables()
for k in self.predict_tower:
logger.info("Building graph for predict towerp{}...".format(k))
logger.info("Building graph for predict tower p{}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
tf.name_scope('towerp{}'.format(k)):
self.model.build_graph(inputs, False)
......
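Beyond the one-character log fix, the surrounding context builds one extra copy of the graph per predict tower; in isolation the pattern is roughly (model and inputs assumed to exist):

```python
tf.get_variable_scope().reuse_variables()        # share weights with the training tower
for k in [0, -1]:                                # illustrative: one GPU tower, one CPU tower
    device = '/gpu:{}'.format(k) if k >= 0 else '/cpu:0'
    with tf.device(device), tf.name_scope('towerp{}'.format(k)):
        model.build_graph(inputs, False)         # is_training=False for prediction
```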
......@@ -7,6 +7,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
INPUT_VARS_KEY = 'INPUT_VARIABLES'
# export all upper case variables
all_local_names = locals().keys()
......
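The tail of this file is cut off by the diff; the comment suggests the usual idiom of exporting every upper-case constant, roughly (a guess, not the verbatim source):

```python
all_local_names = list(locals().keys())
__all__ = [x for x in all_local_names if x.isupper()]
```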