Commit a266459e authored by Yuxin Wu

use InputDesc inside input_source, instead of ModelDesc

parent 8dcf454d
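In short: every consumer that used to ask the ModelDesc for ready-made placeholders now asks it only for the abstract InputDesc list and builds (or reuses) the placeholders itself. A minimal sketch of the new pattern, using only names introduced in this diff (the surrounding `model` object is hypothetical):

    def setup_placeholders(model):
        # old: placeholders came pre-built from the ModelDesc
        #   return model.get_reused_placehdrs()
        # new: only the InputDesc metadata comes from the ModelDesc;
        # the caller builds (or reuses) the placeholders itself
        inputs = model.get_inputs_desc()
        return [v.build_placeholder_reuse() for v in inputs]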
@@ -29,8 +29,8 @@ class Model(ModelDesc):
         self.cifar_classnum = cifar_classnum

     def _get_inputs(self):
-        return [InputDesc(tf.float32, [None, 30, 30, 3], 'input'),
-                InputDesc(tf.int32, [None], 'label')
+        return [InputDesc(tf.float32, (None, 30, 30, 3), 'input'),
+                InputDesc(tf.int32, (None,), 'label')
                 ]

     def _build_graph(self, inputs):
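The shapes switch from lists to tuples because, as a comment later in this commit notes, an InputDesc has to be hashable. A quick illustration (plain Python, independent of tensorpack):

    shape_tuple = (None, 30, 30, 3)
    shape_list = [None, 30, 30, 3]
    _ = hash(shape_tuple)   # fine: a tuple of hashables is hashable
    # hash(shape_list)      # TypeError: unhashable type: 'list'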
@@ -6,14 +6,12 @@
 from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 import tensorflow as tf
-import pickle
 import six

-from ..utils import logger
 from ..utils.argtools import memoized
 from .regularize import regularize_cost_from_collection

-__all__ = ['InputDesc', 'InputVar', 'ModelDesc']
+__all__ = ['InputDesc', 'ModelDesc']


 class InputDesc(
@@ -24,23 +22,36 @@ class InputDesc(
     input source.
     """

-    def dumps(self):
-        """
-        Returns:
-            str: serialized string
-        """
-        return pickle.dumps(self)
+    _cached_placeholder = None

-    @staticmethod
-    def loads(buf):
+    def __init__(self, type, shape, name):
         """
         Args:
-            buf (str): serialized string
-        Returns:
-            InputDesc:
+            type (tf.DType):
+            shape (tuple):
+            name (str):
         """
-        return pickle.loads(buf)
+        shape = tuple(shape)    # has to be tuple for self to be hashable
+        super(InputDesc, self).__init__(type, shape, name)
+
+    # TODO in serialization, skip _cached_placeholder
+    # def dumps(self):
+    #     """
+    #     Returns:
+    #         str: serialized string
+    #     """
+    #     return pickle.dumps(self)
+
+    # @staticmethod
+    # def loads(buf):
+    #     """
+    #     Args:
+    #         buf (str): serialized string
+    #     Returns:
+    #         InputDesc:
+    #     """
+    #     return pickle.loads(buf)

     def build_placeholder(self, prefix=''):
         """
@@ -53,11 +64,13 @@ class InputDesc(
             tf.Tensor:
         """
         with tf.name_scope(None):   # clear any name scope it might get called in
-            return tf.placeholder(
+            ret = tf.placeholder(
                 self.type, shape=self.shape,
                 name=prefix + self.name)
+        if prefix == '' and self._cached_placeholder is None:
+            self._cached_placeholder = ret
+        return ret

-    # TODO cache results from build_placeholder, and skip it in serialization
     @memoized
     def build_placeholder_reuse(self):
         """
@@ -66,21 +79,18 @@ class InputDesc(
         Returns:
             tf.Tensor:
         """
+        if self._cached_placeholder is not None:
+            return self._cached_placeholder
         return self.build_placeholder()


-class InputVar(InputDesc):
-    def __init__(self, *args, **kwargs):
-        logger.warn("[Deprecated] InputVar was renamed to InputDesc!")
-        super(InputVar, self).__init__(*args, **kwargs)
-
-
 @six.add_metaclass(ABCMeta)
 class ModelDesc(object):
     """ Base class for a model description.
     """

-    # inputs:
+    # TODO remove this method?
     @memoized
     def get_reused_placehdrs(self):
         """
@@ -89,7 +99,7 @@ class ModelDesc(object):
         Returns:
             list[tf.Tensor]: the list of input placeholders in the graph.
         """
-        return self.build_placeholders()
+        return [v.build_placeholder_reuse() for v in self.get_inputs_desc()]

     def build_placeholders(self, prefix=''):
         """
@@ -99,7 +109,7 @@ class ModelDesc(object):
         Returns:
             list[tf.Tensor]: the list of built placeholders.
         """
-        inputs = self._get_inputs()
+        inputs = self.get_inputs_desc()
         ret = []
         for v in inputs:
             ret.append(v.build_placeholder(prefix))
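Taken together, the model_desc.py changes give InputDesc a small caching contract: build_placeholder() always creates a new node (caching the unprefixed one on first use), while build_placeholder_reuse() hands back the cached placeholder whenever it exists. A sketch of the intended behavior (TF1-style graph mode, assuming InputDesc is imported from tensorpack; the name 'img' and shape are hypothetical):

    import tensorflow as tf

    d = InputDesc(tf.float32, (None, 28, 28), 'img')
    p1 = d.build_placeholder_reuse()      # no cache yet: builds and caches
    p2 = d.build_placeholder_reuse()      # returns the same cached node
    assert p1 is p2
    p3 = d.build_placeholder('tower0/')   # prefixed: fresh node, never cached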
@@ -83,7 +83,8 @@ class FeedInput(InputSource):
         return self.ds.size()

     def setup(self, model):
-        self._all_placehdrs = model.get_reused_placehdrs()
+        inputs = model.get_inputs_desc()
+        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         if self._input_names is None:
             self._placehdrs_to_feed = self._all_placehdrs
         else:
@@ -115,13 +116,13 @@ class DataParallelFeedInput(FeedInput):
         self._nr_tower = len(tower_names)

     def setup(self, model):
+        inputs = model.get_inputs_desc()
         self._placehdrs_per_tower = []
         self._feed_placehdrs_per_tower = []
         for tname in self._tower_names:
             # build a list of placeholders for each tower
             self._placehdrs_per_tower.append(
-                model.build_placeholders(
-                    prefix=tname + '/'))
+                [v.build_placeholder(prefix=tname + '/') for v in inputs])

         # apply input mapping and store results in feed_placehdrs_per_tower
         if self._input_names is None:
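DataParallelFeedInput now loops over the InputDesc list directly, building one independent placeholder list per tower. For example (hypothetical tower names and shape; imports as in the sketch above), an InputDesc named 'input' yields placeholders 'tower0/input' and 'tower1/input':

    inputs = [InputDesc(tf.float32, (None, 224, 224, 3), 'input')]
    per_tower = [[v.build_placeholder(prefix=t + '/') for v in inputs]
                 for t in ('tower0', 'tower1')]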
@@ -232,7 +233,8 @@ class QueueInput(FeedfreeInput):
     # TODO use input data mapping. not all placeholders are needed
     def setup(self, model):
         logger.info("Setting up the queue for CPU prefetching ...")
-        self.input_placehdrs = model.get_reused_placehdrs()
+        inputs = model.get_inputs_desc()
+        self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         if self._names is None:
             self._queue_feedpoint = self.input_placehdrs
         else:
@@ -289,7 +291,8 @@ class BatchQueueInput(FeedfreeInput):
     def setup(self, model):
         logger.info("Setting up the queue for CPU prefetching ...")
-        self.input_placehdrs = model.get_reused_placehdrs()
+        inputs = model.get_inputs_desc()
+        self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         assert len(self.input_placehdrs) > 0, \
             "BatchQueueInput has to be used with some InputDesc!"
@@ -377,39 +380,42 @@ class DummyConstantInput(TensorInput):
             tlist = []
             ctx = get_current_tower_context()
             assert ctx is not None
-            assert len(self.shapes) == len(self.input_placehdrs)
-            for idx, p in enumerate(self.input_placehdrs):
+            assert len(self.shapes) == len(self.inputs_desc)
+            for idx, p in enumerate(self.inputs_desc):
                 tlist.append(tf.constant(
-                    0, dtype=p.dtype,
-                    name='dummy-{}-{}'.format(p.op.name, ctx.index),
+                    0, dtype=p.type,
+                    name='dummy-{}-{}'.format(p.name, ctx.index),
                     shape=self.shapes[idx]))
             return tlist
         super(DummyConstantInput, self).__init__(fn)

     def setup(self, model):
-        self.input_placehdrs = model.get_reused_placehdrs()
+        self.inputs_desc = model.get_inputs_desc()


+# TODO doesn't support remapping
 class ZMQInput(TensorInput):
     """
     Not well implemented yet. Don't use.
     """
     def __init__(self, endpoint):
         self._endpoint = endpoint

         from tensorpack.user_ops import zmq_recv

         def fn():
-            ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs])
+            ret = zmq_recv(self._endpoint, [x.type for x in self.inputs_desc])
             if isinstance(ret, tf.Tensor):
                 ret = [ret]
-            assert len(ret) == len(self.input_placehdrs)
-            for qv, v in zip(ret, self.input_placehdrs):
-                qv.set_shape(v.get_shape())
+            assert len(ret) == len(self.inputs_desc)
+            for qv, v in zip(ret, self.inputs_desc):
+                qv.set_shape(v.shape)
             return ret
         super(ZMQInput, self).__init__(fn)

     def setup(self, model):
-        self.input_placehdrs = model.get_reused_placehdrs()
-        assert len(self.input_placehdrs) > 0, \
+        self.inputs_desc = model.get_inputs_desc()
+        assert len(self.inputs_desc) > 0, \
             "ZMQInput has to be used with InputDesc!"
@@ -522,11 +528,13 @@ class ReorderInputSource(FeedfreeInput):
         return self._input.size()

     def setup(self, model):
-        self._all_placehdrs = model.get_reused_placehdrs()
+        inputs = model.get_inputs_desc()
+        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         self._input.setup(model)

     def setup_training(self, trainer):
-        self._all_placehdrs = trainer.model.get_reused_placehdrs()
+        inputs = trainer.model.get_inputs_desc()
+        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         self._input.setup_training(trainer)

     def reset_state(self):
@@ -1,3 +1,3 @@
 # -*- coding: UTF-8 -*-
-# File: trainer.py
+# File: simple.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

@@ -37,7 +37,7 @@ class SimpleTrainer(Trainer):
     def _setup(self):
         self._input_source.setup_training(self)
         model = self.model
-        self.inputs = model.get_reused_placehdrs()
+        self.inputs = self._input_source.get_input_tensors()
         with TowerContext('', is_training=True):
             model.build_graph(self.inputs)
         cost_var = model.get_cost()
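With this last change the trainer no longer reaches into the model for placeholders at all: the InputSource it was given is the single authority for input tensors. A minimal sketch of the resulting flow (a hypothetical helper, trainer internals simplified):

    def setup_inputs(trainer):
        # the InputSource, not the ModelDesc, now owns the input tensors
        trainer._input_source.setup_training(trainer)
        return trainer._input_source.get_input_tensors()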