Commit a6936913 authored by Yuxin Wu

make 'TrainConfig' a simple key-value holder -- let it use train_with_defaults to handle the defaults
parent 55098813
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: DQNModel.py # File: DQNModel.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import abc import abc
import tensorflow as tf import tensorflow as tf
......
...@@ -28,7 +28,7 @@ Evaluate the [pretrained model](http://models.tensorpack.com/ShuffleNet/): ...@@ -28,7 +28,7 @@ Evaluate the [pretrained model](http://models.tensorpack.com/ShuffleNet/):
This Inception-BN script reaches 27% single-crop error after 300k steps with 6 GPUs. This Inception-BN script reaches 27% single-crop error after 300k steps with 6 GPUs.
This VGG16 script reaches 28.8% single-crop error after 100 epochs (30h with 8 P100s). It gets 1% better if BN is enabled. This VGG16 script reaches 29~30% single-crop error after 100 epochs (30h with 8 P100s), and 28% if BN is enabled.
### ResNet, DoReFa-Net ### ResNet, DoReFa-Net
......
...@@ -124,7 +124,7 @@ if __name__ == '__main__': ...@@ -124,7 +124,7 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
logger.set_logger_dir(os.path.join('train_log', 'vgg16')) logger.set_logger_dir(os.path.join('train_log', 'vgg16-norm={}'.format(args.norm)))
config = get_config() config = get_config()
if args.load: if args.load:
......
...@@ -18,31 +18,37 @@ class PredictConfig(object): ...@@ -18,31 +18,37 @@ class PredictConfig(object):
model=None, model=None,
inputs_desc=None, inputs_desc=None,
tower_func=None, tower_func=None,
session_creator=None,
session_init=None,
input_names=None, input_names=None,
output_names=None, output_names=None,
session_creator=None,
session_init=None,
return_input=False, return_input=False,
create_graph=True, create_graph=True,
): ):
""" """
You need to set either `model`, or `inputs_desc` plus `tower_func`.
They are needed to construct the graph.
You'll also have to set `output_names` as it does not have a default.
Args: Args:
model (ModelDescBase): to be used to obtain inputs_desc and tower_func. model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]): inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors tower_func: a callable which takes input tensors and construct a tower.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
session_creator (tf.train.SessionCreator): how to create the session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`tf.train.ChiefSessionCreator()`. session. Defaults to :class:`tf.train.ChiefSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session. session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing. Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
return_input (bool): same as in :attr:`PredictorBase.return_input`. return_input (bool): same as in :attr:`PredictorBase.return_input`.
create_graph (bool): create a new graph, or use the default graph create_graph (bool): create a new graph, or use the default graph
when predictor is first initialized. when predictor is first initialized.
You need to set either `model`, or `inputs_desc` plus `tower_func`.
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
......
...@@ -7,6 +7,7 @@ import weakref ...@@ -7,6 +7,7 @@ import weakref
import time import time
from six.moves import range from six.moves import range
import six import six
import copy
from ..callbacks import ( from ..callbacks import (
Callback, Callbacks, Monitors, TrainingMonitor) Callback, Callbacks, Monitors, TrainingMonitor)
...@@ -187,6 +188,8 @@ class Trainer(object): ...@@ -187,6 +188,8 @@ class Trainer(object):
callbacks ([Callback]): callbacks ([Callback]):
monitors ([TrainingMonitor]): monitors ([TrainingMonitor]):
""" """
assert isinstance(callbacks, list), callbacks
assert isinstance(monitors, list), monitors
describe_trainable_vars() # TODO weird describe_trainable_vars() # TODO weird
self.register_callback(MaintainStepCounter()) self.register_callback(MaintainStepCounter())
...@@ -284,7 +287,7 @@ class Trainer(object): ...@@ -284,7 +287,7 @@ class Trainer(object):
session_creator, session_init, session_creator, session_init,
steps_per_epoch, starting_epoch=1, max_epoch=9999999): steps_per_epoch, starting_epoch=1, max_epoch=9999999):
""" """
Implemented by: Implemented by three lines:
.. code-block:: python .. code-block:: python
...@@ -299,18 +302,24 @@ class Trainer(object): ...@@ -299,18 +302,24 @@ class Trainer(object):
self.main_loop(steps_per_epoch, starting_epoch, max_epoch) self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
def train_with_defaults( def train_with_defaults(
self, callbacks=None, monitors=None, self, _sentinel=None,
callbacks=None, monitors=None,
session_creator=None, session_init=None, session_creator=None, session_init=None,
steps_per_epoch=None, starting_epoch=1, max_epoch=9999999): steps_per_epoch=None, starting_epoch=1, max_epoch=9999999,
extra_callbacks=None):
""" """
Same as :meth:`train()`, but will: Same as :meth:`train()`, except:
1. Append :meth:`DEFAULT_CALLBACKS()` to callbacks. 1. Add `extra_callbacks` to callbacks. The default value for
2. Append :meth:`DEFAULT_MONITORS()` to monitors. `extra_callbacks` is :meth:`DEFAULT_CALLBACKS()`.
2. Default value for `monitors` is :meth:`DEFAULT_MONITORS()`.
3. Provide default values for every option except `steps_per_epoch`. 3. Provide default values for every option except `steps_per_epoch`.
""" """
callbacks = (callbacks or []) + DEFAULT_CALLBACKS() assert _sentinel is None, "Please call `train_with_defaults` with keyword arguments only!"
monitors = (monitors or []) + DEFAULT_MONITORS() callbacks = copy.copy(callbacks or [])
monitors = DEFAULT_MONITORS() if monitors is None else monitors
extra_callbacks = DEFAULT_CALLBACKS() if extra_callbacks is None else extra_callbacks
callbacks.extend(extra_callbacks)
assert steps_per_epoch is not None assert steps_per_epoch is not None
session_creator = session_creator or NewSessionCreator() session_creator = session_creator or NewSessionCreator()
......
...@@ -12,7 +12,7 @@ from ..callbacks import ( ...@@ -12,7 +12,7 @@ from ..callbacks import (
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger from ..utils import logger
from ..tfutils.sessinit import JustCurrentSession, SessionInit, SaverRestore from ..tfutils.sessinit import SessionInit, SaverRestore
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource from ..input_source import InputSource
...@@ -52,21 +52,20 @@ def DEFAULT_MONITORS(): ...@@ -52,21 +52,20 @@ def DEFAULT_MONITORS():
class TrainConfig(object): class TrainConfig(object):
""" """
A collection of options to be used for trainers. A collection of options to be used for single-cost trainers.
""" """
def __init__(self, def __init__(self,
dataflow=None, data=None, model=None, dataflow=None, data=None,
model=None,
callbacks=None, extra_callbacks=None, monitors=None, callbacks=None, extra_callbacks=None, monitors=None,
session_creator=None, session_config=None, session_init=None, session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999, starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
nr_tower=1, tower=None,
**kwargs):
""" """
Args: Args:
dataflow (DataFlow): dataflow (DataFlow):
data (InputSource): data (InputSource):
model (ModelDescBase): model (ModelDesc):
callbacks (list): a list of :class:`Callback` to perform during training. callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument extra_callbacks (list): the same as ``callbacks``. This argument
...@@ -107,20 +106,18 @@ class TrainConfig(object): ...@@ -107,20 +106,18 @@ class TrainConfig(object):
assert_type(model, ModelDescBase) assert_type(model, ModelDescBase)
self.model = model self.model = model
if callbacks is None: if callbacks is not None:
callbacks = [] assert_type(callbacks, list)
assert_type(callbacks, list) self.callbacks = callbacks
if extra_callbacks is not None: if extra_callbacks is not None:
self._callbacks = callbacks + extra_callbacks assert_type(extra_callbacks, list)
else: self.extra_callbacks = extra_callbacks
self._callbacks = callbacks + DEFAULT_CALLBACKS() if monitors is not None:
assert_type(monitors, list)
self.monitors = monitors if monitors is not None else DEFAULT_MONITORS() self.monitors = monitors
if session_init is not None:
if session_init is None: assert_type(session_init, SessionInit)
session_init = JustCurrentSession()
self.session_init = session_init self.session_init = session_init
assert_type(self.session_init, SessionInit)
if session_creator is None: if session_creator is None:
if session_config is not None: if session_config is not None:
...@@ -149,27 +146,6 @@ class TrainConfig(object): ...@@ -149,27 +146,6 @@ class TrainConfig(object):
self.starting_epoch = int(starting_epoch) self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch) self.max_epoch = int(max_epoch)
# Tower stuff are for Trainer v1 only:
nr_tower = max(nr_tower, 1)
self.nr_tower = nr_tower
if tower is not None:
assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
self.tower = tower
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
@property
def nr_tower(self):
return len(self.tower)
@nr_tower.setter
def nr_tower(self, value):
self.tower = list(range(value))
@property
def callbacks(self): # disable setter
return self._callbacks
class AutoResumeTrainConfig(TrainConfig): class AutoResumeTrainConfig(TrainConfig):
""" """
......
...@@ -76,14 +76,16 @@ def launch_train_with_config(config, trainer): ...@@ -76,14 +76,16 @@ def launch_train_with_config(config, trainer):
inputs_desc = model.get_inputs_desc() inputs_desc = model.get_inputs_desc()
input = config.data or config.dataflow input = config.data or config.dataflow
input = apply_default_prefetch(input, trainer) input = apply_default_prefetch(input, trainer)
if config.nr_tower > 1:
logger.warn("With trainer v2, setting tower in TrainConfig has no effect.")
logger.warn("It's enough to set the tower when initializing the trainer.")
trainer.setup_graph( trainer.setup_graph(
inputs_desc, input, inputs_desc, input,
model._build_graph_get_cost, model.get_optimizer) model._build_graph_get_cost, model.get_optimizer)
trainer.train( trainer.train_with_defaults(
config.callbacks, config.monitors, callbacks=config.callbacks,
config.session_creator, config.session_init, monitors=config.monitors,
config.steps_per_epoch, config.starting_epoch, config.max_epoch) session_creator=config.session_creator,
session_init=config.session_init,
steps_per_epoch=config.steps_per_epoch,
starting_epoch=config.starting_epoch,
max_epoch=config.max_epoch,
extra_callbacks=config.extra_callbacks)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment