Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
76fa8e38
Commit
76fa8e38
authored
May 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Simplify inference_runner: 1. move input_names mapping to InputSource 2. add DataParallelFeedInput
parent
48f6c267
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
167 additions
and
100 deletions
+167
-100
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+32
-84
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+75
-12
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-4
tensorpack/train/utils.py
tensorpack/train/utils.py
+59
-0
No files found.
tensorpack/callbacks/inference_runner.py
View file @
76fa8e38
...
@@ -10,14 +10,14 @@ from tensorflow.python.training.monitored_session \
...
@@ -10,14 +10,14 @@ from tensorflow.python.training.monitored_session \
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
tqdm
import
tqdm
import
six
import
six
import
copy
from
six.moves
import
range
from
six.moves
import
zip
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils
import
logger
,
get_tqdm_kwargs
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
..tfutils.common
import
get_
op_tensor_name
,
get_
tensors_by_names
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
from
..train.input_source
import
TensorInput
,
FeedInput
from
..train.input_source
import
TensorInput
,
FeedInput
,
DataParallelFeedInput
from
..train.utils
import
get_tensors_inputs
from
..predict
import
PredictorTowerBuilder
from
..predict
import
PredictorTowerBuilder
from
.base
import
Callback
from
.base
import
Callback
...
@@ -60,11 +60,12 @@ class InferenceRunnerBase(Callback):
...
@@ -60,11 +60,12 @@ class InferenceRunnerBase(Callback):
"""
"""
Args:
Args:
input (InputSource): the input to use. Must have ``size()``.
input (InputSource): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
infs (list[Inferencer]): list of :class:`Inferencer` to run.
input_names (list): must be a subset of the names in InputDesc.
input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
prefix(str): a prefix used to build the tower. Must be set
prefix(str): a prefix used to build the tower. Must be set
differently if more than one :class:`InferenceRunner` are used.
differently if more than one :class:`InferenceRunner` are used.
extra_hooks (list
): extra ``SessionRunHook`
` to run with the evaluation.
extra_hooks (list
[SessionRunHook]): extra :class:`SessionRunHook
` to run with the evaluation.
"""
"""
self
.
_input_source
=
input
self
.
_input_source
=
input
if
not
isinstance
(
infs
,
list
):
if
not
isinstance
(
infs
,
list
):
...
@@ -87,33 +88,17 @@ class InferenceRunnerBase(Callback):
...
@@ -87,33 +88,17 @@ class InferenceRunnerBase(Callback):
extra_hooks
=
[]
extra_hooks
=
[]
self
.
_extra_hooks
=
extra_hooks
self
.
_extra_hooks
=
extra_hooks
def
_setup_input_names
(
self
):
# just use all the placeholders, if input_name is None
if
self
.
input_names
is
None
:
inputs
=
self
.
trainer
.
model
.
get_reused_placehdrs
()
self
.
input_names
=
[
x
.
name
for
x
in
inputs
]
# TODO sparse. even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
# def get_name(x):
# if isinstance(x, tf.SparseTensor):
# return x.op.name.split('/')[0]
# return x.name
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_input_source
.
setup
(
self
.
trainer
.
model
)
self
.
_input_source
.
setup
(
self
.
trainer
.
model
)
self
.
_setup_input_names
()
# Use predict_tower in train config. either gpuid or -1
# Use predict_tower in train config. either gpuid or -1
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
in_tensors
=
self
.
_find_input_tensors
()
assert
isinstance
(
in_tensors
,
list
),
in_tensors
def
fn
(
_
):
def
fn
(
_
):
in_tensors
=
self
.
_find_input_tensors
()
assert
isinstance
(
in_tensors
,
list
),
in_tensors
self
.
trainer
.
model
.
build_graph
(
in_tensors
)
self
.
trainer
.
model
.
build_graph
(
in_tensors
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
self
.
_feed_tensors
=
self
.
_find_feed_tensors
()
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
def
_before_train
(
self
):
def
_before_train
(
self
):
...
@@ -128,10 +113,6 @@ class InferenceRunnerBase(Callback):
...
@@ -128,10 +113,6 @@ class InferenceRunnerBase(Callback):
def
_find_input_tensors
(
self
):
def
_find_input_tensors
(
self
):
pass
pass
@
abstractmethod
def
_find_feed_tensors
(
self
):
pass
@
abstractmethod
@
abstractmethod
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
pass
pass
...
@@ -143,8 +124,7 @@ class InferenceRunnerBase(Callback):
...
@@ -143,8 +124,7 @@ class InferenceRunnerBase(Callback):
# iterate over the data, and run the hooked session
# iterate over the data, and run the hooked session
self
.
_input_source
.
reset_state
()
self
.
_input_source
.
reset_state
()
for
_
in
tqdm
.
trange
(
self
.
_input_source
.
size
(),
**
get_tqdm_kwargs
()):
for
_
in
tqdm
.
trange
(
self
.
_input_source
.
size
(),
**
get_tqdm_kwargs
()):
dp
=
self
.
_input_source
.
next_feed
()
feed
=
self
.
_input_source
.
next_feed
()
feed
=
dict
(
zip
(
self
.
_feed_tensors
,
dp
))
self
.
_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
self
.
_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
...
@@ -160,19 +140,15 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -160,19 +140,15 @@ class InferenceRunner(InferenceRunnerBase):
Args:
Args:
ds (DataFlow): the DataFlow to run inferencer on.
ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
infs (list): a list of `Inferencer` instances.
input_names(list): list of tensors to feed the dataflow to.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
Defaults to all the input placeholders.
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
input
=
FeedInput
(
ds
)
input
=
FeedInput
(
ds
,
input_names
)
super
(
InferenceRunner
,
self
)
.
__init__
(
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
input_names
,
prefix
=
''
,
extra_hooks
=
extra_hooks
)
input
,
infs
,
input_names
,
prefix
=
''
,
extra_hooks
=
extra_hooks
)
def
_find_input_tensors
(
self
):
def
_find_input_tensors
(
self
):
return
self
.
trainer
.
model
.
get_reused_placehdrs
()
return
self
.
_input_source
.
get_input_tensors
()
def
_find_feed_tensors
(
self
):
return
self
.
_get_tensors_maybe_in_tower
(
self
.
input_names
)
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
out_names
=
inf
.
get_output_tensors
()
...
@@ -191,7 +167,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
...
@@ -191,7 +167,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
Args:
Args:
input (TensorInput): the input to use. Must have ``size()``.
input (TensorInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
infs (list): list of :class:`Inferencer` to run.
input_names (list
): must be a subset of the names in InputDesc
.
input_names (list
[str]): same as in :class:`InferenceRunnerBase`
.
prefix(str): a prefix used to build the tower. Must be set
prefix(str): a prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
"""
...
@@ -199,36 +175,14 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
...
@@ -199,36 +175,14 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
super
(
FeedfreeInferenceRunner
,
self
)
.
__init__
(
super
(
FeedfreeInferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
input_names
,
prefix
=
prefix
,
extra_hooks
=
extra_hooks
)
input
,
infs
,
input_names
,
prefix
=
prefix
,
extra_hooks
=
extra_hooks
)
def
_setup_input_names
(
self
):
super
(
FeedfreeInferenceRunner
,
self
)
.
_setup_input_names
()
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
trainer
.
model
.
get_inputs_desc
()])
for
n
in
self
.
input_names
:
opname
=
get_op_tensor_name
(
n
)[
0
]
assert
opname
in
placeholder_names
,
\
"[FeedfreeInferenceRunner] name {} is not a model input!"
.
format
(
n
)
def
_find_input_tensors
(
self
):
def
_find_input_tensors
(
self
):
# TODO move mapping to InputSource
tensors
=
self
.
_input_source
.
get_input_tensors
()
tensors
=
self
.
_input_source
.
get_input_tensors
()
placeholders
=
self
.
trainer
.
model
.
get_reused_placehdrs
()
assert
len
(
self
.
input_names
)
==
len
(
tensors
),
\
if
self
.
input_names
is
None
:
"[FeedfreeInferenceRunner] Input names must match the "
\
return
tensors
"length of the input data, but {} != {}"
.
format
(
len
(
self
.
input_names
),
len
(
tensors
))
# use placeholders for the unused inputs, use TensorInput for the used inpupts
ret
=
copy
.
copy
(
self
.
trainer
.
model
.
get_reused_placehdrs
())
for
name
,
tensor
in
zip
(
self
.
input_names
,
tensors
):
tname
=
get_op_tensor_name
(
name
)[
1
]
for
idx
,
hdr
in
enumerate
(
ret
):
if
hdr
.
name
==
tname
:
ret
[
idx
]
=
tensor
break
else
:
else
:
assert
tname
in
set
([
k
.
name
for
k
in
ret
]),
\
return
get_tensors_inputs
(
placeholders
,
tensors
,
self
.
input_names
)
"Input name {} is not among model inputs: {}!"
.
format
(
tname
,
ret
)
self
.
_input_tensors
=
ret
return
ret
def
_find_feed_tensors
(
self
):
return
[]
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
# all is tensorname
out_names
=
inf
.
get_output_tensors
()
# all is tensorname
...
@@ -243,22 +197,24 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
...
@@ -243,22 +197,24 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
return
InferencerToHook
(
inf
,
ret
)
return
InferencerToHook
(
inf
,
ret
)
class
DataParallelInferenceRunner
(
InferenceRunner
):
class
DataParallelInferenceRunner
(
InferenceRunner
Base
):
def
__init__
(
self
,
ds
,
infs
,
gpus
,
input_names
=
None
):
def
__init__
(
self
,
ds
,
infs
,
gpus
,
input_names
=
None
):
super
(
DataParallelInferenceRunner
,
self
)
.
__init__
(
ds
,
infs
,
input_names
)
self
.
_tower_names
=
[
TowerContext
.
get_predict_tower_name
(
k
)
for
k
in
range
(
len
(
gpus
))]
input
=
DataParallelFeedInput
(
ds
,
self
.
_tower_names
,
input_names
=
input_names
)
super
(
DataParallelInferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
input_names
)
self
.
_gpus
=
gpus
self
.
_gpus
=
gpus
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
model
=
self
.
trainer
.
model
model
=
self
.
trainer
.
model
self
.
_input_source
.
setup
(
model
)
self
.
_input_source
.
setup
(
model
)
self
.
_setup_input_names
()
# build graph
# build graph
def
build_tower
(
k
):
def
build_tower
(
k
):
towername
=
TowerContext
.
get_predict_tower_name
(
k
)
# inputs (placeholders) for this tower only
# inputs (placeholders) for this tower only
input_tensors
=
model
.
build_placeholders
(
input_tensors
=
self
.
_input_source
.
get_input_tensors
()
prefix
=
towername
+
'/'
)
model
.
build_graph
(
input_tensors
)
model
.
build_graph
(
input_tensors
)
builder
=
PredictorTowerBuilder
(
build_tower
,
prefix
=
self
.
_prefix
)
builder
=
PredictorTowerBuilder
(
build_tower
,
prefix
=
self
.
_prefix
)
...
@@ -267,7 +223,6 @@ class DataParallelInferenceRunner(InferenceRunner):
...
@@ -267,7 +223,6 @@ class DataParallelInferenceRunner(InferenceRunner):
builder
.
build
(
t
)
builder
.
build
(
t
)
# setup feeds and hooks
# setup feeds and hooks
self
.
_feed_tensors
=
self
.
_find_feed_tensors
()
self
.
_hooks_parallel
=
[
self
.
_build_hook_parallel
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks_parallel
=
[
self
.
_build_hook_parallel
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
...
@@ -278,10 +233,6 @@ class DataParallelInferenceRunner(InferenceRunner):
...
@@ -278,10 +233,6 @@ class DataParallelInferenceRunner(InferenceRunner):
'/'
+
n
for
n
in
names
])
'/'
+
n
for
n
in
names
])
return
ret
return
ret
def
_find_feed_tensors
(
self
):
names
=
self
.
_duplicate_names_across_towers
(
self
.
input_names
)
return
get_tensors_by_names
(
names
)
class
InferencerToHookDataParallel
(
InferencerToHook
):
class
InferencerToHookDataParallel
(
InferencerToHook
):
def
__init__
(
self
,
inf
,
fetches
,
size
):
def
__init__
(
self
,
inf
,
fetches
,
size
):
super
(
DataParallelInferenceRunner
.
InferencerToHookDataParallel
,
self
)
.
__init__
(
inf
,
fetches
)
super
(
DataParallelInferenceRunner
.
InferencerToHookDataParallel
,
self
)
.
__init__
(
inf
,
fetches
)
...
@@ -322,16 +273,13 @@ class DataParallelInferenceRunner(InferenceRunner):
...
@@ -322,16 +273,13 @@ class DataParallelInferenceRunner(InferenceRunner):
nr_tower
=
len
(
self
.
_gpus
)
nr_tower
=
len
(
self
.
_gpus
)
with
tqdm
.
tqdm
(
total
=
total
,
**
get_tqdm_kwargs
())
as
pbar
:
with
tqdm
.
tqdm
(
total
=
total
,
**
get_tqdm_kwargs
())
as
pbar
:
while
total
>=
nr_tower
:
while
total
>=
nr_tower
:
dps
=
[]
feed
=
self
.
_input_source
.
next_feed
()
for
k
in
self
.
_gpus
:
dps
.
extend
(
self
.
_input_source
.
next_feed
())
feed
=
dict
(
zip
(
self
.
_feed_tensors
,
dps
))
self
.
_parallel_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
self
.
_parallel_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
pbar
.
update
(
nr_tower
)
pbar
.
update
(
nr_tower
)
total
-=
nr_tower
total
-=
nr_tower
# take care of the rest
# take care of the rest
while
total
>
0
:
while
total
>
0
:
dp
=
self
.
_input_source
.
next_feed
()
feed
=
self
.
_input_source
.
next_feed
(
cnt
=
1
)
feed
=
dict
(
zip
(
self
.
_feed_tensors
[:
len
(
dp
)],
dp
))
self
.
_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
self
.
_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
pbar
.
update
(
1
)
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
tensorpack/train/input_source.py
View file @
76fa8e38
...
@@ -12,8 +12,9 @@ except ImportError:
...
@@ -12,8 +12,9 @@ except ImportError:
from
itertools
import
chain
from
itertools
import
chain
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
six
import
six
from
six.moves
import
range
from
six.moves
import
range
,
zip
from
.utils
import
get_placeholders_by_names
from
..dataflow
import
DataFlow
,
RepeatedData
from
..dataflow
import
DataFlow
,
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.summary
import
add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..tfutils
import
get_op_tensor_name
...
@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread
...
@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.base
import
Callback
from
..callbacks.base
import
Callback
__all__
=
[
'InputSource'
,
'FeedfreeInput'
,
__all__
=
[
'InputSource'
,
'FeedfreeInput'
,
'DataParallelFeedInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'ZMQInput'
,
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper'
]
'DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper'
]
...
@@ -38,8 +39,9 @@ class InputSource(object):
...
@@ -38,8 +39,9 @@ class InputSource(object):
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
"""
"""
Returns:
Returns:
list: A list of tensors corresponding to the inputs of the model.
list: A list of tensors corresponding to the inputs of the model,
Always create and return a list of new input tensors when called.
used as input of :func:`build_graph`.
For non-placeholder tensors, should always create and return new tensors when called.
"""
"""
def
setup
(
self
,
model
):
def
setup
(
self
,
model
):
...
@@ -53,27 +55,37 @@ class InputSource(object):
...
@@ -53,27 +55,37 @@ class InputSource(object):
pass
pass
def
next_feed
(
self
):
def
next_feed
(
self
):
return
[]
"""
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
"""
return
{}
class
FeedInput
(
InputSource
):
class
FeedInput
(
InputSource
):
""" Input by iterating over a DataFlow and feed datapoints. """
""" Input by iterating over a DataFlow and feed datapoints. """
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
,
input_names
=
None
):
"""
"""
Args:
Args:
ds (DataFlow): the input DataFlow.
ds (DataFlow): the input DataFlow.
input_names (list[str]): input names this DataFlow maps to
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
self
.
ds
=
ds
self
.
_input_names
=
input_names
def
size
(
self
):
def
size
(
self
):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
def
setup
(
self
,
model
):
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
self
.
_all_placehdrs
=
model
.
get_reused_placehdrs
()
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
if
self
.
_input_names
is
None
:
rds
.
reset_state
()
self
.
_placehdrs_to_feed
=
self
.
_all_placehdrs
self
.
data_producer
=
rds
.
get_data
()
else
:
self
.
_placehdrs_to_feed
=
get_placeholders_by_names
(
self
.
_all_placehdrs
,
self
.
_input_names
)
self
.
reset_state
()
def
reset_state
(
self
):
def
reset_state
(
self
):
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
...
@@ -81,10 +93,61 @@ class FeedInput(InputSource):
...
@@ -81,10 +93,61 @@ class FeedInput(InputSource):
self
.
data_producer
=
rds
.
get_data
()
self
.
data_producer
=
rds
.
get_data
()
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
return
self
.
input
_placehdrs
return
self
.
_all
_placehdrs
def
next_feed
(
self
):
def
next_feed
(
self
):
return
next
(
self
.
data_producer
)
dp
=
next
(
self
.
data_producer
)
return
dict
(
zip
(
self
.
_placehdrs_to_feed
,
dp
))
class DataParallelFeedInput(FeedInput):
    """
    Feed k datapoints into k copies of the model placeholders, one copy per tower.

    A separate list of placeholders is built for every tower name, and
    each call to :meth:`next_feed` pulls one datapoint per tower from the
    underlying DataFlow.
    """

    def __init__(self, ds, tower_names, input_names=None):
        """
        Args:
            ds (DataFlow): the input DataFlow.
            tower_names (list[str]): names of the towers to build placeholders for.
            input_names (list[str]): input names this DataFlow maps to.
                When None, every placeholder of a tower is fed.
        """
        super(DataParallelFeedInput, self).__init__(ds, input_names)
        self._tower_names = tower_names
        self._nr_tower = len(tower_names)

    def setup(self, model):
        """
        Build the per-tower placeholders from ``model`` and work out which of
        them receive data, then reset the data producer.
        """
        # one list of placeholders per tower, each created under "<tower>/"
        self._placehdrs_per_tower = []
        self._feed_placehdrs_per_tower = []
        for tower in self._tower_names:
            tower_placehdrs = model.build_placeholders(prefix=tower + '/')
            self._placehdrs_per_tower.append(tower_placehdrs)

        if self._input_names is None:
            # no mapping given: feed all placeholders (same list object is shared)
            self._feed_placehdrs_per_tower = self._placehdrs_per_tower
        else:
            # apply the input mapping tower by tower, using tower-prefixed names
            per_tower = zip(self._placehdrs_per_tower, self._tower_names)
            for tower_placehdrs, tower in per_tower:
                prefixed_names = [tower + '/' + n for n in self._input_names]
                selected = get_placeholders_by_names(tower_placehdrs, prefixed_names)
                self._feed_placehdrs_per_tower.append(selected)
        self.reset_state()

    def get_input_tensors(self):
        """
        Returns:
            the placeholders built for the tower selected by the current
            tower context's ``index``.
        """
        # NOTE(review): get_current_tower_context is not among the imports
        # visible in this hunk -- confirm it is imported in this module.
        tower_ctx = get_current_tower_context()
        return self._placehdrs_per_tower[tower_ctx.index]

    def next_feed(self, cnt=None):
        """
        Args:
            cnt: how many towers to feed to. Defaults to the total number of towers

        Returns:
            dict: a feed_dict mapping the placeholders of the first ``cnt``
            towers to one freshly produced datapoint each.
        """
        nr_feed = self._nr_tower if cnt is None else cnt
        feed = {}
        for tower_idx in range(nr_feed):
            # draw the datapoint first, then pair it with this tower's placeholders
            datapoint = next(self.data_producer)
            placehdrs = self._feed_placehdrs_per_tower[tower_idx]
            feed.update(zip(placehdrs, datapoint))
        return feed
class
FeedfreeInput
(
InputSource
):
class
FeedfreeInput
(
InputSource
):
...
...
tensorpack/train/trainer.py
View file @
76fa8e38
...
@@ -3,8 +3,6 @@
...
@@ -3,8 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
six.moves
import
zip
from
.base
import
Trainer
from
.base
import
Trainer
from
..utils
import
logger
from
..utils
import
logger
...
@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer):
...
@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer):
def
run_step
(
self
):
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
""" Feed data into the graph and run the updates. """
dp
=
self
.
_input_source
.
next_feed
()
feed
=
self
.
_input_source
.
next_feed
()
feed
=
dict
(
zip
(
self
.
inputs
,
dp
))
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
def
_setup
(
self
):
def
_setup
(
self
):
...
...
tensorpack/train/utils.py
0 → 100644
View file @
76fa8e38
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
copy
from
six.moves
import
zip
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
__all__
=
[
'get_tensors_inputs'
,
'get_placeholders_by_names'
]
def get_tensors_inputs(placeholders, tensors, names):
    """
    Build the input list for `build_graph()` when some inputs come from
    normal tensors rather than placeholders.

    Args:
        placeholders (list[Tensor]): all the model input placeholders.
        tensors (list[Tensor]): list of tf.Tensor
        names (list[str]): names matching the tensors

    Returns:
        list[Tensor]: inputs to be used with build_graph(),
            with the corresponding placeholders replaced by tensors.
            The ``placeholders`` list itself is not modified.

    Raises:
        ValueError: if a name does not match any placeholder
            (the problem is logged before re-raising).
    """
    assert len(tensors) == len(names), \
        "Input tensors {} and input names {} have different length!".format(
            tensors, names)
    known_names = [ph.name for ph in placeholders]
    # shallow copy, so the substitutions below never touch the caller's list
    inputs = copy.copy(placeholders)
    for given_name, replacement in zip(names, tensors):
        tensorname = get_op_tensor_name(given_name)[1]
        try:
            position = known_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        inputs[position] = replacement
    return inputs
def get_placeholders_by_names(placeholders, names):
    """
    Pick the placeholders that correspond to ``names``, in the order of ``names``.

    Args:
        placeholders (list[Tensor]): candidate placeholders.
        names (list[str]): names of the placeholders to select.

    Returns:
        list[Tensor]: a sublist of placeholders, matching names

    Raises:
        ValueError: if a name does not match any placeholder
            (the problem is logged before re-raising).
    """
    known_names = [ph.name for ph in placeholders]
    selected = []
    for given_name in names:
        tensorname = get_op_tensor_name(given_name)[1]
        try:
            position = known_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        else:
            selected.append(placeholders[position])
    return selected
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment