Commit 63c0f891 authored by Yuxin Wu's avatar Yuxin Wu

move InputSource to a separate folder

parent 9e995a8d
...@@ -8,6 +8,7 @@ API Documentation ...@@ -8,6 +8,7 @@ API Documentation
dataflow dataflow
dataflow.dataset dataflow.dataset
dataflow.imgaug dataflow.imgaug
input_source
models models
callbacks callbacks
graph_builder graph_builder
......
tensorpack.input_source package
================================
.. automodule:: tensorpack.input_source
:members:
:undoc-members:
:show-inheritance:
...@@ -17,9 +17,8 @@ from ..utils.utils import get_tqdm_kwargs ...@@ -17,9 +17,8 @@ from ..utils.utils import get_tqdm_kwargs
from ..utils.develop import deprecated from ..utils.develop import deprecated
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..graph_builder.input_source_base import InputSource from ..input_source import (
from ..graph_builder.input_source import ( InputSource, FeedInput, QueueInput)
FeedInput, QueueInput)
from .base import Callback from .base import Callback
from .group import Callbacks from .group import Callbacks
......
...@@ -9,7 +9,7 @@ import tensorflow as tf ...@@ -9,7 +9,7 @@ import tensorflow as tf
import six import six
from ..utils.argtools import memoized from ..utils.argtools import memoized
from .input_source_base import InputSource from ..input_source import InputSource
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase'] __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
......
...@@ -8,7 +8,7 @@ from ..tfutils.common import get_op_tensor_name, get_tensors_by_names ...@@ -8,7 +8,7 @@ from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..tfutils.collection import freeze_collection from ..tfutils.collection import freeze_collection
from ..utils.naming import TOWER_FREEZE_KEYS from ..utils.naming import TOWER_FREEZE_KEYS
from .input_source import PlaceholderInput from ..input_source import PlaceholderInput
__all__ = [] __all__ = []
......
...@@ -3,68 +3,15 @@ ...@@ -3,68 +3,15 @@
# File: utils.py # File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import copy
from six.moves import zip
from contextlib import contextmanager from contextlib import contextmanager
import operator import operator
import tensorflow as tf import tensorflow as tf
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable', __all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable',
'override_to_local_variable'] 'override_to_local_variable']
def get_tensors_inputs(placeholders, tensors, names):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert len(tensors) == len(names), \
"Input tensors {} and input names {} have different length!".format(
tensors, names)
ret = copy.copy(placeholders)
placeholder_names = [p.name for p in placeholders]
for name, tensor in zip(names, tensors):
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret[idx] = tensor
return ret
def get_sublist_by_names(lst, names):
"""
Args:
lst (list): list of objects with "name" property.
Returns:
list: a sublist of objects, matching names
"""
orig_names = [p.name for p in lst]
ret = []
for name in names:
try:
idx = orig_names.index(name)
except ValueError:
logger.error("Name {} doesn't appear in lst {}!".format(
name, str(orig_names)))
raise
ret.append(lst[idx])
return ret
@contextmanager @contextmanager
def override_to_local_variable(enable=True): def override_to_local_variable(enable=True):
if enable: if enable:
......
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import iter_modules
import os
import os.path
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else []
del globals()[name]
for k in lst:
if not k.startswith('__'):
globals()[k] = p.__dict__[k]
__all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
_SKIP = []
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
if module_name not in _SKIP:
global_import(module_name)
...@@ -3,17 +3,68 @@ ...@@ -3,17 +3,68 @@
# File: input_source_base.py # File: input_source_base.py
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import copy
import six import six
from six.moves import zip
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
from ..utils.argtools import memoized from ..utils.argtools import memoized
from .utils import get_sublist_by_names, get_tensors_inputs
from ..callbacks.base import CallbackFactory from ..callbacks.base import CallbackFactory
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['InputSource', 'remap_input_source'] __all__ = ['InputSource', 'remap_input_source']
def get_tensors_inputs(placeholders, tensors, names):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert len(tensors) == len(names), \
"Input tensors {} and input names {} have different length!".format(
tensors, names)
ret = copy.copy(placeholders)
placeholder_names = [p.name for p in placeholders]
for name, tensor in zip(names, tensors):
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret[idx] = tensor
return ret
def get_sublist_by_names(lst, names):
"""
Args:
lst (list): list of objects with "name" property.
Returns:
list: a sublist of objects, matching names
"""
orig_names = [p.name for p in lst]
ret = []
for name in names:
try:
idx = orig_names.index(name)
except ValueError:
logger.error("Name {} doesn't appear in lst {}!".format(
name, str(orig_names)))
raise
ret.append(lst[idx])
return ret
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class InputSource(object): class InputSource(object):
""" Base class for the abstract InputSource. """ """ Base class for the abstract InputSource. """
......
...@@ -9,7 +9,7 @@ import six ...@@ -9,7 +9,7 @@ import six
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..graph_builder.input_source import PlaceholderInput from ..input_source import PlaceholderInput
__all__ = ['PredictorBase', 'AsyncPredictorBase', __all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor', 'OnlinePredictor', 'OfflinePredictor',
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import tensorflow as tf import tensorflow as tf
from ..utils import logger from ..utils import logger
from ..graph_builder.predictor_factory import PredictorFactory from ..graph_builder.predictor_factory import PredictorFactory
from ..graph_builder.input_source import PlaceholderInput from ..input_source import PlaceholderInput
from .base import OnlinePredictor from .base import OnlinePredictor
__all__ = ['MultiTowerOfflinePredictor', __all__ = ['MultiTowerOfflinePredictor',
......
...@@ -10,7 +10,7 @@ This simplifies the process of exporting a model for TensorFlow serving. ...@@ -10,7 +10,7 @@ This simplifies the process of exporting a model for TensorFlow serving.
import tensorflow as tf import tensorflow as tf
from ..utils import logger from ..utils import logger
from ..graph_builder.model_desc import ModelDescBase from ..graph_builder.model_desc import ModelDescBase
from ..graph_builder.input_source import PlaceholderInput from ..input_source import PlaceholderInput
from ..tfutils import TowerContext, sessinit from ..tfutils import TowerContext, sessinit
......
...@@ -12,7 +12,7 @@ from ..utils import logger ...@@ -12,7 +12,7 @@ from ..utils import logger
from ..tfutils import (JustCurrentSession, from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit) get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
from ..graph_builder.input_source_base import InputSource from ..input_source import InputSource
__all__ = ['TrainConfig'] __all__ = ['TrainConfig']
......
...@@ -8,7 +8,7 @@ import tensorflow as tf ...@@ -8,7 +8,7 @@ import tensorflow as tf
from ..callbacks.graph import RunOp from ..callbacks.graph import RunOp
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput from ..input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from ..graph_builder.training import ( from ..graph_builder.training import (
SyncMultiGPUParameterServerBuilder, SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder, SyncMultiGPUReplicatedBuilder,
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
from .base import Trainer from .base import Trainer
from ..utils import logger from ..utils import logger
from ..graph_builder.input_source import FeedInput, QueueInput from ..input_source import FeedInput, QueueInput
from ..graph_builder.training import SimpleBuilder from ..graph_builder.training import SimpleBuilder
__all__ = ['SimpleTrainer', 'QueueInputTrainer'] __all__ = ['SimpleTrainer', 'QueueInputTrainer']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment