Commit 069c0b9c authored by Yuxin Wu

use explicit kwargs in TrainConfig

parent b5f8c73a
@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
 Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
 [google drive](https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
 They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
-The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation error.
+The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
 Alternative link to this page: [http://dorefa.net](http://dorefa.net)
...
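Since the pretrained weights ship as a plain numpy dictionary, porting them is mostly a matter of iterating over name/array pairs. Below is a minimal sketch of inspecting such a checkpoint, assuming it was saved with `np.save` as a pickled dict; the filename and key layout are illustrative guesses, not taken from the release:

```python
import numpy as np

# Hypothetical filename; use whatever file the google drive folder provides.
# A dict saved via np.save comes back as a 0-d object array, hence .item().
params = np.load('ResNet18-1-4-32.npy',
                 allow_pickle=True, encoding='latin1').item()

# Iterate over parameter-name -> ndarray pairs, e.g. to convert them
# into another framework's variable format.
for name, value in sorted(params.items()):
    print('{:<50} {}'.format(name, value.shape))
```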
@@ -22,32 +22,30 @@ __all__ = ['Trainer', 'StopTraining']

 class StopTraining(BaseException):
+    """
+    An exception thrown to stop training.
+    """
     pass


 @six.add_metaclass(ABCMeta)
 class Trainer(object):
-    """ Base class for a trainer."""
-    """a `StatHolder` instance"""
-    stat_holder = None
-    """`tf.SummaryWriter`"""
-    summary_writer = None
-    """a tf.Tensor which returns summary string"""
-    summary_op = None
-    """ TrainConfig """
-    config = None
-    """ a ModelDesc"""
-    model = None
-    """ the current session"""
-    sess = None
-    """ the `tf.train.Coordinator` """
-    coord = None
+    """ Base class for a trainer.
+
+    Attributes:
+        stat_holder (StatHolder)
+        summary_writer (tf.summary.FileWriter)
+        summary_op (tf.Operation): an Op which outputs all summaries.
+        config (TrainConfig): the config used in this trainer.
+        model (ModelDesc)
+        sess (tf.Session): the current session in use.
+        coord (tf.train.Coordinator)
+    """

     def __init__(self, config):
         """
-        :param config: a `TrainConfig` instance
+        Args:
+            config (TrainConfig): the train config.
         """
         assert isinstance(config, TrainConfig), type(config)
         self.config = config
@@ -56,27 +54,35 @@ class Trainer(object):
         self.coord = tf.train.Coordinator()

     def train(self):
-        """ Start training"""
+        """ Start training """
         self.setup()
         self.main_loop()

     @abstractmethod
     def run_step(self):
-        """ run an iteration"""
-        pass
+        """ Abstract method. Run one iteration. """

     def get_predict_func(self, input_names, output_names):
-        """ return a online predictor"""
+        """
+        Args:
+            input_names (list), output_names(list): list of names
+        Returns:
+            an OnlinePredictor
+        """
         raise NotImplementedError()

     def get_predict_funcs(self, input_names, output_names, n):
-        """ return n predictor functions.
+        """ Return n predictors.
         Can be overwritten by subclasses to exploit more
-        parallelism among funcs.
+        parallelism among predictors.
         """
         return [self.get_predict_func(input_names, output_names) for k in range(n)]

     def trigger_epoch(self):
+        """
+        Called after each epoch.
+        """
         # trigger subclass
         self._trigger_epoch()
         # trigger callbacks
@@ -85,7 +91,6 @@ class Trainer(object):

     @abstractmethod
     def _trigger_epoch(self):
-        """ This is called right after all steps in an epoch are finished"""
         pass

     def _process_summary(self, summary_str):
@@ -100,11 +105,21 @@ class Trainer(object):
         self.summary_writer.add_summary(summary, get_global_step())

     def write_scalar_summary(self, name, val):
+        """
+        Write a scalar summary to both the TF events file and the StatHolder.
+
+        Args:
+            name(str)
+            val(float)
+        """
         self.summary_writer.add_summary(
             create_summary(name, val), get_global_step())
         self.stat_holder.add_stat(name, val)

     def setup(self):
+        """
+        Set up the trainer and prepare for the main loop.
+        """
         self._setup()
         describe_model()
         get_global_step_var()
@@ -120,7 +135,6 @@ class Trainer(object):
         self.stat_holder = StatHolder(logger.LOG_DIR)

         logger.info("Initializing graph variables ...")
-        # TODO newsession + sessinit?
         initop = tf.global_variables_initializer()
         self.sess.run(initop)
         self.config.session_init.init(self.sess)
@@ -134,6 +148,9 @@ class Trainer(object):
         """ setup Trainer-specific stuff for training"""

     def main_loop(self):
+        """
+        Run the main training loop.
+        """
         callbacks = self.config.callbacks
         with self.sess.as_default():
             try:
...
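For orientation, here is a minimal sketch of what a concrete subclass of the `Trainer` interface above could look like. It only illustrates the contract (`run_step` and `_trigger_epoch` are the abstract methods, `_setup` is the Trainer-specific hook); the `train_op` attribute, the helper it calls, and the import path are assumptions for illustration, not part of this diff:

```python
from tensorpack.train.base import Trainer  # assumed import path

class ToyTrainer(Trainer):
    """Sketch of a concrete trainer fulfilling the abstract interface."""

    def _setup(self):
        # Build the graph here. self.train_op is a hypothetical attribute
        # that a real subclass would create from the model and optimizer.
        self.train_op = build_train_op(self.model, self.config.optimizer)  # hypothetical helper

    def run_step(self):
        # One iteration, as required by the abstract run_step().
        self.sess.run(self.train_op)

    def _trigger_epoch(self):
        # Runs after each epoch; write_scalar_summary (shown above) logs
        # to both the TF events file and the StatHolder.
        self.write_scalar_summary('example_stat', 0.0)
```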
@@ -17,54 +17,64 @@ __all__ = ['TrainConfig']

 class TrainConfig(object):
     """
-    Config for training a model with a single loss
+    Config for trainer.
     """

-    def __init__(self, **kwargs):
+    def __init__(self, dataset=None, data=None,
+                 model=None, optimizer=None, callbacks=None,
+                 session_config=get_default_sess_config(),
+                 session_init=None,
+                 starting_epoch=1, step_per_epoch=None, max_epoch=99999,
+                 nr_tower=1, tower=None, predict_tower=[0],
+                 **kwargs):
         """
-        :param dataset: the dataset to train. a `DataFlow` instance.
-        :param data: an `InputData` instance
-        :param optimizer: a `tf.train.Optimizer` instance defining the optimizer for training.
-        :param callbacks: a `callback.Callbacks` instance. Define
-            the callbacks to perform during training.
-        :param session_config: a `tf.ConfigProto` instance to instantiate the session.
-        :param session_init: a `sessinit.SessionInit` instance to
-            initialize variables of a session. default to a new session.
-        :param model: a `ModelDesc` instance.
-        :param starting_epoch: int. default to be 1.
-        :param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
-        :param max_epoch: maximum number of epochs to run training. default to inf
-        :param nr_tower: int. number of training towers. default to 1.
-        :param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
-        :param predict_tower: list of prediction towers in their relative gpu id. Defaults to [0]
+        Args:
+            dataset (DataFlow): the dataset to train.
+            data (InputData): an `InputData` instance. Only one of ``dataset``
+                or ``data`` has to be present.
+            model (ModelDesc): the model to train.
+            optimizer (tf.train.Optimizer): the optimizer for training.
+            callbacks (Callbacks): the callbacks to perform during training.
+            session_config (tf.ConfigProto): the config used to instantiate the session.
+            session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
+            starting_epoch (int): the index of the first epoch.
+            step_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
+                Defaults to the input data size.
+            max_epoch (int): maximum number of epochs to run training.
+            nr_tower (int): number of training towers.
+            tower (list of int): list of training towers in relative id.
+            predict_tower (list of int): list of prediction towers in their relative gpu id.
         """
-        # TODO type checker decorator
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__

-        if 'dataset' in kwargs:
-            assert 'data' not in kwargs, "dataset and data cannot both be present in TrainConfig!"
-            self.dataset = kwargs.pop('dataset')
+        if dataset is not None:
+            assert data is None, "dataset and data cannot both be present in TrainConfig!"
+            self.dataset = dataset
             assert_type(self.dataset, DataFlow)
         else:
-            self.data = kwargs.pop('data')
+            self.data = data
             assert_type(self.data, InputData)

-        self.optimizer = kwargs.pop('optimizer')
+        self.optimizer = optimizer
         assert_type(self.optimizer, tf.train.Optimizer)
-        self.callbacks = kwargs.pop('callbacks')
+        self.callbacks = callbacks
         assert_type(self.callbacks, Callbacks)
-        self.model = kwargs.pop('model')
+        self.model = model
         assert_type(self.model, ModelDesc)

-        self.session_config = kwargs.pop('session_config', get_default_sess_config())
+        self.session_config = session_config
         assert_type(self.session_config, tf.ConfigProto)
-        self.session_init = kwargs.pop('session_init', JustCurrentSession())
+        if session_init is None:
+            session_init = JustCurrentSession()
+        self.session_init = session_init
         assert_type(self.session_init, SessionInit)

-        self.step_per_epoch = kwargs.pop('step_per_epoch', None)
+        self.step_per_epoch = step_per_epoch
         if self.step_per_epoch is None:
             try:
-                if hasattr(self, 'dataset'):
+                if dataset is not None:
                     self.step_per_epoch = self.dataset.size()
                 else:
                     self.step_per_epoch = self.data.size()
@@ -73,22 +83,20 @@ class TrainConfig(object):
         else:
             self.step_per_epoch = int(self.step_per_epoch)

-        self.starting_epoch = int(kwargs.pop('starting_epoch', 1))
-        self.max_epoch = int(kwargs.pop('max_epoch', 99999))
+        self.starting_epoch = int(starting_epoch)
+        self.max_epoch = int(max_epoch)
         assert self.step_per_epoch >= 0 and self.max_epoch > 0

-        if 'nr_tower' in kwargs:
-            assert 'tower' not in kwargs, "Cannot set both nr_tower and tower in TrainConfig!"
-            self.nr_tower = kwargs.pop('nr_tower')
-        elif 'tower' in kwargs:
-            self.tower = kwargs.pop('tower')
-        else:
-            self.tower = [0]
-        self.predict_tower = kwargs.pop('predict_tower', [0])
+        self.nr_tower = nr_tower
+        if tower is not None:
+            assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
+            self.tower = tower
+
+        self.predict_tower = predict_tower
         if isinstance(self.predict_tower, int):
             self.predict_tower = [self.predict_tower]

-        # TODO deprecated @Dec20
+        # TODO deprecated @Jan20
         self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
         if self.extra_threads_procs:
             logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
...
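With the change above, a `TrainConfig` is built from named parameters instead of a raw `**kwargs` dict, so the defaults are visible in the signature and the `dataset`/`data` and `nr_tower`/`tower` exclusivity checks operate on real arguments; `**kwargs` survives only to absorb the deprecated `extra_threads_procs` option, which now triggers the warning shown above. A usage sketch, where the dataflow, model, and callback objects are hypothetical placeholders and only the keyword names come from the new signature:

```python
import tensorflow as tf

# my_dataflow, MyModel and my_callbacks are hypothetical placeholders.
config = TrainConfig(
    dataset=my_dataflow,                     # a DataFlow; or pass data=<InputData> instead
    model=MyModel(),                         # a ModelDesc
    optimizer=tf.train.AdamOptimizer(1e-3),
    callbacks=my_callbacks,                  # a Callbacks instance
    step_per_epoch=1000,                     # defaults to the input data size if omitted
    max_epoch=100,
)
```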
@@ -15,11 +15,14 @@ from .input_data import QueueInput, FeedfreeInput
 from .base import Trainer
 from .trainer import MultiPredictorTowerTrainer

-__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer', 'SimpleFeedfreeTrainer', 'QueueInputTrainer']
+__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer',
+           'SimpleFeedfreeTrainer', 'QueueInputTrainer']


 class FeedfreeTrainer(Trainer):
-    """ A trainer which runs iteration without feed_dict (therefore faster) """
+    """ A trainer which runs iteration without feed_dict (therefore faster)
+
+    Expect ``self.data`` to be a :class:`FeedfreeInput`.
+    """

     def _trigger_epoch(self):
         # need to run summary_op every epoch
@@ -37,7 +40,7 @@ class FeedfreeTrainer(Trainer):

 class SingleCostFeedfreeTrainer(FeedfreeTrainer):
+    """ A feedfree Trainer which assumes a single cost. """
     def _get_cost_and_grad(self):
         """ get the cost and gradient on a new tower"""
         actual_inputs = self._get_input_tensors()
@@ -52,7 +55,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
         return cost_var, grads

     def run_step(self):
-        """ Simply run self.train_op"""
+        """ Simply run ``self.train_op``, which minimizes the cost."""
         self.sess.run(self.train_op)
         # if not hasattr(self, 'cnt'):
         #     self.cnt = 0
...
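Putting the pieces together: a feedfree setup wraps a `DataFlow` in a `FeedfreeInput` (such as the `QueueInput` imported above) and passes it through the new `data=` keyword. A sketch under the same hypothetical placeholders as before; the constructor call mirrors the base `Trainer.__init__(config)` shown earlier:

```python
config = TrainConfig(
    data=QueueInput(my_dataflow),   # a FeedfreeInput, as FeedfreeTrainer expects
    model=MyModel(),
    optimizer=tf.train.MomentumOptimizer(0.01, 0.9),
    callbacks=my_callbacks,
)
SimpleFeedfreeTrainer(config).train()  # each run_step just runs self.train_op
```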
@@ -13,7 +13,7 @@ from ..tfutils.summary import add_moving_summary
 from ..utils import logger
 from ..callbacks.concurrency import StartProcOrThread

-__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
+__all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput',
            'DummyConstantInput']
...