Commit 0b226adb authored by Yuxin Wu

allow empty model & data in config

parent e791b9a5
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..callbacks import ( from ..callbacks import (
Callbacks, MovingAverageSummary, MovingAverageSummary,
ProgressBar, MergeAllSummaries, ProgressBar, MergeAllSummaries,
TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps) TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
...@@ -24,20 +24,23 @@ class TrainConfig(object): ...@@ -24,20 +24,23 @@ class TrainConfig(object):
""" """
def __init__(self, def __init__(self,
dataflow=None, data=None, dataflow=None, data=None, model=None,
model=None, callbacks=None, extra_callbacks=None, monitors=None,
callbacks=None, extra_callbacks=None,
monitors=None,
session_creator=None, session_config=None, session_init=None, session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999, starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None, predict_tower=[0], nr_tower=1, tower=None, predict_tower=None,
**kwargs): **kwargs):
""" """
Note:
It depends on the specific trainer what fields are necessary.
Most existing trainers in tensorpack require one of `dataflow` or `data`,
and `model` to be present in the config.
Args: Args:
dataflow (DataFlow): the dataflow to train. dataflow (DataFlow):
data (InputSource): an `InputSource` instance. Only one of ``dataflow`` data (InputSource):
or ``data`` has to be present. model (ModelDesc):
model (ModelDesc): the model to train.
callbacks (list): a list of :class:`Callback` to perform during training. callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are is only used to provide the defaults. The defaults are
...@@ -45,14 +48,17 @@ class TrainConfig(object): ...@@ -45,14 +48,17 @@ class TrainConfig(object):
callbacks that will be used in the end are ``callbacks + extra_callbacks``. callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`. monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFEventWriter(), JSONWriter(), ScalarPrinter()]``. Defaults to ``[TFEventWriter(), JSONWriter(), ScalarPrinter()]``.
session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()` session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by :func:`tfutils.get_default_sess_config()`. with the config returned by :func:`tfutils.get_default_sess_config()`.
session_config (tf.ConfigProto): when session_creator is None, use this to create the session. session_config (tf.ConfigProto): when session_creator is None, use this to create the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to do nothing. session_init (SessionInit): how to initialize variables of a session. Defaults to do nothing.
starting_epoch (int): The index of the first epoch. starting_epoch (int): The index of the first epoch.
steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch. steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size. Defaults to the input data size.
max_epoch (int): maximum number of epoch to run training. max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers. nr_tower (int): number of training towers.
tower (list of int): list of training towers in relative id. tower (list of int): list of training towers in relative id.
predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu. predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu.
...@@ -62,7 +68,7 @@ class TrainConfig(object): ...@@ -62,7 +68,7 @@ class TrainConfig(object):
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
# process data # process data & model
if 'dataset' in kwargs: if 'dataset' in kwargs:
dataflow = kwargs.pop('dataset') dataflow = kwargs.pop('dataset')
log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.", "2017-09-11") log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.", "2017-09-11")
...@@ -71,16 +77,16 @@ class TrainConfig(object): ...@@ -71,16 +77,16 @@ class TrainConfig(object):
self.dataflow = dataflow self.dataflow = dataflow
assert_type(self.dataflow, DataFlow) assert_type(self.dataflow, DataFlow)
self.data = None self.data = None
else: if data is not None:
self.data = data self.data = data
assert_type(self.data, InputSource) assert_type(self.data, InputSource)
self.dataflow = None self.dataflow = None
if model is not None:
assert_type(model, ModelDesc)
self.model = model
if callbacks is None: if callbacks is None:
callbacks = [] callbacks = []
assert not isinstance(callbacks, Callbacks), \
"TrainConfig(callbacks=Callbacks([...]))" \
"Change the argument 'callbacks=' to a *list* of callbacks without StatPrinter()."
assert_type(callbacks, list) assert_type(callbacks, list)
if extra_callbacks is None: if extra_callbacks is None:
extra_callbacks = [ extra_callbacks = [
...@@ -89,16 +95,11 @@ class TrainConfig(object): ...@@ -89,16 +95,11 @@ class TrainConfig(object):
MergeAllSummaries(), MergeAllSummaries(),
RunUpdateOps()] RunUpdateOps()]
self._callbacks = callbacks + extra_callbacks self._callbacks = callbacks + extra_callbacks
assert_type(self._callbacks, list)
if monitors is None: if monitors is None:
monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()] monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()]
self.monitors = monitors self.monitors = monitors
if model is not None:
assert_type(model, ModelDesc)
self.model = model
if session_init is None: if session_init is None:
session_init = JustCurrentSession() session_init = JustCurrentSession()
self.session_init = session_init self.session_init = session_init
...@@ -128,22 +129,19 @@ class TrainConfig(object): ...@@ -128,22 +129,19 @@ class TrainConfig(object):
self.starting_epoch = int(starting_epoch) self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch) self.max_epoch = int(max_epoch)
assert self.steps_per_epoch >= 0 and self.max_epoch > 0 assert self.steps_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = nr_tower self.nr_tower = nr_tower
if tower is not None: if tower is not None:
assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!" assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
self.tower = tower self.tower = tower
if predict_tower is None:
predict_tower = [0]
self.predict_tower = predict_tower self.predict_tower = predict_tower
if isinstance(self.predict_tower, int): if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower] self.predict_tower = [self.predict_tower]
assert len(set(self.predict_tower)) == len(self.predict_tower), \
"Cannot have duplicated predict_tower!"
assert 'optimizer' not in kwargs, \
"TrainConfig(optimizer=...) was already deprecated! " \
"Use ModelDesc._get_optimizer() instead."
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
@property @property
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment