Commit 8cea7d8a authored by Yuxin Wu

Add train_with_defaults

parent af667ff4
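
This commit adds a `Trainer.train_with_defaults()` method and migrates the GAN examples
away from building a `TrainConfig` and passing it to `train_with_config()`. A minimal
before/after sketch of the migration, reusing the names from the first hunk below:

    # before: collect options in a TrainConfig, then hand it to the trainer
    config = TrainConfig(
        callbacks=[ModelSaver()],
        steps_per_epoch=500,
        max_epoch=100,
    )
    if args.load:
        config.session_init = SaverRestore(args.load)
    GANTrainer(QueueInput(get_data()), Model()).train_with_config(config)

    # after: pass the same options directly; DEFAULT_CALLBACKS() and
    # DEFAULT_MONITORS() are appended automatically by train_with_defaults()
    GANTrainer(QueueInput(get_data()), Model()).train_with_defaults(
        callbacks=[ModelSaver()],
        steps_per_epoch=500,
        max_epoch=100,
        session_init=SaverRestore(args.load) if args.load else None)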
@@ -134,11 +134,9 @@ if __name__ == '__main__':
         sample(args.load)
     else:
         logger.auto_set_dir()
-        config = TrainConfig(
+        GANTrainer(QueueInput(get_data()), Model()).train_with_defaults(
             callbacks=[ModelSaver()],
             steps_per_epoch=500,
             max_epoch=100,
+            session_init=SaverRestore(args.load) if args.load else None
         )
-        if args.load:
-            config.session_init = SaverRestore(args.load)
-        GANTrainer(QueueInput(get_data()), Model()).train_with_config(config)
@@ -218,7 +218,7 @@ if __name__ == '__main__':
     data = get_data(args.data)
     data = PrintData(data)
-    config = TrainConfig(
+    GANTrainer(QueueInput(data), Model()).train_with_defaults(
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter(
@@ -230,5 +230,3 @@ if __name__ == '__main__':
         steps_per_epoch=data.size(),
+        session_init=SaverRestore(args.load) if args.load else None
     )
-    GANTrainer(QueueInput(data), Model()).train_with_config(config)
@@ -156,12 +156,11 @@ if __name__ == '__main__':
     else:
         assert args.data
         logger.auto_set_dir()
-        config = TrainConfig(
+        GANTrainer(
+            input=QueueInput(get_data(args.data)),
+            model=Model()).train_with_defaults(
             callbacks=[ModelSaver()],
             steps_per_epoch=300,
             max_epoch=200,
+            session_init=SaverRestore(args.load) if args.load else None
         )
-        GANTrainer(
-            input=QueueInput(get_data(args.data)),
-            model=Model()).train_with_config(config)
@@ -217,13 +217,11 @@ if __name__ == '__main__':
     data = get_celebA_data(args.data, args.style_A, args.style_B)
-    config = TrainConfig(
+    # train 1 D after 2 G
+    SeparateGANTrainer(
+        QueueInput(data), Model(), d_period=3).train_with_defaults(
         callbacks=[ModelSaver()],
         steps_per_epoch=300,
         max_epoch=250,
+        session_init=SaverRestore(args.load) if args.load else None
     )
-    # train 1 D after 2 G
-    SeparateGANTrainer(
-        QueueInput(data), Model(), d_period=3).train_with_config(config)
@@ -210,15 +210,13 @@ if __name__ == '__main__':
         logger.auto_set_dir()
         data = QueueInput(get_data())
-        config = TrainConfig(
+        GANTrainer(data, Model()).train_with_defaults(
             callbacks=[
                 PeriodicTrigger(ModelSaver(), every_k_epochs=3),
                 ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
             ],
             steps_per_epoch=data.size(),
             max_epoch=300,
+            session_init=SaverRestore(args.load) if args.load else None
         )
-        if args.load:
-            config.session_init = SaverRestore(args.load)
-        GANTrainer(data, Model()).train_with_config(config)
@@ -95,12 +95,11 @@ if __name__ == '__main__':
     else:
         assert args.data
         logger.auto_set_dir()
-        config = TrainConfig(
+        SeparateGANTrainer(
+            QueueInput(DCGAN.get_data(args.data)),
+            Model(), g_period=6).train_with_defaults(
             callbacks=[ModelSaver()],
             steps_per_epoch=300,
             max_epoch=200,
+            session_init=SaverRestore(args.load) if args.load else None
         )
-        SeparateGANTrainer(
-            QueueInput(DCGAN.get_data(args.data)),
-            Model(), g_period=6).train_with_config(config)
@@ -245,11 +245,10 @@ if __name__ == '__main__':
         sample(args.load)
     else:
         logger.auto_set_dir()
-        cfg = TrainConfig(
+        GANTrainer(QueueInput(get_data()),
+                   Model()).train_with_defaults(
            callbacks=[ModelSaver(keep_freq=0.1)],
            steps_per_epoch=500,
            max_epoch=100,
+           session_init=SaverRestore(args.load) if args.load else None
        )
-        GANTrainer(QueueInput(get_data()),
-                   Model()).train_with_config(cfg)
@@ -76,14 +76,15 @@ if __name__ == '__main__':
     else:
         assert args.data
         logger.auto_set_dir()
-        config = TrainConfig(
+        # The original code uses a different schedule, but this seems to work well.
+        # Train 1 D after 2 G
+        SeparateGANTrainer(
+            input=QueueInput(DCGAN.get_data(args.data)),
+            model=Model(),
+            d_period=3).train_with_defaults(
             callbacks=[ModelSaver(), ClipCallback()],
             steps_per_epoch=500,
             max_epoch=200,
+            session_init=SaverRestore(args.load) if args.load else None
         )
-        # The original code uses a different schedule, but this seems to work well.
-        # Train 1 D after 2 G
-        SeparateGANTrainer(
-            input=QueueInput(DCGAN.get_data(args.data)),
-            model=Model(), d_period=3).train_with_config(config)
@@ -9,13 +9,16 @@ from six.moves import range
 import six
 from abc import abstractmethod, ABCMeta
+from ..callbacks import (
+    Callback, Callbacks, Monitors, TrainingMonitor,
+    MovingAverageSummary,
+    ProgressBar, MergeAllSummaries,
+    TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
 from ..utils import logger
 from ..utils.argtools import call_only_once, memoized
-from ..callbacks import Callback, Callbacks
-from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sessinit import JustCurrentSession
-from ..tfutils.sesscreate import ReuseSessionCreator
+from ..tfutils.sesscreate import ReuseSessionCreator, NewSessionCreator
 from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
 from ..tfutils.gradproc import FilterNoneGrad
 from ..callbacks.steps import MaintainStepCounter
@@ -31,6 +34,18 @@ from ..trainv1.config import TrainConfig

 __all__ = ['TrainConfig', 'Trainer', 'SingleCostTrainer', 'TowerTrainer']


+def DEFAULT_CALLBACKS():
+    return [
+        MovingAverageSummary(),
+        ProgressBar(),
+        MergeAllSummaries(),
+        RunUpdateOps()]
+
+
+def DEFAULT_MONITORS():
+    return [TFEventWriter(), JSONWriter(), ScalarPrinter()]
+
+
 class Trainer(object):
     """ Base class for a trainer.
     """
@@ -220,8 +235,8 @@ class Trainer(object):
     def train_with_config(self, config):
         """
-        An alias to simplify the use of `TrainConfig`.
-        It is equivalent to the following:
+        An alias to simplify the use of `TrainConfig` with `Trainer`.
+        This method is literally the following:

         .. code-block:: python
@@ -240,6 +255,28 @@ class Trainer(object):
             config.session_creator, config.session_init,
             config.steps_per_epoch, config.starting_epoch, config.max_epoch)

+    def train_with_defaults(
+            self, callbacks=None, monitors=None,
+            session_creator=None, session_init=None,
+            steps_per_epoch=None, starting_epoch=1, max_epoch=9999):
+        """
+        Same as :meth:`train()`, but will:
+
+        1. Append `DEFAULT_CALLBACKS()` to callbacks.
+        2. Append `DEFAULT_MONITORS()` to monitors.
+        3. Provide default values for every option except `steps_per_epoch`.
+        """
+        callbacks = (callbacks or []) + DEFAULT_CALLBACKS()
+        monitors = (monitors or []) + DEFAULT_MONITORS()
+        assert steps_per_epoch is not None
+        session_creator = session_creator or NewSessionCreator()
+        session_init = session_init or JustCurrentSession()
+        self.train(callbacks, monitors,
+                   session_creator, session_init,
+                   steps_per_epoch, starting_epoch, max_epoch)
+
     # create the old trainer when called with TrainConfig
     def __new__(cls, *args, **kwargs):
         if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
......
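
A usage sketch of the new method; `SomeTrainer` stands in for any concrete `Trainer`
subclass, and the `assert` above makes `steps_per_epoch` the only required argument:

    trainer = SomeTrainer()
    trainer.train_with_defaults(
        # MovingAverageSummary, ProgressBar, MergeAllSummaries and
        # RunUpdateOps are appended automatically via DEFAULT_CALLBACKS()
        callbacks=[ModelSaver()],
        steps_per_epoch=500)  # required; every other option has a default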
@@ -86,4 +86,5 @@ def launch_train_with_config(config, trainer):
     trainer.setup_graph(
         inputs_desc, input,
         model._build_graph_get_cost, model.get_optimizer)
+    config.data = config.dataflow = config.model = None
     trainer.train_with_config(config)
@@ -17,6 +17,18 @@ from ..utils.develop import log_deprecated

 __all__ = ['TrainConfig']


+def DEFAULT_CALLBACKS():
+    return [
+        MovingAverageSummary(),
+        ProgressBar(),
+        MergeAllSummaries(),
+        RunUpdateOps()]
+
+
+def DEFAULT_MONITORS():
+    return [TFEventWriter(), JSONWriter(), ScalarPrinter()]
+
+
 class TrainConfig(object):
     """
     A collection of options to be used for trainers.
@@ -84,9 +96,9 @@ class TrainConfig(object):
             callbacks = []
         assert_type(callbacks, list)
         self._callbacks = callbacks + \
-            (extra_callbacks or TrainConfig.DEFAULT_EXTRA_CALLBACKS())
+            (extra_callbacks or DEFAULT_CALLBACKS())

-        self.monitors = monitors or TrainConfig.DEFAULT_MONITORS()
+        self.monitors = monitors or DEFAULT_MONITORS()

         if session_init is None:
             session_init = JustCurrentSession()
@@ -148,15 +160,3 @@ class TrainConfig(object):
     @property
     def callbacks(self):  # disable setter
         return self._callbacks
-
-    @staticmethod
-    def DEFAULT_EXTRA_CALLBACKS():
-        return [
-            MovingAverageSummary(),
-            ProgressBar(),
-            MergeAllSummaries(),
-            RunUpdateOps()]
-
-    @staticmethod
-    def DEFAULT_MONITORS():
-        return [TFEventWriter(), JSONWriter(), ScalarPrinter()]
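
Since the defaults are now module-level functions instead of staticmethods on
`TrainConfig`, callers can filter or extend them directly. A sketch of overriding
`extra_callbacks` (the import paths are assumptions inferred from the hunks above;
`ModelSaver` and `ProgressBar` come from tensorpack's callbacks as in the examples):

    from tensorpack.callbacks import ModelSaver, ProgressBar
    from tensorpack.trainv1.config import TrainConfig, DEFAULT_CALLBACKS

    # keep the default extra callbacks, but drop the progress bar
    config = TrainConfig(
        callbacks=[ModelSaver()],
        extra_callbacks=[cb for cb in DEFAULT_CALLBACKS()
                         if not isinstance(cb, ProgressBar)],
        steps_per_epoch=500)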