Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
069c0b9c
Commit
069c0b9c
authored
Jan 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use explicit kwargs in TrainConfig
parent
b5f8c73a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
100 additions
and
72 deletions
+100
-72
examples/DoReFa-Net/README.md
examples/DoReFa-Net/README.md
+1
-1
tensorpack/train/base.py
tensorpack/train/base.py
+43
-26
tensorpack/train/config.py
tensorpack/train/config.py
+48
-40
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+7
-4
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+1
-1
No files found.
examples/DoReFa-Net/README.md
View file @
069c0b9c
...
...
@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[
google drive
](
https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ
)
.
They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation
error
.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation
accuracy
.
Alternative link to this page:
[
http://dorefa.net
](
http://dorefa.net
)
...
...
tensorpack/train/base.py
View file @
069c0b9c
...
...
@@ -22,32 +22,30 @@ __all__ = ['Trainer', 'StopTraining']
class
StopTraining
(
BaseException
):
"""
An exception thrown to stop training.
"""
pass
@
six
.
add_metaclass
(
ABCMeta
)
class
Trainer
(
object
):
""" Base class for a trainer."""
"""a `StatHolder` instance"""
stat_holder
=
None
"""`tf.SummaryWriter`"""
summary_writer
=
None
"""a tf.Tensor which returns summary string"""
summary_op
=
None
""" TrainConfig """
config
=
None
""" a ModelDesc"""
model
=
None
""" the current session"""
sess
=
None
""" the `tf.train.Coordinator` """
coord
=
None
""" Base class for a trainer.
Attributes:
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
coord (tf.train.Coordinator)
"""
def
__init__
(
self
,
config
):
"""
:param config: a `TrainConfig` instance
Args:
config (TrainConfig): the train config.
"""
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
self
.
config
=
config
...
...
@@ -56,27 +54,35 @@ class Trainer(object):
self
.
coord
=
tf
.
train
.
Coordinator
()
def
train
(
self
):
""" Start training"""
""" Start training
"""
self
.
setup
()
self
.
main_loop
()
@
abstractmethod
def
run_step
(
self
):
""" run an iteration"""
pass
""" Abstract method. Run one iteration. """
def
get_predict_func
(
self
,
input_names
,
output_names
):
""" return a online predictor"""
"""
Args:
input_names (list), output_names(list): list of names
Returns:
an OnlinePredictor
"""
raise
NotImplementedError
()
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
"""
return n predictor function
s.
"""
Return n predictor
s.
Can be overwritten by subclasses to exploit more
parallelism among
func
s.
parallelism among
predictor
s.
"""
return
[
self
.
get_predict_func
(
input_names
,
output_names
)
for
k
in
range
(
n
)]
def
trigger_epoch
(
self
):
"""
Called after each epoch.
"""
# trigger subclass
self
.
_trigger_epoch
()
# trigger callbacks
...
...
@@ -85,7 +91,6 @@ class Trainer(object):
@
abstractmethod
def
_trigger_epoch
(
self
):
""" This is called right after all steps in an epoch are finished"""
pass
def
_process_summary
(
self
,
summary_str
):
...
...
@@ -100,11 +105,21 @@ class Trainer(object):
self
.
summary_writer
.
add_summary
(
summary
,
get_global_step
())
def
write_scalar_summary
(
self
,
name
,
val
):
"""
Write a scalar summary to both TF events file and StatHolder.
Args:
name(str)
val(float)
"""
self
.
summary_writer
.
add_summary
(
create_summary
(
name
,
val
),
get_global_step
())
self
.
stat_holder
.
add_stat
(
name
,
val
)
def
setup
(
self
):
"""
Setup the trainer and be ready for the main loop.
"""
self
.
_setup
()
describe_model
()
get_global_step_var
()
...
...
@@ -120,7 +135,6 @@ class Trainer(object):
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
logger
.
info
(
"Initializing graph variables ..."
)
# TODO newsession + sessinit?
initop
=
tf
.
global_variables_initializer
()
self
.
sess
.
run
(
initop
)
self
.
config
.
session_init
.
init
(
self
.
sess
)
...
...
@@ -134,6 +148,9 @@ class Trainer(object):
""" setup Trainer-specific stuff for training"""
def
main_loop
(
self
):
"""
Run the main training loop.
"""
callbacks
=
self
.
config
.
callbacks
with
self
.
sess
.
as_default
():
try
:
...
...
tensorpack/train/config.py
View file @
069c0b9c
...
...
@@ -17,54 +17,64 @@ __all__ = ['TrainConfig']
class
TrainConfig
(
object
):
"""
Config for train
ing a model with a single loss
Config for train
er.
"""
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
dataset
=
None
,
data
=
None
,
model
=
None
,
optimizer
=
None
,
callbacks
=
None
,
session_config
=
get_default_sess_config
(),
session_init
=
None
,
starting_epoch
=
1
,
step_per_epoch
=
None
,
max_epoch
=
99999
,
nr_tower
=
1
,
tower
=
None
,
predict_tower
=
[
0
],
**
kwargs
):
"""
:param dataset: the dataset to train. a `DataFlow` instance.
:param data: an `InputData` instance
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig
.
:param callbacks: a `callback.Callbacks` instance. Define
the callbacks to perform during trainin
g.
:param session_config: a `tf.ConfigProto` instance to instantiate the session
.
:param session_init: a `sessinit.SessionInit` instance to
initialize variables of a session. default
to a new session.
:param model: a `ModelDesc` instance
.
:param starting_epoch: int. default to be 1
.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch
.
:param max_epoch: maximum number of epoch to run training. default to inf
:param nr_tower: int. number of training towers. default to 1
.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given
.
:param predict_tower: list of prediction tower in their relative gpu id. Defaults to [0]
Args:
dataset (DataFlow): the dataset to train.
data (InputData): an `InputData` instance. Only one of ``dataset``
or ``data`` has to be present
.
model (ModelDesc): the model to train.
optimizer (tf.train.Optimizer): the optimizer for training.
callbacks (Callbacks): the callbacks to perform during training
.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults
to a new session.
starting_epoch (int): The index of the first epoch
.
step_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch
.
Defaults to the input data size
.
max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers
.
tower (list of int): list of training towers in relative id
.
predict_tower (list of int): list of prediction towers in their relative gpu id.
"""
# TODO type checker decorator
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
if
'dataset'
in
kwargs
:
assert
'data'
not
in
kwargs
,
"dataset and data cannot be both presented in TrainConfig!"
self
.
dataset
=
kwargs
.
pop
(
'dataset'
)
if
dataset
is
not
None
:
assert
data
is
None
,
"dataset and data cannot be both presented in TrainConfig!"
self
.
dataset
=
dataset
assert_type
(
self
.
dataset
,
DataFlow
)
else
:
self
.
data
=
kwargs
.
pop
(
'data'
)
self
.
data
=
data
assert_type
(
self
.
data
,
InputData
)
self
.
optimizer
=
kwargs
.
pop
(
'optimizer'
)
self
.
optimizer
=
optimizer
assert_type
(
self
.
optimizer
,
tf
.
train
.
Optimizer
)
self
.
callbacks
=
kwargs
.
pop
(
'callbacks'
)
self
.
callbacks
=
callbacks
assert_type
(
self
.
callbacks
,
Callbacks
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
model
=
model
assert_type
(
self
.
model
,
ModelDesc
)
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
())
self
.
session_config
=
session_config
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
,
JustCurrentSession
())
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
self
.
session_init
=
session_init
assert_type
(
self
.
session_init
,
SessionInit
)
self
.
step_per_epoch
=
kwargs
.
pop
(
'step_per_epoch'
,
None
)
self
.
step_per_epoch
=
step_per_epoch
if
self
.
step_per_epoch
is
None
:
try
:
if
hasattr
(
self
,
'dataset'
)
:
if
dataset
is
not
None
:
self
.
step_per_epoch
=
self
.
dataset
.
size
()
else
:
self
.
step_per_epoch
=
self
.
data
.
size
()
...
...
@@ -73,22 +83,20 @@ class TrainConfig(object):
else
:
self
.
step_per_epoch
=
int
(
self
.
step_per_epoch
)
self
.
starting_epoch
=
int
(
kwargs
.
pop
(
'starting_epoch'
,
1
)
)
self
.
max_epoch
=
int
(
kwargs
.
pop
(
'max_epoch'
,
99999
)
)
self
.
starting_epoch
=
int
(
starting_epoch
)
self
.
max_epoch
=
int
(
max_epoch
)
assert
self
.
step_per_epoch
>=
0
and
self
.
max_epoch
>
0
if
'nr_tower'
in
kwargs
:
assert
'tower'
not
in
kwargs
,
"Cannot set both nr_tower and tower in TrainConfig!"
self
.
nr_tower
=
kwargs
.
pop
(
'nr_tower'
)
elif
'tower'
in
kwargs
:
self
.
tower
=
kwargs
.
pop
(
'tower'
)
else
:
self
.
tower
=
[
0
]
self
.
predict_tower
=
kwargs
.
pop
(
'predict_tower'
,
[
0
])
self
.
nr_tower
=
nr_tower
if
tower
is
not
None
:
assert
self
.
nr_tower
==
1
,
"Cannot set both nr_tower and tower in TrainConfig!"
self
.
tower
=
tower
self
.
predict_tower
=
predict_tower
if
isinstance
(
self
.
predict_tower
,
int
):
self
.
predict_tower
=
[
self
.
predict_tower
]
# TODO deprecated @
Dec
20
# TODO deprecated @
Jan
20
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
if
self
.
extra_threads_procs
:
logger
.
warn
(
"[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs"
)
...
...
tensorpack/train/feedfree.py
View file @
069c0b9c
...
...
@@ -15,11 +15,14 @@ from .input_data import QueueInput, FeedfreeInput
from
.base
import
Trainer
from
.trainer
import
MultiPredictorTowerTrainer
__all__
=
[
'FeedfreeTrainer'
,
'SingleCostFeedfreeTrainer'
,
'SimpleFeedfreeTrainer'
,
'QueueInputTrainer'
]
__all__
=
[
'FeedfreeTrainer'
,
'SingleCostFeedfreeTrainer'
,
'SimpleFeedfreeTrainer'
,
'QueueInputTrainer'
]
class
FeedfreeTrainer
(
Trainer
):
""" A trainer which runs iteration without feed_dict (therefore faster) """
""" A trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`.
"""
def
_trigger_epoch
(
self
):
# need to run summary_op every epoch
...
...
@@ -37,7 +40,7 @@ class FeedfreeTrainer(Trainer):
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainer
):
""" A feedfree Trainer which assumes a single cost. """
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient on a new tower"""
actual_inputs
=
self
.
_get_input_tensors
()
...
...
@@ -52,7 +55,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
return
cost_var
,
grads
def
run_step
(
self
):
""" Simply run
self.train_op
"""
""" Simply run
``self.train_op``, which minimizes the cost.
"""
self
.
sess
.
run
(
self
.
train_op
)
# if not hasattr(self, 'cnt'):
# self.cnt = 0
...
...
tensorpack/train/input_data.py
View file @
069c0b9c
...
...
@@ -13,7 +13,7 @@ from ..tfutils.summary import add_moving_summary
from
..utils
import
logger
from
..callbacks.concurrency
import
StartProcOrThread
__all__
=
[
'QueueInput'
,
'FeedfreeInput'
,
'TensorInput'
,
__all__
=
[
'
InputData'
,
'
QueueInput'
,
'FeedfreeInput'
,
'TensorInput'
,
'DummyConstantInput'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment