Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
589a8a35
Commit
589a8a35
authored
Jan 23, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
make progressbar a callback
parent
a59e46cd
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
59 additions
and
28 deletions
+59
-28
scripts/dump-model-params.py
scripts/dump-model-params.py
+1
-2
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+18
-8
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+2
-1
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+18
-3
tensorpack/train/base.py
tensorpack/train/base.py
+14
-10
tensorpack/train/config.py
tensorpack/train/config.py
+6
-4
No files found.
scripts/dump-model-params.py
View file @
589a8a35
...
...
@@ -10,7 +10,6 @@ import imp
from
tensorpack
import
TowerContext
,
logger
,
ModelFromMetaGraph
from
tensorpack.tfutils
import
sessinit
,
varmanip
from
tensorpack.utils.naming
import
EXTRA_SAVE_VARS_KEY
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--config'
,
help
=
'config file'
)
...
...
@@ -44,7 +43,7 @@ with tf.Graph().as_default() as G:
varmanip
.
dump_session_params
(
args
.
output
)
else
:
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var
.
extend
(
tf
.
get_collection
(
EXTRA_SAVE_VARS_KEY
))
var
.
extend
(
tf
.
get_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
))
var_dict
=
{}
for
v
in
var
:
name
=
varmanip
.
get_savename_from_varname
(
v
.
name
)
...
...
tensorpack/callbacks/base.py
View file @
589a8a35
...
...
@@ -15,9 +15,10 @@ class Callback(object):
""" Base class for all callbacks
Attributes:
epoch_num(int): the number of epochs that have completed the update
trainer(Trainer): the trainer
graph(tf.Graph): the graph
epoch_num(int): the epoch that have completed the update.
step_num(int): the step number in the current epoch.
trainer(Trainer): the trainer.
graph(tf.Graph): the graph.
Note:
These attributes are available only after (and including)
...
...
@@ -34,7 +35,6 @@ class Callback(object):
"""
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
self
.
epoch_num
=
self
.
trainer
.
config
.
starting_epoch
-
1
with
tf
.
name_scope
(
type
(
self
)
.
__name__
):
self
.
_setup_graph
()
...
...
@@ -91,7 +91,6 @@ class Callback(object):
"""
Triggered after every epoch.
"""
self
.
epoch_num
+=
1
self
.
_trigger_epoch
()
def
_trigger_epoch
(
self
):
...
...
@@ -106,6 +105,14 @@ class Callback(object):
def
_after_train
(
self
):
pass
@
property
def
epoch_num
(
self
):
return
self
.
trainer
.
epoch_num
@
property
def
step_num
(
self
):
return
self
.
trainer
.
step_num
def
__str__
(
self
):
return
type
(
self
)
.
__name__
...
...
@@ -128,12 +135,15 @@ class ProxyCallback(Callback):
def
_setup_graph
(
self
):
self
.
cb
.
setup_graph
(
self
.
trainer
)
def
_after_train
(
self
):
self
.
cb
.
after_train
()
def
_trigger_epoch
(
self
):
self
.
cb
.
trigger_epoch
()
def
_trigger_step
(
self
,
*
args
):
self
.
cb
.
trigger_step
(
*
args
)
def
_after_train
(
self
):
self
.
cb
.
after_train
()
def
__str__
(
self
):
return
"Proxy-"
+
str
(
self
.
cb
)
...
...
tensorpack/callbacks/stats.py
View file @
589a8a35
...
...
@@ -112,7 +112,8 @@ class StatHolder(object):
class
StatPrinter
(
Callback
):
"""
A callback to control what stats to print. Print everything by default.
A callback to control what stats to print. Enable by default to print
everything in trainer.stat_holder.
"""
def
__init__
(
self
,
print_tag
=
None
):
...
...
tensorpack/callbacks/steps.py
View file @
589a8a35
...
...
@@ -8,13 +8,14 @@
import
tensorflow
as
tf
import
re
from
six.moves
import
zip
import
tqdm
from
..utils
import
logger
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils.naming
import
MOVING_SUMMARY_VARS_KEY
from
..tfutils.common
import
get_op_tensor_name
,
get_global_step_var
from
.base
import
Callback
__all__
=
[
'StepStatPrinter'
,
'SummaryMovingAverage'
]
__all__
=
[
'StepStatPrinter'
,
'SummaryMovingAverage'
,
'ProgressBar'
]
class
StepStatPrinter
(
Callback
):
...
...
@@ -38,7 +39,7 @@ class StepStatPrinter(Callback):
class
SummaryMovingAverage
(
Callback
):
""" Maintain the moving average of the tensors
in every step, and summarize them.
in every step, and summarize them.
Enabled by default.
"""
def
__init__
(
self
,
collection
=
MOVING_SUMMARY_VARS_KEY
,
decay
=
0.95
):
"""
...
...
@@ -65,3 +66,17 @@ class SummaryMovingAverage(Callback):
def
_extra_fetches
(
self
):
return
[
self
.
ema_op
]
class
ProgressBar
(
Callback
):
""" A progress bar based on tqdm. Enabled by default. """
def
_before_train
(
self
):
self
.
_total
=
self
.
trainer
.
config
.
step_per_epoch
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
def
_trigger_step
(
self
,
*
args
):
if
self
.
step_num
==
0
:
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
self
.
_bar
.
update
()
if
self
.
step_num
==
self
.
_total
-
1
:
self
.
_bar
.
__exit__
()
tensorpack/train/base.py
View file @
589a8a35
...
...
@@ -7,11 +7,10 @@ import re
import
weakref
import
six
from
six.moves
import
range
import
tqdm
import
tensorflow
as
tf
from
.config
import
TrainConfig
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils
import
logger
from
..utils.timer
import
timed_operation
from
..callbacks
import
StatHolder
from
..tfutils
import
get_global_step
,
get_global_step_var
...
...
@@ -33,14 +32,18 @@ class Trainer(object):
""" Base class for a trainer.
Attributes:
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
coord (tf.train.Coordinator)
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
extra_fetches (list): list of tensors/ops to fetch by :meth:`run_step`.
epoch_num (int): the current epoch number.
step_num (int): the current step number (in an epoch).
"""
def
__init__
(
self
,
config
):
...
...
@@ -54,6 +57,9 @@ class Trainer(object):
self
.
sess
=
tf
.
Session
(
config
=
self
.
config
.
session_config
)
self
.
coord
=
tf
.
train
.
Coordinator
()
self
.
epoch_num
=
self
.
config
.
starting_epoch
self
.
step_num
=
0
def
train
(
self
):
""" Start training """
self
.
setup
()
...
...
@@ -165,15 +171,13 @@ class Trainer(object):
try
:
callbacks
.
before_train
()
logger
.
info
(
"Start training with global_step={}"
.
format
(
get_global_step
()))
for
epoch_num
in
range
(
for
self
.
epoch_num
in
range
(
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
with
timed_operation
(
'Epoch {} (global_step {})'
.
format
(
epoch_num
,
get_global_step
()
+
self
.
config
.
step_per_epoch
),
self
.
epoch_num
,
get_global_step
()
+
self
.
config
.
step_per_epoch
),
log_start
=
True
):
for
step
in
tqdm
.
trange
(
self
.
config
.
step_per_epoch
,
**
get_tqdm_kwargs
(
leave
=
True
)):
for
self
.
step_num
in
range
(
self
.
config
.
step_per_epoch
):
if
self
.
coord
.
should_stop
():
return
fetch_data
=
self
.
run_step
()
# implemented by subclass
...
...
tensorpack/train/config.py
View file @
589a8a35
...
...
@@ -4,7 +4,9 @@
import
tensorflow
as
tf
from
..callbacks
import
Callbacks
,
SummaryMovingAverage
,
StatPrinter
from
..callbacks
import
(
Callbacks
,
SummaryMovingAverage
,
StatPrinter
,
ProgressBar
)
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..utils
import
logger
...
...
@@ -38,8 +40,8 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
``[SummaryMovingAverage(), StatPrinter()]``. The list of
callbacks that will be used in the end
is
``callbacks + extra_callbacks``.
``[SummaryMovingAverage(),
ProgressBar(),
StatPrinter()]``. The list of
callbacks that will be used in the end
are
``callbacks + extra_callbacks``.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
...
...
@@ -80,7 +82,7 @@ class TrainConfig(object):
callbacks
=
callbacks
.
cbs
[:
-
1
]
# the last one is StatPrinter()
assert_type
(
callbacks
,
list
)
if
extra_callbacks
is
None
:
extra_callbacks
=
[
SummaryMovingAverage
(),
StatPrinter
()]
extra_callbacks
=
[
SummaryMovingAverage
(),
ProgressBar
(),
StatPrinter
()]
self
.
callbacks
=
callbacks
+
extra_callbacks
assert_type
(
self
.
callbacks
,
list
)
self
.
callbacks
=
Callbacks
(
self
.
callbacks
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment