Commit 39fa4656 authored by Yuxin Wu's avatar Yuxin Wu

Use inputs() and tf.placeholder in ModelDesc (#318)

parent 215a4d6d
......@@ -17,8 +17,8 @@ expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost fun
```python
class MyModel(ModelDesc):
def _get_inputs(self):
return [InputDesc(...), InputDesc(...)]
def inputs(self):
return [tf.placeholder(dtype, shape, name), tf.placeholder(dtype, shape, name), ... ]
def _build_graph(self, inputs):
tensorA, tensorB = inputs
......
......@@ -26,10 +26,9 @@ class Model(ModelDesc):
super(Model, self).__init__()
self.cifar_classnum = cifar_classnum
def _get_inputs(self):
return [InputDesc(tf.float32, (None, 30, 30, 3), 'input'),
InputDesc(tf.int32, (None,), 'label')
]
def inputs(self):
return [tf.placeholder(tf.float32, (None, 30, 30, 3), 'input'),
tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs):
image, label = inputs
......
......@@ -20,13 +20,12 @@ IMAGE_SIZE = 28
class Model(ModelDesc):
def _get_inputs(self):
def inputs(self):
"""
Define all the inputs (with type, shape, name) that
the graph will need.
Define all the inputs (with type, shape, name) that the graph will need.
"""
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label')]
return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs):
"""This function should build the model which takes the input variables
......
......@@ -3,11 +3,10 @@
# File: model_desc.py
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import tensorflow as tf
import six
from ..utils import logger
from ..utils.argtools import memoized
from ..utils.develop import log_deprecated
from ..tfutils.gradproc import FilterNoneGrad
......@@ -38,9 +37,10 @@ class InputDesc(
if any(k in name for k in [':', '/', ' ']):
raise ValueError("Invalid InputDesc name: '{}'".format(name))
self = super(InputDesc, cls).__new__(cls, type, shape, name)
self._cached_placeholder = None
self._cached_placeholder = {}
return self
# TODO this method seems unused outside this class
def build_placeholder(self):
"""
Build a tf.placeholder from the metadata.
......@@ -51,8 +51,7 @@ class InputDesc(
with tf.name_scope(None): # clear any name scope it might get called in
ret = tf.placeholder(
self.type, shape=self.shape, name=self.name)
if self._cached_placeholder is None:
self._cached_placeholder = ret # cached_placeholder only caches the prefix='' case
self._register_cached_placeholder(ret)
return ret
# cannot memoize here, because InputDesc is hashed by its fields.
......@@ -63,28 +62,67 @@ class InputDesc(
Returns:
tf.Tensor:
"""
if self._cached_placeholder is not None:
return self._cached_placeholder
g = tf.get_default_graph()
if g in self._cached_placeholder:
return self._cached_placeholder[g]
else:
return self.build_placeholder()
def _register_cached_placeholder(self, placeholder):
graph = placeholder.graph
assert graph not in self._cached_placeholder, \
"Placeholder for this InputDesc had been created before! This is a bug."
self._cached_placeholder[graph] = placeholder
@staticmethod
def from_placeholder(placeholder):
name = placeholder.op.name
if name.endswith('_1') or name.endswith('_2'):
logger.error("Creating InputDesc from a placeholder named {}.".format(name))
logger.error("You might have mistakenly created this placeholder multiple times!")
ret = InputDesc(
placeholder.dtype,
tuple(placeholder.shape.as_list()),
name)
ret._register_cached_placeholder(placeholder)
return ret
@six.add_metaclass(ABCMeta)
class ModelDescBase(object):
""" Base class for a model description.
"""
Base class for a model description.
"""
@memoized
def get_inputs_desc(self):
"""
Returns:
list[:class:`InputDesc`]: list of the underlying :class:`InputDesc`.
a list of :class:`InputDesc`.
"""
try:
return self._get_inputs()
except NotImplementedError:
with tf.Graph().as_default(): # create these placeholder in a temporary graph
inputs = self.inputs()
return [InputDesc.from_placeholder(p) for p in inputs]
@abstractmethod
def _get_inputs(self):
"""
:returns: a list of InputDesc
Returns:
a list of :class:`InputDesc`.
"""
raise NotImplementedError()
def inputs(self):
"""
__Create__ and returns a list of placeholders.
To be implemented by subclass.
The placeholders __have to__ be created inside this function.
Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
"""
raise NotImplementedError()
def build_graph(self, *args):
"""
......@@ -93,13 +131,12 @@ class ModelDescBase(object):
By default it will call :meth:`_build_graph` with a list of input tensors.
Args:
args ([tf.Tensor]): tensors that matches the list of
:class:`InputDesc` defined by ``_get_inputs``.
args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.
Returns:
In general it returns nothing, but a subclass (e.g.
:class:`ModelDesc` may require it to return necessary information
to build the trainer.
:class:`ModelDesc`) may require it to return necessary information
(e.g. cost) to build the trainer.
"""
if len(args) == 1:
arg = args[0]
......@@ -122,7 +159,8 @@ class ModelDescBase(object):
def _build_graph(self, inputs):
"""
This is an old interface which takes a list of tensors, instead of positional arguments.
This is an alternative interface which takes a list of tensors, instead of positional arguments.
By default :meth:`build_graph` will call this method.
"""
pass
......
......@@ -41,7 +41,7 @@ class PlaceholderInput(InputSource):
Just produce placeholders as input tensors.
"""
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder() for v in inputs]
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
def _get_input_tensors(self):
return self._all_placehdrs
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment