Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d9b96535
You need to sign in or sign up before continuing.
Commit
d9b96535
authored
Jul 12, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use hooks for feed_dict in FeedInput
parent
1e7fa5f9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
41 deletions
+39
-41
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+3
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+10
-12
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+23
-25
tensorpack/train/simple.py
tensorpack/train/simple.py
+3
-3
No files found.
tensorpack/callbacks/base.py
View file @
d9b96535
...
@@ -44,7 +44,9 @@ class Callback(object):
...
@@ -44,7 +44,9 @@ class Callback(object):
self
.
_steps_per_epoch
=
trainer
.
config
.
steps_per_epoch
self
.
_steps_per_epoch
=
trainer
.
config
.
steps_per_epoch
self
.
trainer
=
trainer
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
self
.
graph
=
tf
.
get_default_graph
()
with
tf
.
name_scope
(
type
(
self
)
.
__name__
):
scope_name
=
type
(
self
)
.
__name__
scope_name
=
scope_name
.
replace
(
'_'
,
''
)
with
tf
.
name_scope
(
scope_name
):
self
.
_setup_graph
()
self
.
_setup_graph
()
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
...
...
tensorpack/callbacks/inference_runner.py
View file @
d9b96535
...
@@ -22,6 +22,7 @@ from ..predict import PredictorTowerBuilder
...
@@ -22,6 +22,7 @@ from ..predict import PredictorTowerBuilder
from
.base
import
Callback
from
.base
import
Callback
from
.inference
import
Inferencer
from
.inference
import
Inferencer
from
.hooks
import
CallbackToHook
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
,
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
,
'DataParallelInferenceRunner'
]
'DataParallelInferenceRunner'
]
...
@@ -85,8 +86,6 @@ class InferenceRunnerBase(Callback):
...
@@ -85,8 +86,6 @@ class InferenceRunnerBase(Callback):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
assert
len
(
self
.
_input_source
.
get_callbacks
())
==
0
,
\
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# Use predict_tower in train config. either gpuid or -1
# Use predict_tower in train config. either gpuid or -1
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
...
@@ -97,6 +96,8 @@ class InferenceRunnerBase(Callback):
...
@@ -97,6 +96,8 @@ class InferenceRunnerBase(Callback):
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
cbs
=
self
.
_input_source
.
get_callbacks
()
self
.
_hooks
.
extend
([
CallbackToHook
(
cb
)
for
cb
in
cbs
])
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
_hooks
.
extend
(
self
.
_extra_hooks
)
self
.
_hooks
.
extend
(
self
.
_extra_hooks
)
...
@@ -118,8 +119,7 @@ class InferenceRunnerBase(Callback):
...
@@ -118,8 +119,7 @@ class InferenceRunnerBase(Callback):
# iterate over the data, and run the hooked session
# iterate over the data, and run the hooked session
self
.
_input_source
.
reset_state
()
self
.
_input_source
.
reset_state
()
for
_
in
tqdm
.
trange
(
self
.
_size
,
**
get_tqdm_kwargs
()):
for
_
in
tqdm
.
trange
(
self
.
_size
,
**
get_tqdm_kwargs
()):
feed
=
self
.
_input_source
.
next_feed
()
self
.
_hooked_sess
.
run
(
fetches
=
[])
self
.
_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
...
@@ -170,19 +170,17 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
...
@@ -170,19 +170,17 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
placeholder_names
=
[
k
.
name
+
':0'
for
k
in
self
.
trainer
.
model
.
get_inputs_desc
()]
placeholder_names
=
[
k
.
name
+
':0'
for
k
in
self
.
trainer
.
model
.
get_inputs_desc
()]
ret
=
[]
ret
=
[]
for
name
in
out_names
:
for
name
in
out_names
:
if
name
not
in
placeholder_names
:
assert
name
not
in
placeholder_names
,
"Currently inferencer don't support fetching placeholders!"
ret
.
append
(
self
.
_get_tensors_maybe_in_tower
([
name
])[
0
])
ret
.
append
(
self
.
_get_tensors_maybe_in_tower
([
name
])[
0
])
else
:
# requesting an input
idx
=
placeholder_names
.
index
(
name
)
ret
.
append
(
self
.
_input_tensors
[
idx
])
return
InferencerToHook
(
inf
,
ret
)
return
InferencerToHook
(
inf
,
ret
)
# TODO completely broken now!
# TODO some scripts to test
class
DataParallelInferenceRunner
(
InferenceRunnerBase
):
class
DataParallelInferenceRunner
(
InferenceRunnerBase
):
"""
"""
Not tested
. Don't use.
Broken
. Don't use.
"""
"""
# TODO some scripts to test
def
__init__
(
self
,
input
,
infs
,
gpus
):
def
__init__
(
self
,
input
,
infs
,
gpus
):
"""
"""
Args:
Args:
...
@@ -200,7 +198,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -200,7 +198,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
model
=
self
.
trainer
.
model
model
=
self
.
trainer
.
model
self
.
_input_source
.
setup
(
model
.
get_inputs_desc
())
self
.
_input_source
.
setup
(
model
.
get_inputs_desc
())
assert
len
(
self
.
_input_source
.
get_callbacks
())
==
0
,
\
assert
len
(
self
.
_input_source
.
get_callbacks
())
==
0
,
\
"InferenceRunner doesn't support any InputSource which requires callbacks!"
"
DataParallel
InferenceRunner doesn't support any InputSource which requires callbacks!"
# build graph
# build graph
def
build_tower
(
k
):
def
build_tower
(
k
):
...
...
tensorpack/train/input_source.py
View file @
d9b96535
...
@@ -80,17 +80,6 @@ class InputSource(object):
...
@@ -80,17 +80,6 @@ class InputSource(object):
def
_reset_state
(
self
):
def
_reset_state
(
self
):
pass
pass
def
next_feed
(
self
):
"""
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
"""
return
self
.
_next_feed
()
@
abstractmethod
def
_next_feed
(
self
):
pass
def
size
(
self
):
def
size
(
self
):
"""
"""
Returns:
Returns:
...
@@ -122,15 +111,28 @@ class ProxyInputSource(InputSource):
...
@@ -122,15 +111,28 @@ class ProxyInputSource(InputSource):
def
_size
(
self
):
def
_size
(
self
):
return
self
.
_input
.
size
()
return
self
.
_input
.
size
()
def
_next_feed
(
self
):
return
self
.
_input
.
next_feed
()
def
_reset_state
(
self
):
def
_reset_state
(
self
):
self
.
_input
.
reset_state
()
self
.
_input
.
reset_state
()
class
FeedInput
(
InputSource
):
class
FeedInput
(
InputSource
):
""" Input by iterating over a DataFlow and feed datapoints. """
""" Input by iterating over a DataFlow and feed datapoints. """
class
_FeedCallback
(
Callback
):
def
__init__
(
self
,
ds
,
placeholders
):
self
.
_ds
=
ds
self
.
_itr
=
self
.
_ds
.
get_data
()
self
.
_placeholders
=
placeholders
def
_before_run
(
self
,
_
):
dp
=
next
(
self
.
_itr
)
assert
len
(
dp
)
==
len
(
self
.
_placeholders
),
"[FeedInput] datapoints and inputs are of different length!"
feed
=
dict
(
zip
(
self
.
_placeholders
,
dp
))
return
tf
.
train
.
SessionRunArgs
(
fetches
=
[],
feed_dict
=
feed
)
def
_reset
(
self
):
self
.
_ds
.
reset_state
()
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
):
"""
"""
Args:
Args:
...
@@ -138,28 +140,27 @@ class FeedInput(InputSource):
...
@@ -138,28 +140,27 @@ class FeedInput(InputSource):
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
self
.
ds
=
ds
self
.
_repeat_ds
=
RepeatedData
(
self
.
ds
,
-
1
)
def
_size
(
self
):
def
_size
(
self
):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_cb
=
self
.
_FeedCallback
(
self
.
_repeat_ds
,
self
.
_all_placehdrs
)
self
.
reset_state
()
self
.
reset_state
()
def
_reset_state
(
self
):
def
_reset_state
(
self
):
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
self
.
_cb
.
_reset
()
rds
.
reset_state
()
self
.
data_producer
=
rds
.
get_data
()
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
return
self
.
_all_placehdrs
return
self
.
_all_placehdrs
def
_next_feed
(
self
):
def
_get_callbacks
(
self
):
dp
=
next
(
self
.
data_producer
)
return
[
self
.
_cb
]
assert
len
(
dp
)
==
len
(
self
.
_all_placehdrs
),
"[FeedInput] datapoints and inputs are of different length!"
return
dict
(
zip
(
self
.
_all_placehdrs
,
dp
))
# TODO completely broken now!
class
DataParallelFeedInput
(
FeedInput
):
class
DataParallelFeedInput
(
FeedInput
):
"""
"""
Input by feeding k datapoints to k copies of placeholders located on k towers.
Input by feeding k datapoints to k copies of placeholders located on k towers.
...
@@ -182,7 +183,7 @@ class DataParallelFeedInput(FeedInput):
...
@@ -182,7 +183,7 @@ class DataParallelFeedInput(FeedInput):
ctx
=
get_current_tower_context
()
ctx
=
get_current_tower_context
()
return
self
.
_placehdrs_per_tower
[
ctx
.
index
]
return
self
.
_placehdrs_per_tower
[
ctx
.
index
]
def
_
next_feed
(
self
,
cnt
=
None
):
def
next_feed
(
self
,
cnt
=
None
):
"""
"""
Args:
Args:
cnt: how many towers to feed to. Defaults to the total number of towers
cnt: how many towers to feed to. Defaults to the total number of towers
...
@@ -204,9 +205,6 @@ class FeedfreeInput(InputSource):
...
@@ -204,9 +205,6 @@ class FeedfreeInput(InputSource):
def
_reset_state
(
self
):
def
_reset_state
(
self
):
pass
pass
def
_next_feed
(
self
):
return
{}
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
class
EnqueueThread
(
ShareSessionThread
):
class
EnqueueThread
(
ShareSessionThread
):
...
...
tensorpack/train/simple.py
View file @
d9b96535
...
@@ -31,14 +31,14 @@ class SimpleTrainer(Trainer):
...
@@ -31,14 +31,14 @@ class SimpleTrainer(Trainer):
def
run_step
(
self
):
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
""" Feed data into the graph and run the updates. """
feed
=
self
.
_input_source
.
next_feed
()
self
.
hooked_sess
.
run
(
self
.
train_op
)
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
def
_setup
(
self
):
def
_setup
(
self
):
model
=
self
.
model
model
=
self
.
model
self
.
_input_source
.
setup
(
model
.
get_inputs_desc
())
self
.
_input_source
.
setup
(
model
.
get_inputs_desc
())
cbs
=
self
.
_input_source
.
get_callbacks
()
cbs
=
self
.
_input_source
.
get_callbacks
()
assert
len
(
cbs
)
==
0
,
"Feedinput has no callbacks!"
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
self
.
inputs
=
self
.
_input_source
.
get_input_tensors
()
self
.
inputs
=
self
.
_input_source
.
get_input_tensors
()
with
TowerContext
(
''
,
is_training
=
True
):
with
TowerContext
(
''
,
is_training
=
True
):
model
.
build_graph
(
self
.
inputs
)
model
.
build_graph
(
self
.
inputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment