Commit a266459e authored by Yuxin Wu's avatar Yuxin Wu

use InputDesc inside input_source, instead of ModelDesc

parent 8dcf454d
...@@ -29,8 +29,8 @@ class Model(ModelDesc): ...@@ -29,8 +29,8 @@ class Model(ModelDesc):
self.cifar_classnum = cifar_classnum self.cifar_classnum = cifar_classnum
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(tf.float32, [None, 30, 30, 3], 'input'), return [InputDesc(tf.float32, (None, 30, 30, 3), 'input'),
InputDesc(tf.int32, [None], 'label') InputDesc(tf.int32, (None,), 'label')
] ]
def _build_graph(self, inputs): def _build_graph(self, inputs):
......
...@@ -6,14 +6,12 @@ ...@@ -6,14 +6,12 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
import tensorflow as tf import tensorflow as tf
import pickle
import six import six
from ..utils import logger
from ..utils.argtools import memoized from ..utils.argtools import memoized
from .regularize import regularize_cost_from_collection from .regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'InputVar', 'ModelDesc'] __all__ = ['InputDesc', 'ModelDesc']
class InputDesc( class InputDesc(
...@@ -24,23 +22,36 @@ class InputDesc( ...@@ -24,23 +22,36 @@ class InputDesc(
input source. input source.
""" """
def dumps(self): _cached_placeholder = None
"""
Returns:
str: serialized string
"""
return pickle.dumps(self)
@staticmethod def __init__(self, type, shape, name):
def loads(buf):
""" """
Args: Args:
buf (str): serialized string type (tf.DType):
shape (tuple):
Returns: name (str):
InputDesc: """
""" shape = tuple(shape) # has to be tuple for self to be hashable
return pickle.loads(buf) super(InputDesc, self).__init__(type, shape, name)
# TODO in serialization, skip _cached_placeholder
# def dumps(self):
# """
# Returns:
# str: serialized string
# """
# return pickle.dumps(self)
# @staticmethod
# def loads(buf):
# """
# Args:
# buf (str): serialized string
# Returns:
# InputDesc:
# """
# return pickle.loads(buf)
def build_placeholder(self, prefix=''): def build_placeholder(self, prefix=''):
""" """
...@@ -53,11 +64,13 @@ class InputDesc( ...@@ -53,11 +64,13 @@ class InputDesc(
tf.Tensor: tf.Tensor:
""" """
with tf.name_scope(None): # clear any name scope it might get called in with tf.name_scope(None): # clear any name scope it might get called in
return tf.placeholder( ret = tf.placeholder(
self.type, shape=self.shape, self.type, shape=self.shape,
name=prefix + self.name) name=prefix + self.name)
if prefix == '' and self._cached_placeholder is None:
self._cached_placeholder = ret
return ret
# TODO cache results from build_placeholder, and skip it in serialization
@memoized @memoized
def build_placeholder_reuse(self): def build_placeholder_reuse(self):
""" """
...@@ -66,21 +79,18 @@ class InputDesc( ...@@ -66,21 +79,18 @@ class InputDesc(
Returns: Returns:
tf.Tensor: tf.Tensor:
""" """
if self._cached_placeholder is not None:
return self._cached_placeholder
return self.build_placeholder() return self.build_placeholder()
class InputVar(InputDesc):
def __init__(self, *args, **kwargs):
logger.warn("[Deprecated] InputVar was renamed to InputDesc!")
super(InputVar, self).__init__(*args, **kwargs)
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class ModelDesc(object): class ModelDesc(object):
""" Base class for a model description. """ Base class for a model description.
""" """
# inputs: # inputs:
# TODO remove this method?
@memoized @memoized
def get_reused_placehdrs(self): def get_reused_placehdrs(self):
""" """
...@@ -89,7 +99,7 @@ class ModelDesc(object): ...@@ -89,7 +99,7 @@ class ModelDesc(object):
Returns: Returns:
list[tf.Tensor]: the list of input placeholders in the graph. list[tf.Tensor]: the list of input placeholders in the graph.
""" """
return self.build_placeholders() return [v.build_placeholder_reuse() for v in self.get_inputs_desc()]
def build_placeholders(self, prefix=''): def build_placeholders(self, prefix=''):
""" """
...@@ -99,7 +109,7 @@ class ModelDesc(object): ...@@ -99,7 +109,7 @@ class ModelDesc(object):
Returns: Returns:
list[tf.Tensor]: the list of built placeholders. list[tf.Tensor]: the list of built placeholders.
""" """
inputs = self._get_inputs() inputs = self.get_inputs_desc()
ret = [] ret = []
for v in inputs: for v in inputs:
ret.append(v.build_placeholder(prefix)) ret.append(v.build_placeholder(prefix))
......
...@@ -83,7 +83,8 @@ class FeedInput(InputSource): ...@@ -83,7 +83,8 @@ class FeedInput(InputSource):
return self.ds.size() return self.ds.size()
def setup(self, model): def setup(self, model):
self._all_placehdrs = model.get_reused_placehdrs() inputs = model.get_inputs_desc()
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._input_names is None: if self._input_names is None:
self._placehdrs_to_feed = self._all_placehdrs self._placehdrs_to_feed = self._all_placehdrs
else: else:
...@@ -115,13 +116,13 @@ class DataParallelFeedInput(FeedInput): ...@@ -115,13 +116,13 @@ class DataParallelFeedInput(FeedInput):
self._nr_tower = len(tower_names) self._nr_tower = len(tower_names)
def setup(self, model): def setup(self, model):
inputs = model.get_inputs_desc()
self._placehdrs_per_tower = [] self._placehdrs_per_tower = []
self._feed_placehdrs_per_tower = [] self._feed_placehdrs_per_tower = []
for tname in self._tower_names: for tname in self._tower_names:
# build a list of placeholders for each tower # build a list of placeholders for each tower
self._placehdrs_per_tower.append( self._placehdrs_per_tower.append(
model.build_placeholders( [v.build_placeholder(prefix=tname + '/') for v in inputs])
prefix=tname + '/'))
# apply input mapping and store results in feed_placehdrs_per_tower # apply input mapping and store results in feed_placehdrs_per_tower
if self._input_names is None: if self._input_names is None:
...@@ -232,7 +233,8 @@ class QueueInput(FeedfreeInput): ...@@ -232,7 +233,8 @@ class QueueInput(FeedfreeInput):
# TODO use input data mapping. not all placeholders are needed # TODO use input data mapping. not all placeholders are needed
def setup(self, model): def setup(self, model):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs() inputs = model.get_inputs_desc()
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._names is None: if self._names is None:
self._queue_feedpoint = self.input_placehdrs self._queue_feedpoint = self.input_placehdrs
else: else:
...@@ -289,7 +291,8 @@ class BatchQueueInput(FeedfreeInput): ...@@ -289,7 +291,8 @@ class BatchQueueInput(FeedfreeInput):
def setup(self, model): def setup(self, model):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs() inputs = model.get_inputs_desc()
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"BatchQueueInput has to be used with some InputDesc!" "BatchQueueInput has to be used with some InputDesc!"
...@@ -377,39 +380,42 @@ class DummyConstantInput(TensorInput): ...@@ -377,39 +380,42 @@ class DummyConstantInput(TensorInput):
tlist = [] tlist = []
ctx = get_current_tower_context() ctx = get_current_tower_context()
assert ctx is not None assert ctx is not None
assert len(self.shapes) == len(self.input_placehdrs) assert len(self.shapes) == len(self.inputs_desc)
for idx, p in enumerate(self.input_placehdrs): for idx, p in enumerate(self.inputs_desc):
tlist.append(tf.constant( tlist.append(tf.constant(
0, dtype=p.dtype, 0, dtype=p.type,
name='dummy-{}-{}'.format(p.op.name, ctx.index), name='dummy-{}-{}'.format(p.name, ctx.index),
shape=self.shapes[idx])) shape=self.shapes[idx]))
return tlist return tlist
super(DummyConstantInput, self).__init__(fn) super(DummyConstantInput, self).__init__(fn)
def setup(self, model): def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs() self.inputs_desc = model.get_inputs_desc()
# TODO doesn't support remapping # TODO doesn't support remapping
class ZMQInput(TensorInput): class ZMQInput(TensorInput):
"""
Not well implemented yet. Don't use.
"""
def __init__(self, endpoint): def __init__(self, endpoint):
self._endpoint = endpoint self._endpoint = endpoint
from tensorpack.user_ops import zmq_recv from tensorpack.user_ops import zmq_recv
def fn(): def fn():
ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs]) ret = zmq_recv(self._endpoint, [x.dtype for x in self.inputs_desc])
if isinstance(ret, tf.Tensor): if isinstance(ret, tf.Tensor):
ret = [ret] ret = [ret]
assert len(ret) == len(self.input_placehdrs) assert len(ret) == len(self.inputs_desc)
for qv, v in zip(ret, self.input_placehdrs): for qv, v in zip(ret, self.inputs_desc):
qv.set_shape(v.get_shape()) qv.set_shape(v.shape)
return ret return ret
super(ZMQInput, self).__init__(fn) super(ZMQInput, self).__init__(fn)
def setup(self, model): def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs() self.inputs_desc = model.get_inputs_desc()
assert len(self.input_placehdrs) > 0, \ assert len(self.inputs_desc) > 0, \
"ZMQInput has to be used with InputDesc!" "ZMQInput has to be used with InputDesc!"
...@@ -522,11 +528,13 @@ class ReorderInputSource(FeedfreeInput): ...@@ -522,11 +528,13 @@ class ReorderInputSource(FeedfreeInput):
return self._input.size() return self._input.size()
def setup(self, model): def setup(self, model):
self._all_placehdrs = model.get_reused_placehdrs() inputs = model.get_inputs_desc()
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup(model) self._input.setup(model)
def setup_training(self, trainer): def setup_training(self, trainer):
self._all_placehdrs = trainer.model.get_reused_placehdrs() inputs = trainer.model.get_inputs_desc()
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup_training(trainer) self._input.setup_training(trainer)
def reset_state(self): def reset_state(self):
......
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: trainer.py # File: simple.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -37,7 +37,7 @@ class SimpleTrainer(Trainer): ...@@ -37,7 +37,7 @@ class SimpleTrainer(Trainer):
def _setup(self): def _setup(self):
self._input_source.setup_training(self) self._input_source.setup_training(self)
model = self.model model = self.model
self.inputs = model.get_reused_placehdrs() self.inputs = self._input_source.get_input_tensors()
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(self.inputs) model.build_graph(self.inputs)
cost_var = model.get_cost() cost_var = model.get_cost()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment