Commit a6936913 authored by Yuxin Wu's avatar Yuxin Wu

make 'TrainConfig' a simple key-value holder -- let it use train_with_defaults...

make 'TrainConfig' a simple key-value holder -- let it use train_with_defaults to handle the defaults
parent 55098813
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: DQNModel.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import abc
import tensorflow as tf
......
......@@ -28,7 +28,7 @@ Evaluate the [pretrained model](http://models.tensorpack.com/ShuffleNet/):
This Inception-BN script reaches 27% single-crop error after 300k steps with 6 GPUs.
This VGG16 script reaches 28.8% single-crop error after 100 epochs (30h with 8 P100s). It gets 1% better if BN is enabled.
This VGG16 script reaches 29~30% single-crop error after 100 epochs (30h with 8 P100s), and 28% if BN is enabled.
### ResNet, DoReFa-Net
......
......@@ -124,7 +124,7 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
logger.set_logger_dir(os.path.join('train_log', 'vgg16'))
logger.set_logger_dir(os.path.join('train_log', 'vgg16-norm={}'.format(args.norm)))
config = get_config()
if args.load:
......
......@@ -18,31 +18,37 @@ class PredictConfig(object):
model=None,
inputs_desc=None,
tower_func=None,
session_creator=None,
session_init=None,
input_names=None,
output_names=None,
session_creator=None,
session_init=None,
return_input=False,
create_graph=True,
):
"""
You need to set either `model`, or `inputs_desc` plus `tower_func`.
They are needed to construct the graph.
You'll also have to set `output_names` as it does not have a default.
Args:
model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors
tower_func: a callable which takes input tensors and construct a tower.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`tf.train.ChiefSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
return_input (bool): same as in :attr:`PredictorBase.return_input`.
create_graph (bool): create a new graph, or use the default graph
when predictor is first initialized.
You need to set either `model`, or `inputs_desc` plus `tower_func`.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
......
......@@ -7,6 +7,7 @@ import weakref
import time
from six.moves import range
import six
import copy
from ..callbacks import (
Callback, Callbacks, Monitors, TrainingMonitor)
......@@ -187,6 +188,8 @@ class Trainer(object):
callbacks ([Callback]):
monitors ([TrainingMonitor]):
"""
assert isinstance(callbacks, list), callbacks
assert isinstance(monitors, list), monitors
describe_trainable_vars() # TODO weird
self.register_callback(MaintainStepCounter())
......@@ -284,7 +287,7 @@ class Trainer(object):
session_creator, session_init,
steps_per_epoch, starting_epoch=1, max_epoch=9999999):
"""
Implemented by:
Implemented by three lines:
.. code-block:: python
......@@ -299,18 +302,24 @@ class Trainer(object):
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
def train_with_defaults(
self, callbacks=None, monitors=None,
self, _sentinel=None,
callbacks=None, monitors=None,
session_creator=None, session_init=None,
steps_per_epoch=None, starting_epoch=1, max_epoch=9999999):
steps_per_epoch=None, starting_epoch=1, max_epoch=9999999,
extra_callbacks=None):
"""
Same as :meth:`train()`, but will:
Same as :meth:`train()`, except:
1. Append :meth:`DEFAULT_CALLBACKS()` to callbacks.
2. Append :meth:`DEFAULT_MONITORS()` to monitors.
1. Add `extra_callbacks` to callbacks. The default value for
`extra_callbacks` is :meth:`DEFAULT_CALLBACKS()`.
2. Default value for `monitors` is :meth:`DEFAULT_MONITORS()`.
3. Provide default values for every option except `steps_per_epoch`.
"""
callbacks = (callbacks or []) + DEFAULT_CALLBACKS()
monitors = (monitors or []) + DEFAULT_MONITORS()
assert _sentinel is None, "Please call `train_with_defaults` with keyword arguments only!"
callbacks = copy.copy(callbacks or [])
monitors = DEFAULT_MONITORS() if monitors is None else monitors
extra_callbacks = DEFAULT_CALLBACKS() if extra_callbacks is None else extra_callbacks
callbacks.extend(extra_callbacks)
assert steps_per_epoch is not None
session_creator = session_creator or NewSessionCreator()
......
......@@ -12,7 +12,7 @@ from ..callbacks import (
from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..tfutils.sessinit import JustCurrentSession, SessionInit, SaverRestore
from ..tfutils.sessinit import SessionInit, SaverRestore
from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource
......@@ -52,21 +52,20 @@ def DEFAULT_MONITORS():
class TrainConfig(object):
"""
A collection of options to be used for trainers.
A collection of options to be used for single-cost trainers.
"""
def __init__(self,
dataflow=None, data=None, model=None,
dataflow=None, data=None,
model=None,
callbacks=None, extra_callbacks=None, monitors=None,
session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None,
**kwargs):
starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
"""
Args:
dataflow (DataFlow):
data (InputSource):
model (ModelDescBase):
model (ModelDesc):
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
......@@ -107,20 +106,18 @@ class TrainConfig(object):
assert_type(model, ModelDescBase)
self.model = model
if callbacks is None:
callbacks = []
if callbacks is not None:
assert_type(callbacks, list)
self.callbacks = callbacks
if extra_callbacks is not None:
self._callbacks = callbacks + extra_callbacks
else:
self._callbacks = callbacks + DEFAULT_CALLBACKS()
self.monitors = monitors if monitors is not None else DEFAULT_MONITORS()
if session_init is None:
session_init = JustCurrentSession()
assert_type(extra_callbacks, list)
self.extra_callbacks = extra_callbacks
if monitors is not None:
assert_type(monitors, list)
self.monitors = monitors
if session_init is not None:
assert_type(session_init, SessionInit)
self.session_init = session_init
assert_type(self.session_init, SessionInit)
if session_creator is None:
if session_config is not None:
......@@ -149,27 +146,6 @@ class TrainConfig(object):
self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch)
# Tower stuff are for Trainer v1 only:
nr_tower = max(nr_tower, 1)
self.nr_tower = nr_tower
if tower is not None:
assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
self.tower = tower
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
@property
def nr_tower(self):
return len(self.tower)
@nr_tower.setter
def nr_tower(self, value):
self.tower = list(range(value))
@property
def callbacks(self): # disable setter
return self._callbacks
class AutoResumeTrainConfig(TrainConfig):
"""
......
......@@ -76,14 +76,16 @@ def launch_train_with_config(config, trainer):
inputs_desc = model.get_inputs_desc()
input = config.data or config.dataflow
input = apply_default_prefetch(input, trainer)
if config.nr_tower > 1:
logger.warn("With trainer v2, setting tower in TrainConfig has no effect.")
logger.warn("It's enough to set the tower when initializing the trainer.")
trainer.setup_graph(
inputs_desc, input,
model._build_graph_get_cost, model.get_optimizer)
trainer.train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
trainer.train_with_defaults(
callbacks=config.callbacks,
monitors=config.monitors,
session_creator=config.session_creator,
session_init=config.session_init,
steps_per_epoch=config.steps_per_epoch,
starting_epoch=config.starting_epoch,
max_epoch=config.max_epoch,
extra_callbacks=config.extra_callbacks)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment