Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f73717ab
Commit
f73717ab
authored
May 06, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix UPDATE_OPS collection (#81)
parent
89b0c256
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
41 additions
and
29 deletions
+41
-29
CHANGES.md
CHANGES.md
+1
-2
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+29
-3
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+3
-3
tensorpack/predict/base.py
tensorpack/predict/base.py
+2
-2
tensorpack/tfutils/model_utils.py
tensorpack/tfutils/model_utils.py
+1
-16
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+3
-3
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+2
-0
No files found.
CHANGES.md
View file @
f73717ab
...
...
@@ -6,12 +6,11 @@ The backward compatibility will be __preserved for several months__, with a depre
so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also change
s API
and those are not listed here.
TensorFlow itself also change
d APIs before 1.0
and those are not listed here.
+
[
2017/05/06
](
https://github.com/ppwwyyxx/tensorpack/commit/0774ec66e66075486f6a36aba63cc2a151b9fec8
)
.
`replace_get_variable`
was deprecated in favor of the official
`custom_getter`
interface.
`{freeze,remap}_get_variable`
was renamed to
`{freeze,remap}_variables`
.
+
[
2017/04/09
](
https://github.com/ppwwyyxx/tensorpack/commit/5beab907895aec36bdcaed62e25b976aad7979b8
)
.
`ParamRestore`
was renamed to
`DictRestore`
.
+
[
2017/03/16
](
https://github.com/ppwwyyxx/tensorpack/commit/ccae46f4a3ca89dc3df901a338eef8447d19a730
)
.
...
...
tensorpack/callbacks/graph.py
View file @
f73717ab
...
...
@@ -5,21 +5,25 @@
""" Graph related callbacks"""
import
tensorflow
as
tf
from
..utils
import
logger
from
.base
import
Callback
__all__
=
[
'RunOp'
]
__all__
=
[
'RunOp'
,
'RunUpdateOps'
]
class
RunOp
(
Callback
):
""" Run an Op. """
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_as_trigger
=
True
):
run_before
=
True
,
run_as_trigger
=
True
,
run_step
=
False
):
"""
Args:
setup_func: a function that returns the Op in the graph
run_before (bool): run the Op before training
run_epoch (bool): run the Op on every epoch trigger
run_as_trigger (bool): run the Op on every trigger
run_step (bool): run the Op every step (along with training)
Examples:
The `DQN Example
...
...
@@ -29,10 +33,15 @@ class RunOp(Callback):
self
.
setup_func
=
setup_func
self
.
run_before
=
run_before
self
.
run_as_trigger
=
run_as_trigger
self
.
run_step
=
run_step
def
_setup_graph
(
self
):
self
.
_op
=
self
.
setup_func
()
def
_before_run
(
self
,
_
):
if
self
.
run_step
:
return
[
self
.
_op
]
def
_before_train
(
self
):
if
self
.
run_before
:
self
.
_op
.
run
()
...
...
@@ -40,3 +49,20 @@ class RunOp(Callback):
def
_trigger
(
self
):
if
self
.
run_as_trigger
:
self
.
_op
.
run
()
class RunUpdateOps(RunOp):
    """
    Callback that runs all ops found in a graph collection on every
    training step.

    Args:
        collection (str): key of the collection to gather ops from.
            Defaults to ``tf.GraphKeys.UPDATE_OPS``.
    """

    def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
        def build_update_op():
            # Invoked through ``setup_func`` when the graph is set up, so
            # the collection is read at that point rather than when the
            # callback object is constructed.
            update_ops = tf.get_collection(collection)
            if not update_ops:
                # Nothing registered: still hand back a valid op to run.
                return tf.no_op(name='empty_update_ops')
            logger.info(
                "Applying UPDATE_OPS collection of {} ops.".format(len(update_ops)))
            return tf.group(*update_ops, name='update_ops')

        # Only run alongside each training step; never before training
        # and never on triggers.
        super(RunUpdateOps, self).__init__(
            build_update_op,
            run_before=False,
            run_as_trigger=False,
            run_step=True)
tensorpack/models/model_desc.py
View file @
f73717ab
...
...
@@ -128,9 +128,9 @@ class ModelDesc(object):
It calls :meth:`ModelDesc._get_cost()` which by default returns
``self.cost``. You can override :meth:`_get_cost()` if needed.
This function also applies t
fslim collections to the cost automatically,
including ``tf.GraphKeys.REGULARIZATION_LOSSES`` and ``tf.GraphKeys.UPDATE_OPS``
.
This is b
ecause slim users would expect the regularizer being automatically applied once used in slim layers.
This function also applies t
he collection
``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically
.
B
ecause slim users would expect the regularizer to be automatically applied once used in slim layers.
"""
cost
=
self
.
_get_cost
()
return
apply_slim_collections
(
cost
)
...
...
tensorpack/predict/base.py
View file @
f73717ab
...
...
@@ -10,7 +10,7 @@ import six
from
..utils
import
logger
from
..utils.develop
import
deprecated
from
..utils.argtools
import
memoized
from
..utils.naming
import
SUMMARY_BACKUP
_KEYS
from
..utils.naming
import
TOWER_FREEZE
_KEYS
from
..tfutils
import
get_tensors_by_names
,
TowerContext
,
get_op_tensor_name
from
..tfutils.collection
import
freeze_collection
...
...
@@ -188,7 +188,7 @@ class PredictorTowerBuilder(object):
# No matter where this get called, clear any existing name scope.
device
=
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
with
tf
.
name_scope
(
None
),
\
freeze_collection
(
SUMMARY_BACKUP
_KEYS
),
\
freeze_collection
(
TOWER_FREEZE
_KEYS
),
\
TowerContext
(
towername
,
device
=
device
,
is_training
=
False
):
self
.
_fn
(
tower
)
...
...
tensorpack/tfutils/model_utils.py
View file @
f73717ab
...
...
@@ -8,7 +8,6 @@ from tabulate import tabulate
from
..utils
import
logger
from
.summary
import
add_moving_summary
from
.tower
import
get_current_tower_context
__all__
=
[
'describe_model'
,
'get_shape_str'
,
'apply_slim_collections'
]
...
...
@@ -53,10 +52,7 @@ def get_shape_str(tensors):
def
apply_slim_collections
(
cost
):
"""
Apply slim collections to the cost, including:
1. adding the cost with the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
2. make the cost depend on ``tf.GraphKeys.UPDATE_OPS``.
Add the cost with the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
Args:
cost: a scalar tensor
...
...
@@ -70,15 +66,4 @@ def apply_slim_collections(cost):
reg_loss
=
tf
.
add_n
(
list
(
regulization_losses
),
name
=
"regularize_loss"
)
cost
=
tf
.
add
(
reg_loss
,
cost
,
name
=
'total_cost'
)
add_moving_summary
(
reg_loss
,
cost
)
# As these batch-norm statistics quickly accumulate, there is no significant loss of accuracy
# if only the main tower handles all batch-normalization updates, which are then shared across
# the towers
ctx
=
get_current_tower_context
()
if
ctx
is
not
None
and
ctx
.
is_main_training_tower
:
non_grad_updates
=
set
(
tf
.
get_collection
(
tf
.
GraphKeys
.
UPDATE_OPS
))
if
non_grad_updates
:
logger
.
info
(
"Applying UPDATE_OPS collection from the first tower on cost."
)
with
tf
.
control_dependencies
(
non_grad_updates
):
cost
=
tf
.
identity
(
cost
,
name
=
'cost_with_update'
)
return
cost
tensorpack/train/multigpu.py
View file @
f73717ab
...
...
@@ -9,7 +9,7 @@ import re
from
six.moves
import
zip
,
range
from
..utils
import
logger
from
..utils.naming
import
SUMMARY_BACKUP
_KEYS
from
..utils.naming
import
TOWER_FREEZE
_KEYS
from
..utils.concurrency
import
LoopThread
from
..tfutils.tower
import
TowerContext
from
..tfutils.collection
import
backup_collection
,
restore_collection
...
...
@@ -50,8 +50,8 @@ class MultiGPUTrainer(Trainer):
ret
.
append
(
func
())
if
idx
==
0
:
# avoid repeated summary from each device
backup
=
backup_collection
(
SUMMARY_BACKUP
_KEYS
)
# avoid repeated summary
& update_ops
from each device
backup
=
backup_collection
(
TOWER_FREEZE
_KEYS
)
restore_collection
(
backup
)
return
ret
...
...
tensorpack/utils/naming.py
View file @
f73717ab
...
...
@@ -24,6 +24,8 @@ INPUTS_KEY = 'INPUTS_METAINFO'
SUMMARY_BACKUP_KEYS
=
[
tf
.
GraphKeys
.
SUMMARIES
,
MOVING_SUMMARY_OPS_KEY
]
TOWER_FREEZE_KEYS
=
SUMMARY_BACKUP_KEYS
+
[
tf
.
GraphKeys
.
UPDATE_OPS
]
# export all upper case variables
all_local_names
=
locals
()
.
keys
()
__all__
=
[
x
for
x
in
all_local_names
if
x
.
isupper
()]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment