Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
76fa8e38
Commit
76fa8e38
authored
May 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Simplify inference_runner: 1. move input_names mapping to InputSource 2. add DataParallelFeedInput
parent
48f6c267
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
167 additions
and
100 deletions
+167
-100
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+32
-84
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+75
-12
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-4
tensorpack/train/utils.py
tensorpack/train/utils.py
+59
-0
No files found.
tensorpack/callbacks/inference_runner.py
View file @
76fa8e38
This diff is collapsed.
Click to expand it.
tensorpack/train/input_source.py
View file @
76fa8e38
...
@@ -12,8 +12,9 @@ except ImportError:
...
@@ -12,8 +12,9 @@ except ImportError:
from
itertools
import
chain
from
itertools
import
chain
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
six
import
six
from
six.moves
import
range
from
six.moves
import
range
,
zip
from
.utils
import
get_placeholders_by_names
from
..dataflow
import
DataFlow
,
RepeatedData
from
..dataflow
import
DataFlow
,
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.summary
import
add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..tfutils
import
get_op_tensor_name
...
@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread
...
@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.base
import
Callback
from
..callbacks.base
import
Callback
__all__
=
[
'InputSource'
,
'FeedfreeInput'
,
__all__
=
[
'InputSource'
,
'FeedfreeInput'
,
'DataParallelFeedInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'ZMQInput'
,
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper'
]
'DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper'
]
...
@@ -38,8 +39,9 @@ class InputSource(object):
...
@@ -38,8 +39,9 @@ class InputSource(object):
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
"""
"""
Returns:
Returns:
list: A list of tensors corresponding to the inputs of the model.
list: A list of tensors corresponding to the inputs of the model,
Always create and return a list of new input tensors when called.
used as input of :func:`build_graph`.
For non-placeholder tensors, should always create and return new tensors when called.
"""
"""
def
setup
(
self
,
model
):
def
setup
(
self
,
model
):
...
@@ -53,27 +55,37 @@ class InputSource(object):
...
@@ -53,27 +55,37 @@ class InputSource(object):
pass
pass
def
next_feed
(
self
):
def
next_feed
(
self
):
return
[]
"""
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
"""
return
{}
class
FeedInput
(
InputSource
):
class
FeedInput
(
InputSource
):
""" Input by iterating over a DataFlow and feed datapoints. """
""" Input by iterating over a DataFlow and feed datapoints. """
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
,
input_names
=
None
):
"""
"""
Args:
Args:
ds (DataFlow): the input DataFlow.
ds (DataFlow): the input DataFlow.
input_names (list[str]): input names this DataFlow maps to
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
self
.
ds
=
ds
self
.
_input_names
=
input_names
def
size
(
self
):
def
size
(
self
):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
def
setup
(
self
,
model
):
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
self
.
_all_placehdrs
=
model
.
get_reused_placehdrs
()
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
if
self
.
_input_names
is
None
:
rds
.
reset_state
()
self
.
_placehdrs_to_feed
=
self
.
_all_placehdrs
self
.
data_producer
=
rds
.
get_data
()
else
:
self
.
_placehdrs_to_feed
=
get_placeholders_by_names
(
self
.
_all_placehdrs
,
self
.
_input_names
)
self
.
reset_state
()
def
reset_state
(
self
):
def
reset_state
(
self
):
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
...
@@ -81,10 +93,61 @@ class FeedInput(InputSource):
...
@@ -81,10 +93,61 @@ class FeedInput(InputSource):
self
.
data_producer
=
rds
.
get_data
()
self
.
data_producer
=
rds
.
get_data
()
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
return
self
.
input
_placehdrs
return
self
.
_all
_placehdrs
def
next_feed
(
self
):
def
next_feed
(
self
):
return
next
(
self
.
data_producer
)
dp
=
next
(
self
.
data_producer
)
return
dict
(
zip
(
self
.
_placehdrs_to_feed
,
dp
))
class DataParallelFeedInput(FeedInput):
    """
    Input by feeding k datapoints to k copies of placeholders located on k towers.

    One set of placeholders is built per tower name in :meth:`setup`, and
    :meth:`next_feed` pulls one datapoint from the DataFlow for each tower it feeds.
    """
    def __init__(self, ds, tower_names, input_names=None):
        """
        Args:
            ds (DataFlow): the input DataFlow.
            tower_names (list[str]): names of the towers; each name is used as
                the prefix of the placeholders built for that tower.
            input_names (list[str]): input names this DataFlow maps to.
                If None, the datapoints are fed to all placeholders.
        """
        super(DataParallelFeedInput, self).__init__(ds, input_names)
        self._tower_names = tower_names
        self._nr_tower = len(tower_names)

    def setup(self, model):
        """
        Build the placeholders of every tower and work out which of them get fed.

        Args:
            model: must provide ``build_placeholders(prefix=...)``.
        """
        # build a list of placeholders for each tower
        self._placehdrs_per_tower = [
            model.build_placeholders(prefix=tname + '/')
            for tname in self._tower_names]

        # apply input mapping and store results in feed_placehdrs_per_tower
        if self._input_names is None:
            # no mapping given: every placeholder of a tower is fed
            self._feed_placehdrs_per_tower = self._placehdrs_per_tower
        else:
            self._feed_placehdrs_per_tower = []
            for phdrs, tname in zip(self._placehdrs_per_tower, self._tower_names):
                # input_names to be used for this specific tower
                input_names = [tname + '/' + n for n in self._input_names]
                self._feed_placehdrs_per_tower.append(
                    get_placeholders_by_names(phdrs, input_names))
        self.reset_state()

    def get_input_tensors(self):
        """
        Returns:
            the placeholders built for the tower selected by the ``index``
            of the current tower context.
        """
        ctx = get_current_tower_context()
        return self._placehdrs_per_tower[ctx.index]

    def next_feed(self, cnt=None):
        """
        Args:
            cnt: how many towers to feed to. Defaults to the total number of towers

        Returns:
            a feed_dict of {Tensor: data} covering the first ``cnt`` towers.

        Raises:
            ValueError: if ``cnt`` is larger than the number of towers.
        """
        if cnt is None:
            cnt = self._nr_tower
        # Check before touching the producer: failing half-way through the loop
        # would silently drop the datapoints already pulled for earlier towers.
        if cnt > self._nr_tower:
            raise ValueError(
                "Cannot feed {} towers: only {} towers are available!".format(
                    cnt, self._nr_tower))
        feed = {}
        for t in range(cnt):
            dp = next(self.data_producer)
            feed.update(zip(self._feed_placehdrs_per_tower[t], dp))
        return feed
class
FeedfreeInput
(
InputSource
):
class
FeedfreeInput
(
InputSource
):
...
...
tensorpack/train/trainer.py
View file @
76fa8e38
...
@@ -3,8 +3,6 @@
...
@@ -3,8 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
six.moves
import
zip
from
.base
import
Trainer
from
.base
import
Trainer
from
..utils
import
logger
from
..utils
import
logger
...
@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer):
...
@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer):
def
run_step
(
self
):
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
""" Feed data into the graph and run the updates. """
dp
=
self
.
_input_source
.
next_feed
()
feed
=
self
.
_input_source
.
next_feed
()
feed
=
dict
(
zip
(
self
.
inputs
,
dp
))
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
def
_setup
(
self
):
def
_setup
(
self
):
...
...
tensorpack/train/utils.py
0 → 100644
View file @
76fa8e38
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
copy
from
six.moves
import
zip
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
__all__
=
[
'get_tensors_inputs'
,
'get_placeholders_by_names'
]
def get_tensors_inputs(placeholders, tensors, names):
    """
    Quite often we want to `build_graph()` with normal tensors
    (rather than placeholders).

    Args:
        placeholders (list[Tensor]): the placeholders to start from. Any
            sequence of objects with a ``.name`` attribute is accepted.
        tensors (list[Tensor]): list of tf.Tensor
        names (list[str]): names matching the tensors

    Returns:
        list[Tensor]: inputs to be used with build_graph(),
            with the corresponding placeholders replaced by tensors.
            This is always a new list; ``placeholders`` is not modified.

    Raises:
        AssertionError: if ``tensors`` and ``names`` have different length.
        ValueError: if a name does not match the name of any placeholder.
    """
    assert len(tensors) == len(names), \
        "Input tensors {} and input names {} have different length!".format(tensors, names)
    # list() instead of copy.copy(): the result is assigned by index below,
    # which would fail on the shallow copy of a tuple.
    ret = list(placeholders)
    # Map every placeholder name to its first position (what list.index()
    # would return), so that each lookup does not rescan the whole list.
    index_by_name = {}
    for idx, p in enumerate(ret):
        index_by_name.setdefault(p.name, idx)
    for name, tensor in zip(names, tensors):
        # element [1] is presumably the tensor-name form of `name`
        # (get_op_tensor_name is defined elsewhere -- confirm there)
        tensorname = get_op_tensor_name(name)[1]
        if tensorname not in index_by_name:
            msg = "Name {} is not a model input!".format(tensorname)
            logger.error(msg)
            raise ValueError(msg)
        ret[index_by_name[tensorname]] = tensor
    return ret
def get_placeholders_by_names(placeholders, names):
    """
    Pick the placeholders that correspond to the given names.

    Args:
        placeholders (list[Tensor]): candidates, each with a ``.name`` attribute.
        names (list[str]): names to look for; each one is passed through
            ``get_op_tensor_name`` before being compared to the placeholder names.

    Returns:
        list[Tensor]: a sublist of placeholders, matching names
            (one entry per name, in the order of ``names``).

    Raises:
        ValueError: if a name does not match any placeholder. The problem is
            logged before the error is re-raised.
    """
    known_names = [ph.name for ph in placeholders]

    def _find(name):
        # Return the placeholder matching one requested name.
        tensorname = get_op_tensor_name(name)[1]
        try:
            pos = known_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        return placeholders[pos]

    return [_find(n) for n in names]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment