Commit 3b2f7df1 authored by Yuxin Wu

load model from meta

parent 03b92aba
@@ -7,10 +7,10 @@ from abc import ABCMeta, abstractmethod
 import tensorflow as tf
 from collections import namedtuple
-from ..utils import logger
+from ..utils import logger, INPUT_VARS_KEY
 from ..tfutils import *

-__all__ = ['ModelDesc', 'InputVar']
+__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']

 InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
@@ -32,6 +32,8 @@ class ModelDesc(object):
         ret = []
         for v in input_vars:
             ret.append(tf.placeholder(v.type, shape=v.shape, name=v.name))
+        for v in ret:
+            tf.add_to_collection(INPUT_VARS_KEY, v)
         return ret

     def reuse_input_vars(self):
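The new lines register every input placeholder under the INPUT_VARS_KEY collection (the constant added to the naming module at the bottom of this commit), so the inputs can be looked up later without knowing their names; this is what ModelFromMetaGraph below relies on. A minimal sketch of the mechanism outside tensorpack:

import tensorflow as tf

INPUT_VARS_KEY = 'INPUT_VARIABLES'

# what get_input_vars() now does for each placeholder it creates
x = tf.placeholder(tf.float32, shape=[None, 32], name='input')
tf.add_to_collection(INPUT_VARS_KEY, x)

# later (possibly in another process, after tf.train.import_meta_graph)
# the placeholders can be recovered from the collection:
inputs = tf.get_collection(INPUT_VARS_KEY)
assert inputs[0].name == 'input:0'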
@@ -57,28 +59,12 @@ class ModelDesc(object):
         """
         self._build_graph(model_inputs, is_training)

-    #@abstractmethod
+    @abstractmethod
     def _build_graph(self, inputs, is_training):
-        if self._old_version():
-            self.model_inputs = inputs
-            self.is_training = is_training
-        else:
-            raise NotImplementedError()
-
-    def _old_version(self):
-        # for backward-compat only.
-        import inspect
-        args = inspect.getargspec(self._get_cost)
-        return len(args.args) == 3
+        pass

     def get_cost(self):
-        if self._old_version():
-            assert type(self.is_training) == bool
-            logger.warn("!!!using _get_cost to setup the graph is deprecated in favor of _build_graph")
-            logger.warn("See examples for details.")
-            return self._get_cost(self.model_inputs, self.is_training)
-        else:
-            return self._get_cost()
+        return self._get_cost()

     def _get_cost(self, *args):
         return self.cost
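With the backward-compatibility branch removed, _build_graph is a real abstract method again: subclasses must build the graph themselves and typically set self.cost, which the default _get_cost() returns. A minimal sketch of a subclass under the new contract (the shapes and the toy cost are illustrative, not taken from this commit):

import tensorflow as tf

class MyModel(ModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, [None, 28, 28], 'input'),
                InputVar(tf.int32, [None], 'label')]

    def _build_graph(self, inputs, is_training):
        image, label = inputs
        flat = tf.reshape(image, [-1, 28 * 28])
        W = tf.get_variable('W', shape=[28 * 28, 1])
        pred = tf.squeeze(tf.matmul(flat, W), [1])
        # the default _get_cost() simply returns self.cost
        self.cost = tf.reduce_mean(tf.square(pred - tf.cast(label, tf.float32)),
                                   name='cost')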
@@ -87,3 +73,27 @@ class ModelDesc(object):
         """ Return a list of GradientProcessor. They will be executed in order"""
         return [CheckGradient()]#, SummaryGradient()]
+
+
+class ModelFromMetaGraph(ModelDesc):
+    """
+    Load the whole exact TF graph from a saved meta_graph.
+    Only useful for inference.
+    """
+    def __init__(self, filename):
+        tf.train.import_meta_graph(filename)
+        all_coll = tf.get_default_graph().get_all_collection_keys()
+        for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
+                  tf.GraphKeys.VARIABLES]:
+            assert k in all_coll, \
+                "Collection {} not found in metagraph!".format(k)
+
+    def get_input_vars(self):
+        return tf.get_collection(INPUT_VARS_KEY)
+
+    def _get_input_vars(self):
+        raise NotImplementedError("Shouldn't call here")
+
+    def _build_graph(self, _, __):
+        """ Do nothing. Graph was imported already """
+        pass
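ModelFromMetaGraph is the point of this commit: together with the INPUT_VARS_KEY collection it lets inference run on a graph reconstructed entirely from a saved .meta file, without re-running the Python model code. A hedged usage sketch (the PredictConfig arguments, the SaverRestore initializer, and all paths and tensor names are assumptions based on the surrounding code, not shown verbatim in this diff):

config = PredictConfig(
    model=ModelFromMetaGraph('train_log/graph.meta'),    # graph structure from the metagraph
    session_init=SaverRestore('train_log/model-10000'),  # weights from a checkpoint
    output_var_names=['output:0'])                        # tensors to fetch
predict = get_predict_func(config)
outputs = predict([some_input_batch])  # inputs follow the order stored under INPUT_VARS_KEY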
@@ -50,8 +50,9 @@ class PredictConfig(object):
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
-        self.session_config = kwargs.pop('session_config',
-                get_default_sess_config(0.3))
+        # XXX does it work? start with minimal memory, but allow growth.
+        # allow_growth doesn't seem to work very well in TF.
+        self.session_config = kwargs.pop('session_config', get_default_sess_config(0.3))
         self.session_init = kwargs.pop('session_init')
         self.model = kwargs.pop('model')
         self.input_data_mapping = kwargs.pop('input_data_mapping', None)
@@ -61,7 +62,7 @@ class PredictConfig(object):

 def get_predict_func(config):
     """
-    Produce a simple predictor function in a newly-created session without any parallelism.
+    Produce a simple predictor function run inside a new session.
     :param config: a `PredictConfig` instance.
     :returns: A prediction function that takes a list of input values, and return
         a list of output values defined in ``config.output_var_names``.
@@ -77,10 +78,8 @@ def get_predict_func(config):
         input_map = [input_vars[k] for k in config.input_data_mapping]

     # check output_var_names against output_vars
-    output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
-                   for n in output_var_names]
+    output_vars = get_vars_by_names(output_var_names)

-    # XXX does it work? start with minimal memory, but allow growth
     sess = tf.Session(config=config.session_config)
     config.session_init.init(sess)
...
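get_vars_by_names replaces the inlined lookup that was deleted just above; judging from the removed lines, it is a small helper (presumably in ..tfutils) along these lines. The body below is an assumption reconstructed from the old code, not the actual implementation:

def get_vars_by_names(names):
    # resolve a list of variable/tensor names to tensors in the default graph
    g = tf.get_default_graph()
    return [g.get_tensor_by_name(get_op_var_name(n)[1]) for n in names]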
@@ -105,9 +105,9 @@ class PredictorWorkerThread(threading.Thread):
                     for k in range(self.nr_input_var):
                         batched[k].append(inp[k])
                     futures.append(f)
+                    cnt += 1
                 except queue.Empty:
                     break
-                cnt += 1
             return batched, futures

         #self.xxx = None
         while True:
@@ -116,12 +116,9 @@ class PredictorWorkerThread(threading.Thread):
                 outputs = self.func(batched)
                 # debug, for speed testing
                 #if self.xxx is None:
-                    #outputs = self.func([batched])
-                    #self.xxx = outputs
+                    #self.xxx = outputs = self.func([batched])
                 #else:
-                    #outputs = [None, None]
-                    #outputs[0] = [self.xxx[0][0]] * len(batched)
-                    #outputs[1] = [self.xxx[1][0]] * len(batched)
+                    #outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
                 for idx, f in enumerate(futures):
                     f.set_result([k[idx] for k in outputs])
...
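Moving cnt += 1 next to futures.append(f) does not change behaviour (the old increment was skipped anyway once break fired), but it makes the invariant explicit: cnt counts items actually added to the batch. For context, a standalone sketch of the batch-gathering loop being edited here (the function shape, attribute and queue names are stand-ins; only the lines shown in the hunk are taken from the source):

import queue

def fetch_batch(input_queue, nr_input_var, batch_size):
    # block for the first item, then greedily drain up to batch_size without waiting
    inp, f = input_queue.get()
    batched = [[inp[k]] for k in range(nr_input_var)]
    futures = [f]
    cnt = 1
    while cnt < batch_size:
        try:
            inp, f = input_queue.get_nowait()
        except queue.Empty:
            break
        for k in range(nr_input_var):
            batched[k].append(inp[k])
        futures.append(f)
        cnt += 1   # count only items actually added to the batch
    return batched, futures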
@@ -13,8 +13,7 @@ __all__ = ['argscope', 'get_arg_scope']

 _ArgScopeStack = []

 @contextmanager
-def argscope(layers, **kwargs):
-    param = kwargs
+def argscope(layers, **param):
     if not isinstance(layers, list):
         layers = [layers]
@@ -35,6 +34,10 @@ def argscope(layers, **kwargs):
         del _ArgScopeStack[-1]

 def get_arg_scope():
+    """ return the current argscope
+    an argscope is a dict of dict:
+    dict[layername] = {arg: val}
+    """
     if len(_ArgScopeStack) > 0:
         return _ArgScopeStack[-1]
     else:
...
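For context on what the stack above backs: argscope pushes a dict of per-layer default arguments, and get_arg_scope() returns the innermost one. A usage sketch (Conv2D and its argument names are illustrative tensorpack-style layers, not defined in this diff):

with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu):
    # inside the scope: get_arg_scope() == {'Conv2D': {'kernel_shape': 3, 'nl': tf.nn.relu}}
    l = Conv2D('conv0', image, out_channel=32)               # defaults filled in from the scope
    l = Conv2D('conv1', l, out_channel=32, kernel_shape=5)   # explicit arguments still override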
@@ -138,7 +138,7 @@ class QueueInputTrainer(Trainer):
         inputs = self.model.get_input_vars()
         tf.get_variable_scope().reuse_variables()
         for k in self.predict_tower:
-            logger.info("Building graph for predict towerp{}...".format(k))
+            logger.info("Building graph for predict tower p{}...".format(k))
             with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                     tf.name_scope('towerp{}'.format(k)):
                 self.model.build_graph(inputs, False)
...
@@ -7,6 +7,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'

 # extra variables to summarize during training in a moving-average way
 MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
+INPUT_VARS_KEY = 'INPUT_VARIABLES'

 # export all upper case variables
 all_local_names = locals().keys()
...