Commit 6beb258b authored by Yuxin Wu's avatar Yuxin Wu

a general feed-free trainer

parent e4d6992d
......@@ -5,19 +5,15 @@
import tensorflow as tf
import numpy as np
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import sys
import six
from six.moves import zip, map
from six.moves import zip
from ..dataflow import DataFlow
from ..utils import get_tqdm, logger, execute_only_once
from ..utils import logger, execute_only_once
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name, get_op_var_name
from .base import Callback
from .dispatcher import OutputTensorDispatcer
from ..tfutils import get_op_var_name
__all__ = ['InferenceRunner', 'ClassificationError',
__all__ = ['ClassificationError',
'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
@six.add_metaclass(ABCMeta)
......@@ -63,95 +59,6 @@ class Inferencer(object):
def _get_output_tensors(self):
pass
class InferenceRunner(Callback):
    """
    A callback that runs a list of `Inferencer` instances over a DataFlow
    at the end of every epoch, then writes each scalar result as a summary.
    """

    # (index, isOutput): locates a tensor either in the datapoint fed in
    # (isOutput=False) or in the predictor outputs (isOutput=True).
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])

    def __init__(self, ds, infs, input_tensors=None):
        """
        :param ds: inference dataset. a `DataFlow` instance.
        :param infs: a list of `Inferencer` instance.
        :param input_tensors: list of tensor names to feed the dataflow to.
            default to all the input placeholders.
        """
        assert isinstance(ds, DataFlow), type(ds)
        self.ds = ds
        # accept either a single Inferencer or a list of them
        if not isinstance(infs, list):
            self.infs = [infs]
        else:
            self.infs = infs
        for v in self.infs:
            assert isinstance(v, Inferencer), str(v)
        self.input_tensors = input_tensors

    def _setup_graph(self):
        self._find_input_tensors()   # these are all tensor names
        self._find_output_tensors()  # may be either tensor name or op name
        self.pred_func = self.trainer.get_predict_func(
            self.input_tensors, self.output_tensors)

    def _find_input_tensors(self):
        """Default the input tensors to all model input placeholders."""
        if self.input_tensors is None:
            input_vars = self.trainer.model.get_input_vars()
            # TODO even if it works here, sparse still is unavailable
            # because get_tensor_by_name doesn't work for sparse

            def get_name(x):
                # a SparseTensor has no .name; fall back to its op-name prefix
                if isinstance(x, tf.SparseTensor):
                    return x.op.name.split('/')[0]
                return x.name
            self.input_tensors = [get_name(x) for x in input_vars]

    def _find_output_tensors(self):
        """Collect the union of tensors the inferencers need, and build
        per-inferencer IOTensor maps pointing into inputs or outputs."""
        dispatcher = OutputTensorDispatcer()
        for inf in self.infs:
            dispatcher.add_entry(inf.get_output_tensors())
        all_names = dispatcher.get_all_names()

        IOTensor = InferenceRunner.IOTensor
        # only tensors not already fed as input must be computed by the predictor
        self.output_tensors = list(filter(
            lambda x: x not in self.input_tensors, all_names))

        def find_oid(idxs):
            # map dispatcher indices to IOTensor entries
            ret = []
            for idx in idxs:
                name = all_names[idx]
                if name in self.input_tensors:
                    ret.append(IOTensor(self.input_tensors.index(name), False))
                else:
                    ret.append(IOTensor(self.output_tensors.index(name), True))
            return ret
        self.inf_to_tensors = [find_oid(t) for t in dispatcher.get_idx_for_each_entry()]
        # list of list of (var_name: IOTensor)

    def _trigger_epoch(self):
        """Run one full pass over the dataset, feeding every inferencer."""
        for inf in self.infs:
            inf.before_inference()

        self.ds.reset_state()
        with get_tqdm(total=self.ds.size()) as pbar:
            for dp in self.ds.get_data():
                outputs = self.pred_func(dp)
                for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                    # pick each needed value from the datapoint or the outputs
                    inf_output = [(outputs if k.isOutput else dp)[k.index]
                                  for k in tensormap]
                    inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()

    def _write_summary_after_inference(self):
        """Collect each inferencer's statistics and log the scalar ones."""
        for inf in self.infs:
            ret = inf.after_inference()
            for k, v in six.iteritems(ret):
                try:
                    v = float(v)
                except (ValueError, TypeError):
                    # non-numeric statistic: warn and skip instead of
                    # swallowing every possible exception with a bare except
                    logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
                    continue
                self.trainer.write_scalar_summary(k, v)
class ScalarStats(Inferencer):
"""
Write some scalar tensor to both stat and summary.
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: inference_runner.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from collections import namedtuple
import six
from six.moves import zip
from ..dataflow import DataFlow
from .base import Callback
from .inference import Inferencer
from .dispatcher import OutputTensorDispatcer
from ..tfutils import get_op_tensor_name
from ..utils import logger, get_tqdm
__all__ = ['InferenceRunner']
class InferenceRunner(Callback):
    """
    A callback that runs a list of `Inferencer` instances over a DataFlow
    at the end of every epoch, then writes each scalar result as a summary.
    """

    # (index, isOutput): locates a tensor either in the datapoint fed in
    # (isOutput=False) or in the predictor outputs (isOutput=True).
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])

    def __init__(self, ds, infs, input_tensors=None):
        """
        :param ds: inference dataset. a `DataFlow` instance.
        :param infs: a list of `Inferencer` instance.
        :param input_tensors: list of tensor names to feed the dataflow to.
            default to all the input placeholders.
        """
        assert isinstance(ds, DataFlow), type(ds)
        self.ds = ds
        # accept either a single Inferencer or a list of them
        if not isinstance(infs, list):
            self.infs = [infs]
        else:
            self.infs = infs
        for v in self.infs:
            assert isinstance(v, Inferencer), str(v)
        self.input_tensors = input_tensors

    def _setup_graph(self):
        self._find_input_tensors()   # these are all tensor names
        self._find_output_tensors()  # may be either tensor name or op name
        self.pred_func = self.trainer.get_predict_func(
            self.input_tensors, self.output_tensors)

    def _find_input_tensors(self):
        """Default the input tensors to all model input placeholders."""
        if self.input_tensors is None:
            input_vars = self.trainer.model.get_input_vars()
            # TODO even if it works here, sparse still is unavailable
            # because get_tensor_by_name doesn't work for sparse

            def get_name(x):
                # a SparseTensor has no .name; fall back to its op-name prefix
                if isinstance(x, tf.SparseTensor):
                    return x.op.name.split('/')[0]
                return x.name
            self.input_tensors = [get_name(x) for x in input_vars]

    def _find_output_tensors(self):
        """Collect the union of tensors the inferencers need, and build
        per-inferencer IOTensor maps pointing into inputs or outputs."""
        dispatcher = OutputTensorDispatcer()
        for inf in self.infs:
            dispatcher.add_entry(inf.get_output_tensors())
        all_names = dispatcher.get_all_names()

        IOTensor = InferenceRunner.IOTensor
        # only tensors not already fed as input must be computed by the predictor
        self.output_tensors = list(filter(
            lambda x: x not in self.input_tensors, all_names))

        def find_oid(idxs):
            # map dispatcher indices to IOTensor entries
            ret = []
            for idx in idxs:
                name = all_names[idx]
                if name in self.input_tensors:
                    ret.append(IOTensor(self.input_tensors.index(name), False))
                else:
                    ret.append(IOTensor(self.output_tensors.index(name), True))
            return ret
        self.inf_to_tensors = [find_oid(t) for t in dispatcher.get_idx_for_each_entry()]
        # list of list of (var_name: IOTensor)

    def _trigger_epoch(self):
        """Run one full pass over the dataset, feeding every inferencer."""
        for inf in self.infs:
            inf.before_inference()

        self.ds.reset_state()
        with get_tqdm(total=self.ds.size()) as pbar:
            for dp in self.ds.get_data():
                outputs = self.pred_func(dp)
                for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                    # pick each needed value from the datapoint or the outputs
                    inf_output = [(outputs if k.isOutput else dp)[k.index]
                                  for k in tensormap]
                    inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()

    def _write_summary_after_inference(self):
        """Collect each inferencer's statistics and log the scalar ones."""
        for inf in self.infs:
            ret = inf.after_inference()
            for k, v in six.iteritems(ret):
                try:
                    v = float(v)
                except (ValueError, TypeError):
                    # non-numeric statistic: warn and skip instead of
                    # swallowing every possible exception with a bare except
                    logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
                    continue
                self.trainer.write_scalar_summary(k, v)
......@@ -22,6 +22,7 @@ class TrainConfig(object):
"""
:param dataset: the dataset to train. a `DataFlow` instance.
:param data: an `InputData` instance
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for training.
:param callbacks: a `callback.Callbacks` instance. Define
the callbacks to perform during training.
......
......@@ -131,7 +131,7 @@ class TensorInput(FeedfreeInput):
def size(self):
if self._size is None:
raise ValueError("size of TensorInput is None!")
raise ValueError("size of TensorInput is undefined!")
return self._size
def _setup(self, trainer):
......@@ -139,6 +139,3 @@ class TensorInput(FeedfreeInput):
def _get_input_tensors(self):
return self.get_tensor_fn()
class SplitTensorInput(FeedfreeInput):
pass
......@@ -16,7 +16,6 @@ from ..tfutils import (backup_collection, restore_collection,
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer
from .queue import QueueInputTrainer
from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
......
......@@ -14,33 +14,25 @@ from .input_data import QueueInput
from .trainer import (MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer)
__all__ = ['QueueInputTrainer']
__all__ = ['SimpleFeedfreeTrainer', 'QueueInputTrainer']
class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
""" Single GPU Trainer, takes input from a queue"""
def __init__(self, config, input_queue=None, predict_tower=None):
class SimpleFeedfreeTrainer(
MultiPredictorTowerTrainer,
SingleCostFeedfreeTrainer):
def __init__(self, config, predict_tower=None):
"""
:param config: a `TrainConfig` instance
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
Defaults to a FIFO queue of size 100.
:param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu.
A trainer with single cost, single training tower and feed-free input
config.data must exist
"""
if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue)
else:
self._input_method = config.data
assert isinstance(self._input_method, QueueInput)
super(QueueInputTrainer, self).__init__(config)
self._input_method = config.data
super(SimpleFeedfreeTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower)
assert len(self.config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
"SimpleFeedfreeTrainer doesn't support multigpu!"
def _setup(self):
super(SingleCostFeedfreeTrainer, self)._setup()
with TowerContext(''):
super(SimpleFeedfreeTrainer, self)._setup()
with TowerContext('', is_training=True):
cost, grads = self._get_cost_and_grad()
grads = apply_grad_processors(grads, self.model.get_gradient_processor())
......@@ -49,3 +41,23 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
summary_moving_average(), name='train_op')
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
class QueueInputTrainer(SimpleFeedfreeTrainer):
    # Backward-compatible wrapper: turns a feed-based config (config.dataset)
    # into a feed-free one by buffering datapoints through a queue.
    def __init__(self, config, input_queue=None, predict_tower=None):
        """
        Single tower Trainer, takes input from a queue

        :param config: a `TrainConfig` instance. config.dataset must exist
        :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
            Defaults to a FIFO queue of size 100.
        :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
            Use -1 for cpu.
        """
        # wrap the dataset in a QueueInput so the feed-free parent can
        # consume it; this overwrites any pre-existing config.data
        config.data = QueueInput(config.dataset, input_queue)
        assert len(config.tower) == 1, \
            "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
        super(QueueInputTrainer, self).__init__(config, predict_tower)

    def _setup(self):
        # no queue-specific graph setup beyond what the parent does
        super(QueueInputTrainer, self)._setup()
......@@ -71,7 +71,7 @@ class SimpleTrainer(Trainer):
self._input_method._setup(self)
model = self.model
self.input_vars = model.get_input_vars()
with TowerContext(''):
with TowerContext('', is_training=True):
model.build_graph(self.input_vars)
cost_var = model.get_cost()
add_moving_summary(cost_var)
......@@ -153,4 +153,3 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment