Commit 8dcf454d authored by Yuxin Wu's avatar Yuxin Wu

add build_placeholder method in InputDesc

parent c7de2013
...@@ -4,40 +4,70 @@ ...@@ -4,40 +4,70 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple
import tensorflow as tf import tensorflow as tf
import pickle import pickle
import six import six
from ..utils import logger from ..utils import logger
from ..utils.naming import INPUTS_KEY
from ..utils.argtools import memoized from ..utils.argtools import memoized
from .regularize import regularize_cost_from_collection from .regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'InputVar', 'ModelDesc'] __all__ = ['InputDesc', 'InputVar', 'ModelDesc']
class InputDesc(
        namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
    """
    Metadata about an input entry point to the graph.
    This metadata can be later used to build placeholders or other types of
    input source.
    """

    def dumps(self):
        """
        Serialize this metadata to a string.

        Returns:
            str: serialized string
        """
        return pickle.dumps(self)

    @staticmethod
    def loads(buf):
        """
        Deserialize an ``InputDesc`` from a string produced by :meth:`dumps`.

        Args:
            buf (str): serialized string
        Returns:
            InputDesc:
        """
        return pickle.loads(buf)

    def build_placeholder(self, prefix=''):
        """
        Build a tf.placeholder from the metadata, with an optional prefix.

        Args:
            prefix(str): the name of the placeholder will be ``prefix + self.name``
        Returns:
            tf.Tensor:
        """
        # Clear any enclosing name scope so the placeholder gets exactly
        # the requested name, no matter where this is called from.
        with tf.name_scope(None):
            return tf.placeholder(
                self.type, shape=self.shape,
                name=prefix + self.name)

    # TODO cache results from build_placeholder, and skip it in serialization
    @memoized
    def build_placeholder_reuse(self):
        """
        Build a tf.placeholder from the metadata, or return an old one.

        Returns:
            tf.Tensor:
        """
        return self.build_placeholder()
class InputVar(InputDesc): class InputVar(InputDesc):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -70,17 +100,12 @@ class ModelDesc(object): ...@@ -70,17 +100,12 @@ class ModelDesc(object):
list[tf.Tensor]: the list of built placeholders. list[tf.Tensor]: the list of built placeholders.
""" """
inputs = self._get_inputs() inputs = self._get_inputs()
for v in inputs:
tf.add_to_collection(INPUTS_KEY, v.dumps())
ret = [] ret = []
with tf.name_scope(None): # clear any name scope it might get called in
for v in inputs: for v in inputs:
placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder ret.append(v.build_placeholder(prefix))
ret.append(placehdr_f(
v.type, shape=v.shape,
name=prefix + v.name))
return ret return ret
@memoized
def get_inputs_desc(self): def get_inputs_desc(self):
""" """
Returns: Returns:
...@@ -150,33 +175,3 @@ class ModelDesc(object): ...@@ -150,33 +175,3 @@ class ModelDesc(object):
def _get_gradient_processor(self): def _get_gradient_processor(self):
return [] return []
class ModelFromMetaGraph(ModelDesc):
    """
    Load the exact TF graph from a saved meta_graph.
    Only useful for inference.
    """

    # TODO this class may not be functional anymore. don't use

    def __init__(self, filename):
        """
        Args:
            filename (str): file name of the saved meta graph.
        """
        tf.train.import_meta_graph(filename)
        # Warn about any expected collection missing from the imported graph.
        existing = tf.get_default_graph().get_all_collection_keys()
        expected = [INPUTS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
                    tf.GraphKeys.GLOBAL_VARIABLES]
        for key in expected:
            if key not in existing:
                logger.warn("Collection {} not found in metagraph!".format(key))

    def _get_inputs(self):
        # Input metadata was serialized into this collection when the
        # graph was saved; deserialize it back into InputDesc objects.
        serialized = tf.get_collection(INPUTS_KEY)
        return [InputDesc.loads(s) for s in serialized]

    def _build_graph(self, _, __):
        """ Do nothing. Graph was imported already """
        pass
...@@ -38,6 +38,7 @@ class Trainer(object): ...@@ -38,6 +38,7 @@ class Trainer(object):
config (TrainConfig): the config used in this trainer. config (TrainConfig): the config used in this trainer.
model (ModelDesc) model (ModelDesc)
sess (tf.Session): the current session in use. sess (tf.Session): the current session in use.
hooked_sess (tf.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Callbacks can use it for logging. monitors (Monitors): the monitors. Callbacks can use it for logging.
epoch_num (int): the number of epochs that have finished. epoch_num (int): the number of epochs that have finished.
...@@ -107,9 +108,6 @@ class Trainer(object): ...@@ -107,9 +108,6 @@ class Trainer(object):
""" Abstract method: run one iteration. Subclass should define what is "iteration". """ Abstract method: run one iteration. Subclass should define what is "iteration".
""" """
def _trigger_epoch(self):
pass
def setup(self): def setup(self):
""" """
Setup the trainer and be ready for the main loop. Setup the trainer and be ready for the main loop.
...@@ -192,7 +190,6 @@ class Trainer(object): ...@@ -192,7 +190,6 @@ class Trainer(object):
self._epoch_num, self.global_step, time.time() - start_time)) self._epoch_num, self.global_step, time.time() - start_time))
# trigger epoch outside the timing region. # trigger epoch outside the timing region.
self._trigger_epoch()
self._callbacks.trigger_epoch() self._callbacks.trigger_epoch()
logger.info("Training has finished!") logger.info("Training has finished!")
except (StopTraining, tf.errors.OutOfRangeError): except (StopTraining, tf.errors.OutOfRangeError):
......
...@@ -16,9 +16,6 @@ PREDICT_TOWER = 'towerp' ...@@ -16,9 +16,6 @@ PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way # extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS' MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
# metainfo for input tensors
INPUTS_KEY = 'INPUTS_METAINFO'
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY] SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS] TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment