Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
63c0f891
Commit
63c0f891
authored
Oct 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
move InputSource to a separate folder
parent
9e995a8d
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
102 additions
and
65 deletions
+102
-65
docs/modules/index.rst
docs/modules/index.rst
+1
-0
docs/modules/input_source.rst
docs/modules/input_source.rst
+7
-0
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-3
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+1
-1
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+1
-1
tensorpack/graph_builder/utils.py
tensorpack/graph_builder/utils.py
+0
-53
tensorpack/input_source/__init__.py
tensorpack/input_source/__init__.py
+32
-0
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+0
-0
tensorpack/input_source/input_source_base.py
tensorpack/input_source/input_source_base.py
+52
-1
tensorpack/predict/base.py
tensorpack/predict/base.py
+1
-1
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+1
-1
tensorpack/tfutils/export.py
tensorpack/tfutils/export.py
+1
-1
tensorpack/train/config.py
tensorpack/train/config.py
+1
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-1
tensorpack/train/simple.py
tensorpack/train/simple.py
+1
-1
No files found.
docs/modules/index.rst
View file @
63c0f891
...
...
@@ -8,6 +8,7 @@ API Documentation
dataflow
dataflow.dataset
dataflow.imgaug
input_source
models
callbacks
graph_builder
...
...
docs/modules/input_source.rst
0 → 100644
View file @
63c0f891
tensorpack.input_source package
================================
.. automodule:: tensorpack.input_source
:members:
:undoc-members:
:show-inheritance:
tensorpack/callbacks/inference_runner.py
View file @
63c0f891
...
...
@@ -17,9 +17,8 @@ from ..utils.utils import get_tqdm_kwargs
from
..utils.develop
import
deprecated
from
..dataflow.base
import
DataFlow
from
..graph_builder.input_source_base
import
InputSource
from
..graph_builder.input_source
import
(
FeedInput
,
QueueInput
)
from
..input_source
import
(
InputSource
,
FeedInput
,
QueueInput
)
from
.base
import
Callback
from
.group
import
Callbacks
...
...
tensorpack/graph_builder/model_desc.py
View file @
63c0f891
...
...
@@ -9,7 +9,7 @@ import tensorflow as tf
import
six
from
..utils.argtools
import
memoized
from
.
input_source_bas
e
import
InputSource
from
.
.input_sourc
e
import
InputSource
from
..models.regularize
import
regularize_cost_from_collection
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
63c0f891
...
...
@@ -8,7 +8,7 @@ from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.collection
import
freeze_collection
from
..utils.naming
import
TOWER_FREEZE_KEYS
from
.input_source
import
PlaceholderInput
from
.
.
input_source
import
PlaceholderInput
__all__
=
[]
...
...
tensorpack/graph_builder/utils.py
View file @
63c0f891
...
...
@@ -3,68 +3,15 @@
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
copy
from
six.moves
import
zip
from
contextlib
import
contextmanager
import
operator
import
tensorflow
as
tf
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
__all__
=
[
'LeastLoadedDeviceSetter'
,
'OverrideToLocalVariable'
,
'override_to_local_variable'
]
def
get_tensors_inputs
(
placeholders
,
tensors
,
names
):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert
len
(
tensors
)
==
len
(
names
),
\
"Input tensors {} and input names {} have different length!"
.
format
(
tensors
,
names
)
ret
=
copy
.
copy
(
placeholders
)
placeholder_names
=
[
p
.
name
for
p
in
placeholders
]
for
name
,
tensor
in
zip
(
names
,
tensors
):
tensorname
=
get_op_tensor_name
(
name
)[
1
]
try
:
idx
=
placeholder_names
.
index
(
tensorname
)
except
ValueError
:
logger
.
error
(
"Name {} is not a model input!"
.
format
(
tensorname
))
raise
ret
[
idx
]
=
tensor
return
ret
def
get_sublist_by_names
(
lst
,
names
):
"""
Args:
lst (list): list of objects with "name" property.
Returns:
list: a sublist of objects, matching names
"""
orig_names
=
[
p
.
name
for
p
in
lst
]
ret
=
[]
for
name
in
names
:
try
:
idx
=
orig_names
.
index
(
name
)
except
ValueError
:
logger
.
error
(
"Name {} doesn't appear in lst {}!"
.
format
(
name
,
str
(
orig_names
)))
raise
ret
.
append
(
lst
[
idx
])
return
ret
@
contextmanager
def
override_to_local_variable
(
enable
=
True
):
if
enable
:
...
...
tensorpack/input_source/__init__.py
0 → 100644
View file @
63c0f891
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
pkgutil
import
iter_modules
import
os
import
os.path
__all__
=
[]
def
global_import
(
name
):
p
=
__import__
(
name
,
globals
(),
locals
(),
level
=
1
)
lst
=
p
.
__all__
if
'__all__'
in
dir
(
p
)
else
[]
del
globals
()[
name
]
for
k
in
lst
:
if
not
k
.
startswith
(
'__'
):
globals
()[
k
]
=
p
.
__dict__
[
k
]
__all__
.
append
(
k
)
_CURR_DIR
=
os
.
path
.
dirname
(
__file__
)
_SKIP
=
[]
for
_
,
module_name
,
_
in
iter_modules
(
[
_CURR_DIR
]):
srcpath
=
os
.
path
.
join
(
_CURR_DIR
,
module_name
+
'.py'
)
if
not
os
.
path
.
isfile
(
srcpath
):
continue
if
module_name
.
startswith
(
'_'
):
continue
if
module_name
not
in
_SKIP
:
global_import
(
module_name
)
tensorpack/
graph_builder
/input_source.py
→
tensorpack/
input_source
/input_source.py
View file @
63c0f891
File moved
tensorpack/
graph_builder
/input_source_base.py
→
tensorpack/
input_source
/input_source_base.py
View file @
63c0f891
...
...
@@ -3,17 +3,68 @@
# File: input_source_base.py
from
abc
import
ABCMeta
,
abstractmethod
import
copy
import
six
from
six.moves
import
zip
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..utils.argtools
import
memoized
from
.utils
import
get_sublist_by_names
,
get_tensors_inputs
from
..callbacks.base
import
CallbackFactory
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
__all__
=
[
'InputSource'
,
'remap_input_source'
]
def
get_tensors_inputs
(
placeholders
,
tensors
,
names
):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert
len
(
tensors
)
==
len
(
names
),
\
"Input tensors {} and input names {} have different length!"
.
format
(
tensors
,
names
)
ret
=
copy
.
copy
(
placeholders
)
placeholder_names
=
[
p
.
name
for
p
in
placeholders
]
for
name
,
tensor
in
zip
(
names
,
tensors
):
tensorname
=
get_op_tensor_name
(
name
)[
1
]
try
:
idx
=
placeholder_names
.
index
(
tensorname
)
except
ValueError
:
logger
.
error
(
"Name {} is not a model input!"
.
format
(
tensorname
))
raise
ret
[
idx
]
=
tensor
return
ret
def
get_sublist_by_names
(
lst
,
names
):
"""
Args:
lst (list): list of objects with "name" property.
Returns:
list: a sublist of objects, matching names
"""
orig_names
=
[
p
.
name
for
p
in
lst
]
ret
=
[]
for
name
in
names
:
try
:
idx
=
orig_names
.
index
(
name
)
except
ValueError
:
logger
.
error
(
"Name {} doesn't appear in lst {}!"
.
format
(
name
,
str
(
orig_names
)))
raise
ret
.
append
(
lst
[
idx
])
return
ret
@
six
.
add_metaclass
(
ABCMeta
)
class
InputSource
(
object
):
""" Base class for the abstract InputSource. """
...
...
tensorpack/predict/base.py
View file @
63c0f891
...
...
@@ -9,7 +9,7 @@ import six
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..
graph_builder.
input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
'OnlinePredictor'
,
'OfflinePredictor'
,
...
...
tensorpack/predict/multigpu.py
View file @
63c0f891
...
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
from
..utils
import
logger
from
..graph_builder.predictor_factory
import
PredictorFactory
from
..
graph_builder.
input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
.base
import
OnlinePredictor
__all__
=
[
'MultiTowerOfflinePredictor'
,
...
...
tensorpack/tfutils/export.py
View file @
63c0f891
...
...
@@ -10,7 +10,7 @@ This simplifies the process of exporting a model for TensorFlow serving.
import
tensorflow
as
tf
from
..utils
import
logger
from
..graph_builder.model_desc
import
ModelDescBase
from
..
graph_builder.
input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
..tfutils
import
TowerContext
,
sessinit
...
...
tensorpack/train/config.py
View file @
63c0f891
...
...
@@ -12,7 +12,7 @@ from ..utils import logger
from
..tfutils
import
(
JustCurrentSession
,
get_default_sess_config
,
SessionInit
)
from
..tfutils.sesscreate
import
NewSessionCreator
from
..
graph_builder.input_source_bas
e
import
InputSource
from
..
input_sourc
e
import
InputSource
__all__
=
[
'TrainConfig'
]
...
...
tensorpack/train/multigpu.py
View file @
63c0f891
...
...
@@ -8,7 +8,7 @@ import tensorflow as tf
from
..callbacks.graph
import
RunOp
from
..utils.develop
import
log_deprecated
from
..
graph_builder.
input_source
import
QueueInput
,
StagingInputWrapper
,
DummyConstantInput
from
..input_source
import
QueueInput
,
StagingInputWrapper
,
DummyConstantInput
from
..graph_builder.training
import
(
SyncMultiGPUParameterServerBuilder
,
SyncMultiGPUReplicatedBuilder
,
...
...
tensorpack/train/simple.py
View file @
63c0f891
...
...
@@ -6,7 +6,7 @@
from
.base
import
Trainer
from
..utils
import
logger
from
..
graph_builder.
input_source
import
FeedInput
,
QueueInput
from
..input_source
import
FeedInput
,
QueueInput
from
..graph_builder.training
import
SimpleBuilder
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment