Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
63c0f891
Commit
63c0f891
authored
Oct 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
move InputSource to a separate folder
parent
9e995a8d
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
102 additions
and
65 deletions
+102
-65
docs/modules/index.rst
docs/modules/index.rst
+1
-0
docs/modules/input_source.rst
docs/modules/input_source.rst
+7
-0
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-3
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+1
-1
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+1
-1
tensorpack/graph_builder/utils.py
tensorpack/graph_builder/utils.py
+0
-53
tensorpack/input_source/__init__.py
tensorpack/input_source/__init__.py
+32
-0
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+0
-0
tensorpack/input_source/input_source_base.py
tensorpack/input_source/input_source_base.py
+52
-1
tensorpack/predict/base.py
tensorpack/predict/base.py
+1
-1
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+1
-1
tensorpack/tfutils/export.py
tensorpack/tfutils/export.py
+1
-1
tensorpack/train/config.py
tensorpack/train/config.py
+1
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-1
tensorpack/train/simple.py
tensorpack/train/simple.py
+1
-1
No files found.
docs/modules/index.rst
View file @
63c0f891
...
@@ -8,6 +8,7 @@ API Documentation
...
@@ -8,6 +8,7 @@ API Documentation
dataflow
dataflow
dataflow.dataset
dataflow.dataset
dataflow.imgaug
dataflow.imgaug
input_source
models
models
callbacks
callbacks
graph_builder
graph_builder
...
...
docs/modules/input_source.rst
0 → 100644
View file @
63c0f891
tensorpack.input_source package
================================
.. automodule:: tensorpack.input_source
:members:
:undoc-members:
:show-inheritance:
tensorpack/callbacks/inference_runner.py
View file @
63c0f891
...
@@ -17,9 +17,8 @@ from ..utils.utils import get_tqdm_kwargs
...
@@ -17,9 +17,8 @@ from ..utils.utils import get_tqdm_kwargs
from
..utils.develop
import
deprecated
from
..utils.develop
import
deprecated
from
..dataflow.base
import
DataFlow
from
..dataflow.base
import
DataFlow
from
..graph_builder.input_source_base
import
InputSource
from
..input_source
import
(
from
..graph_builder.input_source
import
(
InputSource
,
FeedInput
,
QueueInput
)
FeedInput
,
QueueInput
)
from
.base
import
Callback
from
.base
import
Callback
from
.group
import
Callbacks
from
.group
import
Callbacks
...
...
tensorpack/graph_builder/model_desc.py
View file @
63c0f891
...
@@ -9,7 +9,7 @@ import tensorflow as tf
...
@@ -9,7 +9,7 @@ import tensorflow as tf
import
six
import
six
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
.
input_source_bas
e
import
InputSource
from
.
.input_sourc
e
import
InputSource
from
..models.regularize
import
regularize_cost_from_collection
from
..models.regularize
import
regularize_cost_from_collection
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
63c0f891
...
@@ -8,7 +8,7 @@ from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
...
@@ -8,7 +8,7 @@ from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
from
..tfutils.collection
import
freeze_collection
from
..tfutils.collection
import
freeze_collection
from
..utils.naming
import
TOWER_FREEZE_KEYS
from
..utils.naming
import
TOWER_FREEZE_KEYS
from
.input_source
import
PlaceholderInput
from
.
.
input_source
import
PlaceholderInput
__all__
=
[]
__all__
=
[]
...
...
tensorpack/graph_builder/utils.py
View file @
63c0f891
...
@@ -3,68 +3,15 @@
...
@@ -3,68 +3,15 @@
# File: utils.py
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
copy
from
six.moves
import
zip
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
operator
import
operator
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
__all__
=
[
'LeastLoadedDeviceSetter'
,
'OverrideToLocalVariable'
,
__all__
=
[
'LeastLoadedDeviceSetter'
,
'OverrideToLocalVariable'
,
'override_to_local_variable'
]
'override_to_local_variable'
]
def
get_tensors_inputs
(
placeholders
,
tensors
,
names
):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert
len
(
tensors
)
==
len
(
names
),
\
"Input tensors {} and input names {} have different length!"
.
format
(
tensors
,
names
)
ret
=
copy
.
copy
(
placeholders
)
placeholder_names
=
[
p
.
name
for
p
in
placeholders
]
for
name
,
tensor
in
zip
(
names
,
tensors
):
tensorname
=
get_op_tensor_name
(
name
)[
1
]
try
:
idx
=
placeholder_names
.
index
(
tensorname
)
except
ValueError
:
logger
.
error
(
"Name {} is not a model input!"
.
format
(
tensorname
))
raise
ret
[
idx
]
=
tensor
return
ret
def
get_sublist_by_names
(
lst
,
names
):
"""
Args:
lst (list): list of objects with "name" property.
Returns:
list: a sublist of objects, matching names
"""
orig_names
=
[
p
.
name
for
p
in
lst
]
ret
=
[]
for
name
in
names
:
try
:
idx
=
orig_names
.
index
(
name
)
except
ValueError
:
logger
.
error
(
"Name {} doesn't appear in lst {}!"
.
format
(
name
,
str
(
orig_names
)))
raise
ret
.
append
(
lst
[
idx
])
return
ret
@
contextmanager
@
contextmanager
def
override_to_local_variable
(
enable
=
True
):
def
override_to_local_variable
(
enable
=
True
):
if
enable
:
if
enable
:
...
...
tensorpack/input_source/__init__.py
0 → 100644
View file @
63c0f891
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
pkgutil
import
iter_modules
import
os
import
os.path
__all__
=
[]
def
global_import
(
name
):
p
=
__import__
(
name
,
globals
(),
locals
(),
level
=
1
)
lst
=
p
.
__all__
if
'__all__'
in
dir
(
p
)
else
[]
del
globals
()[
name
]
for
k
in
lst
:
if
not
k
.
startswith
(
'__'
):
globals
()[
k
]
=
p
.
__dict__
[
k
]
__all__
.
append
(
k
)
_CURR_DIR
=
os
.
path
.
dirname
(
__file__
)
_SKIP
=
[]
for
_
,
module_name
,
_
in
iter_modules
(
[
_CURR_DIR
]):
srcpath
=
os
.
path
.
join
(
_CURR_DIR
,
module_name
+
'.py'
)
if
not
os
.
path
.
isfile
(
srcpath
):
continue
if
module_name
.
startswith
(
'_'
):
continue
if
module_name
not
in
_SKIP
:
global_import
(
module_name
)
tensorpack/
graph_builder
/input_source.py
→
tensorpack/
input_source
/input_source.py
View file @
63c0f891
File moved
tensorpack/
graph_builder
/input_source_base.py
→
tensorpack/
input_source
/input_source_base.py
View file @
63c0f891
...
@@ -3,17 +3,68 @@
...
@@ -3,17 +3,68 @@
# File: input_source_base.py
# File: input_source_base.py
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
copy
import
six
import
six
from
six.moves
import
zip
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
.utils
import
get_sublist_by_names
,
get_tensors_inputs
from
..callbacks.base
import
CallbackFactory
from
..callbacks.base
import
CallbackFactory
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
__all__
=
[
'InputSource'
,
'remap_input_source'
]
__all__
=
[
'InputSource'
,
'remap_input_source'
]
def
get_tensors_inputs
(
placeholders
,
tensors
,
names
):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert
len
(
tensors
)
==
len
(
names
),
\
"Input tensors {} and input names {} have different length!"
.
format
(
tensors
,
names
)
ret
=
copy
.
copy
(
placeholders
)
placeholder_names
=
[
p
.
name
for
p
in
placeholders
]
for
name
,
tensor
in
zip
(
names
,
tensors
):
tensorname
=
get_op_tensor_name
(
name
)[
1
]
try
:
idx
=
placeholder_names
.
index
(
tensorname
)
except
ValueError
:
logger
.
error
(
"Name {} is not a model input!"
.
format
(
tensorname
))
raise
ret
[
idx
]
=
tensor
return
ret
def
get_sublist_by_names
(
lst
,
names
):
"""
Args:
lst (list): list of objects with "name" property.
Returns:
list: a sublist of objects, matching names
"""
orig_names
=
[
p
.
name
for
p
in
lst
]
ret
=
[]
for
name
in
names
:
try
:
idx
=
orig_names
.
index
(
name
)
except
ValueError
:
logger
.
error
(
"Name {} doesn't appear in lst {}!"
.
format
(
name
,
str
(
orig_names
)))
raise
ret
.
append
(
lst
[
idx
])
return
ret
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
InputSource
(
object
):
class
InputSource
(
object
):
""" Base class for the abstract InputSource. """
""" Base class for the abstract InputSource. """
...
...
tensorpack/predict/base.py
View file @
63c0f891
...
@@ -9,7 +9,7 @@ import six
...
@@ -9,7 +9,7 @@ import six
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
from
..
graph_builder.
input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
'OnlinePredictor'
,
'OfflinePredictor'
,
'OnlinePredictor'
,
'OfflinePredictor'
,
...
...
tensorpack/predict/multigpu.py
View file @
63c0f891
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..graph_builder.predictor_factory
import
PredictorFactory
from
..graph_builder.predictor_factory
import
PredictorFactory
from
..
graph_builder.
input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
.base
import
OnlinePredictor
from
.base
import
OnlinePredictor
__all__
=
[
'MultiTowerOfflinePredictor'
,
__all__
=
[
'MultiTowerOfflinePredictor'
,
...
...
tensorpack/tfutils/export.py
View file @
63c0f891
...
@@ -10,7 +10,7 @@ This simplifies the process of exporting a model for TensorFlow serving.
...
@@ -10,7 +10,7 @@ This simplifies the process of exporting a model for TensorFlow serving.
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..graph_builder.model_desc
import
ModelDescBase
from
..graph_builder.model_desc
import
ModelDescBase
from
..
graph_builder.
input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
..tfutils
import
TowerContext
,
sessinit
from
..tfutils
import
TowerContext
,
sessinit
...
...
tensorpack/train/config.py
View file @
63c0f891
...
@@ -12,7 +12,7 @@ from ..utils import logger
...
@@ -12,7 +12,7 @@ from ..utils import logger
from
..tfutils
import
(
JustCurrentSession
,
from
..tfutils
import
(
JustCurrentSession
,
get_default_sess_config
,
SessionInit
)
get_default_sess_config
,
SessionInit
)
from
..tfutils.sesscreate
import
NewSessionCreator
from
..tfutils.sesscreate
import
NewSessionCreator
from
..
graph_builder.input_source_bas
e
import
InputSource
from
..
input_sourc
e
import
InputSource
__all__
=
[
'TrainConfig'
]
__all__
=
[
'TrainConfig'
]
...
...
tensorpack/train/multigpu.py
View file @
63c0f891
...
@@ -8,7 +8,7 @@ import tensorflow as tf
...
@@ -8,7 +8,7 @@ import tensorflow as tf
from
..callbacks.graph
import
RunOp
from
..callbacks.graph
import
RunOp
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
from
..
graph_builder.
input_source
import
QueueInput
,
StagingInputWrapper
,
DummyConstantInput
from
..input_source
import
QueueInput
,
StagingInputWrapper
,
DummyConstantInput
from
..graph_builder.training
import
(
from
..graph_builder.training
import
(
SyncMultiGPUParameterServerBuilder
,
SyncMultiGPUParameterServerBuilder
,
SyncMultiGPUReplicatedBuilder
,
SyncMultiGPUReplicatedBuilder
,
...
...
tensorpack/train/simple.py
View file @
63c0f891
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
from
.base
import
Trainer
from
.base
import
Trainer
from
..utils
import
logger
from
..utils
import
logger
from
..
graph_builder.
input_source
import
FeedInput
,
QueueInput
from
..input_source
import
FeedInput
,
QueueInput
from
..graph_builder.training
import
SimpleBuilder
from
..graph_builder.training
import
SimpleBuilder
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
]
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment