Commit 39fa4656 authored by Yuxin Wu

Use inputs() and tf.placeholder in ModelDesc (#318)

parent 215a4d6d
...@@ -17,8 +17,8 @@ expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost fun ...@@ -17,8 +17,8 @@ expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost fun
```python ```python
class MyModel(ModelDesc): class MyModel(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(...), InputDesc(...)] return [tf.placeholder(dtype, shape, name), tf.placeholder(dtype, shape, name), ... ]
def _build_graph(self, inputs): def _build_graph(self, inputs):
tensorA, tensorB = inputs tensorA, tensorB = inputs
......
...@@ -26,10 +26,9 @@ class Model(ModelDesc): ...@@ -26,10 +26,9 @@ class Model(ModelDesc):
super(Model, self).__init__() super(Model, self).__init__()
self.cifar_classnum = cifar_classnum self.cifar_classnum = cifar_classnum
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, 30, 30, 3), 'input'), return [tf.placeholder(tf.float32, (None, 30, 30, 3), 'input'),
InputDesc(tf.int32, (None,), 'label') tf.placeholder(tf.int32, (None,), 'label')]
]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -20,13 +20,12 @@ IMAGE_SIZE = 28 ...@@ -20,13 +20,12 @@ IMAGE_SIZE = 28
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
""" """
Define all the inputs (with type, shape, name) that Define all the inputs (with type, shape, name) that the graph will need.
the graph will need.
""" """
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'), return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
"""This function should build the model which takes the input variables """This function should build the model which takes the input variables
......
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
# File: model_desc.py # File: model_desc.py
from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
import tensorflow as tf import tensorflow as tf
import six
from ..utils import logger
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils.gradproc import FilterNoneGrad from ..tfutils.gradproc import FilterNoneGrad
...@@ -38,9 +37,10 @@ class InputDesc( ...@@ -38,9 +37,10 @@ class InputDesc(
if any(k in name for k in [':', '/', ' ']): if any(k in name for k in [':', '/', ' ']):
raise ValueError("Invalid InputDesc name: '{}'".format(name)) raise ValueError("Invalid InputDesc name: '{}'".format(name))
self = super(InputDesc, cls).__new__(cls, type, shape, name) self = super(InputDesc, cls).__new__(cls, type, shape, name)
self._cached_placeholder = None self._cached_placeholder = {}
return self return self
# TODO this method seems unused outside this class
def build_placeholder(self): def build_placeholder(self):
""" """
Build a tf.placeholder from the metadata. Build a tf.placeholder from the metadata.
...@@ -51,8 +51,7 @@ class InputDesc( ...@@ -51,8 +51,7 @@ class InputDesc(
with tf.name_scope(None): # clear any name scope it might get called in with tf.name_scope(None): # clear any name scope it might get called in
ret = tf.placeholder( ret = tf.placeholder(
self.type, shape=self.shape, name=self.name) self.type, shape=self.shape, name=self.name)
if self._cached_placeholder is None: self._register_cached_placeholder(ret)
self._cached_placeholder = ret # cached_placeholder only caches the prefix='' case
return ret return ret
# cannot memoize here, because InputDesc is hashed by its fields. # cannot memoize here, because InputDesc is hashed by its fields.
...@@ -63,28 +62,67 @@ class InputDesc( ...@@ -63,28 +62,67 @@ class InputDesc(
Returns: Returns:
tf.Tensor: tf.Tensor:
""" """
if self._cached_placeholder is not None: g = tf.get_default_graph()
return self._cached_placeholder if g in self._cached_placeholder:
return self.build_placeholder() return self._cached_placeholder[g]
else:
return self.build_placeholder()
def _register_cached_placeholder(self, placeholder):
    """
    Record ``placeholder`` in this object's per-graph placeholder cache.

    The cache (``self._cached_placeholder``) is keyed by the graph that the
    placeholder belongs to (``placeholder.graph``).

    Args:
        placeholder: an object exposing a ``graph`` attribute
            (presumably a tensor created by ``tf.placeholder`` -- confirm
            against callers).

    Raises:
        AssertionError: if a placeholder has already been registered for
            the same graph.
    """
    graph = placeholder.graph
    # Raised explicitly instead of using an ``assert`` statement so that the
    # guard is not stripped under ``python -O``; without it a duplicate
    # registration would silently overwrite the cached placeholder.
    if graph in self._cached_placeholder:
        raise AssertionError(
            "Placeholder for this InputDesc had been created before! This is a bug.")
    self._cached_placeholder[graph] = placeholder
@staticmethod
def from_placeholder(placeholder):
    """
    Build an :class:`InputDesc` that describes an existing placeholder.

    The dtype, static shape and op name are read from ``placeholder``, and
    the placeholder itself is registered in the new object's cache.

    Args:
        placeholder: the tensor to describe; its ``op.name``, ``dtype`` and
            ``shape`` attributes are read here.

    Returns:
        InputDesc: the description, with ``placeholder`` already cached.
    """
    op_name = placeholder.op.name
    # NOTE(review): a "_1"/"_2" suffix presumably means the name was
    # uniquified because the same placeholder was created more than once.
    if op_name.endswith(('_1', '_2')):
        logger.error("Creating InputDesc from a placeholder named {}.".format(op_name))
        logger.error("You might have mistakenly created this placeholder multiple times!")
    dtype = placeholder.dtype
    static_shape = tuple(placeholder.shape.as_list())
    desc = InputDesc(dtype, static_shape, op_name)
    desc._register_cached_placeholder(placeholder)
    return desc
@six.add_metaclass(ABCMeta)
class ModelDescBase(object): class ModelDescBase(object):
""" Base class for a model description.
""" """
Base class for a model description.
"""
@memoized @memoized
def get_inputs_desc(self): def get_inputs_desc(self):
""" """
Returns: Returns:
list[:class:`InputDesc`]: list of the underlying :class:`InputDesc`. a list of :class:`InputDesc`.
""" """
return self._get_inputs() try:
return self._get_inputs()
except NotImplementedError:
with tf.Graph().as_default(): # create these placeholder in a temporary graph
inputs = self.inputs()
return [InputDesc.from_placeholder(p) for p in inputs]
@abstractmethod
def _get_inputs(self): def _get_inputs(self):
""" """
:returns: a list of InputDesc Returns:
a list of :class:`InputDesc`.
"""
raise NotImplementedError()
def inputs(self):
"""
Create and return a list of placeholders.
To be implemented by subclass; this base implementation only raises
NotImplementedError.
The placeholders __have to__ be created inside this function.
Note: :meth:`get_inputs_desc` calls this method inside a temporary ``tf.Graph``.
Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
"""
raise NotImplementedError()
def build_graph(self, *args): def build_graph(self, *args):
""" """
...@@ -93,13 +131,12 @@ class ModelDescBase(object): ...@@ -93,13 +131,12 @@ class ModelDescBase(object):
By default it will call :meth:`_build_graph` with a list of input tensors. By default it will call :meth:`_build_graph` with a list of input tensors.
Args: Args:
args ([tf.Tensor]): tensors that matches the list of args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.
:class:`InputDesc` defined by ``_get_inputs``.
Returns: Returns:
In general it returns nothing, but a subclass (e.g. In general it returns nothing, but a subclass (e.g.
:class:`ModelDesc` may require it to return necessary information :class:`ModelDesc`) may require it to return necessary information
to build the trainer. (e.g. cost) to build the trainer.
""" """
if len(args) == 1: if len(args) == 1:
arg = args[0] arg = args[0]
...@@ -122,7 +159,8 @@ class ModelDescBase(object): ...@@ -122,7 +159,8 @@ class ModelDescBase(object):
def _build_graph(self, inputs): def _build_graph(self, inputs):
""" """
This is an old interface which takes a list of tensors, instead of positional arguments. This is an alternative interface which takes a list of tensors, instead of positional arguments.
By default :meth:`build_graph` will call this method.
""" """
pass pass
......
...@@ -41,7 +41,7 @@ class PlaceholderInput(InputSource): ...@@ -41,7 +41,7 @@ class PlaceholderInput(InputSource):
Just produce placeholders as input tensors. Just produce placeholders as input tensors.
""" """
def _setup(self, inputs): def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder() for v in inputs] self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
def _get_input_tensors(self): def _get_input_tensors(self):
return self._all_placehdrs return self._all_placehdrs
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment