Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e2f9798a
Commit
e2f9798a
authored
Jul 11, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
replace setup_training by get_callbacks in InputSource
parent
14c564cc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
18 deletions
+29
-18
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+4
-0
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+4
-1
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+18
-16
tensorpack/train/simple.py
tensorpack/train/simple.py
+3
-1
No files found.
tensorpack/callbacks/inference_runner.py
View file @
e2f9798a
...
...
@@ -84,6 +84,8 @@ class InferenceRunnerBase(Callback):
def
_setup_graph
(
self
):
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
assert
len
(
self
.
_input_source
.
get_callbacks
())
==
0
,
\
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# Use predict_tower in train config. either gpuid or -1
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
...
...
@@ -189,6 +191,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def
_setup_graph
(
self
):
model
=
self
.
trainer
.
model
self
.
_input_source
.
setup
(
model
.
get_inputs_desc
())
assert
len
(
self
.
_input_source
.
get_callbacks
())
==
0
,
\
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# build graph
def
build_tower
(
k
):
...
...
tensorpack/train/feedfree.py
View file @
e2f9798a
...
...
@@ -37,7 +37,10 @@ class FeedfreeTrainerBase(Trainer):
def
_setup
(
self
):
assert
isinstance
(
self
.
_input_source
,
FeedfreeInput
),
type
(
self
.
_input_source
)
self
.
_input_source
.
setup_training
(
self
)
self
.
_input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
input_callbacks
=
self
.
_input_source
.
get_callbacks
()
for
cb
in
input_callbacks
:
self
.
register_callback
(
cb
)
def
run_step
(
self
):
""" Simply run ``self.train_op``."""
...
...
tensorpack/train/input_source.py
View file @
e2f9798a
...
...
@@ -53,8 +53,12 @@ class InputSource(object):
"""
pass
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
.
get_inputs_desc
())
def
get_callbacks
(
self
):
"""
Returns:
list[Callback]: extra callbacks required by this InputSource.
"""
return
[]
@
abstractmethod
def
reset_state
(
self
):
...
...
@@ -248,11 +252,10 @@ class QueueInput(FeedfreeInput):
name
=
'input_queue'
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_queue_feedpoint
)
def
setup_training
(
self
,
trainer
):
super
(
QueueInput
,
self
)
.
setup_training
(
trainer
)
def
get_callbacks
(
self
):
cb
=
StartProcOrThread
(
self
.
thread
)
cb
.
chief_only
=
False
trainer
.
register_callback
(
cb
)
return
[
cb
]
def
get_input_tensors
(
self
):
with
tf
.
device
(
'/cpu:0'
):
...
...
@@ -321,9 +324,10 @@ class BatchQueueInput(FeedfreeInput):
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
def
setup_training
(
self
,
trainer
):
super
(
BatchQueueInput
,
self
)
.
setup_training
(
trainer
)
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_callbacks
(
self
):
cb
=
StartProcOrThread
(
self
.
thread
)
cb
.
chief_only
=
False
return
[
cb
]
def
get_input_tensors
(
self
):
with
tf
.
device
(
'/cpu:0'
):
...
...
@@ -461,13 +465,13 @@ class StagingInputWrapper(FeedfreeInput):
self
.
_input
.
setup
(
inputs
)
self
.
setup_staging_areas
()
def
setup_training
(
self
,
trainer
):
self
.
_input
.
setup_training
(
trainer
)
self
.
setup_staging_areas
()
def
get_callbacks
(
self
):
cbs
=
self
.
_input
.
get_callbacks
()
trainer
.
register_callback
(
cbs
.
append
(
StagingInputWrapper
.
StagingCallback
(
self
.
get_stage_op
(),
self
.
get_unstage_op
(),
self
.
_nr_stage
))
return
cbs
def
setup_staging_areas
(
self
):
logger
.
info
(
"Setting up StagingArea for GPU prefetching ..."
)
...
...
@@ -531,10 +535,8 @@ class ReorderInputSource(FeedfreeInput):
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_input
.
setup
(
inputs
)
def
setup_training
(
self
,
trainer
):
inputs
=
trainer
.
model
.
get_inputs_desc
()
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_input
.
setup_training
(
trainer
)
def
get_callbacks
(
self
):
return
self
.
_input
.
get_callbacks
()
def
reset_state
(
self
):
self
.
_input
.
reset_state
()
...
...
tensorpack/train/simple.py
View file @
e2f9798a
...
...
@@ -35,8 +35,10 @@ class SimpleTrainer(Trainer):
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
def
_setup
(
self
):
self
.
_input_source
.
setup_training
(
self
)
model
=
self
.
model
self
.
_input_source
.
setup
(
model
.
get_inputs_desc
())
cbs
=
self
.
_input_source
.
get_callbacks
()
assert
len
(
cbs
)
==
0
,
"Feedinput has no callbacks!"
self
.
inputs
=
self
.
_input_source
.
get_input_tensors
()
with
TowerContext
(
''
,
is_training
=
True
):
model
.
build_graph
(
self
.
inputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment