Commit a8cb7e33 authored by Yuxin Wu

Move InputSource- and ModelDesc-related code to the new graph_builder package.

parent 3f5b9d51
@@ -4,10 +4,14 @@
from tensorpack.libinfo import __version__
from tensorpack.train import *
from tensorpack.models import *
from tensorpack.dataflow import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
from tensorpack.tfutils import *
from tensorpack.train import *
from tensorpack.graph_builder import *
from tensorpack.predict import *
@@ -16,8 +16,8 @@ from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..train.input_source import (
FeedInput, DataParallelFeedInput, FeedfreeInput, InputSource)
from ..graph_builder.input_source import (
FeedInput, DataParallelFeedInput, FeedfreeInput)
from ..predict import PredictorTowerBuilder
from .base import Callback
@@ -190,7 +190,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
if isinstance(input, DataFlow):
tower_names = [TowerContext.get_predict_tower_name(k) for k in range(len(gpus))]
input = DataParallelFeedInput(input, tower_names)
assert isinstance(input, InputSource), input
assert isinstance(input, DataParallelFeedInput), input
super(DataParallelInferenceRunner, self).__init__(input, infs)
self._gpus = gpus
......
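For context, a hedged usage sketch of the tightened constructor above (assuming `DataParallelInferenceRunner` and `ScalarStats` are re-exported from `tensorpack.callbacks`; `df` is any DataFlow):

from tensorpack.callbacks import DataParallelInferenceRunner, ScalarStats

def make_runner(df):
    # A bare DataFlow gets wrapped into a DataParallelFeedInput over the
    # predict towers; any other input must already be a DataParallelFeedInput,
    # or the new assert fires.
    return DataParallelInferenceRunner(df, [ScalarStats('cost')], gpus=[0, 1])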
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import iter_modules
import os
import os.path
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else []
del globals()[name]
for k in lst:
globals()[k] = p.__dict__[k]
__all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
_SKIP = []
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
if module_name not in _SKIP:
global_import(module_name)
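This new `graph_builder/__init__.py` is tensorpack's usual auto-import boilerplate: every non-private, non-skipped `.py` module in the package is imported, and the names listed in its `__all__` are re-exported at package level. For instance (module names taken from imports elsewhere in this commit):

# ModelDesc lives in graph_builder/model_desc.py and QueueInput in
# graph_builder/input_source.py, but the loop above re-exports both,
# so package-level imports work:
from tensorpack.graph_builder import ModelDesc, QueueInput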
@@ -10,109 +10,22 @@ except ImportError:
pass
from itertools import chain
from abc import ABCMeta, abstractmethod
import six
from six.moves import range, zip
from .utils import get_sublist_by_names, get_tensors_inputs
from .input_source_base import InputSource
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback
__all__ = ['InputSource',
'FeedInput', 'DataParallelFeedInput',
__all__ = ['FeedInput', 'DataParallelFeedInput',
'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'ZMQInput', 'DummyConstantInput', 'TensorInput',
'StagingInputWrapper', 'remap_input_source']
@six.add_metaclass(ABCMeta)
class InputSource(object):
""" Base class for the abstract InputSource. """
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model,
used as input of :func:`build_graph`.
For non-placeholder tensors, this method should always create and return new tensors when called.
"""
return self._get_input_tensors()
@abstractmethod
def _get_input_tensors(self):
pass
def setup(self, inputs_desc):
"""
Args:
inputs_desc (list[InputDesc]): list of input desc
"""
self._setup(inputs_desc)
def _setup(self, inputs_desc):
pass
def get_callbacks(self):
"""
Returns:
list[Callback]: extra callbacks required by this InputSource.
"""
return self._get_callbacks()
def _get_callbacks(self):
return []
def reset_state(self):
"""
The semantics of this method have not been well defined.
"""
# TODO
self._reset_state()
@abstractmethod
def _reset_state(self):
pass
def size(self):
"""
Returns:
int: epoch size of the InputSource
"""
return self._size()
def _size(self):
raise NotImplementedError()
class ProxyInputSource(InputSource):
"""
An InputSource which proxies every method to ``self._input``.
"""
def __init__(self, input):
assert isinstance(input, InputSource), input
self._input = input
def _get_input_tensors(self):
return self._input.get_input_tensors()
def _setup(self, inputs_desc):
self._input.setup(inputs_desc)
def _get_callbacks(self):
return self._input.get_callbacks()
def _size(self):
return self._input.size()
def _reset_state(self):
self._input.reset_state()
'StagingInputWrapper']
class FeedInput(InputSource):
@@ -299,6 +212,7 @@ class QueueInput(FeedfreeInput):
self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
def _get_callbacks(self):
from ..callbacks.concurrency import StartProcOrThread  # deferred import: callbacks itself now imports graph_builder, so a module-level import here would be circular
cb = StartProcOrThread(self.thread)
cb.chief_only = False
return [cb]
@@ -542,59 +456,3 @@ class StagingInputWrapper(FeedfreeInput):
def _get_unstage_op(self):
all_outputs = list(chain.from_iterable(self._unstage_ops))
return tf.group(*all_outputs)
def remap_input_source(input, names):
"""
When you have an :class:`InputSource` which doesn't match the inputs of
your :class:`ModelDesc`, use `RemapInputSource`.
It produces placeholders for all the inputs in your model,
except that the corresponding ones are replaced with the tensors produced
by the given :class:`InputSource`.
Args:
input(InputSource): a :class:`InputSource` whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
Returns:
InputSource:
Examples:
.. code-block:: python
input1 = QueueInput(ds)
# assume ds produces 'image' and 'label', but the graph takes more
# inputs for some reason, or takes inputs in a different order:
inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
InputDesc(tf.float32, (None,20,20,3), 'label'),
InputDesc(tf.int32, (None,), 'image') ]
input2 = remap_input_source(input1, ['image', 'label'])
input2.setup(inputs_desc)
# now, input2.get_input_tensors() will return a placeholder for 'score',
# plus the tensors returned by input1.get_input_tensors()
"""
def __init__(self, input, names):
ProxyInputSource.__init__(self, input)
assert isinstance(names, (list, tuple)), names
self._names = tuple(names)
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
inputs_subset = get_sublist_by_names(inputs, self._names)
self._input.setup(inputs_subset)
def _get_input_tensors(self):
ret = self._input.get_input_tensors()
assert len(ret) == len(self._names)
return get_tensors_inputs(
self._all_placehdrs, ret, self._names)
oldcls = type(input)
# inherit oldcls so that type check in various places would work
cls = type('Remapped' + oldcls.__name__, (ProxyInputSource, oldcls), {
'__init__': __init__,
'_setup': _setup,
'_get_input_tensors': _get_input_tensors})
return cls(input, names)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: input_source_base.py
from abc import ABCMeta, abstractmethod
import six
from ._utils import get_sublist_by_names, get_tensors_inputs
__all__ = ['InputSource', 'remap_input_source']
@six.add_metaclass(ABCMeta)
class InputSource(object):
""" Base class for the abstract InputSource. """
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model,
used as input of :func:`build_graph`.
For non-placeholder tensors, this method should always create and return new tensors when called.
"""
return self._get_input_tensors()
@abstractmethod
def _get_input_tensors(self):
pass
def setup(self, inputs_desc):
"""
Args:
inputs_desc (list[InputDesc]): list of input desc
"""
self._setup(inputs_desc)
def _setup(self, inputs_desc):
pass
def get_callbacks(self):
"""
Returns:
list[Callback]: extra callbacks required by this InputSource.
"""
return self._get_callbacks()
def _get_callbacks(self):
return []
def reset_state(self):
"""
The semantics of this method have not been well defined.
"""
# TODO
self._reset_state()
@abstractmethod
def _reset_state(self):
pass
def size(self):
"""
Returns:
int: epoch size of the InputSource
"""
return self._size()
def _size(self):
raise NotImplementedError()
class ProxyInputSource(InputSource):
"""
An InputSource which proxies every method to ``self._input``.
"""
def __init__(self, input):
assert isinstance(input, InputSource), input
self._input = input
def _get_input_tensors(self):
return self._input.get_input_tensors()
def _setup(self, inputs_desc):
self._input.setup(inputs_desc)
def _get_callbacks(self):
return self._input.get_callbacks()
def _size(self):
return self._input.size()
def _reset_state(self):
self._input.reset_state()
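As a reading aid (not part of this commit), a minimal hypothetical subclass of the InputSource class defined above; the abstract interface only requires `_get_input_tensors` and `_reset_state`:

import tensorflow as tf

class FixedValueInput(InputSource):
    """Hypothetical InputSource feeding fixed arrays; illustration only."""
    def __init__(self, arrays):
        self._arrays = arrays    # one numpy array per model input

    def _get_input_tensors(self):
        # Per the get_input_tensors() docstring, non-placeholder tensors
        # should be created anew on each call.
        return [tf.constant(a) for a in self._arrays]

    def _reset_state(self):
        pass                     # fixed values carry no state to reset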
def remap_input_source(input, names):
"""
When you have an :class:`InputSource` which doesn't match the inputs of
your :class:`ModelDesc`, use `RemapInputSource`.
It produces placeholders for all the inputs in your model,
except that the corresponding ones are replaced with the tensors produced
by the given :class:`InputSource`.
Args:
input(InputSource): a :class:`InputSource` whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
Returns:
InputSource:
Examples:
.. code-block:: python
input1 = QueueInput(ds)
# assume ds produces 'image' and 'label', but the graph takes more
# inputs for some reason, or takes inputs in a different order:
inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
InputDesc(tf.float32, (None,20,20,3), 'label'),
InputDesc(tf.int32, (None,), 'image') ]
input2 = remap_input_source(input1, ['image', 'label'])
input2.setup(inputs_desc)
# now, input2.get_input_tensors() will return a placeholder for 'score',
# plus the tensors returned by input1.get_input_tensors()
"""
def __init__(self, input, names):
ProxyInputSource.__init__(self, input)
assert isinstance(names, (list, tuple)), names
self._names = tuple(names)
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
inputs_subset = get_sublist_by_names(inputs, self._names)
self._input.setup(inputs_subset)
def _get_input_tensors(self):
ret = self._input.get_input_tensors()
assert len(ret) == len(self._names)
return get_tensors_inputs(
self._all_placehdrs, ret, self._names)
oldcls = type(input)
# inherit oldcls so that type check in various places would work
cls = type('Remapped' + oldcls.__name__, (ProxyInputSource, oldcls), {
'__init__': __init__,
'_setup': _setup,
'_get_input_tensors': _get_input_tensors})
return cls(input, names)
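The dynamic `type(...)` call above builds a subclass of both `ProxyInputSource` and the input's own class, so existing `isinstance` checks against the original type keep passing. A hedged sketch reusing the names from the docstring example:

def demo(input1, inputs_desc):
    # input1: the QueueInput from the docstring example;
    # inputs_desc: its three InputDesc ('score', 'label', 'image').
    input2 = remap_input_source(input1, ['image', 'label'])
    input2.setup(inputs_desc)
    tensors = input2.get_input_tensors()
    # A fresh placeholder is returned for 'score'; the 'image' and 'label'
    # tensors come from input1, as the docstring describes.
    assert isinstance(input2, type(input1))    # type checks still work
    return tensors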
@@ -9,9 +9,8 @@ import tensorflow as tf
import six
from ..utils.argtools import memoized
# TODO sort out import issues
# from ..train.input_source import InputSource
from .regularize import regularize_cost_from_collection
from .input_source_base import InputSource
from ..models.regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc']
@@ -141,7 +140,7 @@ class ModelDesc(object):
inputs (list[tf.Tensor] or InputSource): a list of tensors, or an :class:`InputSource`,
that match the list of :class:`InputDesc` defined by ``_get_inputs``.
"""
if not isinstance(inputs, (list, tuple)):
if isinstance(inputs, InputSource):
inputs = inputs.get_input_tensors()
assert len(inputs) == len(self.get_inputs_desc()), \
"Number of inputs passed to the graph != number of inputs defined " \
......
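An illustration of the tightened check above; the enclosing method is presumably `ModelDesc.build_graph`, going by its docstring, and `model`/`queue_input` are hypothetical:

def build(model, queue_input):
    # Both forms are accepted; anything else now falls through to the
    # length assertion instead of being silently treated as a tensor list.
    model.build_graph(queue_input)                       # an InputSource
    model.build_graph(queue_input.get_input_tensors())   # or a list of tensors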
@@ -5,7 +5,7 @@
import tensorflow as tf
import six
from ..models import ModelDesc
from ..graph_builder import ModelDesc
from ..utils.develop import log_deprecated
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
......
# -*- coding: UTF-8 -*-
# File: exporter.py
# File: export.py
# Author: Patrick Wieschollek <mail@patwie.com>
"""
@@ -8,9 +8,9 @@ This simplifies the process of exporting a model for TensorFlow serving.
"""
import tensorflow as tf
from tensorpack.utils import logger
from tensorpack.models import ModelDesc
from tensorpack.tfutils import TowerContext, sessinit
from ..utils import logger
from ..graph_builder.model_desc import ModelDesc
from ..tfutils import TowerContext, sessinit
__all__ = ['ModelExport']
@@ -49,7 +49,7 @@ class ModelExport(object):
prediction = sess.run(prediction, {lowres: ...})[0]
Args:
model (ModelDescr): the model description which should be exported
model (ModelDesc): the model description which should be exported
input_names (list(str)): names of input tensors
output_names (list(str)): names of output tensors
"""
......
@@ -31,8 +31,10 @@ class TowerContext(object):
self._is_training = bool(is_training)
self._index = int(index)
self._vs_name = str(vs_name)
self._vs_name = vs_name
if self.has_own_variables:
assert not tf.get_variable_scope().reuse, "reuse=True in tower {}!".format(tower_name)
@property
def is_main_training_tower(self):
@@ -48,14 +50,15 @@ class TowerContext(object):
@property
def has_own_variables(self):
"""
Whether this tower is supposed to have its own variables.
"""
return self.is_main_training_tower or len(self._vs_name) > 0
@property
def name(self):
return self._name
# TODO remove this and add something like `tower.variables`
# variable_scope name
@property
def vs_name(self):
return self._vs_name
......
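A tiny self-contained restatement of the `has_own_variables` rule above, in plain Python:

def has_own_variables(is_main_training_tower, vs_name):
    # The main training tower always owns its variables; any other tower
    # does so only when given an explicit variable-scope name.
    return is_main_training_tower or len(vs_name) > 0

assert has_own_variables(True, '')           # main training tower
assert not has_own_variables(False, '')      # replica sharing variables
assert has_own_variables(False, 'tower1')    # tower with its own scope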
@@ -7,13 +7,13 @@ from ..callbacks import (
ProgressBar, MergeAllSummaries,
TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..graph_builder.model_desc import ModelDesc
from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
from .input_source import InputSource
from ..graph_builder.input_source_base import InputSource
__all__ = ['TrainConfig']
......
@@ -9,7 +9,7 @@ from six.moves import zip
from ..utils import logger
from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_source import QueueInput, FeedfreeInput
from ..graph_builder.input_source import QueueInput, FeedfreeInput
from .base import Trainer
......
@@ -17,7 +17,7 @@ from ..callbacks.graph import RunOp
from .base import Trainer
from .feedfree import SingleCostFeedfreeTrainer
from .input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
......
@@ -7,7 +7,7 @@ from .base import Trainer
from ..utils import logger
from ..tfutils import TowerContext
from .input_source import FeedInput
from ..graph_builder.input_source import FeedInput
__all__ = ['SimpleTrainer']
......