Commit 651a5aea authored by Yuxin Wu

deprecate TrainConfig.dataset and use 'dataflow' instead

parent 069c0b9c
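In short: every call site that passed ``dataset=`` to TrainConfig now passes ``dataflow=``; the old keyword is still accepted but logs a deprecation warning. A minimal before/after sketch, assuming the usual top-level tensorpack imports; get_data() and MyModel are placeholders, not part of this commit:

    import tensorflow as tf
    from tensorpack import TrainConfig, Callbacks, StatPrinter, ModelSaver

    df = get_data()  # placeholder: returns any DataFlow instance

    # Old spelling: still accepted after this commit, but logs
    # "[Deprecated] TrainConfig.dataset has been deprecated. ..."
    config = TrainConfig(
        dataset=df,
        optimizer=tf.train.AdamOptimizer(1e-3),
        callbacks=Callbacks([StatPrinter(), ModelSaver()]),
        model=MyModel())  # placeholder ModelDesc subclass

    # New spelling:
    config = TrainConfig(
        dataflow=df,
        optimizer=tf.train.AdamOptimizer(1e-3),
        callbacks=Callbacks([StatPrinter(), ModelSaver()]),
        model=MyModel())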
@@ -173,7 +173,7 @@ def get_config():
     lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -233,7 +233,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 1e-4, summary=True)
     return TrainConfig(
-        dataset=data_train,
+        dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-5),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -161,7 +161,7 @@ def get_config():
     tf.summary.scalar('lr', lr)
     return TrainConfig(
-        dataset=data_train,
+        dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-5),
         callbacks=Callbacks([
             StatPrinter(),
@@ -109,7 +109,7 @@ def get_config():
     dataset = get_data()
     lr = symbolic_functions.get_scalar_var('learning_rate', 2e-4, summary=True)
     return TrainConfig(
-        dataset=dataset,
+        dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -6,13 +6,13 @@
 import tensorflow as tf
 import numpy as np
 import time
-from tensorpack import (FeedfreeTrainer, TowerContext,
+from tensorpack import (FeedfreeTrainerBase, TowerContext,
                         get_global_step_var, QueueInput)
 from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
 from tensorpack.dataflow import DataFlow


-class GANTrainer(FeedfreeTrainer):
+class GANTrainer(FeedfreeTrainerBase):
     def __init__(self, config):
         self._input_method = QueueInput(config.dataset)
@@ -168,7 +168,7 @@ def get_config():
     dataset = get_data()
     lr = symbolic_functions.get_scalar_var('learning_rate', 2e-4, summary=True)
     return TrainConfig(
-        dataset=dataset,
+        dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), PeriodicCallback(ModelSaver(), 3),
@@ -104,7 +104,7 @@ def get_config():
     dataset = get_data()
     lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
     return TrainConfig(
-        dataset=dataset,
+        dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -171,7 +171,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 3e-5, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -158,7 +158,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.045, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -266,7 +266,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.045, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -204,7 +204,7 @@ def get_config():
     lr = symbf.get_scalar_var('learning_rate', 0.001, summary=True)
     return TrainConfig(
-        dataset=dataflow,
+        dataflow=dataflow,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -139,7 +139,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.01, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -187,7 +187,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.1, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -68,7 +68,7 @@ def get_config():
     lr = get_scalar_var('learning_rate', 0.01, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(),
@@ -153,7 +153,7 @@ def get_config():
     lr = symbf.get_scalar_var('learning_rate', 5e-4, summary=True)
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -94,7 +94,7 @@ def get_config(ds_train, ds_test):
     lr = symbolic_functions.get_scalar_var('learning_rate', 5e-3, summary=True)
     return TrainConfig(
-        dataset=ds_train,
+        dataflow=ds_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -107,7 +107,7 @@ def get_config():
     lr = symbolic_functions.get_scalar_var('learning_rate', 2e-3, summary=True)
     return TrainConfig(
-        dataset=ds,
+        dataflow=ds,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -122,7 +122,7 @@ def get_config(cifar_classnum):
             return lr * 0.31
     return TrainConfig(
-        dataset=dataset_train,
+        dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -114,7 +114,7 @@ def get_config():
     # get the config which contains everything necessary in a training
     return TrainConfig(
-        dataset=dataset_train,  # the DataFlow instance for training
+        dataflow=dataset_train,  # the DataFlow instance for training
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(),  # print statistics in terminal after every epoch
@@ -99,7 +99,7 @@ def get_config():
     tf.summary.scalar('lr', lr)
     return TrainConfig(
-        dataset=data_train,
+        dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(), ModelSaver(),
@@ -20,7 +20,7 @@ class TrainConfig(object):
     Config for trainer.
     """

-    def __init__(self, dataset=None, data=None,
+    def __init__(self, dataflow=None, data=None,
                  model=None, optimizer=None, callbacks=None,
                  session_config=get_default_sess_config(),
                  session_init=None,
@@ -29,8 +29,8 @@ class TrainConfig(object):
                  **kwargs):
         """
         Args:
-            dataset (DataFlow): the dataset to train.
-            data (InputData): an `InputData` instance. Only one of ``dataset``
+            dataflow (DataFlow): the dataflow to train.
+            data (InputData): an `InputData` instance. Only one of ``dataflow``
                 or ``data`` has to be present.
             model (ModelDesc): the model to train.
             optimizer (tf.train.Optimizer): the optimizer for training.
@@ -49,13 +49,19 @@ class TrainConfig(object):
         # TODO type checker decorator
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
-        if dataset is not None:
-            assert data is None, "dataset and data cannot be both presented in TrainConfig!"
-            self.dataset = dataset
-            assert_type(self.dataset, DataFlow)
+        if 'dataset' in kwargs:
+            dataflow = kwargs.pop('dataset')
+            logger.warn("[Deprecated] TrainConfig.dataset has been deprecated. Use TrainConfig.dataflow instead.")
+        if dataflow is not None:
+            assert data is None, "dataflow and data cannot be both presented in TrainConfig!"
+            self.dataflow = dataflow
+            assert_type(self.dataflow, DataFlow)
+            self.data = None
         else:
             self.data = data
             assert_type(self.data, InputData)
+            self.dataflow = None
         self.optimizer = optimizer
         assert_type(self.optimizer, tf.train.Optimizer)
@@ -74,8 +80,8 @@ class TrainConfig(object):
         self.step_per_epoch = step_per_epoch
         if self.step_per_epoch is None:
             try:
-                if dataset is not None:
-                    self.step_per_epoch = self.dataset.size()
+                if dataflow is not None:
+                    self.step_per_epoch = self.dataflow.size()
                 else:
                     self.step_per_epoch = self.data.size()
             except NotImplementedError:
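The shim above follows a standard pattern for renaming a constructor argument: accept the legacy name through ``**kwargs``, pop it, warn, and forward the value to the new name. A self-contained sketch of the same pattern, with illustrative names that are not tensorpack's:

    import warnings

    class Config(object):
        def __init__(self, dataflow=None, **kwargs):
            if 'dataset' in kwargs:
                # Legacy keyword: forward its value to the new argument.
                dataflow = kwargs.pop('dataset')
                warnings.warn("'dataset' is deprecated, use 'dataflow'.",
                              DeprecationWarning)
            self.dataflow = dataflow

    c = Config(dataset=[1, 2, 3])  # still works, but warns
    assert c.dataflow == [1, 2, 3]

Note that the new __init__ also sets the unused attribute (self.data or self.dataflow) to None, so callers can test ``config.dataflow is not None``; the trainer changes below depend on that invariant.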
@@ -15,12 +15,12 @@ from .input_data import QueueInput, FeedfreeInput
 from .base import Trainer
 from .trainer import MultiPredictorTowerTrainer

-__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer',
+__all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
            'SimpleFeedfreeTrainer', 'QueueInputTrainer']


-class FeedfreeTrainer(Trainer):
-    """ A trainer which runs iteration without feed_dict (therefore faster)
+class FeedfreeTrainerBase(Trainer):
+    """ A base trainer which runs iteration without feed_dict (therefore faster)

     Expect ``self.data`` to be a :class:`FeedfreeInput`.
     """
@@ -39,7 +39,7 @@ class FeedfreeTrainer(Trainer):
         self._input_method._setup(self)


-class SingleCostFeedfreeTrainer(FeedfreeTrainer):
+class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     """ A feedfree Trainer which assumes a single cost. """
     def _get_cost_and_grad(self):
         """ get the cost and gradient on a new tower"""
@@ -78,11 +78,16 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
 class SimpleFeedfreeTrainer(
         MultiPredictorTowerTrainer,
         SingleCostFeedfreeTrainer):
+    """
+    A trainer with single cost, single training tower, any number of
+    prediction tower, and feed-free input.
+    """

     def __init__(self, config):
         """
-        A trainer with single cost, single training tower and feed-free input
-        config.data must exists
+        Args:
+            config (TrainConfig): ``config.data`` must exist and is a
+                :class:`FeedfreeInput`.
         """
         self._input_method = config.data
         assert isinstance(self._input_method, FeedfreeInput), self._input_method
@@ -105,17 +110,20 @@ class SimpleFeedfreeTrainer(
 class QueueInputTrainer(SimpleFeedfreeTrainer):
+    """
+    A trainer which automatically wraps ``config.dataflow``
+    """

     def __init__(self, config, input_queue=None, predict_tower=None):
         """
         Single tower Trainer, takes input from a queue

-        :param config: a `TrainConfig` instance. config.dataset must exist
+        :param config: a `TrainConfig` instance. config.dataflow must exist
         :param input_queue: a `tf.QueueBase` instance
         :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
             Use -1 for cpu.
         """
-        config.data = QueueInput(config.dataset, input_queue)
+        config.data = QueueInput(config.dataflow, input_queue)
         if predict_tower is not None:
             logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
                         "Use TrainConfig.predict_tower instead!")
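Any external subclass of the old FeedfreeTrainer, such as GANTrainer in the example above, must now inherit from FeedfreeTrainerBase. A hedged sketch of a minimal subclass after this commit, assuming FeedfreeTrainerBase and QueueInput stay importable from the tensorpack top level as in the GAN example:

    from tensorpack import FeedfreeTrainerBase, QueueInput

    class MyTrainer(FeedfreeTrainerBase):
        def __init__(self, config):
            # Wrap the renamed attribute in a queue-based input method.
            self._input_method = QueueInput(config.dataflow)
            super(MyTrainer, self).__init__(config)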
@@ -53,8 +53,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
                           MultiPredictorTowerTrainer):

     def __init__(self, config, input_queue=None, predict_tower=None):
-        if hasattr(config, 'dataset'):
-            self._input_method = QueueInput(config.dataset, input_queue)
+        if config.dataflow is not None:
+            self._input_method = QueueInput(config.dataflow, input_queue)
         else:
             self._input_method = config.data
             assert isinstance(self._input_method, QueueInput)
@@ -122,8 +122,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
                  input_queue=None,
                  average_gradient=True,
                  predict_tower=None):
-        if hasattr(config, 'dataset'):
-            self._input_method = QueueInput(config.dataset, input_queue)
+        if config.dataflow is not None:
+            self._input_method = QueueInput(config.dataflow, input_queue)
         else:
             self._input_method = config.data
             assert isinstance(self._input_method, QueueInput)
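The guard change in both multi-GPU trainers follows from the TrainConfig change: the attribute is now always defined (possibly None), so hasattr() can no longer distinguish the two cases. A tiny stand-in sketch of the pitfall:

    class Cfg(object):
        def __init__(self, dataflow=None, data=None):
            self.dataflow = dataflow  # always set, possibly None
            self.data = data

    cfg = Cfg(data=object())
    assert hasattr(cfg, 'dataflow')  # True even when no dataflow was given
    assert cfg.dataflow is None      # the check that still discriminates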
@@ -59,11 +59,11 @@ class SimpleTrainer(Trainer):
     def __init__(self, config):
         super(SimpleTrainer, self).__init__(config)
         self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
-        if not hasattr(config, 'dataset'):
+        if config.dataflow is None:
             self._input_method = config.data
             assert isinstance(self._input_method, FeedInput)
         else:
-            self._input_method = FeedInput(config.dataset)
+            self._input_method = FeedInput(config.dataflow)

     def run_step(self):
         feed = self._input_method.next_feed()
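SimpleTrainer follows the same rule: when config.dataflow is set it is wrapped in a FeedInput automatically, otherwise config.data must already be a FeedInput. A hedged sketch of the two equivalent spellings; the FeedInput import path is inferred from the ``from .input_data import ...`` line above and may differ, and df, opt, cbs and MyModel() are placeholders:

    from tensorpack import SimpleTrainer, TrainConfig
    from tensorpack.train.input_data import FeedInput  # assumed path

    # Let the trainer build the FeedInput from the dataflow:
    SimpleTrainer(TrainConfig(dataflow=df, optimizer=opt,
                              callbacks=cbs, model=MyModel())).train()

    # Or pass a prebuilt FeedInput through `data`:
    SimpleTrainer(TrainConfig(data=FeedInput(df), optimizer=opt,
                              callbacks=cbs, model=MyModel())).train()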