Commit 99c70935 authored by Yuxin Wu

move predict_tower into trainconfig

parent 48ef46aa
......@@ -258,4 +258,5 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
config.tower = train_tower
trainer(config, predict_tower=predict_tower).train()
config.predict_tower = predict_tower
trainer(config).train()
......@@ -14,9 +14,21 @@ from .inference import Inferencer
from .dispatcher import OutputTensorDispatcer
from ..tfutils import get_op_tensor_name
from ..utils import logger, get_tqdm
from ..train.input_data import FeedfreeInput
__all__ = ['InferenceRunner']
def summary_inferencer(trainer, infs):
    """
    Collect the statistics of every inferencer and write them as scalar summaries.

    :param trainer: object providing ``write_scalar_summary(name, value)``.
    :param infs: iterable of inferencers; ``after_inference()`` is called on each
        one and is expected to return a dict mapping a name to a statistic.

    Statistics that cannot be converted to ``float`` are skipped with a warning
    that names the inferencer class.
    """
    for inf in infs:
        ret = inf.after_inference()
        if ret is None:
            # NOTE(review): assumes an inferencer may return None when it has
            # nothing to report -- confirm. Skipping avoids a crash inside
            # six.iteritems.
            continue
        for k, v in six.iteritems(ret):
            try:
                v = float(v)
            except Exception:
                # Deliberately best-effort: any conversion failure means the
                # value is not a scalar. Unlike a bare `except:`, this does not
                # swallow KeyboardInterrupt / SystemExit.
                logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
                continue
            trainer.write_scalar_summary(k, v)
class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
......@@ -31,14 +43,14 @@ class InferenceRunner(Callback):
:param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders.
"""
assert isinstance(ds, DataFlow), type(ds)
assert isinstance(ds, DataFlow), ds
self.ds = ds
if not isinstance(infs, list):
self.infs = [infs]
else:
self.infs = infs
for v in self.infs:
assert isinstance(v, Inferencer), str(v)
assert isinstance(v, Inferencer), v
self.input_tensors = input_tensors
def _setup_graph(self):
......@@ -96,12 +108,30 @@ class InferenceRunner(Callback):
self._write_summary_after_inference()
def _write_summary_after_inference(self):
for inf in self.infs:
ret = inf.after_inference()
for k, v in six.iteritems(ret):
try:
v = float(v)
except:
logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
continue
self.trainer.write_scalar_summary(k, v)
summary_inferencer(self.trainer, self.infs)
class FeedfreeInferenceRunner(Callback):
    """
    A callback that holds a feed-free input together with a list of inferencers.

    :param input: a :class:`FeedfreeInput` instance.
    :param infs: an :class:`Inferencer` instance, or a list of them.
    :param input_tensors: stored as ``self.input_tensor_names``; not yet used
        for filtering (see the TODO in ``_setup_graph``).
    """
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])

    def __init__(self, input, infs, input_tensors=None):
        assert isinstance(input, FeedfreeInput), input
        self._input_data = input
        # Normalize a single inferencer into a one-element list.
        self.infs = infs if isinstance(infs, list) else [infs]
        for inferencer in self.infs:
            assert isinstance(inferencer, Inferencer), inferencer
        self.input_tensor_names = input_tensors

    def _setup_graph(self):
        """Set up the input with the trainer and fetch its input tensors."""
        input_data = self._input_data
        input_data._setup(self.trainer)
        # only 1 prediction tower will be used for inference
        self._input_tensors = input_data.get_input_tensors()
        # TODO filter by names
        self._find_output_tensors()

    def _find_output_tensors(self):
        """Placeholder: no output tensors are looked up yet."""
        pass
......@@ -4,12 +4,12 @@
import tensorflow as tf
from ..callbacks import Callbacks
from ..callbacks.group import Callbacks
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..dataflow import DataFlow
from .input_data import InputData
__all__ = ['TrainConfig']
......@@ -35,6 +35,7 @@ class TrainConfig(object):
:param max_epoch: maximum number of epochs to run training. Defaults to inf.
:param nr_tower: int. number of training towers. default to 1.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
:param predict_tower: list of prediction towers, given by their relative GPU ids. Defaults to [0].
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
......@@ -81,6 +82,9 @@ class TrainConfig(object):
self.tower = kwargs.pop('tower')
else:
self.tower = [0]
self.predict_tower = kwargs.pop('predict_tower', [0])
if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower]
# TODO deprecated @Dec20
self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
......
......@@ -63,7 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
class SimpleFeedfreeTrainer(
MultiPredictorTowerTrainer,
SingleCostFeedfreeTrainer):
def __init__(self, config, predict_tower=None):
def __init__(self, config):
"""
A trainer with single cost, single training tower and feed-free input
config.data must exist
......@@ -71,7 +71,7 @@ class SimpleFeedfreeTrainer(
self._input_method = config.data
assert isinstance(self._input_method, FeedfreeInput), self._input_method
super(SimpleFeedfreeTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower)
self._setup_predictor_factory(config.predict_tower)
assert len(self.config.tower) == 1, \
"SimpleFeedfreeTrainer doesn't support multigpu!"
......@@ -99,6 +99,9 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
Use -1 for cpu.
"""
config.data = QueueInput(config.dataset, input_queue)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super(QueueInputTrainer, self).__init__(config, predict_tower)
super(QueueInputTrainer, self).__init__(config)
......@@ -53,9 +53,13 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
else:
self._input_method = config.data
assert isinstance(self._input_method, QueueInput)
super(SyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
super(SyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(config.predict_tower)
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
assert tf.test.is_gpu_available()
......@@ -101,8 +105,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
MultiPredictorTowerTrainer):
def __init__(self, config,
input_queue=None,
predict_tower=None,
average_gradient=True):
average_gradient=True,
predict_tower=None):
if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue)
else:
......@@ -110,7 +114,11 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
assert isinstance(self._input_method, QueueInput)
super(AsyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
self._setup_predictor_factory(config.predict_tower)
self._average_gradient = average_gradient
assert tf.test.is_gpu_available()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment