Commit a8cb7e33 authored by Yuxin Wu

Move related stuff to graph_builder.

parent 3f5b9d51
@@ -4,10 +4,14 @@
 from tensorpack.libinfo import __version__
-from tensorpack.train import *
 from tensorpack.models import *
-from tensorpack.dataflow import *
 from tensorpack.utils import *
-from tensorpack.tfutils import *
 from tensorpack.callbacks import *
+from tensorpack.dataflow import *
+from tensorpack.tfutils import *
+from tensorpack.train import *
+from tensorpack.graph_builder import *
 from tensorpack.predict import *
@@ -16,8 +16,8 @@
 from ..utils import logger, get_tqdm_kwargs
 from ..dataflow import DataFlow
 from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import TowerContext
-from ..train.input_source import (
-    FeedInput, DataParallelFeedInput, FeedfreeInput, InputSource)
+from ..graph_builder.input_source import (
+    FeedInput, DataParallelFeedInput, FeedfreeInput)
 from ..predict import PredictorTowerBuilder
 from .base import Callback
@@ -190,7 +190,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         if isinstance(input, DataFlow):
             tower_names = [TowerContext.get_predict_tower_name(k) for k in range(len(gpus))]
             input = DataParallelFeedInput(input, tower_names)
-        assert isinstance(input, InputSource), input
+        assert isinstance(input, DataParallelFeedInput), input
         super(DataParallelInferenceRunner, self).__init__(input, infs)
         self._gpus = gpus
...
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import iter_modules
import os
import os.path

__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else []
    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]
        __all__.append(k)


_CURR_DIR = os.path.dirname(__file__)
_SKIP = []
for _, module_name, _ in iter_modules([_CURR_DIR]):
    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
    if not os.path.isfile(srcpath):
        continue
    if module_name.startswith('_'):
        continue
    if module_name not in _SKIP:
        global_import(module_name)
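This new `__init__.py` re-exports, at package level, every name listed in a submodule's `__all__`, so callers keep importing symbols from `tensorpack.graph_builder` even as they move between files. A sketch of the effect, assuming `input_source.py` lands in this package as the rest of this commit arranges:

    # after global_import('input_source') has run in graph_builder/__init__.py,
    # names from input_source.__all__ are reachable at package level:
    from tensorpack.graph_builder import QueueInput
    # equivalent to the longer form:
    from tensorpack.graph_builder.input_source import QueueInput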
@@ -10,109 +10,22 @@
 except ImportError:
     pass
 from itertools import chain
-from abc import ABCMeta, abstractmethod
-import six
 from six.moves import range, zip

-from .utils import get_sublist_by_names, get_tensors_inputs
+from .input_source_base import InputSource
 from ..dataflow import DataFlow, RepeatedData
 from ..tfutils.summary import add_moving_summary
 from ..tfutils import get_op_tensor_name
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
-from ..callbacks.concurrency import StartProcOrThread
 from ..callbacks.base import Callback

-__all__ = ['InputSource',
-           'FeedInput', 'DataParallelFeedInput',
+__all__ = ['FeedInput', 'DataParallelFeedInput',
            'FeedfreeInput',
            'QueueInput', 'BatchQueueInput',
            'ZMQInput', 'DummyConstantInput', 'TensorInput',
-           'StagingInputWrapper', 'remap_input_source']
+           'StagingInputWrapper']
-@six.add_metaclass(ABCMeta)
-class InputSource(object):
-    """ Base class for the abstract InputSource. """
-
-    def get_input_tensors(self):
-        """
-        Returns:
-            list: A list of tensors corresponding to the inputs of the model,
-                used as input of :func:`build_graph`.
-                For non-placeholder tensors, should always create and return new tensors when called.
-        """
-        return self._get_input_tensors()
-
-    @abstractmethod
-    def _get_input_tensors(self):
-        pass
-
-    def setup(self, inputs_desc):
-        """
-        Args:
-            inputs_desc (list[InputDesc]): list of input desc
-        """
-        self._setup(inputs_desc)
-
-    def _setup(self, inputs_desc):
-        pass
-
-    def get_callbacks(self):
-        """
-        Returns:
-            list[Callback]: extra callbacks required by this InputSource.
-        """
-        return self._get_callbacks()
-
-    def _get_callbacks(self):
-        return []
-
-    def reset_state(self):
-        """
-        Semantics of this method has not been well defined.
-        """
-        # TODO
-        self._reset_state()
-
-    @abstractmethod
-    def _reset_state(self):
-        pass
-
-    def size(self):
-        """
-        Returns:
-            int: epoch size of the InputSource
-        """
-        return self._size()
-
-    def _size(self):
-        raise NotImplementedError()
-
-
-class ProxyInputSource(InputSource):
-    """
-    An InputSource which proxies every method to ``self._input``.
-    """
-    def __init__(self, input):
-        assert isinstance(input, InputSource), input
-        self._input = input
-
-    def _get_input_tensors(self):
-        return self._input.get_input_tensors()
-
-    def _setup(self, inputs_desc):
-        self._input.setup(inputs_desc)
-
-    def _get_callbacks(self):
-        return self._input.get_callbacks()
-
-    def _size(self):
-        return self._input.size()
-
-    def _reset_state(self):
-        self._input.reset_state()
 class FeedInput(InputSource):
@@ -299,6 +212,7 @@ class QueueInput(FeedfreeInput):
         self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)

     def _get_callbacks(self):
+        from ..callbacks.concurrency import StartProcOrThread
         cb = StartProcOrThread(self.thread)
         cb.chief_only = False
         return [cb]
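Note the `StartProcOrThread` import moves from module level (removed in the hunk above) into `_get_callbacks`, presumably because `graph_builder` and `callbacks` now import from each other: deferring the import to call time means both modules finish initializing before it runs. A generic sketch of the pattern, with hypothetical module names:

    # a.py (hypothetical)
    from b import make_callback    # module-level: runs while a.py is loading

    class Thing(object):
        pass

    # b.py (hypothetical)
    def make_callback():
        from a import Thing        # deferred: by the time this executes,
        return Thing()             # a.py has finished importing, so no cycle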
@@ -542,59 +456,3 @@
     def _get_unstage_op(self):
         all_outputs = list(chain.from_iterable(self._unstage_ops))
         return tf.group(*all_outputs)
-
-
-def remap_input_source(input, names):
-    """
-    When you have some :class:`InputSource` which doesn't match the inputs in
-    your :class:`ModelDesc`, use `RemapInputSource`.
-    It produces placeholders for all the inputs in your model,
-    except that the corresponding ones are replaced with the tensor produced
-    by the given :class:`InputSource`.
-
-    Args:
-        input(InputSource): an :class:`InputSource` whose tensors will get mapped.
-        names(list[str]): list of input names corresponding to the tensors
-            produced by ``input``.
-
-    Returns:
-        InputSource:
-
-    Examples:
-
-    .. code-block:: python
-
-        input1 = QueueInput(ds)
-        # assume ds produces 'image' and 'label', but the graph takes more
-        # inputs for some reason, or takes inputs of a different order:
-        inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
-                       InputDesc(tf.float32, (None,20,20,3), 'label'),
-                       InputDesc(tf.int32, (None,), 'image') ]
-        input2 = remap_input_source(input1, ['image', 'label'])
-        input2.setup(inputs_desc)
-        # now, input2.get_input_tensors() will return a placeholder for 'score',
-        # plus the tensors returned by input1.get_input_tensors()
-    """
-    def __init__(self, input, names):
-        ProxyInputSource.__init__(self, input)
-        assert isinstance(names, (list, tuple)), names
-        self._names = tuple(names)
-
-    def _setup(self, inputs):
-        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
-        inputs_subset = get_sublist_by_names(inputs, self._names)
-        self._input.setup(inputs_subset)
-
-    def _get_input_tensors(self):
-        ret = self._input.get_input_tensors()
-        assert len(ret) == len(self._names)
-        return get_tensors_inputs(
-            self._all_placehdrs, ret, self._names)
-
-    oldcls = type(input)
-    # inherit oldcls so that type checks in various places still work
-    cls = type('Remapped' + oldcls.__name__, (ProxyInputSource, oldcls), {
-        '__init__': __init__,
-        '_setup': _setup,
-        '_get_input_tensors': _get_input_tensors})
-    return cls(input, names)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: input_source_base.py

from abc import ABCMeta, abstractmethod
import six

from ._utils import get_sublist_by_names, get_tensors_inputs

__all__ = ['InputSource', 'remap_input_source']


@six.add_metaclass(ABCMeta)
class InputSource(object):
    """ Base class for the abstract InputSource. """

    def get_input_tensors(self):
        """
        Returns:
            list: A list of tensors corresponding to the inputs of the model,
                used as input of :func:`build_graph`.
                For non-placeholder tensors, should always create and return new tensors when called.
        """
        return self._get_input_tensors()

    @abstractmethod
    def _get_input_tensors(self):
        pass

    def setup(self, inputs_desc):
        """
        Args:
            inputs_desc (list[InputDesc]): list of input desc
        """
        self._setup(inputs_desc)

    def _setup(self, inputs_desc):
        pass

    def get_callbacks(self):
        """
        Returns:
            list[Callback]: extra callbacks required by this InputSource.
        """
        return self._get_callbacks()

    def _get_callbacks(self):
        return []

    def reset_state(self):
        """
        Semantics of this method has not been well defined.
        """
        # TODO
        self._reset_state()

    @abstractmethod
    def _reset_state(self):
        pass

    def size(self):
        """
        Returns:
            int: epoch size of the InputSource
        """
        return self._size()

    def _size(self):
        raise NotImplementedError()


class ProxyInputSource(InputSource):
    """
    An InputSource which proxies every method to ``self._input``.
    """
    def __init__(self, input):
        assert isinstance(input, InputSource), input
        self._input = input

    def _get_input_tensors(self):
        return self._input.get_input_tensors()

    def _setup(self, inputs_desc):
        self._input.setup(inputs_desc)

    def _get_callbacks(self):
        return self._input.get_callbacks()

    def _size(self):
        return self._input.size()

    def _reset_state(self):
        self._input.reset_state()
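The base class follows a template-method pattern: the public `setup`/`get_input_tensors`/`size` wrappers delegate to the underscore hooks, so a concrete source only overrides the latter. A minimal sketch of a hypothetical subclass that serves a fixed list of pre-built tensors (the class and argument names are illustrative, not part of this commit):

    class ConstantListInput(InputSource):
        """ Hypothetical InputSource feeding a fixed list of tensors. """
        def __init__(self, tensors, epoch_size):
            self._tensors = tensors
            self._epoch_size = epoch_size

        def _get_input_tensors(self):
            # per the base-class docstring, non-placeholder tensors should
            # normally be re-created on each call; a fixed list is a
            # simplification for illustration
            return self._tensors

        def _reset_state(self):
            pass    # nothing to reset for constant tensors

        def _size(self):
            return self._epoch_size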
def remap_input_source(input, names):
    """
    When you have some :class:`InputSource` which doesn't match the inputs in
    your :class:`ModelDesc`, use `RemapInputSource`.
    It produces placeholders for all the inputs in your model,
    except that the corresponding ones are replaced with the tensor produced
    by the given :class:`InputSource`.

    Args:
        input(InputSource): an :class:`InputSource` whose tensors will get mapped.
        names(list[str]): list of input names corresponding to the tensors
            produced by ``input``.

    Returns:
        InputSource:

    Examples:

    .. code-block:: python

        input1 = QueueInput(ds)
        # assume ds produces 'image' and 'label', but the graph takes more
        # inputs for some reason, or takes inputs of a different order:
        inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
                       InputDesc(tf.float32, (None,20,20,3), 'label'),
                       InputDesc(tf.int32, (None,), 'image') ]
        input2 = remap_input_source(input1, ['image', 'label'])
        input2.setup(inputs_desc)
        # now, input2.get_input_tensors() will return a placeholder for 'score',
        # plus the tensors returned by input1.get_input_tensors()
    """
    def __init__(self, input, names):
        ProxyInputSource.__init__(self, input)
        assert isinstance(names, (list, tuple)), names
        self._names = tuple(names)

    def _setup(self, inputs):
        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
        inputs_subset = get_sublist_by_names(inputs, self._names)
        self._input.setup(inputs_subset)

    def _get_input_tensors(self):
        ret = self._input.get_input_tensors()
        assert len(ret) == len(self._names)
        return get_tensors_inputs(
            self._all_placehdrs, ret, self._names)

    oldcls = type(input)
    # inherit oldcls so that type checks in various places still work
    cls = type('Remapped' + oldcls.__name__, (ProxyInputSource, oldcls), {
        '__init__': __init__,
        '_setup': _setup,
        '_get_input_tensors': _get_input_tensors})
    return cls(input, names)
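`remap_input_source` builds its return class at runtime with `type()`, inheriting from both `ProxyInputSource` and the wrapped input's own class so that existing `isinstance(input, QueueInput)`-style checks keep passing. A stripped-down sketch of the same trick, with hypothetical names:

    class Base(object):
        def greet(self):
            return "base"

    def make_loud(obj):
        # build a subclass of the object's own class on the fly, overriding
        # one method while keeping isinstance(..., Base) true
        def greet(self):
            return Base.greet(self).upper()
        cls = type('Loud' + type(obj).__name__, (type(obj),), {'greet': greet})
        return cls()

    b = make_loud(Base())
    assert isinstance(b, Base)
    assert b.greet() == "BASE"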
@@ -9,9 +9,8 @@
 import tensorflow as tf
 import six

 from ..utils.argtools import memoized
-# TODO sort out import issues
-# from ..train.input_source import InputSource
-from .regularize import regularize_cost_from_collection
+from .input_source_base import InputSource
+from ..models.regularize import regularize_cost_from_collection

 __all__ = ['InputDesc', 'ModelDesc']
@@ -141,7 +140,7 @@ class ModelDesc(object):
             inputs (list[tf.Tensor] or InputSource): a list of tensors, or an :class:`InputSource`,
                 that match the list of :class:`InputDesc` defined by ``_get_inputs``.
         """
-        if not isinstance(inputs, (list, tuple)):
+        if isinstance(inputs, InputSource):
             inputs = inputs.get_input_tensors()
         assert len(inputs) == len(self.get_inputs_desc()), \
             "Number of inputs passed to the graph != number of inputs defined " \
...
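The condition flips from duck typing ("anything that is not a list/tuple") to an explicit `isinstance(inputs, InputSource)` test, which the new import at the top of the file makes possible: only a real `InputSource` gets unwrapped, and a bare tensor list passes through unchanged. A hedged sketch of the two call styles, using the `_get_inputs`/`_build_graph` names from the surrounding docstrings (`MyModel` and `ds` are placeholders for user code):

    import tensorflow as tf

    class MyModel(ModelDesc):                  # hypothetical user model
        def _get_inputs(self):
            return [InputDesc(tf.float32, (None, 28, 28), 'image'),
                    InputDesc(tf.int32, (None,), 'label')]
        def _build_graph(self, inputs):
            image, label = inputs
            # ... build the actual graph here ...

    model = MyModel()
    # style 1: a plain list of tensors is used as-is
    model.build_graph([tf.placeholder(tf.float32, (None, 28, 28), 'image'),
                       tf.placeholder(tf.int32, (None,), 'label')])
    # style 2: an InputSource is unwrapped via get_input_tensors()
    model.build_graph(QueueInput(ds))          # ds: some DataFlow, assumed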
@@ -5,7 +5,7 @@
 import tensorflow as tf
 import six

-from ..models import ModelDesc
+from ..graph_builder import ModelDesc
 from ..utils.develop import log_deprecated
 from ..tfutils import get_default_sess_config
 from ..tfutils.sessinit import SessionInit, JustCurrentSession
...
 # -*- coding: UTF-8 -*-
-# File: exporter.py
+# File: export.py
 # Author: Patrick Wieschollek <mail@patwie.com>

 """
@@ -8,9 +8,9 @@
 This simplifies the process of exporting a model for TensorFlow serving.
 """
 import tensorflow as tf

-from tensorpack.utils import logger
-from tensorpack.models import ModelDesc
-from tensorpack.tfutils import TowerContext, sessinit
+from ..utils import logger
+from ..graph_builder.model_desc import ModelDesc
+from ..tfutils import TowerContext, sessinit

 __all__ = ['ModelExport']

@@ -49,7 +49,7 @@ class ModelExport(object):
             prediction = sess.run(prediction, {lowres: ...})[0]

         Args:
-            model (ModelDescr): the model description which should be exported
+            model (ModelDesc): the model description which should be exported
             input_names (list(str)): names of input tensors
             output_names (list(str)): names of output tensors
         """
...
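For context, a hedged usage sketch matching the docstring fragment above; the `export()` method name and its `(checkpoint, export_path)` signature are assumed here, not shown in this diff:

    # MyModel: a hypothetical user ModelDesc with a 'lowres' input
    # and a 'prediction' output; paths are illustrative
    e = ModelExport(MyModel(), ['lowres'], ['prediction'])
    e.export('train_log/mymodel/checkpoint', 'export_dir')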
@@ -31,8 +31,10 @@ class TowerContext(object):
         self._is_training = bool(is_training)

         self._index = int(index)
-        self._vs_name = vs_name
+        self._vs_name = str(vs_name)
+        if self.has_own_variables:
+            assert not tf.get_variable_scope().reuse, "reuse=True in tower {}!".format(tower_name)

     @property
     def is_main_training_tower(self):
@@ -48,14 +50,15 @@
     @property
     def has_own_variables(self):
+        """
+        Whether this tower is supposed to have its own variables.
+        """
         return self.is_main_training_tower or len(self._vs_name) > 0

     @property
     def name(self):
         return self._name

-    # variable_scope name
+    # TODO remove this and add something like `tower.variables`
     @property
     def vs_name(self):
         return self._vs_name
...
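The new assertion enforces an invariant: a tower that owns its variables (the main training tower, or any tower with a non-empty `vs_name`) must not be entered under a reusing variable scope, since it is expected to create fresh variables. A minimal TF1 sketch of the flag being checked (tower name is illustrative):

    import tensorflow as tf

    with tf.variable_scope('tower0'):
        print(tf.get_variable_scope().reuse)   # False: variables may be created

    with tf.variable_scope('tower0', reuse=True):
        print(tf.get_variable_scope().reuse)   # True: this is the state the new
                                               # assertion rejects for a
                                               # variable-owning tower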
@@ -7,13 +7,13 @@
 from ..callbacks import (
     ProgressBar, MergeAllSummaries,
     TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
 from ..dataflow.base import DataFlow
-from ..models import ModelDesc
+from ..graph_builder.model_desc import ModelDesc
 from ..utils import logger
 from ..utils.develop import log_deprecated
 from ..tfutils import (JustCurrentSession,
                        get_default_sess_config, SessionInit)
 from ..tfutils.sesscreate import NewSessionCreator
-from .input_source import InputSource
+from ..graph_builder.input_source_base import InputSource

 __all__ = ['TrainConfig']
...
@@ -9,7 +9,7 @@
 from six.moves import zip
 from ..utils import logger
 from ..tfutils.gradproc import FilterNoneGrad
 from ..tfutils.tower import TowerContext, get_current_tower_context
-from .input_source import QueueInput, FeedfreeInput
+from ..graph_builder.input_source import QueueInput, FeedfreeInput
 from .base import Trainer
...
@@ -17,7 +17,7 @@
 from ..callbacks.graph import RunOp
 from .base import Trainer
 from .feedfree import SingleCostFeedfreeTrainer
-from .input_source import QueueInput, StagingInputWrapper, DummyConstantInput
+from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput

 __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
            'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
...
@@ -7,7 +7,7 @@
 from .base import Trainer
 from ..utils import logger
 from ..tfutils import TowerContext
-from .input_source import FeedInput
+from ..graph_builder.input_source import FeedInput

 __all__ = ['SimpleTrainer']
...