Commit 069c0b9c authored by Yuxin Wu

use explicit kwargs in TrainConfig

parent b5f8c73a
......@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
Pretrained models for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[google drive](https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
They're provided as numpy dictionaries, so it should be very easy to port them into other applications.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation error.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
Alternative link to this page: [http://dorefa.net](http://dorefa.net)
......
......@@ -22,32 +22,30 @@ __all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
"""
An exception thrown to stop training.
"""
pass
@six.add_metaclass(ABCMeta)
class Trainer(object):
""" Base class for a trainer."""
"""a `StatHolder` instance"""
stat_holder = None
"""`tf.SummaryWriter`"""
summary_writer = None
"""a tf.Tensor which returns summary string"""
summary_op = None
""" TrainConfig """
config = None
""" a ModelDesc"""
model = None
""" the current session"""
sess = None
""" the `tf.train.Coordinator` """
coord = None
""" Base class for a trainer.
Attributes:
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
coord (tf.train.Coordinator)
"""
def __init__(self, config):
"""
:param config: a `TrainConfig` instance
Args:
config (TrainConfig): the train config.
"""
assert isinstance(config, TrainConfig), type(config)
self.config = config
......@@ -56,27 +54,35 @@ class Trainer(object):
self.coord = tf.train.Coordinator()
def train(self):
""" Start training"""
""" Start training """
self.setup()
self.main_loop()
@abstractmethod
def run_step(self):
""" run an iteration"""
pass
""" Abstract method. Run one iteration. """
def get_predict_func(self, input_names, output_names):
""" return a online predictor"""
"""
Args:
input_names (list), output_names(list): list of names
Returns:
an OnlinePredictor
"""
raise NotImplementedError()
def get_predict_funcs(self, input_names, output_names, n):
""" return n predictor functions.
""" Return n predictors.
Can be overwritten by subclasses to exploit more
parallelism among funcs.
parallelism among predictors.
"""
return [self.get_predict_func(input_names, output_names) for k in range(n)]
def trigger_epoch(self):
"""
Called after each epoch.
"""
# trigger subclass
self._trigger_epoch()
# trigger callbacks
......@@ -85,7 +91,6 @@ class Trainer(object):
@abstractmethod
def _trigger_epoch(self):
""" This is called right after all steps in an epoch are finished"""
pass
def _process_summary(self, summary_str):
......@@ -100,11 +105,21 @@ class Trainer(object):
self.summary_writer.add_summary(summary, get_global_step())
def write_scalar_summary(self, name, val):
"""
Write a scalar summary to both TF events file and StatHolder.
Args:
name(str)
val(float)
"""
self.summary_writer.add_summary(
create_summary(name, val), get_global_step())
self.stat_holder.add_stat(name, val)
def setup(self):
"""
Setup the trainer and be ready for the main loop.
"""
self._setup()
describe_model()
get_global_step_var()
......@@ -120,7 +135,6 @@ class Trainer(object):
self.stat_holder = StatHolder(logger.LOG_DIR)
logger.info("Initializing graph variables ...")
# TODO newsession + sessinit?
initop = tf.global_variables_initializer()
self.sess.run(initop)
self.config.session_init.init(self.sess)
......@@ -134,6 +148,9 @@ class Trainer(object):
""" setup Trainer-specific stuff for training"""
def main_loop(self):
"""
Run the main training loop.
"""
callbacks = self.config.callbacks
with self.sess.as_default():
try:
......
......@@ -17,54 +17,64 @@ __all__ = ['TrainConfig']
class TrainConfig(object):
"""
Config for training a model with a single loss
Config for trainer.
"""
def __init__(self, **kwargs):
def __init__(self, dataset=None, data=None,
model=None, optimizer=None, callbacks=None,
session_config=get_default_sess_config(),
session_init=None,
starting_epoch=1, step_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None, predict_tower=[0],
**kwargs):
"""
:param dataset: the dataset to train. a `DataFlow` instance.
:param data: an `InputData` instance
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for training.
:param callbacks: a `callback.Callbacks` instance. Define
the callbacks to perform during training.
:param session_config: a `tf.ConfigProto` instance to instantiate the session.
:param session_init: a `sessinit.SessionInit` instance to
initialize variables of a session. default to a new session.
:param model: a `ModelDesc` instance.
:param starting_epoch: int. default to be 1.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param max_epoch: maximum number of epochs to run training. default to inf
:param nr_tower: int. number of training towers. default to 1.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
:param predict_tower: list of prediction tower in their relative gpu id. Defaults to [0]
Args:
dataset (DataFlow): the dataset to train.
data (InputData): an `InputData` instance. Exactly one of ``dataset``
or ``data`` must be given; passing both is an error.
model (ModelDesc): the model to train.
optimizer (tf.train.Optimizer): the optimizer for training.
callbacks (Callbacks): the callbacks to perform during training.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
step_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size.
max_epoch (int): maximum number of epochs to run training.
nr_tower (int): number of training towers.
tower (list of int): list of training towers in relative id.
predict_tower (list of int): list of prediction towers in their relative gpu id.
"""
# TODO type checker decorator
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
if 'dataset' in kwargs:
assert 'data' not in kwargs, "dataset and data cannot be both presented in TrainConfig!"
self.dataset = kwargs.pop('dataset')
if dataset is not None:
assert data is None, "dataset and data cannot be both presented in TrainConfig!"
self.dataset = dataset
assert_type(self.dataset, DataFlow)
else:
self.data = kwargs.pop('data')
self.data = data
assert_type(self.data, InputData)
self.optimizer = kwargs.pop('optimizer')
self.optimizer = optimizer
assert_type(self.optimizer, tf.train.Optimizer)
self.callbacks = kwargs.pop('callbacks')
self.callbacks = callbacks
assert_type(self.callbacks, Callbacks)
self.model = kwargs.pop('model')
self.model = model
assert_type(self.model, ModelDesc)
self.session_config = kwargs.pop('session_config', get_default_sess_config())
self.session_config = session_config
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init', JustCurrentSession())
if session_init is None:
session_init = JustCurrentSession()
self.session_init = session_init
assert_type(self.session_init, SessionInit)
self.step_per_epoch = kwargs.pop('step_per_epoch', None)
self.step_per_epoch = step_per_epoch
if self.step_per_epoch is None:
try:
if hasattr(self, 'dataset'):
if dataset is not None:
self.step_per_epoch = self.dataset.size()
else:
self.step_per_epoch = self.data.size()
......@@ -73,22 +83,20 @@ class TrainConfig(object):
else:
self.step_per_epoch = int(self.step_per_epoch)
self.starting_epoch = int(kwargs.pop('starting_epoch', 1))
self.max_epoch = int(kwargs.pop('max_epoch', 99999))
self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch)
assert self.step_per_epoch >= 0 and self.max_epoch > 0
if 'nr_tower' in kwargs:
assert 'tower' not in kwargs, "Cannot set both nr_tower and tower in TrainConfig!"
self.nr_tower = kwargs.pop('nr_tower')
elif 'tower' in kwargs:
self.tower = kwargs.pop('tower')
else:
self.tower = [0]
self.predict_tower = kwargs.pop('predict_tower', [0])
self.nr_tower = nr_tower
if tower is not None:
assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
self.tower = tower
self.predict_tower = predict_tower
if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower]
# TODO deprecated @Dec20
# TODO deprecated @Jan20
self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
if self.extra_threads_procs:
logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
......
......@@ -15,11 +15,14 @@ from .input_data import QueueInput, FeedfreeInput
from .base import Trainer
from .trainer import MultiPredictorTowerTrainer
__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer', 'SimpleFeedfreeTrainer', 'QueueInputTrainer']
__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer',
'SimpleFeedfreeTrainer', 'QueueInputTrainer']
class FeedfreeTrainer(Trainer):
""" A trainer which runs iteration without feed_dict (therefore faster) """
""" A trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`.
"""
def _trigger_epoch(self):
# need to run summary_op every epoch
......@@ -37,7 +40,7 @@ class FeedfreeTrainer(Trainer):
class SingleCostFeedfreeTrainer(FeedfreeTrainer):
""" A feedfree Trainer which assumes a single cost. """
def _get_cost_and_grad(self):
""" get the cost and gradient on a new tower"""
actual_inputs = self._get_input_tensors()
......@@ -52,7 +55,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
return cost_var, grads
def run_step(self):
""" Simply run self.train_op"""
""" Simply run ``self.train_op``, which minimizes the cost."""
self.sess.run(self.train_op)
# if not hasattr(self, 'cnt'):
# self.cnt = 0
......
......@@ -13,7 +13,7 @@ from ..tfutils.summary import add_moving_summary
from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread
__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
__all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput',
'DummyConstantInput']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment