Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
069c0b9c
Commit
069c0b9c
authored
Jan 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use explicit kwargs in TrainConfig
parent
b5f8c73a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
100 additions
and
72 deletions
+100
-72
examples/DoReFa-Net/README.md
examples/DoReFa-Net/README.md
+1
-1
tensorpack/train/base.py
tensorpack/train/base.py
+43
-26
tensorpack/train/config.py
tensorpack/train/config.py
+48
-40
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+7
-4
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+1
-1
No files found.
examples/DoReFa-Net/README.md
View file @
069c0b9c
...
@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
...
@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[
google drive
](
https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ
)
.
[
google drive
](
https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ
)
.
They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation
error
.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation
accuracy
.
Alternative link to this page:
[
http://dorefa.net
](
http://dorefa.net
)
Alternative link to this page:
[
http://dorefa.net
](
http://dorefa.net
)
...
...
tensorpack/train/base.py
View file @
069c0b9c
...
@@ -22,32 +22,30 @@ __all__ = ['Trainer', 'StopTraining']
...
@@ -22,32 +22,30 @@ __all__ = ['Trainer', 'StopTraining']
class
StopTraining
(
BaseException
):
class
StopTraining
(
BaseException
):
"""
An exception thrown to stop training.
"""
pass
pass
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
Trainer
(
object
):
class
Trainer
(
object
):
""" Base class for a trainer."""
""" Base class for a trainer.
"""a `StatHolder` instance"""
Attributes:
stat_holder
=
None
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
"""`tf.SummaryWriter`"""
summary_op (tf.Operation): an Op which outputs all summaries.
summary_writer
=
None
config (TrainConfig): the config used in this trainer.
"""a tf.Tensor which returns summary string"""
model (ModelDesc)
summary_op
=
None
sess (tf.Session): the current session in use.
""" TrainConfig """
coord (tf.train.Coordinator)
config
=
None
"""
""" a ModelDesc"""
model
=
None
""" the current session"""
sess
=
None
""" the `tf.train.Coordinator` """
coord
=
None
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
"""
"""
:param config: a `TrainConfig` instance
Args:
config (TrainConfig): the train config.
"""
"""
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
self
.
config
=
config
self
.
config
=
config
...
@@ -56,27 +54,35 @@ class Trainer(object):
...
@@ -56,27 +54,35 @@ class Trainer(object):
self
.
coord
=
tf
.
train
.
Coordinator
()
self
.
coord
=
tf
.
train
.
Coordinator
()
def
train
(
self
):
def
train
(
self
):
""" Start training"""
""" Start training
"""
self
.
setup
()
self
.
setup
()
self
.
main_loop
()
self
.
main_loop
()
@
abstractmethod
@
abstractmethod
def
run_step
(
self
):
def
run_step
(
self
):
""" run an iteration"""
""" Abstract method. Run one iteration. """
pass
def
get_predict_func
(
self
,
input_names
,
output_names
):
def
get_predict_func
(
self
,
input_names
,
output_names
):
""" return a online predictor"""
"""
Args:
input_names (list), output_names(list): list of names
Returns:
an OnlinePredictor
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
"""
return n predictor function
s.
"""
Return n predictor
s.
Can be overwritten by subclasses to exploit more
Can be overwritten by subclasses to exploit more
parallelism among
func
s.
parallelism among
predictor
s.
"""
"""
return
[
self
.
get_predict_func
(
input_names
,
output_names
)
for
k
in
range
(
n
)]
return
[
self
.
get_predict_func
(
input_names
,
output_names
)
for
k
in
range
(
n
)]
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
"""
Called after each epoch.
"""
# trigger subclass
# trigger subclass
self
.
_trigger_epoch
()
self
.
_trigger_epoch
()
# trigger callbacks
# trigger callbacks
...
@@ -85,7 +91,6 @@ class Trainer(object):
...
@@ -85,7 +91,6 @@ class Trainer(object):
@
abstractmethod
@
abstractmethod
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
""" This is called right after all steps in an epoch are finished"""
pass
pass
def
_process_summary
(
self
,
summary_str
):
def
_process_summary
(
self
,
summary_str
):
...
@@ -100,11 +105,21 @@ class Trainer(object):
...
@@ -100,11 +105,21 @@ class Trainer(object):
self
.
summary_writer
.
add_summary
(
summary
,
get_global_step
())
self
.
summary_writer
.
add_summary
(
summary
,
get_global_step
())
def
write_scalar_summary
(
self
,
name
,
val
):
def
write_scalar_summary
(
self
,
name
,
val
):
"""
Write a scalar summary to both TF events file and StatHolder.
Args:
name(str)
val(float)
"""
self
.
summary_writer
.
add_summary
(
self
.
summary_writer
.
add_summary
(
create_summary
(
name
,
val
),
get_global_step
())
create_summary
(
name
,
val
),
get_global_step
())
self
.
stat_holder
.
add_stat
(
name
,
val
)
self
.
stat_holder
.
add_stat
(
name
,
val
)
def
setup
(
self
):
def
setup
(
self
):
"""
Setup the trainer and be ready for the main loop.
"""
self
.
_setup
()
self
.
_setup
()
describe_model
()
describe_model
()
get_global_step_var
()
get_global_step_var
()
...
@@ -120,7 +135,6 @@ class Trainer(object):
...
@@ -120,7 +135,6 @@ class Trainer(object):
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
logger
.
info
(
"Initializing graph variables ..."
)
logger
.
info
(
"Initializing graph variables ..."
)
# TODO newsession + sessinit?
initop
=
tf
.
global_variables_initializer
()
initop
=
tf
.
global_variables_initializer
()
self
.
sess
.
run
(
initop
)
self
.
sess
.
run
(
initop
)
self
.
config
.
session_init
.
init
(
self
.
sess
)
self
.
config
.
session_init
.
init
(
self
.
sess
)
...
@@ -134,6 +148,9 @@ class Trainer(object):
...
@@ -134,6 +148,9 @@ class Trainer(object):
""" setup Trainer-specific stuff for training"""
""" setup Trainer-specific stuff for training"""
def
main_loop
(
self
):
def
main_loop
(
self
):
"""
Run the main training loop.
"""
callbacks
=
self
.
config
.
callbacks
callbacks
=
self
.
config
.
callbacks
with
self
.
sess
.
as_default
():
with
self
.
sess
.
as_default
():
try
:
try
:
...
...
tensorpack/train/config.py
View file @
069c0b9c
...
@@ -17,54 +17,64 @@ __all__ = ['TrainConfig']
...
@@ -17,54 +17,64 @@ __all__ = ['TrainConfig']
class
TrainConfig
(
object
):
class
TrainConfig
(
object
):
"""
"""
Config for train
ing a model with a single loss
Config for train
er.
"""
"""
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
dataset
=
None
,
data
=
None
,
model
=
None
,
optimizer
=
None
,
callbacks
=
None
,
session_config
=
get_default_sess_config
(),
session_init
=
None
,
starting_epoch
=
1
,
step_per_epoch
=
None
,
max_epoch
=
99999
,
nr_tower
=
1
,
tower
=
None
,
predict_tower
=
[
0
],
**
kwargs
):
"""
"""
:param dataset: the dataset to train. a `DataFlow` instance.
Args:
:param data: an `InputData` instance
dataset (DataFlow): the dataset to train.
data (InputData): an `InputData` instance. Only one of ``dataset``
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig
.
or ``data`` has to be present
.
:param callbacks: a `callback.Callbacks` instance. Define
model (ModelDesc): the model to train.
the callbacks to perform during trainin
g.
optimizer (tf.train.Optimizer): the optimizer for training.
:param session_config: a `tf.ConfigProto` instance to instantiate the session
.
callbacks (Callbacks): the callbacks to perform during training
.
:param session_init: a `sessinit.SessionInit` instance to
session_config (tf.ConfigProto): the config used to instantiate the session.
initialize variables of a session. default
to a new session.
session_init (SessionInit): how to initialize variables of a session. Defaults
to a new session.
:param model: a `ModelDesc` instance
.
starting_epoch (int): The index of the first epoch
.
:param starting_epoch: int. default to be 1
.
step_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch
.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch
.
Defaults to the input data size
.
:param max_epoch: maximum number of epoch to run training. default to inf
max_epoch (int): maximum number of epochs to run training.
:param nr_tower: int. number of training towers. default to 1
.
nr_tower (int): number of training towers
.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given
.
tower (list of int): list of training towers in relative id
.
:param predict_tower: list of prediction tower in their relative gpu id. Defaults to [0]
predict_tower (list of int): list of prediction towers in their relative gpu id.
"""
"""
# TODO type checker decorator
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
assert
isinstance
(
v
,
tp
),
v
.
__class__
if
'dataset'
in
kwargs
:
if
dataset
is
not
None
:
assert
'data'
not
in
kwargs
,
"dataset and data cannot be both presented in TrainConfig!"
assert
data
is
None
,
"dataset and data cannot be both presented in TrainConfig!"
self
.
dataset
=
kwargs
.
pop
(
'dataset'
)
self
.
dataset
=
dataset
assert_type
(
self
.
dataset
,
DataFlow
)
assert_type
(
self
.
dataset
,
DataFlow
)
else
:
else
:
self
.
data
=
kwargs
.
pop
(
'data'
)
self
.
data
=
data
assert_type
(
self
.
data
,
InputData
)
assert_type
(
self
.
data
,
InputData
)
self
.
optimizer
=
kwargs
.
pop
(
'optimizer'
)
self
.
optimizer
=
optimizer
assert_type
(
self
.
optimizer
,
tf
.
train
.
Optimizer
)
assert_type
(
self
.
optimizer
,
tf
.
train
.
Optimizer
)
self
.
callbacks
=
kwargs
.
pop
(
'callbacks'
)
self
.
callbacks
=
callbacks
assert_type
(
self
.
callbacks
,
Callbacks
)
assert_type
(
self
.
callbacks
,
Callbacks
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
model
=
model
assert_type
(
self
.
model
,
ModelDesc
)
assert_type
(
self
.
model
,
ModelDesc
)
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
())
self
.
session_config
=
session_config
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
,
JustCurrentSession
())
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
self
.
session_init
=
session_init
assert_type
(
self
.
session_init
,
SessionInit
)
assert_type
(
self
.
session_init
,
SessionInit
)
self
.
step_per_epoch
=
kwargs
.
pop
(
'step_per_epoch'
,
None
)
self
.
step_per_epoch
=
step_per_epoch
if
self
.
step_per_epoch
is
None
:
if
self
.
step_per_epoch
is
None
:
try
:
try
:
if
hasattr
(
self
,
'dataset'
)
:
if
dataset
is
not
None
:
self
.
step_per_epoch
=
self
.
dataset
.
size
()
self
.
step_per_epoch
=
self
.
dataset
.
size
()
else
:
else
:
self
.
step_per_epoch
=
self
.
data
.
size
()
self
.
step_per_epoch
=
self
.
data
.
size
()
...
@@ -73,22 +83,20 @@ class TrainConfig(object):
...
@@ -73,22 +83,20 @@ class TrainConfig(object):
else
:
else
:
self
.
step_per_epoch
=
int
(
self
.
step_per_epoch
)
self
.
step_per_epoch
=
int
(
self
.
step_per_epoch
)
self
.
starting_epoch
=
int
(
kwargs
.
pop
(
'starting_epoch'
,
1
)
)
self
.
starting_epoch
=
int
(
starting_epoch
)
self
.
max_epoch
=
int
(
kwargs
.
pop
(
'max_epoch'
,
99999
)
)
self
.
max_epoch
=
int
(
max_epoch
)
assert
self
.
step_per_epoch
>=
0
and
self
.
max_epoch
>
0
assert
self
.
step_per_epoch
>=
0
and
self
.
max_epoch
>
0
if
'nr_tower'
in
kwargs
:
self
.
nr_tower
=
nr_tower
assert
'tower'
not
in
kwargs
,
"Cannot set both nr_tower and tower in TrainConfig!"
if
tower
is
not
None
:
self
.
nr_tower
=
kwargs
.
pop
(
'nr_tower'
)
assert
self
.
nr_tower
==
1
,
"Cannot set both nr_tower and tower in TrainConfig!"
elif
'tower'
in
kwargs
:
self
.
tower
=
tower
self
.
tower
=
kwargs
.
pop
(
'tower'
)
else
:
self
.
predict_tower
=
predict_tower
self
.
tower
=
[
0
]
self
.
predict_tower
=
kwargs
.
pop
(
'predict_tower'
,
[
0
])
if
isinstance
(
self
.
predict_tower
,
int
):
if
isinstance
(
self
.
predict_tower
,
int
):
self
.
predict_tower
=
[
self
.
predict_tower
]
self
.
predict_tower
=
[
self
.
predict_tower
]
# TODO deprecated @
Dec
20
# TODO deprecated @
Jan
20
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
if
self
.
extra_threads_procs
:
if
self
.
extra_threads_procs
:
logger
.
warn
(
"[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs"
)
logger
.
warn
(
"[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs"
)
...
...
tensorpack/train/feedfree.py
View file @
069c0b9c
...
@@ -15,11 +15,14 @@ from .input_data import QueueInput, FeedfreeInput
...
@@ -15,11 +15,14 @@ from .input_data import QueueInput, FeedfreeInput
from
.base
import
Trainer
from
.base
import
Trainer
from
.trainer
import
MultiPredictorTowerTrainer
from
.trainer
import
MultiPredictorTowerTrainer
__all__
=
[
'FeedfreeTrainer'
,
'SingleCostFeedfreeTrainer'
,
'SimpleFeedfreeTrainer'
,
'QueueInputTrainer'
]
__all__
=
[
'FeedfreeTrainer'
,
'SingleCostFeedfreeTrainer'
,
'SimpleFeedfreeTrainer'
,
'QueueInputTrainer'
]
class
FeedfreeTrainer
(
Trainer
):
class
FeedfreeTrainer
(
Trainer
):
""" A trainer which runs iteration without feed_dict (therefore faster) """
""" A trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`.
"""
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
# need to run summary_op every epoch
# need to run summary_op every epoch
...
@@ -37,7 +40,7 @@ class FeedfreeTrainer(Trainer):
...
@@ -37,7 +40,7 @@ class FeedfreeTrainer(Trainer):
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainer
):
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainer
):
""" A feedfree Trainer which assumes a single cost. """
def
_get_cost_and_grad
(
self
):
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient on a new tower"""
""" get the cost and gradient on a new tower"""
actual_inputs
=
self
.
_get_input_tensors
()
actual_inputs
=
self
.
_get_input_tensors
()
...
@@ -52,7 +55,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
...
@@ -52,7 +55,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
return
cost_var
,
grads
return
cost_var
,
grads
def
run_step
(
self
):
def
run_step
(
self
):
""" Simply run
self.train_op
"""
""" Simply run
``self.train_op``, which minimizes the cost.
"""
self
.
sess
.
run
(
self
.
train_op
)
self
.
sess
.
run
(
self
.
train_op
)
# if not hasattr(self, 'cnt'):
# if not hasattr(self, 'cnt'):
# self.cnt = 0
# self.cnt = 0
...
...
tensorpack/train/input_data.py
View file @
069c0b9c
...
@@ -13,7 +13,7 @@ from ..tfutils.summary import add_moving_summary
...
@@ -13,7 +13,7 @@ from ..tfutils.summary import add_moving_summary
from
..utils
import
logger
from
..utils
import
logger
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
__all__
=
[
'QueueInput'
,
'FeedfreeInput'
,
'TensorInput'
,
__all__
=
[
'
InputData'
,
'
QueueInput'
,
'FeedfreeInput'
,
'TensorInput'
,
'DummyConstantInput'
]
'DummyConstantInput'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment