Commit 8dcf454d authored by Yuxin Wu's avatar Yuxin Wu

add build_placeholder method in InputDesc

parent c7de2013
......@@ -4,40 +4,70 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import tensorflow as tf
import pickle
import six
from ..utils import logger
from ..utils.naming import INPUTS_KEY
from ..utils.argtools import memoized
from .regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'InputVar', 'ModelDesc']
class InputDesc(object):
""" Store metadata about input placeholders. """
def __init__(self, type, shape, name, sparse=False):
"""
Args:
type: tf type of the tensor.
shape (list):
name (str):
sparse (bool): whether to use ``tf.sparse_placeholder``.
"""
self.type = type
self.shape = shape
self.name = name
self.sparse = sparse
class InputDesc(
namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
"""
Metadata about an input entry point to the graph.
This metadata can be later used to build placeholders or other types of
input source.
"""
def dumps(self):
"""
Returns:
str: serialized string
"""
return pickle.dumps(self)
@staticmethod
def loads(buf):
"""
Args:
buf (str): serialized string
Returns:
InputDesc:
"""
return pickle.loads(buf)
def build_placeholder(self, prefix=''):
"""
Build a tf.placeholder from the metadata, with an optional prefix.
Args:
prefix(str): the name of the placeholder will be ``prefix + self.name``
Returns:
tf.Tensor:
"""
with tf.name_scope(None): # clear any name scope it might get called in
return tf.placeholder(
self.type, shape=self.shape,
name=prefix + self.name)
# TODO cache results from build_placeholder, and skip it in serialization
@memoized
def build_placeholder_reuse(self):
"""
Build a tf.placeholder from the metadata, or return an old one.
Returns:
tf.Tensor:
"""
return self.build_placeholder()
class InputVar(InputDesc):
def __init__(self, *args, **kwargs):
......@@ -70,17 +100,12 @@ class ModelDesc(object):
list[tf.Tensor]: the list of built placeholders.
"""
inputs = self._get_inputs()
for v in inputs:
tf.add_to_collection(INPUTS_KEY, v.dumps())
ret = []
with tf.name_scope(None): # clear any name scope it might get called in
for v in inputs:
placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
ret.append(placehdr_f(
v.type, shape=v.shape,
name=prefix + v.name))
for v in inputs:
ret.append(v.build_placeholder(prefix))
return ret
@memoized
def get_inputs_desc(self):
"""
Returns:
......@@ -150,33 +175,3 @@ class ModelDesc(object):
def _get_gradient_processor(self):
return []
class ModelFromMetaGraph(ModelDesc):
"""
Load the exact TF graph from a saved meta_graph.
Only useful for inference.
"""
# TODO this class may not be functional anymore. don't use
def __init__(self, filename):
"""
Args:
filename (str): file name of the saved meta graph.
"""
tf.train.import_meta_graph(filename)
all_coll = tf.get_default_graph().get_all_collection_keys()
for k in [INPUTS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
tf.GraphKeys.GLOBAL_VARIABLES]:
if k not in all_coll:
logger.warn("Collection {} not found in metagraph!".format(k))
def _get_inputs(self):
col = tf.get_collection(INPUTS_KEY)
col = [InputDesc.loads(v) for v in col]
return col
def _build_graph(self, _, __):
""" Do nothing. Graph was imported already """
pass
......@@ -38,6 +38,7 @@ class Trainer(object):
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
hooked_sess (tf.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Callbacks can use it for logging.
epoch_num (int): the number of epochs that have finished.
......@@ -107,9 +108,6 @@ class Trainer(object):
""" Abstract method: run one iteration. Subclass should define what is "iteration".
"""
def _trigger_epoch(self):
pass
def setup(self):
"""
Setup the trainer and be ready for the main loop.
......@@ -192,7 +190,6 @@ class Trainer(object):
self._epoch_num, self.global_step, time.time() - start_time))
# trigger epoch outside the timing region.
self._trigger_epoch()
self._callbacks.trigger_epoch()
logger.info("Training has finished!")
except (StopTraining, tf.errors.OutOfRangeError):
......
......@@ -16,9 +16,6 @@ PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
# metainfo for input tensors
INPUTS_KEY = 'INPUTS_METAINFO'
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment