Commit ef9e27a6 authored by Yuxin Wu's avatar Yuxin Wu

dynamically generate the remapped class

parent adf51f22
...@@ -20,7 +20,6 @@ from ..tfutils.summary import add_moving_summary ...@@ -20,7 +20,6 @@ from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name from ..tfutils import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from ..utils.argtools import memoized
from ..utils.concurrency import ShareSessionThread from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback from ..callbacks.base import Callback
...@@ -62,6 +61,9 @@ class InputSource(object): ...@@ -62,6 +61,9 @@ class InputSource(object):
@abstractmethod @abstractmethod
def reset_state(self): def reset_state(self):
"""
Semantics of this method has not been well defined.
"""
pass pass
@abstractmethod @abstractmethod
...@@ -80,6 +82,33 @@ class InputSource(object): ...@@ -80,6 +82,33 @@ class InputSource(object):
return NotImplementedError() return NotImplementedError()
class ProxyInputSource(InputSource):
"""
An InputSource which proxy every method to ``self._input``.
"""
def __init__(self, input):
assert isinstance(input, InputSource), input
self._input = input
def get_input_tensors(self):
return self._input.get_input_tensors()
def setup(self, inputs_desc):
self._input.setup(inputs_desc)
def get_callbacks(self):
return self._input.get_callbacks()
def size(self):
return self._input.size()
def next_feed(self):
return self._input.next_feed()
def reset_state(self):
self._input.reset_state()
class FeedInput(InputSource): class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """ """ Input by iterating over a DataFlow and feed datapoints. """
def __init__(self, ds): def __init__(self, ds):
...@@ -465,57 +494,47 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -465,57 +494,47 @@ class StagingInputWrapper(FeedfreeInput):
def get_staging_name(idx): def get_staging_name(idx):
return 'StagingArea{}'.format(idx) return 'StagingArea{}'.format(idx)
@memoized
def get_stage_op(self): def get_stage_op(self):
return tf.group(*self._stage_ops) return tf.group(*self._stage_ops)
@memoized
def get_unstage_op(self): def get_unstage_op(self):
all_outputs = list(chain.from_iterable(self._unstage_ops)) all_outputs = list(chain.from_iterable(self._unstage_ops))
return tf.group(*all_outputs) return tf.group(*all_outputs)
# TODO dynamically generate inheritance def remap_input_source(input, names):
# TODO make it a function, not a class
class remap_input_source(FeedInput, FeedfreeInput):
""" """
When you have some :class:`InputSource` which doesn't match the inputs in When you have some :class:`InputSource` which doesn't match the inputs in
your :class:`ModelDesc`, use `RemapInputSource`. your :class:`ModelDesc`, use `RemapInputSource`.
It produces placeholders for all the inputs in your model, It produces placeholders for all the inputs in your model,
except that the corresponding ones are replaced with the tensor produced except that the corresponding ones are replaced with the tensor produced
by the given :class:`InputSource`. by the given :class:`InputSource`.
Args:
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
""" """
def __init__(self, input, names): def __init__(self, input, names):
""" ProxyInputSource.__init__(self, input)
Args:
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
"""
assert isinstance(input, InputSource), input
self._input = input
assert isinstance(names, (list, tuple)), names assert isinstance(names, (list, tuple)), names
self._names = tuple(names) self._names = tuple(names)
def size(self):
return self._input.size()
def setup(self, inputs): def setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
inputs_subset = get_sublist_by_names(inputs, self._names) inputs_subset = get_sublist_by_names(inputs, self._names)
self._input.setup(inputs_subset) self._input.setup(inputs_subset)
def get_callbacks(self):
return self._input.get_callbacks()
def reset_state(self):
self._input.reset_state()
def next_feed(self):
return self._input.next_feed()
def get_input_tensors(self): def get_input_tensors(self):
ret = self._input.get_input_tensors() ret = self._input.get_input_tensors()
assert len(ret) == len(self._names) assert len(ret) == len(self._names)
return get_tensors_inputs( return get_tensors_inputs(
self._all_placehdrs, ret, self._names) self._all_placehdrs, ret, self._names)
oldcls = type(input)
# inherit oldcls so that type check in various places would work
cls = type('Remapped' + oldcls.__name__, (ProxyInputSource, oldcls), {
'__init__': __init__,
'setup': setup,
'get_input_tensors': get_input_tensors})
return cls(input, names)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment