Shashank Suhas / seminar-breakout

Commit e5f5da3c authored Dec 13, 2018 by Yuxin Wu

Let MovingAverageSummary bind the train_op instead of the session.

parent 93819550

Showing 4 changed files with 35 additions and 12 deletions (+35 -12)
.travis.yml                           +1  -0
tensorpack/callbacks/summary.py       +26 -8
tensorpack/graph_builder/training.py  +6  -2
tensorpack/train/trainers.py          +2  -2
.travis.yml

@@ -55,6 +55,7 @@ before_script:
   # Check that these private names can be imported because tensorpack is using them
   - python -c "from tensorflow.python.client.session import _FetchHandler"
   - python -c "from tensorflow.python.training.monitored_session import _HookedSession"
+  - python -c "import tensorflow as tf; tf.Operation._add_control_input"
 
 script:
   - flake8 .
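The added check guards the private TensorFlow attribute that the summary.py change below relies on; if a TensorFlow release removes or renames it, CI fails early instead of breaking at runtime. A minimal local equivalent of that check (a sketch, not part of the commit):

import tensorflow as tf

# Fail fast if the private control-input API this commit depends on is
# missing from the installed TensorFlow (sketch; not from the commit).
assert hasattr(tf.Operation, '_add_control_input'), \
    "tf.Operation._add_control_input is gone in this TF version"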
tensorpack/callbacks/summary.py

@@ -19,28 +19,46 @@ class MovingAverageSummary(Callback):
     This callback is enabled by default.
     Maintain the moving average of summarized tensors in every step,
     by ops added to the collection.
-    Note that it only __maintains__ the moving averages in the graph,
+    Note that it only __maintains__ the moving averages by updating
+    the relevant variables in the graph,
     the actual summary should be done in other callbacks.
     """
-    def __init__(self, collection=MOVING_SUMMARY_OPS_KEY):
+    def __init__(self, collection=MOVING_SUMMARY_OPS_KEY, train_op=None):
         """
         Args:
             collection(str): the collection of EMA-maintaining ops.
                 The default value would work with
                 the tensors you added by :func:`tfutils.summary.add_moving_summary()`,
                 but you can use other collections as well.
+            train_op (tf.Operation or str): the (name of) training op to associate the maintaining ops with.
+                If not provided, the EMA-maintaining ops will be hooked to
+                `trainer.hooked_session` and be executed in every iteration.
+                Otherwise, the EMA-maintaining ops will be executed whenever
+                the training op is executed.
         """
         self._collection = collection
+        self._train_op = train_op
 
     def _setup_graph(self):
-        ops = tf.get_collection(self._collection)
-        logger.info("Maintain moving average summary of {} tensors in collection {}.".format(
-            len(ops), self._collection))
-        self.ema_op = tf.group(*ops, name='maintain_moving_average_summary')
-        self._fetch = tf.train.SessionRunArgs(fetches=self.ema_op)
+        ops = [k.op for k in tf.get_collection(self._collection)]
+        if self._train_op is None:
+            logger.info("[MovingAverageSummary] {} operations in collection '{}' "
+                        "will be run with session hooks.".format(len(ops), self._collection))
+            self.ema_op = tf.group(*ops, name='maintain_moving_average_summary')
+            self._fetch = tf.train.SessionRunArgs(fetches=self.ema_op)
+        else:
+            if isinstance(self._train_op, tf.Tensor):
+                self._train_op = self._train_op.op
+            if not isinstance(self._train_op, tf.Operation):
+                self._train_op = self.graph.get_operation_by_name(self._train_op)
+            self._train_op._add_control_inputs(ops)
+            logger.info("[MovingAverageSummary] {} operations in collection '{}'"
+                        " will be run together with operation '{}'.".format(
+                            len(ops), self._collection, self._train_op.name))
 
     def _before_run(self, _):
-        return self._fetch
+        if self._train_op is None:
+            return self._fetch
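The callback now has two modes. A brief usage sketch of the new parameter (illustrative; only the MovingAverageSummary calls reflect this commit):

from tensorpack.callbacks import MovingAverageSummary

# Default mode: the EMA-maintaining ops are returned as SessionRunArgs
# from _before_run, so the hooked session fetches them on every iteration.
cb_hooked = MovingAverageSummary()

# New mode: pass the training op (tf.Operation, tf.Tensor, or its name);
# the EMA ops become control inputs of that op and run exactly when it
# runs, with no session hook involved. The name 'train_op' matches the
# rename in trainers.py below.
cb_bound = MovingAverageSummary(train_op='train_op')

Binding to the train_op ties the EMA updates to actual training steps rather than to every run made through the hooked session.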
tensorpack/graph_builder/training.py

@@ -77,7 +77,7 @@ class DataParallelBuilder(GraphBuilder):
                 raise ValueError("Number of gradients from each tower is different! " + str(nvars))
 
     @staticmethod
-    def build_on_towers(
+    def call_for_each_tower(
             towers, func, devices=None, use_vs=None):
         """
         Run `func` on all GPUs (towers) and return the results.

@@ -119,6 +119,10 @@ class DataParallelBuilder(GraphBuilder):
                 ret.append(func())
         return ret
 
+    @staticmethod
+    def build_on_towers(*args, **kwargs):
+        return DataParallelBuilder.call_for_each_tower(*args, **kwargs)
+
 
 class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
     """
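The rename keeps the old entry point as a thin forwarding alias, so existing callers of build_on_towers keep working. A sketch of the equivalence (the import path and tower_fn here are assumptions for illustration):

from tensorpack.graph_builder import DataParallelBuilder

def tower_fn():
    # hypothetical: build one tower's forward/backward graph
    pass

towers = [0, 1]  # tower ids, one per GPU
new_style = DataParallelBuilder.call_for_each_tower(towers, tower_fn)
old_style = DataParallelBuilder.build_on_towers(towers, tower_fn)  # same result via the alias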
tensorpack/train/trainers.py

@@ -53,7 +53,7 @@ class SimpleTrainer(SingleCostTrainer):
         with TrainTowerContext(''):
             grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
             opt = get_opt_fn()
-            self.train_op = opt.apply_gradients(grads, name='min_op')
+            self.train_op = opt.apply_gradients(grads, name='train_op')
         return []

@@ -404,7 +404,7 @@ class HorovodTrainer(SingleCostTrainer):
             grads = self.allreduce(grads)
             opt = get_opt_fn()
-            self.train_op = opt.apply_gradients(grads, name='min_op')
+            self.train_op = opt.apply_gradients(grads, name='train_op')
 
     def broadcast(self):
         logger.info("Running horovod broadcast ...")
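Renaming the op from 'min_op' to 'train_op' gives the training op a stable, meaningful name, which is what makes the string form of MovingAverageSummary's train_op argument resolvable. A sketch of the lookup that _setup_graph performs (assuming a graph built after this commit):

import tensorflow as tf

# After this commit the op created by apply_gradients is named 'train_op',
# so it can be recovered from the graph by name:
train_op = tf.get_default_graph().get_operation_by_name('train_op')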