Commit 0b226adb authored by Yuxin Wu

allow empty model & data in config

parent e791b9a5
......@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..callbacks import (
Callbacks, MovingAverageSummary,
MovingAverageSummary,
ProgressBar, MergeAllSummaries,
TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow
......@@ -24,20 +24,23 @@ class TrainConfig(object):
"""
def __init__(self,
dataflow=None, data=None,
model=None,
callbacks=None, extra_callbacks=None,
monitors=None,
dataflow=None, data=None, model=None,
callbacks=None, extra_callbacks=None, monitors=None,
session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None, predict_tower=[0],
nr_tower=1, tower=None, predict_tower=None,
**kwargs):
"""
Note:
It depends on the specific trainer what fields are necessary.
Most existing trainers in tensorpack require one of `dataflow` or `data`,
as well as `model`, to be present in the config.
Args:
dataflow (DataFlow): the dataflow to train.
data (InputSource): an `InputSource` instance. Only one of ``dataflow``
or ``data`` has to be present.
model (ModelDesc): the model to train.
dataflow (DataFlow):
data (InputSource):
model (ModelDesc):
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
......@@ -45,14 +48,17 @@ class TrainConfig(object):
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFEventWriter(), JSONWriter(), ScalarPrinter()]``.
session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by :func:`tfutils.get_default_sess_config()`.
session_config (tf.ConfigProto): when session_creator is None, use this to create the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to do nothing.
starting_epoch (int): The index of the first epoch.
steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size.
max_epoch (int): maximum number of epochs to run training.
nr_tower (int): number of training towers.
tower (list of int): list of training towers in relative id.
predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu.
......@@ -62,7 +68,7 @@ class TrainConfig(object):
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
# process data
# process data & model
if 'dataset' in kwargs:
dataflow = kwargs.pop('dataset')
log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.", "2017-09-11")
......@@ -71,16 +77,16 @@ class TrainConfig(object):
self.dataflow = dataflow
assert_type(self.dataflow, DataFlow)
self.data = None
else:
if data is not None:
self.data = data
assert_type(self.data, InputSource)
self.dataflow = None
if model is not None:
assert_type(model, ModelDesc)
self.model = model
if callbacks is None:
callbacks = []
assert not isinstance(callbacks, Callbacks), \
"TrainConfig(callbacks=Callbacks([...]))" \
"Change the argument 'callbacks=' to a *list* of callbacks without StatPrinter()."
assert_type(callbacks, list)
if extra_callbacks is None:
extra_callbacks = [
......@@ -89,16 +95,11 @@ class TrainConfig(object):
MergeAllSummaries(),
RunUpdateOps()]
self._callbacks = callbacks + extra_callbacks
assert_type(self._callbacks, list)
if monitors is None:
monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()]
self.monitors = monitors
if model is not None:
assert_type(model, ModelDesc)
self.model = model
if session_init is None:
session_init = JustCurrentSession()
self.session_init = session_init
......@@ -128,22 +129,19 @@ class TrainConfig(object):
self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch)
assert self.steps_per_epoch >= 0 and self.max_epoch > 0
assert self.steps_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = nr_tower
if tower is not None:
assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
self.tower = tower
if predict_tower is None:
predict_tower = [0]
self.predict_tower = predict_tower
if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower]
assert len(set(self.predict_tower)) == len(self.predict_tower), \
"Cannot have duplicated predict_tower!"
assert 'optimizer' not in kwargs, \
"TrainConfig(optimizer=...) was already deprecated! " \
"Use ModelDesc._get_optimizer() instead."
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
@property
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment