Commit ef9e27a6 authored by Yuxin Wu's avatar Yuxin Wu

dynamically generate the remapped class

parent adf51f22
......@@ -20,7 +20,6 @@ from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import memoized
from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback
......@@ -62,6 +61,9 @@ class InputSource(object):
@abstractmethod
def reset_state(self):
"""
Semantics of this method has not been well defined.
"""
pass
@abstractmethod
......@@ -80,6 +82,33 @@ class InputSource(object):
return NotImplementedError()
class ProxyInputSource(InputSource):
"""
An InputSource which proxy every method to ``self._input``.
"""
def __init__(self, input):
assert isinstance(input, InputSource), input
self._input = input
def get_input_tensors(self):
return self._input.get_input_tensors()
def setup(self, inputs_desc):
self._input.setup(inputs_desc)
def get_callbacks(self):
return self._input.get_callbacks()
def size(self):
return self._input.size()
def next_feed(self):
return self._input.next_feed()
def reset_state(self):
self._input.reset_state()
class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """
def __init__(self, ds):
......@@ -465,57 +494,47 @@ class StagingInputWrapper(FeedfreeInput):
def get_staging_name(idx):
return 'StagingArea{}'.format(idx)
@memoized
def get_stage_op(self):
return tf.group(*self._stage_ops)
@memoized
def get_unstage_op(self):
all_outputs = list(chain.from_iterable(self._unstage_ops))
return tf.group(*all_outputs)
# TODO dynamically generate inheritance
# TODO make it a function, not a class
class remap_input_source(FeedInput, FeedfreeInput):
def remap_input_source(input, names):
"""
When you have some :class:`InputSource` which doesn't match the inputs in
your :class:`ModelDesc`, use `RemapInputSource`.
It produces placeholders for all the inputs in your model,
except that the corresponding ones are replaced with the tensor produced
by the given :class:`InputSource`.
Args:
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
"""
def __init__(self, input, names):
"""
Args:
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
"""
assert isinstance(input, InputSource), input
self._input = input
ProxyInputSource.__init__(self, input)
assert isinstance(names, (list, tuple)), names
self._names = tuple(names)
def size(self):
return self._input.size()
def setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
inputs_subset = get_sublist_by_names(inputs, self._names)
self._input.setup(inputs_subset)
def get_callbacks(self):
return self._input.get_callbacks()
def reset_state(self):
self._input.reset_state()
def next_feed(self):
return self._input.next_feed()
def get_input_tensors(self):
ret = self._input.get_input_tensors()
assert len(ret) == len(self._names)
return get_tensors_inputs(
self._all_placehdrs, ret, self._names)
oldcls = type(input)
# inherit oldcls so that type check in various places would work
cls = type('Remapped' + oldcls.__name__, (ProxyInputSource, oldcls), {
'__init__': __init__,
'setup': setup,
'get_input_tensors': get_input_tensors})
return cls(input, names)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment