Commit 63c0f891 authored by Yuxin Wu's avatar Yuxin Wu

move InputSource to a separate folder

parent 9e995a8d
......@@ -8,6 +8,7 @@ API Documentation
dataflow
dataflow.dataset
dataflow.imgaug
input_source
models
callbacks
graph_builder
......
tensorpack.input_source package
================================
.. automodule:: tensorpack.input_source
:members:
:undoc-members:
:show-inheritance:
......@@ -17,9 +17,8 @@ from ..utils.utils import get_tqdm_kwargs
from ..utils.develop import deprecated
from ..dataflow.base import DataFlow
from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import (
FeedInput, QueueInput)
from ..input_source import (
InputSource, FeedInput, QueueInput)
from .base import Callback
from .group import Callbacks
......
......@@ -9,7 +9,7 @@ import tensorflow as tf
import six
from ..utils.argtools import memoized
from .input_source_base import InputSource
from ..input_source import InputSource
from ..models.regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
......
......@@ -8,7 +8,7 @@ from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..tfutils.collection import freeze_collection
from ..utils.naming import TOWER_FREEZE_KEYS
from .input_source import PlaceholderInput
from ..input_source import PlaceholderInput
__all__ = []
......
......@@ -3,68 +3,15 @@
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import copy
from six.moves import zip
from contextlib import contextmanager
import operator
import tensorflow as tf
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable',
'override_to_local_variable']
def get_tensors_inputs(placeholders, tensors, names):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert len(tensors) == len(names), \
"Input tensors {} and input names {} have different length!".format(
tensors, names)
ret = copy.copy(placeholders)
placeholder_names = [p.name for p in placeholders]
for name, tensor in zip(names, tensors):
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret[idx] = tensor
return ret
def get_sublist_by_names(lst, names):
"""
Args:
lst (list): list of objects with "name" property.
Returns:
list: a sublist of objects, matching names
"""
orig_names = [p.name for p in lst]
ret = []
for name in names:
try:
idx = orig_names.index(name)
except ValueError:
logger.error("Name {} doesn't appear in lst {}!".format(
name, str(orig_names)))
raise
ret.append(lst[idx])
return ret
@contextmanager
def override_to_local_variable(enable=True):
if enable:
......
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import iter_modules
import os
import os.path
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else []
del globals()[name]
for k in lst:
if not k.startswith('__'):
globals()[k] = p.__dict__[k]
__all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
_SKIP = []
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
if module_name not in _SKIP:
global_import(module_name)
......@@ -3,17 +3,68 @@
# File: input_source_base.py
from abc import ABCMeta, abstractmethod
import copy
import six
from six.moves import zip
from contextlib import contextmanager
import tensorflow as tf
from ..utils.argtools import memoized
from .utils import get_sublist_by_names, get_tensors_inputs
from ..callbacks.base import CallbackFactory
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['InputSource', 'remap_input_source']
def get_tensors_inputs(placeholders, tensors, names):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert len(tensors) == len(names), \
"Input tensors {} and input names {} have different length!".format(
tensors, names)
ret = copy.copy(placeholders)
placeholder_names = [p.name for p in placeholders]
for name, tensor in zip(names, tensors):
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret[idx] = tensor
return ret
def get_sublist_by_names(lst, names):
"""
Args:
lst (list): list of objects with "name" property.
Returns:
list: a sublist of objects, matching names
"""
orig_names = [p.name for p in lst]
ret = []
for name in names:
try:
idx = orig_names.index(name)
except ValueError:
logger.error("Name {} doesn't appear in lst {}!".format(
name, str(orig_names)))
raise
ret.append(lst[idx])
return ret
@six.add_metaclass(ABCMeta)
class InputSource(object):
""" Base class for the abstract InputSource. """
......
......@@ -9,7 +9,7 @@ import six
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..graph_builder.input_source import PlaceholderInput
from ..input_source import PlaceholderInput
__all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor',
......
......@@ -6,7 +6,7 @@
import tensorflow as tf
from ..utils import logger
from ..graph_builder.predictor_factory import PredictorFactory
from ..graph_builder.input_source import PlaceholderInput
from ..input_source import PlaceholderInput
from .base import OnlinePredictor
__all__ = ['MultiTowerOfflinePredictor',
......
......@@ -10,7 +10,7 @@ This simplifies the process of exporting a model for TensorFlow serving.
import tensorflow as tf
from ..utils import logger
from ..graph_builder.model_desc import ModelDescBase
from ..graph_builder.input_source import PlaceholderInput
from ..input_source import PlaceholderInput
from ..tfutils import TowerContext, sessinit
......
......@@ -12,7 +12,7 @@ from ..utils import logger
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
from ..graph_builder.input_source_base import InputSource
from ..input_source import InputSource
__all__ = ['TrainConfig']
......
......@@ -8,7 +8,7 @@ import tensorflow as tf
from ..callbacks.graph import RunOp
from ..utils.develop import log_deprecated
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from ..input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from ..graph_builder.training import (
SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder,
......
......@@ -6,7 +6,7 @@
from .base import Trainer
from ..utils import logger
from ..graph_builder.input_source import FeedInput, QueueInput
from ..input_source import FeedInput, QueueInput
from ..graph_builder.training import SimpleBuilder
__all__ = ['SimpleTrainer', 'QueueInputTrainer']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment