Commit 99c70935 authored by Yuxin Wu

move predict_tower into trainconfig

parent 48ef46aa
...@@ -258,4 +258,5 @@ if __name__ == '__main__': ...@@ -258,4 +258,5 @@ if __name__ == '__main__':
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
config.tower = train_tower config.tower = train_tower
trainer(config, predict_tower=predict_tower).train() config.predict_tower = predict_tower
trainer(config).train()
...@@ -14,9 +14,21 @@ from .inference import Inferencer ...@@ -14,9 +14,21 @@ from .inference import Inferencer
from .dispatcher import OutputTensorDispatcer from .dispatcher import OutputTensorDispatcer
from ..tfutils import get_op_tensor_name from ..tfutils import get_op_tensor_name
from ..utils import logger, get_tqdm from ..utils import logger, get_tqdm
from ..train.input_data import FeedfreeInput
__all__ = ['InferenceRunner'] __all__ = ['InferenceRunner']
def summary_inferencer(trainer, infs):
    """Write the scalar statistics produced by each inferencer to the trainer.

    For every inferencer, ``after_inference()`` is called and each
    ``(name, value)`` pair of the returned mapping is converted with
    ``float()`` and passed to ``trainer.write_scalar_summary``.  Values that
    cannot be converted are skipped after logging a warning.

    :param trainer: object providing ``write_scalar_summary(name, value)``.
    :param infs: iterable of inferencers; each ``after_inference()`` result
        must be a mapping that supports ``.items()``.
    """
    for inf in infs:
        ret = inf.after_inference()
        for k, v in ret.items():
            try:
                v = float(v)
            except Exception:
                # Best effort: a value that cannot become a float is reported
                # and skipped.  Catching ``Exception`` rather than using a bare
                # ``except`` lets KeyboardInterrupt / SystemExit propagate.
                logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
                continue
            trainer.write_scalar_summary(k, v)
class InferenceRunner(Callback): class InferenceRunner(Callback):
""" """
A callback that runs different kinds of inferencer. A callback that runs different kinds of inferencer.
...@@ -31,14 +43,14 @@ class InferenceRunner(Callback): ...@@ -31,14 +43,14 @@ class InferenceRunner(Callback):
:param input_tensor_names: list of tensors to feed the dataflow to. :param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders. default to all the input placeholders.
""" """
assert isinstance(ds, DataFlow), type(ds) assert isinstance(ds, DataFlow), ds
self.ds = ds self.ds = ds
if not isinstance(infs, list): if not isinstance(infs, list):
self.infs = [infs] self.infs = [infs]
else: else:
self.infs = infs self.infs = infs
for v in self.infs: for v in self.infs:
assert isinstance(v, Inferencer), str(v) assert isinstance(v, Inferencer), v
self.input_tensors = input_tensors self.input_tensors = input_tensors
def _setup_graph(self): def _setup_graph(self):
...@@ -96,12 +108,30 @@ class InferenceRunner(Callback): ...@@ -96,12 +108,30 @@ class InferenceRunner(Callback):
self._write_summary_after_inference() self._write_summary_after_inference()
def _write_summary_after_inference(self): def _write_summary_after_inference(self):
for inf in self.infs: summary_inferencer(self.trainer, self.infs)
ret = inf.after_inference()
for k, v in six.iteritems(ret): class FeedfreeInferenceRunner(Callback):
try: IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
v = float(v)
except: def __init__(self, input, infs, input_tensors=None):
logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__)) assert isinstance(input, FeedfreeInput), input
continue self._input_data = input
self.trainer.write_scalar_summary(k, v) if not isinstance(infs, list):
self.infs = [infs]
else:
self.infs = infs
for v in self.infs:
assert isinstance(v, Inferencer), v
self.input_tensor_names = input_tensors
def _setup_graph(self):
self._input_data._setup(self.trainer)
# only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors()
# TODO filter by names
self._find_output_tensors()
def _find_output_tensors(self):
pass
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
import tensorflow as tf import tensorflow as tf
from ..callbacks import Callbacks from ..callbacks.group import Callbacks
from ..dataflow.base import DataFlow
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import logger from ..utils import logger
from ..tfutils import (JustCurrentSession, from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit) get_default_sess_config, SessionInit)
from ..dataflow import DataFlow
from .input_data import InputData from .input_data import InputData
__all__ = ['TrainConfig'] __all__ = ['TrainConfig']
...@@ -35,6 +35,7 @@ class TrainConfig(object): ...@@ -35,6 +35,7 @@ class TrainConfig(object):
:param max_epoch: maximum number of epoch to run training. default to inf :param max_epoch: maximum number of epoch to run training. default to inf
:param nr_tower: int. number of training towers. default to 1. :param nr_tower: int. number of training towers. default to 1.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given. :param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
:param predict_tower: list of prediction towers, given by their relative GPU ids. A single int is also accepted and is wrapped into a list. Defaults to [0].
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
...@@ -81,6 +82,9 @@ class TrainConfig(object): ...@@ -81,6 +82,9 @@ class TrainConfig(object):
self.tower = kwargs.pop('tower') self.tower = kwargs.pop('tower')
else: else:
self.tower = [0] self.tower = [0]
self.predict_tower = kwargs.pop('predict_tower', [0])
if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower]
# TODO deprecated @Dec20 # TODO deprecated @Dec20
self.extra_threads_procs = kwargs.pop('extra_threads_procs', []) self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
......
...@@ -63,7 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer): ...@@ -63,7 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
class SimpleFeedfreeTrainer( class SimpleFeedfreeTrainer(
MultiPredictorTowerTrainer, MultiPredictorTowerTrainer,
SingleCostFeedfreeTrainer): SingleCostFeedfreeTrainer):
def __init__(self, config, predict_tower=None): def __init__(self, config):
""" """
A trainer with single cost, single training tower and feed-free input A trainer with single cost, single training tower and feed-free input
config.data must exists config.data must exists
...@@ -71,7 +71,7 @@ class SimpleFeedfreeTrainer( ...@@ -71,7 +71,7 @@ class SimpleFeedfreeTrainer(
self._input_method = config.data self._input_method = config.data
assert isinstance(self._input_method, FeedfreeInput), self._input_method assert isinstance(self._input_method, FeedfreeInput), self._input_method
super(SimpleFeedfreeTrainer, self).__init__(config) super(SimpleFeedfreeTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower) self._setup_predictor_factory(config.predict_tower)
assert len(self.config.tower) == 1, \ assert len(self.config.tower) == 1, \
"SimpleFeedfreeTrainer doesn't support multigpu!" "SimpleFeedfreeTrainer doesn't support multigpu!"
...@@ -99,6 +99,9 @@ class QueueInputTrainer(SimpleFeedfreeTrainer): ...@@ -99,6 +99,9 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
Use -1 for cpu. Use -1 for cpu.
""" """
config.data = QueueInput(config.dataset, input_queue) config.data = QueueInput(config.dataset, input_queue)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
assert len(config.tower) == 1, \ assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead." "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super(QueueInputTrainer, self).__init__(config, predict_tower) super(QueueInputTrainer, self).__init__(config)
...@@ -53,9 +53,13 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -53,9 +53,13 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
else: else:
self._input_method = config.data self._input_method = config.data
assert isinstance(self._input_method, QueueInput) assert isinstance(self._input_method, QueueInput)
super(SyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower) if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
super(SyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(config.predict_tower)
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU." assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
assert tf.test.is_gpu_available() assert tf.test.is_gpu_available()
...@@ -101,8 +105,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -101,8 +105,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
MultiPredictorTowerTrainer): MultiPredictorTowerTrainer):
def __init__(self, config, def __init__(self, config,
input_queue=None, input_queue=None,
predict_tower=None, average_gradient=True,
average_gradient=True): predict_tower=None):
if hasattr(config, 'dataset'): if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue) self._input_method = QueueInput(config.dataset, input_queue)
else: else:
...@@ -110,7 +114,11 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -110,7 +114,11 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
assert isinstance(self._input_method, QueueInput) assert isinstance(self._input_method, QueueInput)
super(AsyncMultiGPUTrainer, self).__init__(config) super(AsyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower) if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
self._setup_predictor_factory(config.predict_tower)
self._average_gradient = average_gradient self._average_gradient = average_gradient
assert tf.test.is_gpu_available() assert tf.test.is_gpu_available()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment