Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a6f88814
Commit
a6f88814
authored
Apr 05, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
docs in callbacks
parent
817bb882
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
139 additions
and
43 deletions
+139
-43
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+33
-17
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+8
-0
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+15
-6
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+11
-5
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+19
-3
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+25
-4
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+24
-5
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+4
-3
No files found.
tensorpack/callbacks/base.py
View file @
a6f88814
...
...
@@ -10,67 +10,83 @@ from abc import abstractmethod, ABCMeta
from
..utils
import
*
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'TrainCallback
'
,
'TestCallback
'
]
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'TrainCallback
Type'
,
'TestCallbackType
'
]
class
TrainCallback
(
object
):
class
TrainCallback
Type
(
object
):
pass
class
TestCallback
(
object
):
class
TestCallback
Type
(
object
):
pass
class
Callback
(
object
):
""" Base class for all callbacks """
__metaclass__
=
ABCMeta
type
=
TrainCallback
()
""" The graph that this callback should run on.
Either TrainCallback or TestCallback
type
=
TrainCallbackType
()
""" Determine the graph that this callback should run on.
Either `TrainCallbackType()` or `TestCallbackType()`.
Default is `TrainCallbackType()`
"""
def
before_train
(
self
,
trainer
):
"""
Called before starting iterative training.
:param trainer: a :class:`train.Trainer` instance
"""
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
self
.
epoch_num
=
self
.
trainer
.
config
.
starting_epoch
self
.
_before_train
()
def
_before_train
(
self
):
"""
Called before starting iterative training
"""
pass
def
after_train
(
self
):
"""
Called after training.
"""
self
.
_after_train
()
def
_after_train
(
self
):
"""
Called after training
"""
pass
def
trigger_step
(
self
):
"""
Callback to be triggered after every step (every backpropagation)
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
"""
@
property
def
global_step
(
self
):
"""
Access the global step value of this training.
"""
return
self
.
trainer
.
global_step
def
trigger_epoch
(
self
):
"""
epoch_num is the number of epoch finished.
Triggered after every epoch.
In this function, self.epoch_num would be the number of epoch finished.
"""
self
.
_trigger_epoch
()
self
.
epoch_num
+=
1
def
_trigger_epoch
(
self
):
"""
Callback to be triggered after every epoch (full iteration of input dataset)
"""
pass
class
PeriodicCallback
(
Callback
):
"""
A callback to be triggered after every `period` epochs.
"""
def
__init__
(
self
,
period
):
self
.
period
=
period
"""
:param period: int
"""
self
.
period
=
int
(
period
)
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
%
self
.
period
==
0
:
...
...
tensorpack/callbacks/common.py
View file @
a6f88814
...
...
@@ -12,7 +12,15 @@ from ..utils import *
__all__
=
[
'PeriodicSaver'
]
class
PeriodicSaver
(
PeriodicCallback
):
"""
Save the model to logger directory.
"""
def
__init__
(
self
,
period
=
1
,
keep_recent
=
10
,
keep_freq
=
0.5
):
"""
:param period: number of epochs to save models.
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
"""
super
(
PeriodicSaver
,
self
)
.
__init__
(
period
)
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
...
...
tensorpack/callbacks/dump.py
View file @
a6f88814
...
...
@@ -9,21 +9,30 @@ import numpy as np
from
.base
import
Callback
from
..utils
import
logger
from
..tfutils
import
get_op_var_name
__all__
=
[
'DumpParamAsImage'
]
class
DumpParamAsImage
(
Callback
):
"""
Dump a variable to image(s) after every epoch.
"""
def
__init__
(
self
,
var_name
,
prefix
=
None
,
map_func
=
None
,
scale
=
255
,
clip
=
False
):
"""
map_func: map the value of the variable to an image or list of images, default to identity
images should have shape [h, w] or [h, w, c].
scale: a multiplier on pixel values, applied after map_func. default to 255
clip: clip the result to [0, 255]
:param var_name: the name of the variable.
:param prefix: the filename prefix for saved images. Default is the op name.
:param map_func: map the value of the variable to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity
:param scale: a multiplier on pixel values, applied after map_func. default to 255
:param clip: whether to clip the result to [0, 255]
"""
self
.
var_name
=
var_name
op_name
,
self
.
var_name
=
get_op_var_name
(
var_name
)
self
.
func
=
map_func
if
prefix
is
None
:
self
.
prefix
=
self
.
var
_name
self
.
prefix
=
op
_name
else
:
self
.
prefix
=
prefix
self
.
log_dir
=
logger
.
LOG_DIR
...
...
tensorpack/callbacks/group.py
View file @
a6f88814
...
...
@@ -6,7 +6,7 @@ import tensorflow as tf
from
contextlib
import
contextmanager
import
time
from
.base
import
Callback
,
TrainCallback
,
TestCallback
from
.base
import
Callback
,
TrainCallback
Type
,
TestCallbackType
from
.summary
import
*
from
..utils
import
*
...
...
@@ -91,11 +91,17 @@ class TestCallbackContext(object):
yield
class
Callbacks
(
Callback
):
"""
A container to hold all callbacks, and execute them in the right order and proper session.
"""
def
__init__
(
self
,
cbs
):
"""
:param cbs: a list of `Callbacks`
"""
# check type
for
cb
in
cbs
:
assert
isinstance
(
cb
,
Callback
),
cb
.
__class__
if
not
isinstance
(
cb
.
type
,
(
TrainCallback
,
TestCallback
)):
if
not
isinstance
(
cb
.
type
,
(
TrainCallback
Type
,
TestCallbackType
)):
raise
ValueError
(
"Unknown callback running graph {}!"
.
format
(
str
(
cb
.
type
)))
...
...
@@ -104,7 +110,7 @@ class Callbacks(Callback):
def
_before_train
(
self
):
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
if
isinstance
(
cb
.
type
,
TrainCallback
Type
):
cb
.
before_train
(
self
.
trainer
)
else
:
with
self
.
test_callback_context
.
before_train_context
(
self
.
trainer
):
...
...
@@ -116,7 +122,7 @@ class Callbacks(Callback):
def
trigger_step
(
self
):
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
if
isinstance
(
cb
.
type
,
TrainCallback
Type
):
cb
.
trigger_step
()
# test callback don't have trigger_step
...
...
@@ -125,7 +131,7 @@ class Callbacks(Callback):
test_sess_restored
=
False
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
if
isinstance
(
cb
.
type
,
TrainCallback
Type
):
with
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
cb
.
trigger_epoch
()
else
:
...
...
tensorpack/callbacks/param.py
View file @
a6f88814
...
...
@@ -15,10 +15,17 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter'
]
class
HyperParamSetter
(
Callback
):
"""
Base class to set hyperparameters after every epoch.
"""
__metaclass__
=
ABCMeta
# TODO maybe support InputVar?
def
__init__
(
self
,
var_name
,
shape
=
[]):
"""
:param var_name: name of the variable
:param shape: shape of the variable
"""
self
.
op_name
,
self
.
var_name
=
get_op_var_name
(
var_name
)
self
.
shape
=
shape
self
.
last_value
=
None
...
...
@@ -37,6 +44,9 @@ class HyperParamSetter(Callback):
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
def
get_current_value
(
self
):
"""
:returns: the value to assign to the variable now.
"""
ret
=
self
.
_get_current_value
()
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
logger
.
info
(
"{} at epoch {} is changed to {}"
.
format
(
...
...
@@ -54,10 +64,13 @@ class HyperParamSetter(Callback):
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
class
HumanHyperParamSetter
(
HyperParamSetter
):
"""
Set hyperparameters manually by modifying a file.
"""
def
__init__
(
self
,
var_name
,
file_name
):
"""
read value from file_nam
e.
file_name: e
ach line in the file is a k:v pair
:param var_name: name of the variabl
e.
:param file_name: a file containing the value of the variable. E
ach line in the file is a k:v pair
"""
self
.
file_name
=
file_name
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
var_name
)
...
...
@@ -77,9 +90,12 @@ class HumanHyperParamSetter(HyperParamSetter):
return
None
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
"""
Set hyperparameters by a predefined schedule.
"""
def
__init__
(
self
,
var_name
,
schedule
):
"""
schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
:param
schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
"""
schedule
=
[(
int
(
a
),
float
(
b
))
for
a
,
b
in
schedule
]
self
.
schedule
=
sorted
(
schedule
,
key
=
operator
.
itemgetter
(
0
))
...
...
tensorpack/callbacks/summary.py
View file @
a6f88814
...
...
@@ -14,8 +14,14 @@ from ..utils import *
__all__
=
[
'StatHolder'
,
'StatPrinter'
]
class
StatHolder
(
object
):
def
__init__
(
self
,
log_dir
,
print_tag
=
None
):
self
.
set_print_tag
(
print_tag
)
"""
A holder to keep all statistics aside from tensorflow events.
"""
def
__init__
(
self
,
log_dir
):
"""
:param log_dir: directory to save the stats.
"""
self
.
set_print_tag
([])
self
.
stat_now
=
{}
self
.
log_dir
=
log_dir
...
...
@@ -28,12 +34,23 @@ class StatHolder(object):
self
.
stat_history
=
[]
def
add_stat
(
self
,
k
,
v
):
"""
Add a stat.
:param k: name
:param v: value
"""
self
.
stat_now
[
k
]
=
v
def
set_print_tag
(
self
,
print_tag
):
"""
Set name of stats to print.
"""
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
def
finalize
(
self
):
"""
Called after finishing adding stats. Will print and write stats to disk.
"""
self
.
_print_stat
()
self
.
stat_history
.
append
(
self
.
stat_now
)
self
.
stat_now
=
{}
...
...
@@ -51,9 +68,13 @@ class StatHolder(object):
os
.
rename
(
tmp_filename
,
self
.
filename
)
class
StatPrinter
(
Callback
):
"""
Control what stats to print.
"""
def
__init__
(
self
,
print_tag
=
None
):
""" print_tag : a list of regex to match scalar summary to print
if None, will print all scalar tags
"""
:param print_tag : a list of regex to match scalar summary to print.
If None, will print all scalar tags
"""
self
.
print_tag
=
print_tag
...
...
tensorpack/callbacks/validation_callback.py
View file @
a6f88814
...
...
@@ -10,16 +10,22 @@ from six.moves import zip
from
..utils
import
*
from
..utils.stat
import
*
from
..tfutils.summary
import
*
from
.base
import
PeriodicCallback
,
Callback
,
TestCallback
from
.base
import
PeriodicCallback
,
Callback
,
TestCallback
Type
__all__
=
[
'ValidationError'
,
'ValidationCallback'
,
'ValidationStatPrinter'
]
class
ValidationCallback
(
PeriodicCallback
):
type
=
TestCallback
()
"""
Base class for validation callbacks.
"""
type
=
TestCallbackType
()
def
__init__
(
self
,
ds
,
prefix
,
period
=
1
):
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
super
(
ValidationCallback
,
self
)
.
__init__
(
period
)
self
.
ds
=
ds
self
.
prefix
=
prefix
...
...
@@ -29,6 +35,9 @@ class ValidationCallback(PeriodicCallback):
self
.
_find_output_vars
()
def
get_tensor
(
self
,
name
):
"""
Get tensor from graph.
"""
return
self
.
graph
.
get_tensor_by_name
(
name
)
@
abstractmethod
...
...
@@ -63,6 +72,12 @@ class ValidationStatPrinter(ValidationCallback):
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
"""
def
__init__
(
self
,
ds
,
names_to_print
,
prefix
=
'validation'
,
period
=
1
):
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param names_to_print: names of variables to print
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
super
(
ValidationStatPrinter
,
self
)
.
__init__
(
ds
,
prefix
,
period
)
self
.
names
=
names_to_print
...
...
@@ -88,9 +103,9 @@ class ValidationStatPrinter(ValidationCallback):
class
ValidationError
(
ValidationCallback
):
"""
Validate the accuracy from a
'wrong'
variable
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
Validate the accuracy from a
`wrong`
variable
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch
This callback produce the "true" error,
taking account of the fact that batches might not have the same size in
...
...
@@ -100,6 +115,10 @@ class ValidationError(ValidationCallback):
def
__init__
(
self
,
ds
,
prefix
=
'validation'
,
period
=
1
,
wrong_var_name
=
'wrong:0'
):
"""
:param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable
"""
super
(
ValidationError
,
self
)
.
__init__
(
ds
,
prefix
,
period
)
self
.
wrong_var_name
=
wrong_var_name
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
a6f88814
...
...
@@ -15,6 +15,8 @@ from ..base import DataFlow
__all__
=
[
'SVHNDigit'
]
SVHN_URL
=
"http://ufldl.stanford.edu/housenumbers/"
class
SVHNDigit
(
DataFlow
):
"""
SVHN Cropped Digit Dataset
...
...
@@ -25,7 +27,7 @@ class SVHNDigit(DataFlow):
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
"""
name: 'train', 'test', or 'extra'
data_dir: a directory containing {train,test,extra}_32x32.mat
data_dir: a directory containing
the original
{train,test,extra}_32x32.mat
"""
self
.
shuffle
=
shuffle
self
.
rng
=
get_rng
(
self
)
...
...
@@ -40,8 +42,7 @@ class SVHNDigit(DataFlow):
assert
name
in
[
'train'
,
'test'
,
'extra'
],
name
filename
=
os
.
path
.
join
(
data_dir
,
name
+
'_32x32.mat'
)
assert
os
.
path
.
isfile
(
filename
),
\
"File {} not found! Please download it from
\
http://ufldl.stanford.edu/housenumbers/"
.
format
(
filename
)
"File {} not found! Please download it from {}."
.
format
(
filename
,
SVHN_URL
)
logger
.
info
(
"Loading {} ..."
.
format
(
filename
))
data
=
scipy
.
io
.
loadmat
(
filename
)
self
.
X
=
data
[
'X'
]
.
transpose
(
3
,
0
,
1
,
2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment