Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a6f88814
Commit
a6f88814
authored
Apr 05, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
docs in callbacks
parent
817bb882
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
139 additions
and
43 deletions
+139
-43
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+33
-17
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+8
-0
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+15
-6
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+11
-5
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+19
-3
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+25
-4
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+24
-5
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+4
-3
No files found.
tensorpack/callbacks/base.py
View file @
a6f88814
...
@@ -10,67 +10,83 @@ from abc import abstractmethod, ABCMeta
...
@@ -10,67 +10,83 @@ from abc import abstractmethod, ABCMeta
from
..utils
import
*
from
..utils
import
*
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'TrainCallback
'
,
'TestCallback
'
]
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'TrainCallback
Type'
,
'TestCallbackType
'
]
class
TrainCallback
(
object
):
class
TrainCallback
Type
(
object
):
pass
pass
class
TestCallback
(
object
):
class
TestCallback
Type
(
object
):
pass
pass
class
Callback
(
object
):
class
Callback
(
object
):
""" Base class for all callbacks """
__metaclass__
=
ABCMeta
__metaclass__
=
ABCMeta
type
=
TrainCallback
()
type
=
TrainCallbackType
()
""" The graph that this callback should run on.
""" Determine the graph that this callback should run on.
Either TrainCallback or TestCallback
Either `TrainCallbackType()` or `TestCallbackType()`.
Default is `TrainCallbackType()`
"""
"""
def
before_train
(
self
,
trainer
):
def
before_train
(
self
,
trainer
):
"""
Called before starting iterative training.
:param trainer: a :class:`train.Trainer` instance
"""
self
.
trainer
=
trainer
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
self
.
graph
=
tf
.
get_default_graph
()
self
.
epoch_num
=
self
.
trainer
.
config
.
starting_epoch
self
.
epoch_num
=
self
.
trainer
.
config
.
starting_epoch
self
.
_before_train
()
self
.
_before_train
()
def
_before_train
(
self
):
def
_before_train
(
self
):
"""
pass
Called before starting iterative training
"""
def
after_train
(
self
):
def
after_train
(
self
):
"""
Called after training.
"""
self
.
_after_train
()
self
.
_after_train
()
def
_after_train
(
self
):
def
_after_train
(
self
):
"""
pass
Called after training
"""
def
trigger_step
(
self
):
def
trigger_step
(
self
):
"""
"""
Callback to be triggered after every step (every backpropagation)
Callback to be triggered after every step (every backpropagation)
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
"""
"""
@
property
@
property
def
global_step
(
self
):
def
global_step
(
self
):
"""
Access the global step value of this training.
"""
return
self
.
trainer
.
global_step
return
self
.
trainer
.
global_step
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
"""
"""
epoch_num is the number of epoch finished.
Triggered after every epoch.
In this function, self.epoch_num would be the number of epoch finished.
"""
"""
self
.
_trigger_epoch
()
self
.
_trigger_epoch
()
self
.
epoch_num
+=
1
self
.
epoch_num
+=
1
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
"""
pass
Callback to be triggered after every epoch (full iteration of input dataset)
"""
class
PeriodicCallback
(
Callback
):
class
PeriodicCallback
(
Callback
):
"""
A callback to be triggered after every `period` epochs.
"""
def
__init__
(
self
,
period
):
def
__init__
(
self
,
period
):
self
.
period
=
period
"""
:param period: int
"""
self
.
period
=
int
(
period
)
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
%
self
.
period
==
0
:
if
self
.
epoch_num
%
self
.
period
==
0
:
...
...
tensorpack/callbacks/common.py
View file @
a6f88814
...
@@ -12,7 +12,15 @@ from ..utils import *
...
@@ -12,7 +12,15 @@ from ..utils import *
__all__
=
[
'PeriodicSaver'
]
__all__
=
[
'PeriodicSaver'
]
class
PeriodicSaver
(
PeriodicCallback
):
class
PeriodicSaver
(
PeriodicCallback
):
"""
Save the model to logger directory.
"""
def
__init__
(
self
,
period
=
1
,
keep_recent
=
10
,
keep_freq
=
0.5
):
def
__init__
(
self
,
period
=
1
,
keep_recent
=
10
,
keep_freq
=
0.5
):
"""
:param period: number of epochs to save models.
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
"""
super
(
PeriodicSaver
,
self
)
.
__init__
(
period
)
super
(
PeriodicSaver
,
self
)
.
__init__
(
period
)
self
.
keep_recent
=
keep_recent
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
self
.
keep_freq
=
keep_freq
...
...
tensorpack/callbacks/dump.py
View file @
a6f88814
...
@@ -9,21 +9,30 @@ import numpy as np
...
@@ -9,21 +9,30 @@ import numpy as np
from
.base
import
Callback
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_op_var_name
__all__
=
[
'DumpParamAsImage'
]
__all__
=
[
'DumpParamAsImage'
]
class
DumpParamAsImage
(
Callback
):
class
DumpParamAsImage
(
Callback
):
"""
Dump a variable to image(s) after every epoch.
"""
def
__init__
(
self
,
var_name
,
prefix
=
None
,
map_func
=
None
,
scale
=
255
,
clip
=
False
):
def
__init__
(
self
,
var_name
,
prefix
=
None
,
map_func
=
None
,
scale
=
255
,
clip
=
False
):
"""
"""
map_func: map the value of the variable to an image or list of images, default to identity
:param var_name: the name of the variable.
images should have shape [h, w] or [h, w, c].
scale: a multiplier on pixel values, applied after map_func. default to 255
:param prefix: the filename prefix for saved images. Default is the op name.
clip: clip the result to [0, 255]
:param map_func: map the value of the variable to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity
:param scale: a multiplier on pixel values, applied after map_func. default to 255
:param clip: whether to clip the result to [0, 255]
"""
"""
self
.
var_name
=
var_name
op_name
,
self
.
var_name
=
get_op_var_name
(
var_name
)
self
.
func
=
map_func
self
.
func
=
map_func
if
prefix
is
None
:
if
prefix
is
None
:
self
.
prefix
=
self
.
var
_name
self
.
prefix
=
op
_name
else
:
else
:
self
.
prefix
=
prefix
self
.
prefix
=
prefix
self
.
log_dir
=
logger
.
LOG_DIR
self
.
log_dir
=
logger
.
LOG_DIR
...
...
tensorpack/callbacks/group.py
View file @
a6f88814
...
@@ -6,7 +6,7 @@ import tensorflow as tf
...
@@ -6,7 +6,7 @@ import tensorflow as tf
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
time
import
time
from
.base
import
Callback
,
TrainCallback
,
TestCallback
from
.base
import
Callback
,
TrainCallback
Type
,
TestCallbackType
from
.summary
import
*
from
.summary
import
*
from
..utils
import
*
from
..utils
import
*
...
@@ -91,11 +91,17 @@ class TestCallbackContext(object):
...
@@ -91,11 +91,17 @@ class TestCallbackContext(object):
yield
yield
class
Callbacks
(
Callback
):
class
Callbacks
(
Callback
):
"""
A container to hold all callbacks, and execute them in the right order and proper session.
"""
def
__init__
(
self
,
cbs
):
def
__init__
(
self
,
cbs
):
"""
:param cbs: a list of `Callbacks`
"""
# check type
# check type
for
cb
in
cbs
:
for
cb
in
cbs
:
assert
isinstance
(
cb
,
Callback
),
cb
.
__class__
assert
isinstance
(
cb
,
Callback
),
cb
.
__class__
if
not
isinstance
(
cb
.
type
,
(
TrainCallback
,
TestCallback
)):
if
not
isinstance
(
cb
.
type
,
(
TrainCallback
Type
,
TestCallbackType
)):
raise
ValueError
(
raise
ValueError
(
"Unknown callback running graph {}!"
.
format
(
str
(
cb
.
type
)))
"Unknown callback running graph {}!"
.
format
(
str
(
cb
.
type
)))
...
@@ -104,7 +110,7 @@ class Callbacks(Callback):
...
@@ -104,7 +110,7 @@ class Callbacks(Callback):
def
_before_train
(
self
):
def
_before_train
(
self
):
for
cb
in
self
.
cbs
:
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
if
isinstance
(
cb
.
type
,
TrainCallback
Type
):
cb
.
before_train
(
self
.
trainer
)
cb
.
before_train
(
self
.
trainer
)
else
:
else
:
with
self
.
test_callback_context
.
before_train_context
(
self
.
trainer
):
with
self
.
test_callback_context
.
before_train_context
(
self
.
trainer
):
...
@@ -116,7 +122,7 @@ class Callbacks(Callback):
...
@@ -116,7 +122,7 @@ class Callbacks(Callback):
def
trigger_step
(
self
):
def
trigger_step
(
self
):
for
cb
in
self
.
cbs
:
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
if
isinstance
(
cb
.
type
,
TrainCallback
Type
):
cb
.
trigger_step
()
cb
.
trigger_step
()
# test callback don't have trigger_step
# test callback don't have trigger_step
...
@@ -125,7 +131,7 @@ class Callbacks(Callback):
...
@@ -125,7 +131,7 @@ class Callbacks(Callback):
test_sess_restored
=
False
test_sess_restored
=
False
for
cb
in
self
.
cbs
:
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
if
isinstance
(
cb
.
type
,
TrainCallback
Type
):
with
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
with
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
cb
.
trigger_epoch
()
cb
.
trigger_epoch
()
else
:
else
:
...
...
tensorpack/callbacks/param.py
View file @
a6f88814
...
@@ -15,10 +15,17 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
...
@@ -15,10 +15,17 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter'
]
'ScheduledHyperParamSetter'
]
class
HyperParamSetter
(
Callback
):
class
HyperParamSetter
(
Callback
):
"""
Base class to set hyperparameters after every epoch.
"""
__metaclass__
=
ABCMeta
__metaclass__
=
ABCMeta
# TODO maybe support InputVar?
# TODO maybe support InputVar?
def
__init__
(
self
,
var_name
,
shape
=
[]):
def
__init__
(
self
,
var_name
,
shape
=
[]):
"""
:param var_name: name of the variable
:param shape: shape of the variable
"""
self
.
op_name
,
self
.
var_name
=
get_op_var_name
(
var_name
)
self
.
op_name
,
self
.
var_name
=
get_op_var_name
(
var_name
)
self
.
shape
=
shape
self
.
shape
=
shape
self
.
last_value
=
None
self
.
last_value
=
None
...
@@ -37,6 +44,9 @@ class HyperParamSetter(Callback):
...
@@ -37,6 +44,9 @@ class HyperParamSetter(Callback):
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
def
get_current_value
(
self
):
def
get_current_value
(
self
):
"""
:returns: the value to assign to the variable now.
"""
ret
=
self
.
_get_current_value
()
ret
=
self
.
_get_current_value
()
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
logger
.
info
(
"{} at epoch {} is changed to {}"
.
format
(
logger
.
info
(
"{} at epoch {} is changed to {}"
.
format
(
...
@@ -54,10 +64,13 @@ class HyperParamSetter(Callback):
...
@@ -54,10 +64,13 @@ class HyperParamSetter(Callback):
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
class
HumanHyperParamSetter
(
HyperParamSetter
):
class
HumanHyperParamSetter
(
HyperParamSetter
):
"""
Set hyperparameters manually by modifying a file.
"""
def
__init__
(
self
,
var_name
,
file_name
):
def
__init__
(
self
,
var_name
,
file_name
):
"""
"""
read value from file_nam
e.
:param var_name: name of the variabl
e.
file_name: e
ach line in the file is a k:v pair
:param file_name: a file containing the value of the variable. E
ach line in the file is a k:v pair
"""
"""
self
.
file_name
=
file_name
self
.
file_name
=
file_name
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
var_name
)
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
var_name
)
...
@@ -77,9 +90,12 @@ class HumanHyperParamSetter(HyperParamSetter):
...
@@ -77,9 +90,12 @@ class HumanHyperParamSetter(HyperParamSetter):
return
None
return
None
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
"""
Set hyperparameters by a predefined schedule.
"""
def
__init__
(
self
,
var_name
,
schedule
):
def
__init__
(
self
,
var_name
,
schedule
):
"""
"""
schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
:param
schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
"""
"""
schedule
=
[(
int
(
a
),
float
(
b
))
for
a
,
b
in
schedule
]
schedule
=
[(
int
(
a
),
float
(
b
))
for
a
,
b
in
schedule
]
self
.
schedule
=
sorted
(
schedule
,
key
=
operator
.
itemgetter
(
0
))
self
.
schedule
=
sorted
(
schedule
,
key
=
operator
.
itemgetter
(
0
))
...
...
tensorpack/callbacks/summary.py
View file @
a6f88814
...
@@ -14,8 +14,14 @@ from ..utils import *
...
@@ -14,8 +14,14 @@ from ..utils import *
__all__
=
[
'StatHolder'
,
'StatPrinter'
]
__all__
=
[
'StatHolder'
,
'StatPrinter'
]
class
StatHolder
(
object
):
class
StatHolder
(
object
):
def
__init__
(
self
,
log_dir
,
print_tag
=
None
):
"""
self
.
set_print_tag
(
print_tag
)
A holder to keep all statistics aside from tensorflow events.
"""
def
__init__
(
self
,
log_dir
):
"""
:param log_dir: directory to save the stats.
"""
self
.
set_print_tag
([])
self
.
stat_now
=
{}
self
.
stat_now
=
{}
self
.
log_dir
=
log_dir
self
.
log_dir
=
log_dir
...
@@ -28,12 +34,23 @@ class StatHolder(object):
...
@@ -28,12 +34,23 @@ class StatHolder(object):
self
.
stat_history
=
[]
self
.
stat_history
=
[]
def
add_stat
(
self
,
k
,
v
):
def
add_stat
(
self
,
k
,
v
):
"""
Add a stat.
:param k: name
:param v: value
"""
self
.
stat_now
[
k
]
=
v
self
.
stat_now
[
k
]
=
v
def
set_print_tag
(
self
,
print_tag
):
def
set_print_tag
(
self
,
print_tag
):
"""
Set name of stats to print.
"""
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
def
finalize
(
self
):
def
finalize
(
self
):
"""
Called after finishing adding stats. Will print and write stats to disk.
"""
self
.
_print_stat
()
self
.
_print_stat
()
self
.
stat_history
.
append
(
self
.
stat_now
)
self
.
stat_history
.
append
(
self
.
stat_now
)
self
.
stat_now
=
{}
self
.
stat_now
=
{}
...
@@ -51,9 +68,13 @@ class StatHolder(object):
...
@@ -51,9 +68,13 @@ class StatHolder(object):
os
.
rename
(
tmp_filename
,
self
.
filename
)
os
.
rename
(
tmp_filename
,
self
.
filename
)
class
StatPrinter
(
Callback
):
class
StatPrinter
(
Callback
):
"""
Control what stats to print.
"""
def
__init__
(
self
,
print_tag
=
None
):
def
__init__
(
self
,
print_tag
=
None
):
""" print_tag : a list of regex to match scalar summary to print
"""
if None, will print all scalar tags
:param print_tag : a list of regex to match scalar summary to print.
If None, will print all scalar tags
"""
"""
self
.
print_tag
=
print_tag
self
.
print_tag
=
print_tag
...
...
tensorpack/callbacks/validation_callback.py
View file @
a6f88814
...
@@ -10,16 +10,22 @@ from six.moves import zip
...
@@ -10,16 +10,22 @@ from six.moves import zip
from
..utils
import
*
from
..utils
import
*
from
..utils.stat
import
*
from
..utils.stat
import
*
from
..tfutils.summary
import
*
from
..tfutils.summary
import
*
from
.base
import
PeriodicCallback
,
Callback
,
TestCallback
from
.base
import
PeriodicCallback
,
Callback
,
TestCallback
Type
__all__
=
[
'ValidationError'
,
'ValidationCallback'
,
'ValidationStatPrinter'
]
__all__
=
[
'ValidationError'
,
'ValidationCallback'
,
'ValidationStatPrinter'
]
class
ValidationCallback
(
PeriodicCallback
):
class
ValidationCallback
(
PeriodicCallback
):
type
=
TestCallback
()
"""
"""
Base class for validation callbacks.
Base class for validation callbacks.
"""
"""
type
=
TestCallbackType
()
def
__init__
(
self
,
ds
,
prefix
,
period
=
1
):
def
__init__
(
self
,
ds
,
prefix
,
period
=
1
):
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
super
(
ValidationCallback
,
self
)
.
__init__
(
period
)
super
(
ValidationCallback
,
self
)
.
__init__
(
period
)
self
.
ds
=
ds
self
.
ds
=
ds
self
.
prefix
=
prefix
self
.
prefix
=
prefix
...
@@ -29,6 +35,9 @@ class ValidationCallback(PeriodicCallback):
...
@@ -29,6 +35,9 @@ class ValidationCallback(PeriodicCallback):
self
.
_find_output_vars
()
self
.
_find_output_vars
()
def
get_tensor
(
self
,
name
):
def
get_tensor
(
self
,
name
):
"""
Get tensor from graph.
"""
return
self
.
graph
.
get_tensor_by_name
(
name
)
return
self
.
graph
.
get_tensor_by_name
(
name
)
@
abstractmethod
@
abstractmethod
...
@@ -63,6 +72,12 @@ class ValidationStatPrinter(ValidationCallback):
...
@@ -63,6 +72,12 @@ class ValidationStatPrinter(ValidationCallback):
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
"""
"""
def
__init__
(
self
,
ds
,
names_to_print
,
prefix
=
'validation'
,
period
=
1
):
def
__init__
(
self
,
ds
,
names_to_print
,
prefix
=
'validation'
,
period
=
1
):
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param names_to_print: names of variables to print
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
super
(
ValidationStatPrinter
,
self
)
.
__init__
(
ds
,
prefix
,
period
)
super
(
ValidationStatPrinter
,
self
)
.
__init__
(
ds
,
prefix
,
period
)
self
.
names
=
names_to_print
self
.
names
=
names_to_print
...
@@ -88,9 +103,9 @@ class ValidationStatPrinter(ValidationCallback):
...
@@ -88,9 +103,9 @@ class ValidationStatPrinter(ValidationCallback):
class
ValidationError
(
ValidationCallback
):
class
ValidationError
(
ValidationCallback
):
"""
"""
Validate the accuracy from a
'wrong'
variable
Validate the accuracy from a
`wrong`
variable
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch
This callback produce the "true" error,
This callback produce the "true" error,
taking account of the fact that batches might not have the same size in
taking account of the fact that batches might not have the same size in
...
@@ -100,6 +115,10 @@ class ValidationError(ValidationCallback):
...
@@ -100,6 +115,10 @@ class ValidationError(ValidationCallback):
def
__init__
(
self
,
ds
,
prefix
=
'validation'
,
def
__init__
(
self
,
ds
,
prefix
=
'validation'
,
period
=
1
,
period
=
1
,
wrong_var_name
=
'wrong:0'
):
wrong_var_name
=
'wrong:0'
):
"""
:param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable
"""
super
(
ValidationError
,
self
)
.
__init__
(
ds
,
prefix
,
period
)
super
(
ValidationError
,
self
)
.
__init__
(
ds
,
prefix
,
period
)
self
.
wrong_var_name
=
wrong_var_name
self
.
wrong_var_name
=
wrong_var_name
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
a6f88814
...
@@ -15,6 +15,8 @@ from ..base import DataFlow
...
@@ -15,6 +15,8 @@ from ..base import DataFlow
__all__
=
[
'SVHNDigit'
]
__all__
=
[
'SVHNDigit'
]
SVHN_URL
=
"http://ufldl.stanford.edu/housenumbers/"
class
SVHNDigit
(
DataFlow
):
class
SVHNDigit
(
DataFlow
):
"""
"""
SVHN Cropped Digit Dataset
SVHN Cropped Digit Dataset
...
@@ -25,7 +27,7 @@ class SVHNDigit(DataFlow):
...
@@ -25,7 +27,7 @@ class SVHNDigit(DataFlow):
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
"""
"""
name: 'train', 'test', or 'extra'
name: 'train', 'test', or 'extra'
data_dir: a directory containing {train,test,extra}_32x32.mat
data_dir: a directory containing
the original
{train,test,extra}_32x32.mat
"""
"""
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
self
.
rng
=
get_rng
(
self
)
self
.
rng
=
get_rng
(
self
)
...
@@ -40,8 +42,7 @@ class SVHNDigit(DataFlow):
...
@@ -40,8 +42,7 @@ class SVHNDigit(DataFlow):
assert
name
in
[
'train'
,
'test'
,
'extra'
],
name
assert
name
in
[
'train'
,
'test'
,
'extra'
],
name
filename
=
os
.
path
.
join
(
data_dir
,
name
+
'_32x32.mat'
)
filename
=
os
.
path
.
join
(
data_dir
,
name
+
'_32x32.mat'
)
assert
os
.
path
.
isfile
(
filename
),
\
assert
os
.
path
.
isfile
(
filename
),
\
"File {} not found! Please download it from
\
"File {} not found! Please download it from {}."
.
format
(
filename
,
SVHN_URL
)
http://ufldl.stanford.edu/housenumbers/"
.
format
(
filename
)
logger
.
info
(
"Loading {} ..."
.
format
(
filename
))
logger
.
info
(
"Loading {} ..."
.
format
(
filename
))
data
=
scipy
.
io
.
loadmat
(
filename
)
data
=
scipy
.
io
.
loadmat
(
filename
)
self
.
X
=
data
[
'X'
]
.
transpose
(
3
,
0
,
1
,
2
)
self
.
X
=
data
[
'X'
]
.
transpose
(
3
,
0
,
1
,
2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment