Commit d9e7c6bf authored by Yuxin Wu's avatar Yuxin Wu

allow data / dataset as input to trainconfig

parent 8db5bcd3
...@@ -5,15 +5,15 @@ ...@@ -5,15 +5,15 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from tensorpack import (QueueInputTrainer, TowerContext, from tensorpack import (FeedfreeTrainer, TowerContext,
get_global_step_var, QueueInput) get_global_step_var, QueueInput)
from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
from tensorpack.dataflow import DataFlow from tensorpack.dataflow import DataFlow
class GANTrainer(QueueInputTrainer): class GANTrainer(FeedfreeTrainer):
def __init__(self, config, g_vs_d=1): def __init__(self, config, g_vs_d=1):
super(GANTrainer, self).__init__(config)
self._input_method = QueueInput(config.dataset) self._input_method = QueueInput(config.dataset)
super(GANTrainer, self).__init__(config)
if g_vs_d > 1: if g_vs_d > 1:
self._opt_g = g_vs_d self._opt_g = g_vs_d
self._opt_d = 1 self._opt_d = 1
......
...@@ -82,7 +82,7 @@ class Model(ModelDesc): ...@@ -82,7 +82,7 @@ class Model(ModelDesc):
self.g_loss, self.d_loss = build_GAN_losses(vecpos, vecneg) self.g_loss, self.d_loss = build_GAN_losses(vecpos, vecneg)
self.g_loss = tf.add(self.g_loss, MIloss, name='total_g_loss') self.g_loss = tf.add(self.g_loss, MIloss, name='total_g_loss')
self.d_loss = tf.add(self.d_loss, MIloss, name='total_g_loss') self.d_loss = tf.add(self.d_loss, MIloss, name='total_d_loss')
summary.add_moving_summary(MIloss, self.g_loss, self.d_loss, Hc, Elog_qc_given_x) summary.add_moving_summary(MIloss, self.g_loss, self.d_loss, Hc, Elog_qc_given_x)
all_vars = tf.trainable_variables() all_vars = tf.trainable_variables()
......
...@@ -10,6 +10,7 @@ from ..utils import logger ...@@ -10,6 +10,7 @@ from ..utils import logger
from ..tfutils import (JustCurrentSession, from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit) get_default_sess_config, SessionInit)
from ..dataflow import DataFlow from ..dataflow import DataFlow
from .input_data import InputData
__all__ = ['TrainConfig'] __all__ = ['TrainConfig']
...@@ -20,6 +21,7 @@ class TrainConfig(object): ...@@ -20,6 +21,7 @@ class TrainConfig(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
:param dataset: the dataset to train. a `DataFlow` instance. :param dataset: the dataset to train. a `DataFlow` instance.
:param data: an `InputData` instance
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for training. :param optimizer: a `tf.train.Optimizer` instance defining the optimizer for training.
:param callbacks: a `callback.Callbacks` instance. Define :param callbacks: a `callback.Callbacks` instance. Define
the callbacks to perform during training. the callbacks to perform during training.
...@@ -35,8 +37,14 @@ class TrainConfig(object): ...@@ -35,8 +37,14 @@ class TrainConfig(object):
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
self.dataset = kwargs.pop('dataset') if 'dataset' in kwargs:
assert_type(self.dataset, DataFlow) assert 'data' not in kwargs, "dataset and data cannot be both presented in TrainConfig!"
self.dataset = kwargs.pop('dataset')
assert_type(self.dataset, DataFlow)
else:
self.data = kwargs.pop('data')
assert_type(self.data, InputData)
self.optimizer = kwargs.pop('optimizer') self.optimizer = kwargs.pop('optimizer')
assert_type(self.optimizer, tf.train.Optimizer) assert_type(self.optimizer, tf.train.Optimizer)
self.callbacks = kwargs.pop('callbacks') self.callbacks = kwargs.pop('callbacks')
...@@ -52,7 +60,10 @@ class TrainConfig(object): ...@@ -52,7 +60,10 @@ class TrainConfig(object):
self.step_per_epoch = kwargs.pop('step_per_epoch', None) self.step_per_epoch = kwargs.pop('step_per_epoch', None)
if self.step_per_epoch is None: if self.step_per_epoch is None:
try: try:
self.step_per_epoch = self.dataset.size() if hasattr(self, 'dataset'):
self.step_per_epoch = self.dataset.size()
else:
self.step_per_epoch = self.data.size()
except NotImplementedError: except NotImplementedError:
logger.exception("You must set `step_per_epoch` if dataset.size() is not implemented.") logger.exception("You must set `step_per_epoch` if dataset.size() is not implemented.")
else: else:
...@@ -70,6 +81,7 @@ class TrainConfig(object): ...@@ -70,6 +81,7 @@ class TrainConfig(object):
else: else:
self.tower = [0] self.tower = [0]
# TODO deprecated @Dec20
self.extra_threads_procs = kwargs.pop('extra_threads_procs', []) self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
if self.extra_threads_procs: if self.extra_threads_procs:
logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs") logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: inputmethod.py # File: input_data.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
import threading import threading
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from ..dataflow.common import RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..utils import logger from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread from ..callbacks.concurrency import StartProcOrThread
__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput'] __all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput']
class InputMethod(object): class InputData(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
pass pass
class FeedInput(InputMethod): class FeedInput(InputData):
def __init__(self, ds): def __init__(self, ds):
self.ds = ds self.ds = ds
...@@ -26,8 +27,16 @@ class FeedInput(InputMethod): ...@@ -26,8 +27,16 @@ class FeedInput(InputMethod):
def _setup(self, trainer): def _setup(self, trainer):
self.input_vars = trainer.model.get_input_vars() self.input_vars = trainer.model.get_input_vars()
rds = RepeatedData(self.ds, -1)
rds.reset_state()
self.data_producer = rds.get_data()
class FeedfreeInput(InputMethod): def next_feed(self):
data = next(self.data_producer)
feed = dict(zip(self.input_vars, data))
return feed
class FeedfreeInput(InputData):
def get_input_tensors(self): def get_input_tensors(self):
return self._get_input_tensors() return self._get_input_tensors()
......
...@@ -17,7 +17,7 @@ from ..tfutils.gradproc import apply_grad_processors, ScaleGradient ...@@ -17,7 +17,7 @@ from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer
from .queue import QueueInputTrainer from .queue import QueueInputTrainer
from .inputmethod import QueueInput from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer'] __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
...@@ -47,10 +47,16 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -47,10 +47,16 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer, SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer): MultiPredictorTowerTrainer):
def __init__(self, config, input_queue=None, predict_tower=None): def __init__(self, config, input_queue=None, predict_tower=None):
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU." if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue)
else:
self._input_method = config.data
assert isinstance(self._input_method, QueueInput)
super(SyncMultiGPUTrainer, self).__init__(config) super(SyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower) self._setup_predictor_factory(predict_tower)
self._input_method = QueueInput(config.dataset, input_queue) assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
@staticmethod @staticmethod
def _average_grads(tower_grads): def _average_grads(tower_grads):
...@@ -95,17 +101,23 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -95,17 +101,23 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
input_queue=None, input_queue=None,
predict_tower=None, predict_tower=None,
average_gradient=True): average_gradient=True):
if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue)
else:
self._input_method = config.data
assert isinstance(self._input_method, QueueInput)
super(AsyncMultiGPUTrainer, self).__init__(config) super(AsyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower) self._setup_predictor_factory(predict_tower)
self._input_method = QueueInput(config.dataset, input_queue) self._average_gradient = average_gradient
self.average_gradient = average_gradient
def _setup(self): def _setup(self):
super(SyncMultiGPUTrainer, self)._setup() super(SyncMultiGPUTrainer, self)._setup()
grad_list = MultiGPUTrainer._multi_tower_grads( grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1]) self.config.tower, lambda: self._get_cost_and_grad()[1])
gradprocs = self.model.get_gradient_processor() gradprocs = self.model.get_gradient_processor()
if self.average_gradient and self.config.nr_tower > 1: if self._average_gradient and self.config.nr_tower > 1:
# pretend to average the grads, in order to make async and # pretend to average the grads, in order to make async and
# sync have consistent effective learning rate # sync have consistent effective learning rate
gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)) gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False))
......
...@@ -3,19 +3,16 @@ ...@@ -3,19 +3,16 @@
# File: queue.py # File: queue.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import threading
import tensorflow as tf import tensorflow as tf
from ..dataflow.common import RepeatedData
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils import get_global_step_var, TowerContext
from ..utils import logger from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread from ..tfutils import get_global_step_var
from ..tfutils.tower import TowerContext
from ..tfutils.gradproc import apply_grad_processors from ..tfutils.gradproc import apply_grad_processors
from .inputmethod import QueueInput from ..tfutils.summary import summary_moving_average
from .input_data import QueueInput
from .trainer import (FeedfreeTrainer, MultiPredictorTowerTrainer, from .trainer import (MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer)
SingleCostFeedfreeTrainer)
__all__ = ['QueueInputTrainer'] __all__ = ['QueueInputTrainer']
...@@ -30,14 +27,19 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer): ...@@ -30,14 +27,19 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
:param predict_tower: list of gpu relative idx to run prediction. default to be [0]. :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu. Use -1 for cpu.
""" """
if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue)
else:
self._input_method = config.data
assert isinstance(self._input_method, QueueInput)
super(QueueInputTrainer, self).__init__(config) super(QueueInputTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower)
self._input_method = QueueInput(config.dataset, input_queue)
def _setup(self): self._setup_predictor_factory(predict_tower)
super(QueueInputTrainer, self)._setup()
assert len(self.config.tower) == 1, \ assert len(self.config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead." "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
def _setup(self):
super(SingleCostFeedfreeTrainer, self)._setup()
with TowerContext(''): with TowerContext(''):
cost, grads = self._get_cost_and_grad() cost, grads = self._get_cost_and_grad()
grads = apply_grad_processors(grads, self.model.get_gradient_processor()) grads = apply_grad_processors(grads, self.model.get_gradient_processor())
...@@ -47,4 +49,3 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer): ...@@ -47,4 +49,3 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
summary_moving_average(), name='train_op') summary_moving_average(), name='train_op')
# skip training # skip training
#self.train_op = tf.group(*self.dequed_inputs) #self.train_op = tf.group(*self.dequed_inputs)
...@@ -8,15 +8,13 @@ from six.moves import zip ...@@ -8,15 +8,13 @@ from six.moves import zip
from .base import Trainer from .base import Trainer
from ..dataflow.common import RepeatedData
from ..utils import logger, SUMMARY_BACKUP_KEYS from ..utils import logger, SUMMARY_BACKUP_KEYS
from ..tfutils import (get_tensors_by_names, freeze_collection, from ..tfutils import (get_tensors_by_names, freeze_collection,
get_global_step_var, TowerContext) get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors from ..tfutils.gradproc import apply_grad_processors
from .inputmethod import FeedfreeInput from .input_data import FeedInput, FeedfreeInput
__all__ = ['SimpleTrainer', 'FeedfreeTrainer', 'MultiPredictorTowerTrainer', __all__ = ['SimpleTrainer', 'FeedfreeTrainer', 'MultiPredictorTowerTrainer',
'SingleCostFeedfreeTrainer'] 'SingleCostFeedfreeTrainer']
...@@ -59,13 +57,18 @@ class SimpleTrainer(Trainer): ...@@ -59,13 +57,18 @@ class SimpleTrainer(Trainer):
def __init__(self, config): def __init__(self, config):
super(SimpleTrainer, self).__init__(config) super(SimpleTrainer, self).__init__(config)
self._predictor_factory = PredictorFactory(self.sess, self.model, [0]) self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
if not hasattr(config, 'dataset'):
self._input_method = config.data
assert isinstance(self._input_method, FeedInput)
else:
self._input_method = FeedInput(config.dataset)
def run_step(self): def run_step(self):
data = next(self.data_producer) feed = self._input_method.next_feed()
feed = dict(zip(self.input_vars, data))
self.sess.run([self.train_op], feed_dict=feed) # faster since train_op return None self.sess.run([self.train_op], feed_dict=feed) # faster since train_op return None
def _setup(self): def _setup(self):
self._input_method._setup(self)
model = self.model model = self.model
self.input_vars = model.get_input_vars() self.input_vars = model.get_input_vars()
with TowerContext(''): with TowerContext(''):
...@@ -81,14 +84,9 @@ class SimpleTrainer(Trainer): ...@@ -81,14 +84,9 @@ class SimpleTrainer(Trainer):
self.config.optimizer.apply_gradients(grads, get_global_step_var()), self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op') summary_moving_average(), name='train_op')
# create an infinite data producer
self.config.dataset.reset_state()
self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
def _trigger_epoch(self): def _trigger_epoch(self):
if self.summary_op is not None: if self.summary_op is not None:
data = next(self.data_producer) feed = self._input_method.next_feed()
feed = dict(zip(self.input_vars, data))
summary_str = self.summary_op.eval(feed_dict=feed) summary_str = self.summary_op.eval(feed_dict=feed)
self._process_summary(summary_str) self._process_summary(summary_str)
...@@ -126,7 +124,7 @@ class FeedfreeTrainer(Trainer): ...@@ -126,7 +124,7 @@ class FeedfreeTrainer(Trainer):
return self._input_method.get_input_tensors() return self._input_method.get_input_tensors()
def _setup(self): def _setup(self):
assert isinstance(self._input_method, FeedfreeInput) assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
self._input_method._setup(self) self._input_method._setup(self)
class SingleCostFeedfreeTrainer(FeedfreeTrainer): class SingleCostFeedfreeTrainer(FeedfreeTrainer):
...@@ -155,3 +153,4 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer): ...@@ -155,3 +153,4 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
#trace_file = open('timeline.ctf.json', 'w') #trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format()) #trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit() #import sys; sys.exit()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment