Commit 6beb258b authored by Yuxin Wu's avatar Yuxin Wu

a general feed-free trainer

parent e4d6992d
...@@ -5,19 +5,15 @@ ...@@ -5,19 +5,15 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple
import sys import sys
import six import six
from six.moves import zip, map from six.moves import zip
from ..dataflow import DataFlow from ..utils import logger, execute_only_once
from ..utils import get_tqdm, logger, execute_only_once
from ..utils.stats import RatioCounter, BinaryStatistics from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name, get_op_var_name from ..tfutils import get_op_var_name
from .base import Callback
from .dispatcher import OutputTensorDispatcer
__all__ = ['InferenceRunner', 'ClassificationError', __all__ = ['ClassificationError',
'ScalarStats', 'Inferencer', 'BinaryClassificationStats'] 'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -63,95 +59,6 @@ class Inferencer(object): ...@@ -63,95 +59,6 @@ class Inferencer(object):
def _get_output_tensors(self): def _get_output_tensors(self):
pass pass
class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
"""
IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_tensors=None):
"""
:param ds: inference dataset. a `DataFlow` instance.
:param infs: a list of `Inferencer` instance.
:param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders.
"""
assert isinstance(ds, DataFlow), type(ds)
self.ds = ds
if not isinstance(infs, list):
self.infs = [infs]
else:
self.infs = infs
for v in self.infs:
assert isinstance(v, Inferencer), str(v)
self.input_tensors = input_tensors
def _setup_graph(self):
self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name
self.pred_func = self.trainer.get_predict_func(
self.input_tensors, self.output_tensors)
def _find_input_tensors(self):
if self.input_tensors is None:
input_vars = self.trainer.model.get_input_vars()
# TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
def get_name(x):
if isinstance(x, tf.SparseTensor):
return x.op.name.split('/')[0]
return x.name
self.input_tensors = [get_name(x) for x in input_vars]
def _find_output_tensors(self):
dispatcer = OutputTensorDispatcer()
for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names()
IOTensor = InferenceRunner.IOTensor
self.output_tensors = list(filter(
lambda x: x not in self.input_tensors, all_names))
def find_oid(idxs):
ret = []
for idx in idxs:
name = all_names[idx]
if name in self.input_tensors:
ret.append(IOTensor(self.input_tensors.index(name), False))
else:
ret.append(IOTensor(self.output_tensors.index(name), True))
return ret
self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
# list of list of (var_name: IOTensor)
def _trigger_epoch(self):
for inf in self.infs:
inf.before_inference()
sess = tf.get_default_session()
self.ds.reset_state()
with get_tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
outputs = self.pred_func(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap]
inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference()
def _write_summary_after_inference(self):
for inf in self.infs:
ret = inf.after_inference()
for k, v in six.iteritems(ret):
try:
v = float(v)
except:
logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
continue
self.trainer.write_scalar_summary(k, v)
class ScalarStats(Inferencer): class ScalarStats(Inferencer):
""" """
Write some scalar tensor to both stat and summary. Write some scalar tensor to both stat and summary.
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: inference_runner.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from collections import namedtuple
import six
from six.moves import zip
from ..dataflow import DataFlow
from .base import Callback
from .inference import Inferencer
from .dispatcher import OutputTensorDispatcer
from ..tfutils import get_op_tensor_name
from ..utils import logger, get_tqdm
__all__ = ['InferenceRunner']
class InferenceRunner(Callback):
    """
    A callback that runs a list of `Inferencer` over a `DataFlow` after
    each training epoch, using the trainer's prediction function, and
    writes each inferencer's scalar statistics to the training summary.
    """

    # index: position within either the datapoint (input) or the predicted
    # outputs; isOutput: True if the tensor is taken from the prediction
    # outputs, False if it is echoed back from the input datapoint.
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])

    def __init__(self, ds, infs, input_tensors=None):
        """
        :param ds: inference dataset. a `DataFlow` instance.
        :param infs: a list of `Inferencer` instance.
        :param input_tensors: list of tensor names to feed the dataflow to.
            default to all the input placeholders.
        """
        assert isinstance(ds, DataFlow), type(ds)
        self.ds = ds
        # Accept a single Inferencer for convenience.
        if not isinstance(infs, list):
            self.infs = [infs]
        else:
            self.infs = infs
        for v in self.infs:
            assert isinstance(v, Inferencer), str(v)
        self.input_tensors = input_tensors

    def _setup_graph(self):
        self._find_input_tensors()   # these are all tensor names
        self._find_output_tensors()  # may be either tensor name or op name
        self.pred_func = self.trainer.get_predict_func(
            self.input_tensors, self.output_tensors)

    def _find_input_tensors(self):
        # Default to every input placeholder of the model.
        if self.input_tensors is None:
            input_vars = self.trainer.model.get_input_vars()
            # TODO even if it works here, sparse still is unavailable
            # because get_tensor_by_name doesn't work for sparse

            def get_name(x):
                if isinstance(x, tf.SparseTensor):
                    return x.op.name.split('/')[0]
                return x.name
            self.input_tensors = [get_name(x) for x in input_vars]

    def _find_output_tensors(self):
        # Collect the union of tensors requested by all inferencers,
        # de-duplicated by the dispatcher.
        dispatcher = OutputTensorDispatcer()
        for inf in self.infs:
            dispatcher.add_entry(inf.get_output_tensors())
        all_names = dispatcher.get_all_names()

        IOTensor = InferenceRunner.IOTensor
        # Tensors that are also inputs are fed back from the datapoint
        # instead of being fetched from the graph.
        self.output_tensors = list(filter(
            lambda x: x not in self.input_tensors, all_names))

        def find_oid(idxs):
            ret = []
            for idx in idxs:
                name = all_names[idx]
                if name in self.input_tensors:
                    ret.append(IOTensor(self.input_tensors.index(name), False))
                else:
                    ret.append(IOTensor(self.output_tensors.index(name), True))
            return ret
        self.inf_to_tensors = [find_oid(t) for t in dispatcher.get_idx_for_each_entry()]
        # list of list of (var_name: IOTensor)

    def _trigger_epoch(self):
        for inf in self.infs:
            inf.before_inference()

        self.ds.reset_state()
        with get_tqdm(total=self.ds.size()) as pbar:
            for dp in self.ds.get_data():
                outputs = self.pred_func(dp)
                for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                    # Route each requested tensor from either the datapoint
                    # or the prediction outputs, per its IOTensor record.
                    inf_output = [(outputs if k.isOutput else dp)[k.index]
                                  for k in tensormap]
                    inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()

    def _write_summary_after_inference(self):
        # Each inferencer returns a {name: scalar} dict; non-scalar values
        # are skipped with a warning instead of crashing the epoch.
        for inf in self.infs:
            ret = inf.after_inference()
            for k, v in six.iteritems(ret):
                try:
                    v = float(v)
                except (ValueError, TypeError):
                    # was a bare `except:`, which would also swallow
                    # KeyboardInterrupt/SystemExit
                    logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
                    continue
                self.trainer.write_scalar_summary(k, v)
...@@ -22,6 +22,7 @@ class TrainConfig(object): ...@@ -22,6 +22,7 @@ class TrainConfig(object):
""" """
:param dataset: the dataset to train. a `DataFlow` instance. :param dataset: the dataset to train. a `DataFlow` instance.
:param data: an `InputData` instance :param data: an `InputData` instance
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for training. :param optimizer: a `tf.train.Optimizer` instance defining the optimizer for training.
:param callbacks: a `callback.Callbacks` instance. Define :param callbacks: a `callback.Callbacks` instance. Define
the callbacks to perform during training. the callbacks to perform during training.
......
...@@ -131,7 +131,7 @@ class TensorInput(FeedfreeInput): ...@@ -131,7 +131,7 @@ class TensorInput(FeedfreeInput):
def size(self): def size(self):
if self._size is None: if self._size is None:
raise ValueError("size of TensorInput is None!") raise ValueError("size of TensorInput is undefined!")
return self._size return self._size
def _setup(self, trainer): def _setup(self, trainer):
...@@ -139,6 +139,3 @@ class TensorInput(FeedfreeInput): ...@@ -139,6 +139,3 @@ class TensorInput(FeedfreeInput):
def _get_input_tensors(self): def _get_input_tensors(self):
return self.get_tensor_fn() return self.get_tensor_fn()
class SplitTensorInput(FeedfreeInput):
pass
...@@ -16,7 +16,6 @@ from ..tfutils import (backup_collection, restore_collection, ...@@ -16,7 +16,6 @@ from ..tfutils import (backup_collection, restore_collection,
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer
from .queue import QueueInputTrainer
from .input_data import QueueInput from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer'] __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
......
...@@ -14,33 +14,25 @@ from .input_data import QueueInput ...@@ -14,33 +14,25 @@ from .input_data import QueueInput
from .trainer import (MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer) from .trainer import (MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer)
__all__ = ['QueueInputTrainer'] __all__ = ['SimpleFeedfreeTrainer', 'QueueInputTrainer']
class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer): class SimpleFeedfreeTrainer(
""" Single GPU Trainer, takes input from a queue""" MultiPredictorTowerTrainer,
SingleCostFeedfreeTrainer):
def __init__(self, config, input_queue=None, predict_tower=None): def __init__(self, config, predict_tower=None):
""" """
:param config: a `TrainConfig` instance A trainer with single cost, single training tower and feed-free input
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints. config.data must exists
Defaults to a FIFO queue of size 100.
:param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu.
""" """
if hasattr(config, 'dataset'): self._input_method = config.data
self._input_method = QueueInput(config.dataset, input_queue) super(SimpleFeedfreeTrainer, self).__init__(config)
else:
self._input_method = config.data
assert isinstance(self._input_method, QueueInput)
super(QueueInputTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower) self._setup_predictor_factory(predict_tower)
assert len(self.config.tower) == 1, \ assert len(self.config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead." "SimpleFeedfreeTrainer doesn't support multigpu!"
def _setup(self): def _setup(self):
super(SingleCostFeedfreeTrainer, self)._setup() super(SimpleFeedfreeTrainer, self)._setup()
with TowerContext(''): with TowerContext('', is_training=True):
cost, grads = self._get_cost_and_grad() cost, grads = self._get_cost_and_grad()
grads = apply_grad_processors(grads, self.model.get_gradient_processor()) grads = apply_grad_processors(grads, self.model.get_gradient_processor())
...@@ -49,3 +41,23 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer): ...@@ -49,3 +41,23 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
summary_moving_average(), name='train_op') summary_moving_average(), name='train_op')
# skip training # skip training
#self.train_op = tf.group(*self.dequed_inputs) #self.train_op = tf.group(*self.dequed_inputs)
class QueueInputTrainer(SimpleFeedfreeTrainer):
def __init__(self, config, input_queue=None, predict_tower=None):
"""
Single tower Trainer, takes input from a queue
:param config: a `TrainConfig` instance. config.dataset must exist
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
Defaults to a FIFO queue of size 100.
:param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu.
"""
config.data = QueueInput(config.dataset, input_queue)
assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super(QueueInputTrainer, self).__init__(config, predict_tower)
def _setup(self):
super(QueueInputTrainer, self)._setup()
...@@ -71,7 +71,7 @@ class SimpleTrainer(Trainer): ...@@ -71,7 +71,7 @@ class SimpleTrainer(Trainer):
self._input_method._setup(self) self._input_method._setup(self)
model = self.model model = self.model
self.input_vars = model.get_input_vars() self.input_vars = model.get_input_vars()
with TowerContext(''): with TowerContext('', is_training=True):
model.build_graph(self.input_vars) model.build_graph(self.input_vars)
cost_var = model.get_cost() cost_var = model.get_cost()
add_moving_summary(cost_var) add_moving_summary(cost_var)
...@@ -153,4 +153,3 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer): ...@@ -153,4 +153,3 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
#trace_file = open('timeline.ctf.json', 'w') #trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format()) #trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit() #import sys; sys.exit()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment