Commit d9e7c6bf authored by Yuxin Wu

allow data / dataset as input to trainconfig

parent 8db5bcd3
@@ -5,15 +5,15 @@
 import tensorflow as tf
 import numpy as np
-from tensorpack import (QueueInputTrainer, TowerContext,
+from tensorpack import (FeedfreeTrainer, TowerContext,
         get_global_step_var, QueueInput)
 from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
 from tensorpack.dataflow import DataFlow

-class GANTrainer(QueueInputTrainer):
+class GANTrainer(FeedfreeTrainer):
     def __init__(self, config, g_vs_d=1):
-        super(GANTrainer, self).__init__(config)
+        self._input_method = QueueInput(config.dataset)
+        super(GANTrainer, self).__init__(config)
         if g_vs_d > 1:
             self._opt_g = g_vs_d
             self._opt_d = 1
...
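This first hunk sets the pattern the whole commit follows: the trainer now owns an explicit input method, assigned before the base constructor runs so that base-class setup can find it. A hypothetical driver script under the new wiring (the `my_*` names and `Model` are placeholders for user code, not part of this commit):

```python
# Sketch only: everything except GANTrainer/TrainConfig is a placeholder.
import tensorflow as tf
from tensorpack import TrainConfig

config = TrainConfig(
    dataset=my_dataflow,                     # still a DataFlow, as before
    optimizer=tf.train.AdamOptimizer(2e-4),
    callbacks=my_callbacks,                  # a callback.Callbacks instance
    model=Model(),
)
# GANTrainer builds QueueInput(config.dataset) itself and only then defers
# to FeedfreeTrainer.__init__; the assignment order in the hunk matters.
GANTrainer(config, g_vs_d=1).train()
```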
@@ -82,7 +82,7 @@ class Model(ModelDesc):
         self.g_loss, self.d_loss = build_GAN_losses(vecpos, vecneg)
         self.g_loss = tf.add(self.g_loss, MIloss, name='total_g_loss')
-        self.d_loss = tf.add(self.d_loss, MIloss, name='total_g_loss')
+        self.d_loss = tf.add(self.d_loss, MIloss, name='total_d_loss')
         summary.add_moving_summary(MIloss, self.g_loss, self.d_loss, Hc, Elog_qc_given_x)
         all_vars = tf.trainable_variables()
...
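The one-line fix above is worth a note: both losses previously carried the name `'total_g_loss'`, and moving-average summaries are labeled by tensor name, so the two curves would have collided in TensorBoard. An illustrative TF 1.x snippet (not from this commit):

```python
# Distinct tensor names keep the two summaries distinguishable.
import tensorflow as tf
g_loss = tf.add(1.0, 2.0, name='total_g_loss')   # Tensor 'total_g_loss:0'
d_loss = tf.add(3.0, 4.0, name='total_d_loss')   # no longer shadows g_loss
print(g_loss.name, d_loss.name)
```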
@@ -10,6 +10,7 @@ from ..utils import logger
 from ..tfutils import (JustCurrentSession,
         get_default_sess_config, SessionInit)
 from ..dataflow import DataFlow
+from .input_data import InputData

 __all__ = ['TrainConfig']
@@ -20,6 +21,7 @@ class TrainConfig(object):
     def __init__(self, **kwargs):
         """
         :param dataset: the dataset to train. a `DataFlow` instance.
+        :param data: an `InputData` instance
         :param optimizer: a `tf.train.Optimizer` instance defining the optimizer for training.
         :param callbacks: a `callback.Callbacks` instance. Define
             the callbacks to perform during training.
@@ -35,8 +37,14 @@ class TrainConfig(object):
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
-        self.dataset = kwargs.pop('dataset')
-        assert_type(self.dataset, DataFlow)
+        if 'dataset' in kwargs:
+            assert 'data' not in kwargs, "dataset and data cannot both be given to TrainConfig!"
+            self.dataset = kwargs.pop('dataset')
+            assert_type(self.dataset, DataFlow)
+        else:
+            self.data = kwargs.pop('data')
+            assert_type(self.data, InputData)
         self.optimizer = kwargs.pop('optimizer')
         assert_type(self.optimizer, tf.train.Optimizer)
         self.callbacks = kwargs.pop('callbacks')
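After this hunk `TrainConfig` accepts exactly one of `dataset` (a `DataFlow`, the old behaviour) or `data` (a pre-built `InputData`). A sketch of the two construction styles; `my_dataflow`, `my_model` and `my_callbacks` are placeholders, and the top-level imports assume tensorpack's re-exports of the time (the GAN example above already imports `QueueInput` this way):

```python
# The two mutually exclusive ways to feed TrainConfig after this change.
import tensorflow as tf
from tensorpack import TrainConfig, QueueInput

old_style = TrainConfig(
    dataset=my_dataflow,                  # TrainConfig keeps the DataFlow
    optimizer=tf.train.GradientDescentOptimizer(0.1),
    callbacks=my_callbacks,
    model=my_model,
)
new_style = TrainConfig(
    data=QueueInput(my_dataflow),         # you choose the InputData yourself
    optimizer=tf.train.GradientDescentOptimizer(0.1),
    callbacks=my_callbacks,
    model=my_model,
)
# Passing both trips the assertion added above; step_per_epoch then falls
# back to dataset.size() or data.size(), as the next hunk shows.
```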
@@ -52,7 +60,10 @@ class TrainConfig(object):
         self.step_per_epoch = kwargs.pop('step_per_epoch', None)
         if self.step_per_epoch is None:
             try:
-                self.step_per_epoch = self.dataset.size()
+                if hasattr(self, 'dataset'):
+                    self.step_per_epoch = self.dataset.size()
+                else:
+                    self.step_per_epoch = self.data.size()
             except NotImplementedError:
                 logger.exception("You must set `step_per_epoch` if dataset.size() is not implemented.")
         else:
@@ -70,6 +81,7 @@ class TrainConfig(object):
         else:
             self.tower = [0]

+        # TODO deprecated @Dec20
         self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
         if self.extra_threads_procs:
             logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
...
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# File: inputmethod.py
+# File: input_data.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
 import threading
 from abc import ABCMeta, abstractmethod

 from ..dataflow.common import RepeatedData
 from ..tfutils.summary import add_moving_summary
 from ..utils import logger
 from ..callbacks.concurrency import StartProcOrThread

 __all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput']

-class InputMethod(object):
+class InputData(object):
     __metaclass__ = ABCMeta
     pass

-class FeedInput(InputMethod):
+class FeedInput(InputData):
     def __init__(self, ds):
         self.ds = ds

@@ -26,8 +27,16 @@ class FeedInput(InputMethod):
     def _setup(self, trainer):
         self.input_vars = trainer.model.get_input_vars()
         rds = RepeatedData(self.ds, -1)
+        rds.reset_state()
+        self.data_producer = rds.get_data()

-class FeedfreeInput(InputMethod):
+    def next_feed(self):
+        data = next(self.data_producer)
+        feed = dict(zip(self.input_vars, data))
+        return feed
+
+class FeedfreeInput(InputData):
     def get_input_tensors(self):
         return self._get_input_tensors()
...
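`FeedfreeInput.get_input_tensors` delegating to `_get_input_tensors` marks the intended extension point of the new `InputData` hierarchy. A minimal hypothetical subclass (only the two method names mirror the diff; the class itself is invented for illustration):

```python
# Hypothetical FeedfreeInput subclass that serves tensors already in the
# graph; _setup/_get_input_tensors are the hooks shown in the diff above.
class ConstantInput(FeedfreeInput):
    def __init__(self, tensors):
        self._tensors = tensors

    def _setup(self, trainer):
        pass  # nothing to build; the tensors already exist

    def _get_input_tensors(self):
        return self._tensors
```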
@@ -17,7 +17,7 @@ from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
 from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer
 from .queue import QueueInputTrainer
-from .inputmethod import QueueInput
+from .input_data import QueueInput

 __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
@@ -47,10 +47,16 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
                           SingleCostFeedfreeTrainer,
                           MultiPredictorTowerTrainer):
     def __init__(self, config, input_queue=None, predict_tower=None):
+        assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
+        if hasattr(config, 'dataset'):
+            self._input_method = QueueInput(config.dataset, input_queue)
+        else:
+            self._input_method = config.data
+            assert isinstance(self._input_method, QueueInput)
         super(SyncMultiGPUTrainer, self).__init__(config)
         self._setup_predictor_factory(predict_tower)
-        self._input_method = QueueInput(config.dataset, input_queue)
-        assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."

     @staticmethod
     def _average_grads(tower_grads):
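The body of `_average_grads` is elided by the diff. For orientation, conventional multi-tower averaging in TF 1.x looks like the sketch below (modeled on TensorFlow's multi-GPU tutorial, not this file's actual body):

```python
# tower_grads: one [(grad, var), ...] list per GPU, same variable order.
def average_grads_sketch(tower_grads):
    averaged = []
    for grad_and_vars in zip(*tower_grads):
        var = grad_and_vars[0][1]             # the variable is shared across towers
        grads = [g for g, _ in grad_and_vars]
        averaged.append((tf.add_n(grads) * (1.0 / len(grads)), var))
    return averaged
```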
@@ -95,17 +101,23 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
                  input_queue=None,
                  predict_tower=None,
                  average_gradient=True):
+        if hasattr(config, 'dataset'):
+            self._input_method = QueueInput(config.dataset, input_queue)
+        else:
+            self._input_method = config.data
+            assert isinstance(self._input_method, QueueInput)
         super(AsyncMultiGPUTrainer, self).__init__(config)
         self._setup_predictor_factory(predict_tower)
-        self._input_method = QueueInput(config.dataset, input_queue)
-        self.average_gradient = average_gradient
+        self._average_gradient = average_gradient

     def _setup(self):
         super(SyncMultiGPUTrainer, self)._setup()
         grad_list = MultiGPUTrainer._multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
         gradprocs = self.model.get_gradient_processor()
-        if self.average_gradient and self.config.nr_tower > 1:
+        if self._average_gradient and self.config.nr_tower > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
             gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False))
...
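The `ScaleGradient` trick above checks out numerically: with k towers, synchronous training applies one averaged gradient while asynchronous training applies each tower's gradient separately, so scaling every gradient by 1/k makes the total parameter movement over k async updates match one sync update:

```python
# Pure-Python sanity check of the 1/k scaling (illustrative only).
k = 4
tower_grads = [1.0, 2.0, 3.0, 4.0]
sync_step = sum(tower_grads) / k             # one averaged update
async_steps = [g / k for g in tower_grads]   # k scaled updates
assert abs(sum(async_steps) - sync_step) < 1e-9
```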
@@ -3,19 +3,16 @@
 # File: queue.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import threading
 import tensorflow as tf

-from ..dataflow.common import RepeatedData
-from ..tfutils.summary import summary_moving_average, add_moving_summary
-from ..tfutils import get_global_step_var, TowerContext
-from ..utils import logger
-from ..callbacks.concurrency import StartProcOrThread
+from ..tfutils import get_global_step_var
+from ..tfutils.tower import TowerContext
 from ..tfutils.gradproc import apply_grad_processors
-from .inputmethod import QueueInput
+from ..tfutils.summary import summary_moving_average
+from .input_data import QueueInput
-from .trainer import (FeedfreeTrainer, MultiPredictorTowerTrainer,
-        SingleCostFeedfreeTrainer)
+from .trainer import (MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer)

 __all__ = ['QueueInputTrainer']
@@ -30,14 +27,19 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
         :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
             Use -1 for cpu.
         """
+        if hasattr(config, 'dataset'):
+            self._input_method = QueueInput(config.dataset, input_queue)
+        else:
+            self._input_method = config.data
+            assert isinstance(self._input_method, QueueInput)
         super(QueueInputTrainer, self).__init__(config)
-        self._setup_predictor_factory(predict_tower)
-        self._input_method = QueueInput(config.dataset, input_queue)
-
-    def _setup(self):
-        super(QueueInputTrainer, self)._setup()
+        self._setup_predictor_factory(predict_tower)
         assert len(self.config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
+
+    def _setup(self):
+        super(SingleCostFeedfreeTrainer, self)._setup()
         with TowerContext(''):
             cost, grads = self._get_cost_and_grad()
         grads = apply_grad_processors(grads, self.model.get_gradient_processor())
@@ -47,4 +49,3 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
             summary_moving_average(), name='train_op')
         # skip training
         #self.train_op = tf.group(*self.dequed_inputs)
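The hunk boundary cuts the `train_op` statement in half; only its closing `summary_moving_average(), name='train_op')` is visible. Judging from the identical pattern in `SimpleTrainer._setup` further down, the full statement presumably reads:

```python
# Best-guess reconstruction, inferred from the matching SimpleTrainer code
# below; not verbatim from this file.
self.train_op = tf.group(
    self.config.optimizer.apply_gradients(grads, get_global_step_var()),
    summary_moving_average(), name='train_op')
```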
@@ -8,15 +8,13 @@ from six.moves import zip
 from .base import Trainer
-from ..dataflow.common import RepeatedData
 from ..utils import logger, SUMMARY_BACKUP_KEYS
 from ..tfutils import (get_tensors_by_names, freeze_collection,
         get_global_step_var, TowerContext)
 from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
-from .inputmethod import FeedfreeInput
+from .input_data import FeedInput, FeedfreeInput

 __all__ = ['SimpleTrainer', 'FeedfreeTrainer', 'MultiPredictorTowerTrainer',
            'SingleCostFeedfreeTrainer']
@@ -59,13 +57,18 @@ class SimpleTrainer(Trainer):
     def __init__(self, config):
         super(SimpleTrainer, self).__init__(config)
         self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
+        if not hasattr(config, 'dataset'):
+            self._input_method = config.data
+            assert isinstance(self._input_method, FeedInput)
+        else:
+            self._input_method = FeedInput(config.dataset)

     def run_step(self):
-        data = next(self.data_producer)
-        feed = dict(zip(self.input_vars, data))
+        feed = self._input_method.next_feed()
         self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op returns None

     def _setup(self):
+        self._input_method._setup(self)
         model = self.model
         self.input_vars = model.get_input_vars()
         with TowerContext(''):
@@ -81,14 +84,9 @@ class SimpleTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')

-        # create an infinite data producer
-        self.config.dataset.reset_state()
-        self.data_producer = RepeatedData(self.config.dataset, -1).get_data()

     def _trigger_epoch(self):
         if self.summary_op is not None:
-            data = next(self.data_producer)
-            feed = dict(zip(self.input_vars, data))
+            feed = self._input_method.next_feed()
             summary_str = self.summary_op.eval(feed_dict=feed)
             self._process_summary(summary_str)
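`SimpleTrainer` now funnels every feed through `FeedInput.next_feed()`, both per step and for the epoch-end summary. Condensed, the data path looks like this (a sketch: `placeholders` stands in for `model.get_input_vars()`, and the direct attribute assignments shortcut the `_setup(trainer)` hook shown in input_data.py above):

```python
# Condensed FeedInput data path; normally _setup(trainer) wires these up.
feed_input = FeedInput(my_dataflow)
feed_input.input_vars = placeholders           # stand-in for _setup(trainer)
feed_input.data_producer = iter(my_batches)    # stand-in for RepeatedData(...).get_data()
feed = feed_input.next_feed()                  # {placeholder: array, ...}
sess.run(train_op, feed_dict=feed)
```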
@@ -126,7 +124,7 @@ class FeedfreeTrainer(Trainer):
         return self._input_method.get_input_tensors()

     def _setup(self):
-        assert isinstance(self._input_method, FeedfreeInput)
+        assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
         self._input_method._setup(self)

 class SingleCostFeedfreeTrainer(FeedfreeTrainer):
@@ -155,3 +153,4 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
         #trace_file = open('timeline.ctf.json', 'w')
         #trace_file.write(trace.generate_chrome_trace_format())
         #import sys; sys.exit()