Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
dc31efa4
Commit
dc31efa4
authored
Jul 15, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Change trigger_epoch to trigger in CallbackFactory
parent
c223f223
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
12 deletions
+16
-12
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+13
-6
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+0
-4
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+3
-2
No files found.
tensorpack/callbacks/base.py
View file @
dc31efa4
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
abc
import
ABCMeta
from
abc
import
ABCMeta
import
six
import
six
from
..utils.develop
import
log_deprecated
from
..tfutils.common
import
get_op_or_tensor_by_name
from
..tfutils.common
import
get_op_or_tensor_by_name
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
,
'Triggerable'
]
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
,
'Triggerable'
]
...
@@ -255,17 +256,23 @@ class CallbackFactory(Callback):
...
@@ -255,17 +256,23 @@ class CallbackFactory(Callback):
"""
"""
Create a callback with some lambdas.
Create a callback with some lambdas.
"""
"""
def
__init__
(
self
,
setup_graph
=
None
,
before_train
=
None
,
def
__init__
(
self
,
setup_graph
=
None
,
before_train
=
None
,
trigger
=
None
,
trigger_epoch
=
None
,
after_train
=
None
):
after_train
=
None
,
trigger_epoch
=
None
):
"""
"""
Each lambda takes ``self`` as the only argument.
Each lambda takes ``self`` as the only argument.
trigger_epoch was deprecated.
"""
"""
self
.
_cb_setup_graph
=
setup_graph
self
.
_cb_setup_graph
=
setup_graph
self
.
_cb_before_train
=
before_train
self
.
_cb_before_train
=
before_train
self
.
_cb_trigger
_epoch
=
trigger_epoch
self
.
_cb_trigger
=
trigger
self
.
_cb_after_train
=
after_train
self
.
_cb_after_train
=
after_train
if
trigger_epoch
:
self
.
_cb_trigger
=
trigger_epoch
log_deprecated
(
"CallbackFactory(trigger_epoch=)"
,
"Use trigger instead."
,
"2017-11-15"
)
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
if
self
.
_cb_setup_graph
:
if
self
.
_cb_setup_graph
:
self
.
_cb_setup_graph
(
self
)
self
.
_cb_setup_graph
(
self
)
...
@@ -274,9 +281,9 @@ class CallbackFactory(Callback):
...
@@ -274,9 +281,9 @@ class CallbackFactory(Callback):
if
self
.
_cb_before_train
:
if
self
.
_cb_before_train
:
self
.
_cb_before_train
(
self
)
self
.
_cb_before_train
(
self
)
def
_trigger
_epoch
(
self
):
def
_trigger
(
self
):
if
self
.
_cb_trigger
_epoch
:
if
self
.
_cb_trigger
:
self
.
_cb_trigger
_epoch
(
self
)
self
.
_cb_trigger
(
self
)
def
_after_train
(
self
):
def
_after_train
(
self
):
if
self
.
_cb_after_train
:
if
self
.
_cb_after_train
:
...
...
tensorpack/callbacks/group.py
View file @
dc31efa4
...
@@ -98,7 +98,3 @@ class Callbacks(Callback):
...
@@ -98,7 +98,3 @@ class Callbacks(Callback):
def
_after_epoch
(
self
):
def
_after_epoch
(
self
):
for
cb
in
self
.
cbs
:
for
cb
in
self
.
cbs
:
cb
.
after_epoch
()
cb
.
after_epoch
()
def
append
(
self
,
cb
):
assert
isinstance
(
cb
,
Callback
)
self
.
cbs
.
append
(
cb
)
tensorpack/tfutils/summary.py
View file @
dc31efa4
...
@@ -163,13 +163,14 @@ def add_moving_summary(v, *args, **kwargs):
...
@@ -163,13 +163,14 @@ def add_moving_summary(v, *args, **kwargs):
for
x
in
v
:
for
x
in
v
:
assert
isinstance
(
x
,
tf
.
Tensor
),
x
assert
isinstance
(
x
,
tf
.
Tensor
),
x
assert
x
.
get_shape
()
.
ndims
==
0
,
x
.
get_shape
()
assert
x
.
get_shape
()
.
ndims
==
0
,
x
.
get_shape
()
gs
=
get_global_step_var
()
# TODO will produce variable tower0/xxx?
# TODO will produce variable tower0/xxx?
# TODO not saved under distributed
# TODO not saved under distributed
# TODO use zero_debias
# TODO use zero_debias
# TODO create EMA for each variable separately, so that the maintain ops
# TODO create EMA for each variable separately, so that the maintain ops
# have a decent name (rather than EMA)
# have a decent name (rather than EMA)
gs
=
get_global_step_var
()
# clear namescope, otherwise the variable names will have duplicated name scope
with
tf
.
device
(
gs
.
device
):
with
tf
.
name_scope
(
None
),
tf
.
device
(
gs
.
device
):
averager
=
tf
.
train
.
ExponentialMovingAverage
(
averager
=
tf
.
train
.
ExponentialMovingAverage
(
decay
,
num_updates
=
gs
,
name
=
'EMA'
)
decay
,
num_updates
=
gs
,
name
=
'EMA'
)
avg_maintain_op
=
averager
.
apply
(
v
)
avg_maintain_op
=
averager
.
apply
(
v
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment