Commit 651a5aea authored by Yuxin Wu

deprecate TrainConfig.dataset and use 'dataflow' instead

parent 069c0b9c
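For user code, the change is a one-keyword rename in the TrainConfig constructor; the old keyword keeps working for now through a deprecation shim in the TrainConfig hunk below. A minimal sketch, where `df`, `lr`, and `MyModel` are placeholders for an existing DataFlow, learning-rate variable, and ModelDesc:

    config = TrainConfig(
        dataflow=df,  # was: dataset=df (still accepted, but logs a deprecation warning)
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
        callbacks=Callbacks([StatPrinter(), ModelSaver()]),
        model=MyModel())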
@@ -173,7 +173,7 @@ def get_config():
     lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -233,7 +233,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 1e-4, summary=True)
     return TrainConfig(
-        dataset=data_train,
+        dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-5),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -161,7 +161,7 @@ def get_config():
     tf.summary.scalar('lr', lr)
     return TrainConfig(
-        dataset=data_train,
+        dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-5),
         callbacks=Callbacks([
             StatPrinter(),
...
@@ -109,7 +109,7 @@ def get_config():
     dataset = get_data()
     lr = symbolic_functions.get_scalar_var('learning_rate', 2e-4, summary=True)
     return TrainConfig(
-        dataset=dataset,
+        dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -6,13 +6,13 @@
 import tensorflow as tf
 import numpy as np
 import time
-from tensorpack import (FeedfreeTrainer, TowerContext,
+from tensorpack import (FeedfreeTrainerBase, TowerContext,
                         get_global_step_var, QueueInput)
 from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
 from tensorpack.dataflow import DataFlow
 
 
-class GANTrainer(FeedfreeTrainer):
+class GANTrainer(FeedfreeTrainerBase):
     def __init__(self, config):
         self._input_method = QueueInput(config.dataset)
...
@@ -168,7 +168,7 @@ def get_config():
     dataset = get_data()
     lr = symbolic_functions.get_scalar_var('learning_rate', 2e-4, summary=True)
     return TrainConfig(
-        dataset=dataset,
+        dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), PeriodicCallback(ModelSaver(), 3),
...
@@ -104,7 +104,7 @@ def get_config():
     dataset = get_data()
     lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
     return TrainConfig(
-        dataset=dataset,
+        dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -171,7 +171,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 3e-5, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -158,7 +158,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.045, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -266,7 +266,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.045, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -204,7 +204,7 @@ def get_config():
     lr = symbf.get_scalar_var('learning_rate', 0.001, summary=True)
     return TrainConfig(
-        dataset=dataflow,
+        dataflow=dataflow,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -139,7 +139,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.01, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -187,7 +187,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.1, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -68,7 +68,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.01, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(),
...
@@ -153,7 +153,7 @@ def get_config():
     lr = symbf.get_scalar_var('learning_rate', 5e-4, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -94,7 +94,7 @@ def get_config(ds_train, ds_test):
     lr = symbolic_functions.get_scalar_var('learning_rate', 5e-3, summary=True)
     return TrainConfig(
-        dataset=ds_train,
+        dataflow=ds_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -107,7 +107,7 @@ def get_config():
     lr = symbolic_functions.get_scalar_var('learning_rate', 2e-3, summary=True)
     return TrainConfig(
-        dataset=ds,
+        dataflow=ds,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -122,7 +122,7 @@ def get_config(cifar_classnum):
             return lr * 0.31
 
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -114,7 +114,7 @@ def get_config():
     # get the config which contains everything necessary in a training
     return TrainConfig(
-        dataset=dataset_train,  # the DataFlow instance for training
+        dataflow=dataset_train,  # the DataFlow instance for training
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(),  # print statistics in terminal after every epoch
...
@@ -99,7 +99,7 @@ def get_config():
     tf.summary.scalar('lr', lr)
     return TrainConfig(
-        dataset=data_train,
+        dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
...
@@ -20,7 +20,7 @@ class TrainConfig(object):
     Config for trainer.
     """
 
-    def __init__(self, dataset=None, data=None,
+    def __init__(self, dataflow=None, data=None,
                  model=None, optimizer=None, callbacks=None,
                  session_config=get_default_sess_config(),
                  session_init=None,
@@ -29,8 +29,8 @@ class TrainConfig(object):
                  **kwargs):
         """
         Args:
-            dataset (DataFlow): the dataset to train.
-            data (InputData): an `InputData` instance. Only one of ``dataset``
+            dataflow (DataFlow): the dataflow to train.
+            data (InputData): an `InputData` instance. Only one of ``dataflow``
                 or ``data`` has to be present.
             model (ModelDesc): the model to train.
             optimizer (tf.train.Optimizer): the optimizer for trainig.
@@ -49,13 +49,19 @@ class TrainConfig(object):
 
         # TODO type checker decorator
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
-        if dataset is not None:
-            assert data is None, "dataset and data cannot be both presented in TrainConfig!"
-            self.dataset = dataset
-            assert_type(self.dataset, DataFlow)
+
+        if 'dataset' in kwargs:
+            dataflow = kwargs.pop('dataset')
+            logger.warn("[Deprecated] TrainConfig.dataset has been deprecated. Use TrainConfig.dataflow instead.")
+        if dataflow is not None:
+            assert data is None, "dataflow and data cannot be both presented in TrainConfig!"
+            self.dataflow = dataflow
+            assert_type(self.dataflow, DataFlow)
+            self.data = None
         else:
             self.data = data
             assert_type(self.data, InputData)
+            self.dataflow = None
         self.optimizer = optimizer
         assert_type(self.optimizer, tf.train.Optimizer)
@@ -74,8 +74,8 @@ class TrainConfig(object):
         self.step_per_epoch = step_per_epoch
         if self.step_per_epoch is None:
             try:
-                if dataset is not None:
-                    self.step_per_epoch = self.dataset.size()
+                if dataflow is not None:
+                    self.step_per_epoch = self.dataflow.size()
                 else:
                     self.step_per_epoch = self.data.size()
             except NotImplementedError:
...
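The kwargs shim above is what keeps old call sites alive: a `dataset=` keyword now lands in `**kwargs`, is popped into `dataflow`, and logs a warning before the usual type checks run. A hedged illustration of the resulting behavior (`df`, `opt`, `cbs`, and `m` are placeholders):

    config = TrainConfig(dataset=df, optimizer=opt, callbacks=cbs, model=m)
    # logs: [Deprecated] TrainConfig.dataset has been deprecated. Use TrainConfig.dataflow instead.
    assert config.dataflow is df
    assert config.data is None  # exactly one of .dataflow / .data is set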
@@ -15,12 +15,12 @@ from .input_data import QueueInput, FeedfreeInput
 from .base import Trainer
 from .trainer import MultiPredictorTowerTrainer
 
-__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer',
+__all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
            'SimpleFeedfreeTrainer', 'QueueInputTrainer']
 
 
-class FeedfreeTrainer(Trainer):
-    """ A trainer which runs iteration without feed_dict (therefore faster)
+class FeedfreeTrainerBase(Trainer):
+    """ A base trainer which runs iteration without feed_dict (therefore faster)
 
         Expect ``self.data`` to be a :class:`FeedfreeInput`.
     """
@@ -39,7 +39,7 @@ class FeedfreeTrainer(Trainer):
         self._input_method._setup(self)
 
 
-class SingleCostFeedfreeTrainer(FeedfreeTrainer):
+class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     """ A feedfree Trainer which assumes a single cost. """
     def _get_cost_and_grad(self):
         """ get the cost and gradient on a new tower"""
@@ -78,11 +78,16 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
 class SimpleFeedfreeTrainer(
         MultiPredictorTowerTrainer,
         SingleCostFeedfreeTrainer):
+    """
+    A trainer with single cost, single training tower, any number of
+    prediction tower, and feed-free input.
+    """
 
     def __init__(self, config):
         """
-        A trainer with single cost, single training tower and feed-free input
-        config.data must exists
+        Args:
+            config (TrainConfig): ``config.data`` must exist and is a
+                :class:`FeedfreeInput`.
         """
         self._input_method = config.data
         assert isinstance(self._input_method, FeedfreeInput), self._input_method
@@ -105,17 +105,20 @@ class SimpleFeedfreeTrainer(
 
 class QueueInputTrainer(SimpleFeedfreeTrainer):
+    """
+    A trainer which automatically wraps ``config.dataflow``
+    """
 
     def __init__(self, config, input_queue=None, predict_tower=None):
         """
         Single tower Trainer, takes input from a queue
 
-        :param config: a `TrainConfig` instance. config.dataset must exist
+        :param config: a `TrainConfig` instance. config.dataflow must exist
         :param input_queue: a `tf.QueueBase` instance
         :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
             Use -1 for cpu.
         """
-        config.data = QueueInput(config.dataset, input_queue)
+        config.data = QueueInput(config.dataflow, input_queue)
         if predict_tower is not None:
             logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
                         "Use TrainConfig.predict_tower instead!")
...
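Since `FeedfreeTrainer` is renamed to `FeedfreeTrainerBase`, custom trainers built on it (such as `GANTrainer` in the GAN example above) only need their import and base class updated. A minimal sketch, assuming the subclass forwards `config` to the base constructor as before:

    from tensorpack import FeedfreeTrainerBase, QueueInput

    class MyTrainer(FeedfreeTrainerBase):
        def __init__(self, config):
            # wrap the training dataflow in a queue-based input method
            self._input_method = QueueInput(config.dataflow)
            super(MyTrainer, self).__init__(config)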
@@ -53,8 +53,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
                           MultiPredictorTowerTrainer):
 
     def __init__(self, config, input_queue=None, predict_tower=None):
-        if hasattr(config, 'dataset'):
-            self._input_method = QueueInput(config.dataset, input_queue)
+        if config.dataflow is not None:
+            self._input_method = QueueInput(config.dataflow, input_queue)
         else:
             self._input_method = config.data
             assert isinstance(self._input_method, QueueInput)
@@ -122,8 +122,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
                  input_queue=None,
                  average_gradient=True,
                  predict_tower=None):
-        if hasattr(config, 'dataset'):
-            self._input_method = QueueInput(config.dataset, input_queue)
+        if config.dataflow is not None:
+            self._input_method = QueueInput(config.dataflow, input_queue)
         else:
             self._input_method = config.data
             assert isinstance(self._input_method, QueueInput)
...
@@ -59,11 +59,11 @@ class SimpleTrainer(Trainer):
     def __init__(self, config):
         super(SimpleTrainer, self).__init__(config)
         self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
-        if not hasattr(config, 'dataset'):
+        if config.dataflow is None:
             self._input_method = config.data
             assert isinstance(self._input_method, FeedInput)
         else:
-            self._input_method = FeedInput(config.dataset)
+            self._input_method = FeedInput(config.dataflow)
 
     def run_step(self):
         feed = self._input_method.next_feed()
...
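The trainer-side checks change in step: because `TrainConfig` now always defines both `dataflow` and `data` (exactly one of them non-None), a plain `None` test replaces the old `hasattr(config, 'dataset')` probe. The dispatch used by `SimpleTrainer` above, as a sketch:

    if config.dataflow is not None:
        input_method = FeedInput(config.dataflow)  # feed-based input built from a DataFlow
    else:
        input_method = config.data  # an explicit InputData instance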