Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
95cb6ba2
Commit
95cb6ba2
authored
Jun 01, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add trainer.create_session, add verbose in RunOp
parent
ee1af311
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
36 additions
and
18 deletions
+36
-18
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+4
-0
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+11
-1
tensorpack/train/base.py
tensorpack/train/base.py
+12
-10
tensorpack/train/distributed.py
tensorpack/train/distributed.py
+5
-3
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+1
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+3
-3
No files found.
tensorpack/callbacks/base.py
View file @
95cb6ba2
...
...
@@ -173,6 +173,10 @@ class Callback(object):
"""
return
self
.
_chief_only
@
chief_only
.
setter
def
chief_only
(
self
,
v
):
self
.
_chief_only
=
v
def
__str__
(
self
):
return
type
(
self
)
.
__name__
...
...
tensorpack/callbacks/graph.py
View file @
95cb6ba2
...
...
@@ -17,13 +17,15 @@ class RunOp(Callback):
""" Run an Op. """
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_as_trigger
=
True
,
run_step
=
False
):
run_before
=
True
,
run_as_trigger
=
True
,
run_step
=
False
,
verbose
=
False
):
"""
Args:
setup_func: a function that returns the Op in the graph
run_before (bool): run the Op before training
run_as_trigger (bool): run the Op on every trigger
run_step (bool): run the Op every step (along with training)
verbose (bool): print logs when the op is run.
Examples:
The `DQN Example
...
...
@@ -34,22 +36,30 @@ class RunOp(Callback):
self
.
run_before
=
run_before
self
.
run_as_trigger
=
run_as_trigger
self
.
run_step
=
run_step
self
.
verbose
=
verbose
def
_setup_graph
(
self
):
self
.
_op
=
self
.
setup_func
()
def
_before_train
(
self
):
if
self
.
run_before
:
self
.
_print
()
self
.
_op
.
run
()
def
_trigger
(
self
):
if
self
.
run_as_trigger
:
self
.
_print
()
self
.
_op
.
run
()
def
_before_run
(
self
,
_
):
if
self
.
run_step
:
self
.
_print
()
return
[
self
.
_op
]
def
_print
(
self
):
if
self
.
verbose
:
logger
.
info
(
"Running Op {} ..."
.
format
(
self
.
_op
.
name
))
class
RunUpdateOps
(
RunOp
):
"""
...
...
tensorpack/train/base.py
View file @
95cb6ba2
...
...
@@ -9,8 +9,6 @@ import six
from
six.moves
import
range
import
tensorflow
as
tf
from
tensorflow.python.training.monitored_session
\
import
_HookedSession
as
HookedSession
from
.predict
import
PredictorFactory
from
.config
import
TrainConfig
...
...
@@ -118,6 +116,7 @@ class Trainer(object):
self
.
monitors
=
Monitors
(
self
.
monitors
)
self
.
register_callback
(
self
.
monitors
)
# TODO cache per graph, avoid describing all towers
describe_model
()
# some final operations that might modify the graph
...
...
@@ -125,21 +124,24 @@ class Trainer(object):
self
.
_callbacks
=
Callbacks
(
self
.
_callbacks
)
self
.
_callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
# create session
logger
.
info
(
"Creating the session ..."
)
self
.
sess
=
self
.
config
.
session_creator
.
create_session
()
self
.
_monitored_sess
=
tf
.
train
.
MonitoredSession
(
session_creator
=
ReuseSessionCreator
(
self
.
sess
),
hooks
=
None
)
self
.
_create_session
()
logger
.
info
(
"Initializing the session ..."
)
# init session
self
.
config
.
session_init
.
init
(
self
.
sess
)
self
.
sess
.
graph
.
finalize
()
logger
.
info
(
"Graph Finalized."
)
def
_create_session
(
self
):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks
=
self
.
_callbacks
.
get_hooks
()
self
.
hooked_sess
=
HookedSession
(
self
.
sess
,
hooks
)
self
.
sess
=
self
.
config
.
session_creator
.
create_session
()
self
.
hooked_sess
=
tf
.
train
.
MonitoredSession
(
session_creator
=
ReuseSessionCreator
(
self
.
sess
),
hooks
=
hooks
)
@
abstractmethod
def
_setup
(
self
):
...
...
@@ -167,7 +169,7 @@ class Trainer(object):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
start_time
=
time
.
time
()
for
self
.
local_step
in
range
(
self
.
config
.
steps_per_epoch
):
if
self
.
_monitor
ed_sess
.
should_stop
():
if
self
.
hook
ed_sess
.
should_stop
():
return
self
.
run_step
()
# implemented by subclass
self
.
_callbacks
.
trigger_step
()
...
...
@@ -186,7 +188,7 @@ class Trainer(object):
raise
finally
:
self
.
_callbacks
.
after_train
()
self
.
_monitor
ed_sess
.
close
()
self
.
hook
ed_sess
.
close
()
# Predictor related methods: TODO
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
=
0
):
...
...
tensorpack/train/distributed.py
View file @
95cb6ba2
...
...
@@ -165,8 +165,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
main_fetch
=
tf
.
group
(
*
var_update_ops
,
name
=
'main_fetches'
)
self
.
train_op
=
self
.
add_sync_queues_and_barrier
(
'post_copy_barrier'
,
[
main_fetch
])
self
.
register_callback
(
RunOp
(
self
.
get_post_init_ops
,
run_before
=
True
,
run_as_trigger
=
False
))
cb
=
RunOp
(
self
.
get_post_init_ops
,
run_before
=
True
,
run_as_trigger
=
False
,
verbose
=
True
)
cb
.
chief_only
=
False
self
.
register_callback
(
cb
)
self
.
_set_session_creator
()
...
...
@@ -251,4 +253,4 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
post_init_ops
.
append
(
copy_to
.
assign
(
v
.
read_value
()))
else
:
logger
.
warn
(
"Global varable {} doesn't match a corresponding local var"
.
format
(
v
.
name
))
return
tf
.
group
(
*
post_init_ops
,
name
=
'
post_init_o
ps'
)
return
tf
.
group
(
*
post_init_ops
,
name
=
'
sync_variables_from_
ps'
)
tensorpack/train/input_source.py
View file @
95cb6ba2
...
...
@@ -242,7 +242,7 @@ class QueueInput(FeedfreeInput):
def
setup_training
(
self
,
trainer
):
super
(
QueueInput
,
self
)
.
setup_training
(
trainer
)
cb
=
StartProcOrThread
(
self
.
thread
)
cb
.
_
chief_only
=
False
cb
.
chief_only
=
False
trainer
.
register_callback
(
cb
)
def
get_input_tensors
(
self
):
...
...
tensorpack/train/multigpu.py
View file @
95cb6ba2
...
...
@@ -70,7 +70,7 @@ class MultiGPUTrainerBase(Trainer):
keys_to_freeze
=
TOWER_FREEZE_KEYS
[:]
if
var_strategy
==
'replicated'
:
# TODO ugly
logger
.
info
(
"
UPDATE_OPS from all GPUs will be kept in the collectio
n."
)
logger
.
info
(
"
In replicated mode, UPDATE_OPS from all GPUs will be ru
n."
)
keys_to_freeze
.
remove
(
tf
.
GraphKeys
.
UPDATE_OPS
)
for
idx
,
t
in
enumerate
(
towers
):
...
...
@@ -261,7 +261,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
self
.
train_op
=
tf
.
group
(
*
train_ops
,
name
=
'train_op'
)
self
.
register_callback
(
RunOp
(
SyncMultiGPUTrainerReplicated
.
get_post_init_ops
,
run_before
=
True
,
run_as_trigger
=
True
))
run_before
=
True
,
run_as_trigger
=
True
,
verbose
=
True
))
# Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
...
...
@@ -279,7 +279,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
split_name
=
split_name
[
1
:]
copy_from
=
var_by_name
[
'/'
.
join
(
split_name
)]
post_init_ops
.
append
(
v
.
assign
(
copy_from
.
read_value
()))
return
tf
.
group
(
*
post_init_ops
,
name
=
'
init_sync_vars
'
)
return
tf
.
group
(
*
post_init_ops
,
name
=
'
sync_variables_from_tower0
'
)
class
AsyncMultiGPUTrainer
(
MultiGPUTrainerBase
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment