Shashank Suhas / seminar-breakout

Commit d9b96535
authored Jul 12, 2017 by Yuxin Wu
Use hooks for feed_dict in FeedInput
parent 1e7fa5f9
Showing 4 changed files with 39 additions and 41 deletions
tensorpack/callbacks/base.py             +3  -1
tensorpack/callbacks/inference_runner.py +10 -12
tensorpack/train/input_source.py         +23 -25
tensorpack/train/simple.py               +3  -3
tensorpack/callbacks/base.py

@@ -44,7 +44,9 @@ class Callback(object):
         self._steps_per_epoch = trainer.config.steps_per_epoch
         self.trainer = trainer
         self.graph = tf.get_default_graph()
-        with tf.name_scope(type(self).__name__):
+        scope_name = type(self).__name__
+        scope_name = scope_name.replace('_', '')
+        with tf.name_scope(scope_name):
             self._setup_graph()
 
     def _setup_graph(self):
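What changed here: the name scope for a callback's graph ops is now derived from the class name with underscores stripped, rather than using the raw class name. A minimal sketch of the effect, assuming the TF1-style API tensorpack used at the time (the callback class below is invented for illustration):

    import tensorflow as tf

    class _My_Callback(object):
        def _setup_graph(self):
            # ops created here inherit the enclosing name scope
            self.op = tf.no_op(name='setup')

    cb = _My_Callback()
    scope_name = type(cb).__name__.replace('_', '')   # '_My_Callback' -> 'MyCallback'
    with tf.name_scope(scope_name):
        cb._setup_graph()

    print(cb.op.name)   # 'MyCallback/setup'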
tensorpack/callbacks/inference_runner.py

@@ -22,6 +22,7 @@ from ..predict import PredictorTowerBuilder
 from .base import Callback
 from .inference import Inferencer
+from .hooks import CallbackToHook
 
 __all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
            'DataParallelInferenceRunner']

@@ -85,8 +86,6 @@ class InferenceRunnerBase(Callback):
     def _setup_graph(self):
         self._input_source.setup(self.trainer.model.get_inputs_desc())
-        assert len(self._input_source.get_callbacks()) == 0, \
-            "InferenceRunner doesn't support any InputSource which requires callbacks!"
         # Use predict_tower in train config. either gpuid or -1
         self._predict_tower_id = self.trainer.config.predict_tower[0]

@@ -97,6 +96,8 @@ class InferenceRunnerBase(Callback):
             PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
 
         self._hooks = [self._build_hook(inf) for inf in self.infs]
+        cbs = self._input_source.get_callbacks()
+        self._hooks.extend([CallbackToHook(cb) for cb in cbs])
 
     def _before_train(self):
         self._hooks.extend(self._extra_hooks)

@@ -118,8 +119,7 @@ class InferenceRunnerBase(Callback):
         # iterate over the data, and run the hooked session
         self._input_source.reset_state()
         for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
-            feed = self._input_source.next_feed()
-            self._hooked_sess.run(fetches=[], feed_dict=feed)
+            self._hooked_sess.run(fetches=[])
         summary_inferencer(self.trainer, self.infs)

@@ -170,19 +170,17 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
         placeholder_names = [k.name + ':0' for k in self.trainer.model.get_inputs_desc()]
         ret = []
         for name in out_names:
-            if name not in placeholder_names:
-                ret.append(self._get_tensors_maybe_in_tower([name])[0])
-            else:
-                # requesting an input
-                idx = placeholder_names.index(name)
-                ret.append(self._input_tensors[idx])
+            assert name not in placeholder_names, \
+                "Currently inferencer don't support fetching placeholders!"
+            ret.append(self._get_tensors_maybe_in_tower([name])[0])
         return InferencerToHook(inf, ret)
 
+# TODO completely broken now!
+# TODO some scripts to test
 class DataParallelInferenceRunner(InferenceRunnerBase):
     """
-    Not tested. Don't use.
+    Broken. Don't use.
     """
-    # TODO some scripts to test
     def __init__(self, input, infs, gpus):
         """
         Args:

@@ -200,7 +198,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         model = self.trainer.model
         self._input_source.setup(model.get_inputs_desc())
         assert len(self._input_source.get_callbacks()) == 0, \
-            "InferenceRunner doesn't support any InputSource which requires callbacks!"
+            "DataParallelInferenceRunner doesn't support any InputSource which requires callbacks!"
 
         # build graph
         def build_tower(k):
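InferenceRunnerBase now collects the InputSource's callbacks, wraps each in CallbackToHook, and appends them to its session hooks, so a feeding InputSource can supply feed_dicts through before_run instead of an explicit next_feed(). A rough sketch of what such an adapter does, assuming the callback exposes before_run/after_run in the tf.train.SessionRunHook style (this is a simplification, not the actual code in tensorpack/callbacks/hooks.py):

    import tensorflow as tf

    class CallbackToHookSketch(tf.train.SessionRunHook):
        """Adapt a tensorpack-style Callback to the SessionRunHook interface."""
        def __init__(self, cb):
            self._cb = cb

        def before_run(self, run_context):
            # a feeding callback returns tf.train.SessionRunArgs(fetches=[], feed_dict=...),
            # which the hooked session merges into the actual sess.run call
            return self._cb.before_run(run_context)

        def after_run(self, run_context, run_values):
            self._cb.after_run(run_context, run_values)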
tensorpack/train/input_source.py

@@ -80,17 +80,6 @@ class InputSource(object):
     def _reset_state(self):
         pass
 
-    def next_feed(self):
-        """
-        Returns:
-            a feed_dict of {Tensor: data}, to be used to run the steps
-        """
-        return self._next_feed()
-
-    @abstractmethod
-    def _next_feed(self):
-        pass
-
     def size(self):
         """
         Returns:

@@ -122,15 +111,28 @@ class ProxyInputSource(InputSource):
     def _size(self):
         return self._input.size()
 
-    def _next_feed(self):
-        return self._input.next_feed()
-
     def _reset_state(self):
         self._input.reset_state()
 
 
 class FeedInput(InputSource):
     """ Input by iterating over a DataFlow and feed datapoints. """
+
+    class _FeedCallback(Callback):
+        def __init__(self, ds, placeholders):
+            self._ds = ds
+            self._itr = self._ds.get_data()
+            self._placeholders = placeholders
+
+        def _before_run(self, _):
+            dp = next(self._itr)
+            assert len(dp) == len(self._placeholders), \
+                "[FeedInput] datapoints and inputs are of different length!"
+            feed = dict(zip(self._placeholders, dp))
+            return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
+
+        def _reset(self):
+            self._ds.reset_state()
+
     def __init__(self, ds):
         """
         Args:

@@ -138,28 +140,27 @@ class FeedInput(InputSource):
         """
         assert isinstance(ds, DataFlow), ds
         self.ds = ds
+        self._repeat_ds = RepeatedData(self.ds, -1)
 
     def _size(self):
         return self.ds.size()
 
     def _setup(self, inputs):
         self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self._cb = self._FeedCallback(self._repeat_ds, self._all_placehdrs)
         self.reset_state()
 
     def _reset_state(self):
-        rds = RepeatedData(self.ds, -1)
-        rds.reset_state()
-        self.data_producer = rds.get_data()
+        self._cb._reset()
 
     def _get_input_tensors(self):
         return self._all_placehdrs
 
-    def _next_feed(self):
-        dp = next(self.data_producer)
-        assert len(dp) == len(self._all_placehdrs), \
-            "[FeedInput] datapoints and inputs are of different length!"
-        return dict(zip(self._all_placehdrs, dp))
+    def _get_callbacks(self):
+        return [self._cb]
 
 
+# TODO completely broken now!
 class DataParallelFeedInput(FeedInput):
     """
     Input by feeding k datapoints to k copies of placeholders located on k towers.

@@ -182,7 +183,7 @@ class DataParallelFeedInput(FeedInput):
         ctx = get_current_tower_context()
         return self._placehdrs_per_tower[ctx.index]
 
-    def _next_feed(self, cnt=None):
+    def next_feed(self, cnt=None):
         """
         Args:
             cnt: how many towers to feed to. Defaults to the total number of towers

@@ -204,9 +205,6 @@ class FeedfreeInput(InputSource):
     def _reset_state(self):
         pass
 
-    def _next_feed(self):
-        return {}
-
 
 # TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
 class EnqueueThread(ShareSessionThread):
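This is the core of the commit: FeedInput no longer exposes next_feed(); its nested _FeedCallback pulls the next datapoint inside _before_run and returns it as tf.train.SessionRunArgs, letting the hooked session merge the feed_dict into each run. A self-contained toy showing that mechanism with a plain tf.train.MonitoredSession (the placeholder and data are invented for illustration):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[], name='x')
    y = x * 2.0

    class ToyFeedHook(tf.train.SessionRunHook):
        """Feed one datapoint per step, the way FeedInput._FeedCallback does."""
        def __init__(self, data):
            self._itr = iter(data)

        def before_run(self, run_context):
            # the returned feed_dict is merged into the caller's run() call
            return tf.train.SessionRunArgs(fetches=[], feed_dict={x: next(self._itr)})

    with tf.train.MonitoredSession(hooks=[ToyFeedHook([1.0, 2.0, 3.0])]) as sess:
        for _ in range(3):
            print(sess.run(y))   # 2.0, 4.0, 6.0 -- no feed_dict at the call site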
tensorpack/train/simple.py

@@ -31,14 +31,14 @@ class SimpleTrainer(Trainer):
     def run_step(self):
         """ Feed data into the graph and run the updates. """
-        feed = self._input_source.next_feed()
-        self.hooked_sess.run(self.train_op, feed_dict=feed)
+        self.hooked_sess.run(self.train_op)
 
     def _setup(self):
         model = self.model
         self._input_source.setup(model.get_inputs_desc())
         cbs = self._input_source.get_callbacks()
-        assert len(cbs) == 0, "Feedinput has no callbacks!"
+        for cb in cbs:
+            self.register_callback(cb)
         self.inputs = self._input_source.get_input_tensors()
         with TowerContext('', is_training=True):
             model.build_graph(self.inputs)
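SimpleTrainer now registers whatever callbacks the InputSource provides instead of asserting that there are none, and run_step drops its explicit feed_dict since the registered feeding callback (as a hook on hooked_sess) supplies it. A minimal, hypothetical sketch of that wiring (ToyTrainer and ToyInputSource are stand-ins, not tensorpack classes):

    class ToyInputSource(object):
        """Stand-in for FeedInput: contributes one feeding callback."""
        def get_callbacks(self):
            return ['feed_callback']    # FeedInput returns [self._cb]

    class ToyTrainer(object):
        def __init__(self, input_source):
            self._input_source = input_source
            self._callbacks = []

        def register_callback(self, cb):
            # in tensorpack, registered callbacks are later wrapped into session hooks
            self._callbacks.append(cb)

        def _setup(self):
            for cb in self._input_source.get_callbacks():   # replaces "assert len(cbs) == 0"
                self.register_callback(cb)

    trainer = ToyTrainer(ToyInputSource())
    trainer._setup()
    print(trainer._callbacks)           # ['feed_callback']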