Shashank Suhas / seminar-breakout

Commit c8d40e69
authored Jan 27, 2017 by Yuxin Wu
periodic trigger
parent 3152d495
Showing 5 changed files with 48 additions and 12 deletions.
examples/mnist-convnet.py        +1   -1
tensorpack/callbacks/base.py     +4   -3
tensorpack/callbacks/steps.py    +3   -3
tensorpack/callbacks/trigger.py  +37  -2
tensorpack/train/base.py         +3   -3
examples/mnist-convnet.py
@@ -141,7 +141,7 @@ def get_config():
         dataflow=dataset_train,  # the DataFlow instance for training
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=[
-            ModelSaver(),   # save the model after every epoch
+            PeriodicTrigger(ModelSaver(), every_k_steps=100),   # save the model after every epoch
             InferenceRunner(    # run inference(for validation) after every epoch
                 dataset_test,   # the DataFlow instance used for validation
                 # Calculate both the cost and the error for this DataFlow
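A hedged sketch of what this one-line change does (the example lines below are illustrative only, not part of the diff): instead of ModelSaver firing through its own epoch trigger, the PeriodicTrigger wrapper introduced in tensorpack/callbacks/trigger.py below fires the wrapped callback's trigger() whenever local_step % 100 == 0.

    # Illustrative only; every_k_steps / every_k_epochs are the keyword
    # arguments introduced by this commit and may be combined.
    PeriodicTrigger(ModelSaver(), every_k_steps=100)                    # every 100 steps
    PeriodicTrigger(ModelSaver(), every_k_steps=500, every_k_epochs=1)  # either condition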
tensorpack/callbacks/base.py
@@ -16,7 +16,7 @@ class Callback(object):
     Attributes:
         epoch_num(int): the epoch that have completed the update.
-        step_num(int): the step number in the current epoch.
+        local_step(int): the local step number in the current epoch.
         trainer(Trainer): the trainer.
         graph(tf.Graph): the graph.
@@ -110,8 +110,8 @@ class Callback(object):
         return self.trainer.epoch_num

     @property
-    def step_num(self):
-        return self.trainer.step_num
+    def local_step(self):
+        return self.trainer.local_step

     def __str__(self):
         return type(self).__name__
@@ -127,6 +127,7 @@ class ProxyCallback(Callback):
         Args:
             cb(Callback): the underlying callback
         """
+        assert isinstance(cb, Callback), type(cb)
         self.cb = cb

     def _before_train(self):
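A hypothetical subclass, just to illustrate what the two changes above give a ProxyCallback (the class name and print are mine; only self.cb, the isinstance check, and the renamed local_step property come from the diff, and I assume ProxyCallback._before_train forwards to the wrapped callback as its name suggests):

    class LoggedProxy(ProxyCallback):
        def _before_train(self):
            super(LoggedProxy, self)._before_train()      # keep delegating to self.cb
            # self.cb was validated and stored by ProxyCallback.__init__;
            # self.local_step now reads self.trainer.local_step (renamed from step_num)
            print("wrapping {} at local_step {}".format(self.cb, self.local_step))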
tensorpack/callbacks/steps.py
@@ -55,7 +55,7 @@ class MaintainStepCounter(Callback):
         self.gs_incr_var = tf.assign_add(
             gs_var, 1,
             name=GLOBAL_STEP_INCR_OP_NAME)
-        self.step_num = tf.mod(
+        self.local_step = tf.mod(
             self.gs_incr_var, self.trainer.config.step_per_epoch,
             name=LOCAL_STEP_OP_NAME)
@@ -75,8 +75,8 @@ class ProgressBar(Callback):
         self._tqdm_args = get_tqdm_kwargs(leave=True)

     def _trigger_step(self, *args):
-        if self.step_num == 0:
+        if self.local_step == 0:
             self._bar = tqdm.trange(self._total, **self._tqdm_args)
         self._bar.update()
-        if self.step_num == self._total - 1:
+        if self.local_step == self._total - 1:
             self._bar.close()
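The rename does not change ProgressBar's behaviour. As a standalone sketch (not tensorpack code; plain ints stand in for the callback's local_step and self._total), the per-epoch bar lifecycle is:

    import tqdm

    total = 5
    for local_step in range(total):
        if local_step == 0:
            bar = tqdm.trange(total, leave=True)   # open a fresh bar at the start of an epoch
        bar.update()                               # advance once per step
        if local_step == total - 1:
            bar.close()                            # close it at the last step of the epoch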
tensorpack/callbacks/trigger.py
@@ -6,10 +6,10 @@
 from abc import abstractmethod, ABCMeta
 import six

-from .base import Callback
+from .base import Callback, ProxyCallback

-__all__ = ['Triggerable']
+__all__ = ['Triggerable', 'PeriodicTrigger']


 @six.add_metaclass(ABCMeta)
@@ -43,3 +43,38 @@ class Triggerable(Callback):
     def _trigger_epoch(self):
         """ If used as a callback directly, run the trigger every epoch."""
         self.trigger()
+
+
+class PeriodicTrigger(ProxyCallback):
+    """
+    Trigger a :class:`Triggerable` callback every k steps or every k epochs.
+    """
+    def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
+        """
+        Args:
+            triggerable (Triggerable): a Triggerable instance.
+            every_k_steps (int): trigger when ``local_step % k == 0``. Set to
+                None to disable.
+            every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to
+                None to disable.
+
+        every_k_steps and every_k_epochs can be both set, but cannot be both None.
+        """
+        assert isinstance(triggerable, Triggerable), type(triggerable)
+        super(PeriodicTrigger, self).__init__(triggerable)
+        assert (every_k_epochs is not None) or (every_k_steps is not None), \
+            "every_k_steps and every_k_epochs cannot be both None!"
+        self._step_k = every_k_steps
+        self._epoch_k = every_k_epochs
+
+    def _trigger_step(self, *args):
+        if self._step_k is None:
+            return
+        if self.local_step % self._step_k == 0:
+            self.cb.trigger()
+
+    def _trigger_epoch(self, *args):
+        if self._epoch_k is None:
+            return
+        if self.local_step % self._epoch_k == 0:
+            self.cb.trigger()
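A hedged usage sketch of the new class (the DumpState callback is hypothetical, and I assume Triggerable's abstract hook is _trigger(), invoked by trigger(), as suggested by the abstractmethod import above; the import path follows the file path shown in this diff):

    from tensorpack.callbacks.trigger import Triggerable, PeriodicTrigger

    # Hypothetical Triggerable for illustration only.
    class DumpState(Triggerable):
        def _trigger(self):
            print("triggered at epoch", self.epoch_num)

    # Fires DumpState's trigger() whenever local_step % 200 == 0; the other
    # callback hooks are forwarded to the wrapped instance via ProxyCallback.
    cb = PeriodicTrigger(DumpState(), every_k_steps=200)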
tensorpack/train/base.py
@@ -42,7 +42,7 @@ class Trainer(object):
         summary_op (tf.Operation): an Op which outputs all summaries.

         epoch_num (int): the current epoch number.
-        step_num (int): the current step number (in an epoch).
+        local_step (int): the current step number (in an epoch).
     """

     def __init__(self, config):
@@ -57,7 +57,7 @@ class Trainer(object):
         self.coord = tf.train.Coordinator()

         self.epoch_num = self.config.starting_epoch
-        self.step_num = 0
+        self.local_step = 0

     def train(self):
         """ Start training """
@@ -163,7 +163,7 @@ class Trainer(object):
                 self.config.starting_epoch, self.config.max_epoch + 1):
             logger.info("Start Epoch {} ...".format(self.epoch_num))
             start_time = time.time()
-            for self.step_num in range(self.config.step_per_epoch):
+            for self.local_step in range(self.config.step_per_epoch):
                 if self.coord.should_stop():
                     return
                 fetch_data = self.run_step()  # implemented by subclass
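One detail in the last hunk, shown as a hypothetical standalone snippet (not tensorpack code): binding the loop variable as "for self.local_step in range(...)" rebinds the trainer attribute on every iteration, which is what keeps the Callback.local_step property current while training runs.

    # Minimal illustration of the attribute-as-loop-variable idiom used above.
    class Toy(object):
        def run(self, n):
            for self.local_step in range(n):   # rebinds self.local_step each iteration
                pass
            return self.local_step

    assert Toy().run(3) == 2   # the attribute holds the last step index after the loop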