Commit 8db5bcd3 authored by Yuxin Wu

separate out input method

parent fdab3db2
@@ -5,15 +5,15 @@
 import tensorflow as tf
 import numpy as np
-from tensorpack import (QueueInputTrainerBase, TowerContext,
-        get_global_step_var)
+from tensorpack import (QueueInputTrainer, TowerContext,
+        get_global_step_var, QueueInput)
 from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
 from tensorpack.dataflow import DataFlow

-class GANTrainer(QueueInputTrainerBase):
+class GANTrainer(QueueInputTrainer):
     def __init__(self, config, g_vs_d=1):
         super(GANTrainer, self).__init__(config)
-        self._build_enque_thread()
+        self._input_method = QueueInput(config.dataset)
         if g_vs_d > 1:
             self._opt_g = g_vs_d
             self._opt_d = 1
@@ -22,8 +22,9 @@ class GANTrainer(QueueInputTrainerBase):
             self._opt_d = int(1.0 / g_vs_d)

     def _setup(self):
+        super(GANTrainer, self)._setup()
         with TowerContext(''):
-            actual_inputs = self._get_input_tensors_noreuse()
+            actual_inputs = self._get_input_tensors()
             self.model.build_graph(actual_inputs)
             self.g_min = self.config.optimizer.minimize(self.model.g_loss,
                 var_list=self.model.g_vars, name='g_op')
...
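The hunk above is the core of this commit in miniature: a trainer no longer builds its own enqueue machinery but assigns an input method in __init__ and calls super()._setup() so the base class can wire it up. A minimal sketch of the same wiring for a hypothetical custom trainer, assuming the tensorpack API as of this commit (MyTrainer is a placeholder name):

# Sketch only -- assumes tensorpack at this commit; MyTrainer is hypothetical.
from tensorpack import QueueInputTrainer, QueueInput, TowerContext

class MyTrainer(QueueInputTrainer):
    def __init__(self, config):
        super(MyTrainer, self).__init__(config)
        # pick the input pipeline explicitly; the old code hid this
        # choice inside _build_enque_thread()
        self._input_method = QueueInput(config.dataset)

    def _setup(self):
        # lets the input method create its FIFOQueue and enqueue thread
        super(MyTrainer, self)._setup()
        with TowerContext(''):
            # dequeued tensors now, not placeholders
            inputs = self._get_input_tensors()
            self.model.build_graph(inputs)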
@@ -44,7 +44,7 @@ def Conv2D(x, out_channel, kernel_shape,
     stride = shape4d(stride)

     if W_init is None:
-        W_init = tf.contrib.layers.xavier_initializer_conv2d()
+        W_init = tf.contrib.layers.variance_scaling_initializer()
     if b_init is None:
         b_init = tf.constant_initializer()
...
@@ -30,8 +30,8 @@ def FullyConnected(x, out_dim,
     in_dim = x.get_shape().as_list()[1]

     if W_init is None:
-        #W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
-        W_init = tf.uniform_unit_scaling_initializer(factor=1.43)
+        #W_init = tf.uniform_unit_scaling_initializer(factor=1.43)
+        W_init = tf.contrib.layers.variance_scaling_initializer()
    if b_init is None:
         b_init = tf.constant_initializer()
...
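Both Conv2D and FullyConnected now default to the same initializer. For reference, a hedged sketch of what the swap means, assuming the tf.contrib defaults of that era (factor=2.0, mode='FAN_IN', truncated normal):

# Sketch: old vs. new default initializers (tf.contrib, TF 0.x era).
import tensorflow as tf

# Old Conv2D default: Xavier/Glorot, uniform in +-sqrt(6 / (fan_in + fan_out)).
old_conv = tf.contrib.layers.xavier_initializer_conv2d()

# New shared default: variance scaling. With its defaults this is He init,
# a truncated normal with stddev ~ sqrt(2 / fan_in), a better match for ReLU.
new_init = tf.contrib.layers.variance_scaling_initializer()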
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: inputmethod.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import threading
from abc import ABCMeta, abstractmethod

from ..tfutils.summary import add_moving_summary
from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread

__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput']

class InputMethod(object):
    """ Base class for the ways a trainer obtains its input tensors. """
    __metaclass__ = ABCMeta
    pass

class FeedInput(InputMethod):
    """ Input by feeding datapoints from a DataFlow into placeholders. """
    def __init__(self, ds):
        self.ds = ds

    def size(self):
        return self.ds.size()

    def _setup(self, trainer):
        self.input_vars = trainer.model.get_input_vars()

class FeedfreeInput(InputMethod):
    """ Input without feed_dict; tensors come straight from the graph. """
    def get_input_tensors(self):
        return self._get_input_tensors()

    @abstractmethod
    def _get_input_tensors(self):
        """
        always create and return a list of new input tensors
        """

class EnqueueThread(threading.Thread):
    """ A daemon thread that keeps filling a TF queue from a DataFlow. """
    def __init__(self, trainer, queue, ds, input_placehdrs):
        super(EnqueueThread, self).__init__()
        self.name = 'EnqueueThread'
        self.daemon = True

        self.dataflow = ds
        self.queue = queue

        self.sess = trainer.sess
        self.coord = trainer.coord
        self.placehdrs = input_placehdrs

        self.op = self.queue.enqueue(self.placehdrs)
        self.close_op = self.queue.close(cancel_pending_enqueues=True)
        self.size_op = self.queue.size()
        add_moving_summary(tf.cast(
            self.size_op, tf.float32, name='input_queue_size'))

    def run(self):
        self.dataflow.reset_state()
        with self.sess.as_default():
            try:
                while True:
                    for dp in self.dataflow.get_data():
                        if self.coord.should_stop():
                            return
                        feed = dict(zip(self.placehdrs, dp))
                        #print 'TFQ:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                        self.op.run(feed_dict=feed)
            except tf.errors.CancelledError:
                pass
            except Exception:
                logger.exception("Exception in EnqueueThread:")
            finally:
                self.coord.request_stop()
                try:
                    self.sess.run(self.close_op)
                except RuntimeError:    # session already closed
                    pass
                logger.info("Enqueue Thread Exited.")

class QueueInput(FeedfreeInput):
    """ Input from a FIFOQueue, filled from a DataFlow by an EnqueueThread. """
    def __init__(self, ds, queue=None):
        self.queue = queue
        self.ds = ds

    def size(self):
        return self.ds.size()

    def _setup(self, trainer):
        self.input_placehdrs = trainer.model.get_input_vars()
        assert len(self.input_placehdrs) > 0, \
            "QueueInput can only be used with input placeholders!"
        if self.queue is None:
            self.queue = tf.FIFOQueue(
                50, [x.dtype for x in self.input_placehdrs],
                name='input_queue')
        self.thread = EnqueueThread(
            trainer, self.queue, self.ds, self.input_placehdrs)
        trainer.config.callbacks.append(StartProcOrThread(self.thread))

    def _get_input_tensors(self):
        ret = self.queue.dequeue(name='input_deque')
        if isinstance(ret, tf.Tensor):  # only one input
            ret = [ret]
        assert len(ret) == len(self.input_placehdrs)
        for qv, v in zip(ret, self.input_placehdrs):
            qv.set_shape(v.get_shape())

        # test the overhead of queue
        #with tf.device('/gpu:0'):
            #ret = [tf.Variable(tf.random_normal([128,224,224,3],
                #dtype=tf.float32), trainable=False),
                #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
        return ret

class TensorInput(FeedfreeInput):
    """ Input from a function that returns the input tensors directly. """
    def __init__(self, get_tensor_fn, size=None):
        self.get_tensor_fn = get_tensor_fn
        self._size = size

    def size(self):
        if self._size is None:
            raise ValueError("size of TensorInput is None!")
        return self._size

    def _setup(self, trainer):
        pass

    def _get_input_tensors(self):
        return self.get_tensor_fn()

class SplitTensorInput(FeedfreeInput):
    pass
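A hedged usage sketch of the two concrete input methods defined above (my_dataflow and build_tfrecord_pipeline are hypothetical placeholders, and importing TensorInput from the top-level package assumes it is re-exported the way QueueInput is in the GAN hunk above):

# Sketch only: how QueueInput and TensorInput are meant to be used.
from tensorpack import QueueInput, TensorInput

# Wrap a DataFlow: datapoints are pushed through a 50-element FIFOQueue
# by an EnqueueThread, and towers dequeue from it.
inp = QueueInput(my_dataflow)

# Or bypass DataFlow entirely: the function returns input tensors itself
# (e.g. a native TF reader pipeline); size must be supplied explicitly
# if anything downstream asks for it.
inp = TensorInput(lambda: build_tfrecord_pipeline('train.tfrecords'),
                  size=50000)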
@@ -15,12 +15,13 @@ from ..tfutils import (backup_collection, restore_collection,
         get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
-from .trainer import FeedlessTrainer, SingleCostFeedlessTrainer, MultiPredictorTowerTrainer
-from .queue import QueueInputTrainer, QueueInputTrainerBase
+from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer
+from .queue import QueueInputTrainer
+from .inputmethod import QueueInput

 __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']

-class MultiGPUTrainer(FeedlessTrainer):
+class MultiGPUTrainer(FeedfreeTrainer):
     """ Base class for multi-gpu training"""
     @staticmethod
     def _multi_tower_grads(towers, get_tower_grad_func):
@@ -42,15 +43,14 @@ class MultiGPUTrainer(FeedlessTrainer):
         restore_collection(backup)
         return grad_list

-class SyncMultiGPUTrainer(QueueInputTrainerBase,
-        MultiGPUTrainer,
-        SingleCostFeedlessTrainer,
+class SyncMultiGPUTrainer(MultiGPUTrainer,
+        SingleCostFeedfreeTrainer,
         MultiPredictorTowerTrainer):

     def __init__(self, config, input_queue=None, predict_tower=None):
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
         super(SyncMultiGPUTrainer, self).__init__(config)
         self._setup_predictor_factory(predict_tower)
-        self._build_enque_thread(input_queue)
+        self._input_method = QueueInput(config.dataset, input_queue)

     @staticmethod
     def _average_grads(tower_grads):
@@ -75,6 +75,7 @@ class SyncMultiGPUTrainer(QueueInputTrainerBase,
         return ret

     def _setup(self):
+        super(SyncMultiGPUTrainer, self)._setup()
         grad_list = MultiGPUTrainer._multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
         grads = SyncMultiGPUTrainer._average_grads(grad_list)
@@ -87,9 +88,8 @@ class SyncMultiGPUTrainer(QueueInputTrainerBase,
     def run_step(self):
         self.sess.run(self.train_op)

-class AsyncMultiGPUTrainer(QueueInputTrainerBase,
-        MultiGPUTrainer,
-        SingleCostFeedlessTrainer,
+class AsyncMultiGPUTrainer(MultiGPUTrainer,
+        SingleCostFeedfreeTrainer,
         MultiPredictorTowerTrainer):

     def __init__(self, config,
             input_queue=None,
@@ -97,10 +97,11 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
             average_gradient=True):
         super(AsyncMultiGPUTrainer, self).__init__(config)
         self._setup_predictor_factory(predict_tower)
-        self._build_enque_thread(input_queue)
+        self._input_method = QueueInput(config.dataset, input_queue)
         self.average_gradient = average_gradient

     def _setup(self):
+        super(AsyncMultiGPUTrainer, self)._setup()
         grad_list = MultiGPUTrainer._multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
         gradprocs = self.model.get_gradient_processor()
...
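The multi-GPU trainers above depend on _get_input_tensors() creating a fresh dequeue op on every call, so each tower consumes its own datapoints from the shared queue rather than reading the same ones twice. A self-contained sketch of that behavior in plain TensorFlow (no tensorpack needed):

# Sketch: one dequeue op per tower on a shared queue.
import tensorflow as tf

q = tf.FIFOQueue(50, [tf.float32], name='input_queue')
tower_inputs = []
for i in range(2):
    with tf.device('/gpu:%d' % i):
        # a separate dequeue op per tower: running both in a single
        # sess.run() pops two different elements from the shared queue
        tower_inputs.append(q.dequeue(name='input_deque_gpu%d' % i))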
@@ -12,86 +12,14 @@ from ..tfutils import get_global_step_var, TowerContext
 from ..utils import logger
 from ..callbacks.concurrency import StartProcOrThread
 from ..tfutils.gradproc import apply_grad_processors
-from .trainer import (FeedlessTrainer, MultiPredictorTowerTrainer,
-        SingleCostFeedlessTrainer)
+from .inputmethod import QueueInput
+from .trainer import (FeedfreeTrainer, MultiPredictorTowerTrainer,
+        SingleCostFeedfreeTrainer)

-__all__ = ['QueueInputTrainerBase', 'QueueInputTrainer']
+__all__ = ['QueueInputTrainer']

-class EnqueueThread(threading.Thread):
-    def __init__(self, trainer):
-        super(EnqueueThread, self).__init__()
-        self.name = 'EnqueueThread'
-        self.daemon = True
-
-        self.sess = trainer.sess
-        self.coord = trainer.coord
-        self.dataflow = RepeatedData(trainer.config.dataset, -1)
-
-        self.input_vars = trainer.input_vars
-        self.queue = trainer.input_queue
-        self.op = self.queue.enqueue(self.input_vars)
-        self.close_op = self.queue.close(cancel_pending_enqueues=True)
-        self.size_op = self.queue.size()
-        add_moving_summary(tf.cast(
-            self.size_op, tf.float32, name='input_queue_size'))
-
-    def run(self):
-        self.dataflow.reset_state()
-        with self.sess.as_default():
-            try:
-                while True:
-                    for dp in self.dataflow.get_data():
-                        if self.coord.should_stop():
-                            return
-                        feed = dict(zip(self.input_vars, dp))
-                        #print 'TFQ:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
-                        self.op.run(feed_dict=feed)
-            except tf.errors.CancelledError as e:
-                pass
-            except Exception:
-                logger.exception("Exception in EnqueueThread:")
-            finally:
-                self.coord.request_stop()
-                try:
-                    self.sess.run(self.close_op)
-                except RuntimeError:    # session already closed
-                    pass
-                logger.info("Enqueue Thread Exited.")
-
-class QueueInputTrainerBase(FeedlessTrainer):
-    def _build_enque_thread(self, input_queue=None):
-        """ create a thread that keeps filling the queue """
-        self.input_vars = self.model.get_input_vars()
-        assert len(self.input_vars) > 0, "QueueInput can only be used with input placeholders!"
-        if input_queue is None:
-            self.input_queue = tf.FIFOQueue(
-                50, [x.dtype for x in self.input_vars],
-                name='input_queue')
-        else:
-            self.input_queue = input_queue
-        input_th = EnqueueThread(self)
-        self.config.callbacks.append(StartProcOrThread(input_th))
-
-    def _get_input_tensors_noreuse(self):
-        """ Dequeue a datapoint from input_queue and return.
-            Can be called multiple times.
-        """
-        ret = self.input_queue.dequeue(name='input_deque')
-        if isinstance(ret, tf.Tensor):  # only one input
-            ret = [ret]
-        assert len(ret) == len(self.input_vars)
-        for qv, v in zip(ret, self.input_vars):
-            qv.set_shape(v.get_shape())
-
-        # test the overhead of queue
-        #with tf.device('/gpu:0'):
-            #ret = [tf.Variable(tf.random_normal([128,224,224,3],
-                #dtype=tf.float32), trainable=False),
-                #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
-        return ret
-
-class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, SingleCostFeedlessTrainer):
+class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
     """ Single GPU Trainer, takes input from a queue"""

     def __init__(self, config, input_queue=None, predict_tower=None):
@@ -104,9 +32,10 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, Singl
         """
         super(QueueInputTrainer, self).__init__(config)
         self._setup_predictor_factory(predict_tower)
-        self._build_enque_thread(input_queue)
+        self._input_method = QueueInput(config.dataset, input_queue)

     def _setup(self):
+        super(QueueInputTrainer, self)._setup()
         assert len(self.config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
         with TowerContext(''):
@@ -119,18 +48,3 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, Singl
         # skip training
         #self.train_op = tf.group(*self.dequed_inputs)
-
-    def run_step(self):
-        """ Simply run self.train_op"""
-        self.sess.run(self.train_op)
-
-        # debug-benchmark code:
-        #run_metadata = tf.RunMetadata()
-        #self.sess.run([self.train_op],
-                #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
-                #run_metadata=run_metadata
-        #)
-        #from tensorflow.python.client import timeline
-        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
-        #trace_file = open('timeline.ctf.json', 'w')
-        #trace_file.write(trace.generate_chrome_trace_format())
-        #import sys; sys.exit()
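The super(...)._setup() calls added throughout this commit form a cooperative chain: each trainer's _setup() defers up the MRO until FeedfreeTrainer._setup() (in the trainer.py hunk below) hands control to the input method. A runnable toy sketch of the chain, with tensorpack stripped out:

# Toy sketch of the cooperative _setup() chain (plain Python, no TF).
class Trainer(object):
    def _setup(self):
        pass

class FeedfreeTrainer(Trainer):
    def _setup(self):
        print("input method builds its queue and enqueue thread here")
        super(FeedfreeTrainer, self)._setup()

class QueueInputTrainer(FeedfreeTrainer):
    def _setup(self):
        super(QueueInputTrainer, self)._setup()  # must not be forgotten
        print("build the training graph from dequeued tensors")

QueueInputTrainer()._setup()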
@@ -16,9 +16,10 @@ from ..tfutils import (get_tensors_by_names, freeze_collection,
 from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
+from .inputmethod import FeedfreeInput

-__all__ = ['SimpleTrainer', 'FeedlessTrainer', 'MultiPredictorTowerTrainer',
-        'SingleCostFeedlessTrainer']
+__all__ = ['SimpleTrainer', 'FeedfreeTrainer', 'MultiPredictorTowerTrainer',
+        'SingleCostFeedfreeTrainer']

 class PredictorFactory(object):
     """ Make predictors for a trainer"""
@@ -112,7 +113,7 @@ class MultiPredictorTowerTrainer(Trainer):
     def get_predict_funcs(self, input_names, output_names, n):
         return [self.get_predict_func(input_names, output_names, k) for k in range(n)]

-class FeedlessTrainer(Trainer):
+class FeedfreeTrainer(Trainer):
     """ A trainer which runs iteration without feed_dict (therefore faster) """
     def _trigger_epoch(self):
         # need to run summary_op every epoch
@@ -121,16 +122,17 @@ class FeedlessTrainer(Trainer):
         summary_str = self.summary_op.eval()
         self._process_summary(summary_str)

-    def _get_input_tensors_noreuse(self):
-        """ return a list of actual input tensors.
-            Always return new tensors (for multi tower) if called mutliple times.
-        """
-        pass
+    def _get_input_tensors(self):
+        return self._input_method.get_input_tensors()
+
+    def _setup(self):
+        assert isinstance(self._input_method, FeedfreeInput)
+        self._input_method._setup(self)

-class SingleCostFeedlessTrainer(FeedlessTrainer):
+class SingleCostFeedfreeTrainer(FeedfreeTrainer):
     def _get_cost_and_grad(self):
         """ get the cost and gradient on a new tower"""
-        actual_inputs = self._get_input_tensors_noreuse()
+        actual_inputs = self._get_input_tensors()
         self.model.build_graph(actual_inputs)
         cost_var = self.model.get_cost()
         # GATE_NONE faster?
@@ -139,3 +141,17 @@ class SingleCostFeedlessTrainer(FeedlessTrainer):
         add_moving_summary(cost_var)
         return cost_var, grads
+
+    def run_step(self):
+        """ Simply run self.train_op"""
+        self.sess.run(self.train_op)
+
+        # debug-benchmark code:
+        #run_metadata = tf.RunMetadata()
+        #self.sess.run([self.train_op],
+                #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
+                #run_metadata=run_metadata
+        #)
+        #from tensorflow.python.client import timeline
+        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
+        #trace_file = open('timeline.ctf.json', 'w')
+        #trace_file.write(trace.generate_chrome_trace_format())
+        #import sys; sys.exit()
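The commented-out benchmark block in run_step() sketches per-step profiling with the TF timeline. Cleaned up, the same idea looks like this (same TF 0.x API as the comments above; sess and train_op are assumed to exist in scope):

# Cleaned-up version of the commented-out profiling code above.
import tensorflow as tf
from tensorflow.python.client import timeline

run_metadata = tf.RunMetadata()
sess.run([train_op],
         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
         run_metadata=run_metadata)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with open('timeline.ctf.json', 'w') as f:
    # open the resulting file in chrome://tracing to inspect one step
    f.write(trace.generate_chrome_trace_format())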