Commit 95ab1563 authored by Yuxin Wu

support prefix for FeedfreeInferenceRunner

parent 16216c6a
@@ -159,7 +159,7 @@ def get_data():
     augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
     ds = AugmentImageComponents(ds, augs, (0, 1))
     ds = BatchData(ds, BATCH)
-    ds = PrefetchDataZMQ(ds, 1)
+    ds = PrefetchData(ds, 100, 1)
     return ds
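For reference, a hedged sketch of the new prefetch call in isolation, assuming the dataflow API of that era (FakeData is only a stand-in source): PrefetchData(ds, 100, 1) keeps a buffer of up to 100 datapoints filled by one worker process, whereas the replaced PrefetchDataZMQ(ds, 1) forked one process that pushes datapoints over ZMQ pipes.

from tensorpack.dataflow import FakeData, PrefetchData

# FakeData is a stand-in source: 10 datapoints of one 256x256x3 array each
ds = FakeData([[256, 256, 3]], size=10)
# keep up to 100 prefetched datapoints, produced by a single worker process
ds = PrefetchData(ds, 100, 1)

ds.reset_state()
for dp in ds.get_data():
    pass   # consume the prefetched datapoints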
@@ -11,7 +11,17 @@ __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
 @six.add_metaclass(ABCMeta)
 class Callback(object):
-    """ Base class for all callbacks """
+    """ Base class for all callbacks
+
+    Attributes:
+        epoch_num(int): the number of epochs that have completed the update
+        trainer(Trainer): the trainer
+        graph(tf.Graph): the graph
+
+    Note:
+        These attributes are available only after (and including)
+        :meth:`_setup_graph`.
+    """

     def setup_graph(self, trainer):
         """
@@ -24,7 +34,6 @@ class Callback(object):
         self.trainer = trainer
         self.graph = tf.get_default_graph()
         self.epoch_num = self.trainer.config.starting_epoch - 1
-        # self.epoch_num is always the number of epochs that finished updating parameters.
         with tf.name_scope(type(self).__name__):
             self._setup_graph()
@@ -50,8 +59,6 @@ class Callback(object):
     def trigger_epoch(self):
         """
         Triggered after every epoch.
-        In this function, ``self.epoch_num`` would be the number of epoch finished.
         """
         self.epoch_num += 1
         self._trigger_epoch()
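To illustrate the attributes documented in the new class docstring, a hedged sketch of a user-defined callback; the class itself is made up, but the hooks and attributes are the ones shown in this hunk:

from tensorpack.callbacks import Callback
from tensorpack.utils import logger

class GraphSizeLogger(Callback):
    """Hypothetical callback: log how many ops the graph holds after each epoch."""

    def _setup_graph(self):
        # self.trainer, self.graph and self.epoch_num are available from setup_graph() on
        self._graph = self.graph

    def _trigger_epoch(self):
        # by the time _trigger_epoch runs, self.epoch_num is the number of finished epochs
        n_ops = len(self._graph.get_operations())
        logger.info("epoch {}: graph has {} ops".format(self.epoch_num, n_ops))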
@@ -11,6 +11,7 @@ from six.moves import zip, range
 from ..dataflow import DataFlow
 from ..utils import logger, get_tqdm, PREDICT_TOWER, SUMMARY_BACKUP_KEYS
 from ..tfutils.common import get_op_tensor_name, freeze_collection
+from ..tfutils import TowerContext
 from ..train.input_data import FeedfreeInput
 from ..predict import build_prediction_graph
@@ -151,12 +152,14 @@ class FeedfreeInferenceRunner(Callback):
     pipeline.
     """
-    def __init__(self, input, infs, input_names=None):
+    def __init__(self, input, infs, input_names=None, prefix=''):
         """
         Args:
             input (FeedfreeInput): the input to use. Must have ``size()``.
             infs (list): list of :class:`Inferencer` to run.
             input_names (list): must be a subset of the names of InputVar.
+            prefix(str): a prefix used to build the tower. Must be set
+                differently if more than one :class:`FeedfreeInferenceRunner` is used.
         """
         assert isinstance(input, FeedfreeInput), input
         self._input_data = input
@@ -174,6 +177,7 @@ class FeedfreeInferenceRunner(Callback):
             self._size = input.size()
         except NotImplementedError:
             raise ValueError("Input used in FeedfreeInferenceRunner must have a size!")
+        self._prefix = prefix

     def _setup_graph(self):
         self._find_input_tensors()  # tensors
@@ -185,8 +189,8 @@ class FeedfreeInferenceRunner(Callback):
                 freeze_collection(SUMMARY_BACKUP_KEYS):
             def fn(_):
                 self.trainer.model.build_graph(self._input_tensors)
-            build_prediction_graph(fn, [0])
+            build_prediction_graph(fn, [0], prefix=self._prefix)    # TODO use towerp1 to support multiple FeedfreeInferenceRunner
-            self._tower_prefix = PREDICT_TOWER + '0'
+            self._tower_prefix = TowerContext.get_predict_tower_name(self._prefix, 0)
         self._find_output_tensors()
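What the new prefix buys: a hedged sketch of configuring two runners side by side. QueueInput, FakeData and ScalarStats are stand-ins (any FeedfreeInput with a size() and any Inferencer would do); the essential part is that the two prefixes differ, so each runner builds its predict tower under a distinct name scope.

from tensorpack import *   # FeedfreeInferenceRunner, ScalarStats, QueueInput, FakeData, ...

# two hypothetical validation dataflows wrapped as feed-free inputs;
# FakeData is illustrative only, any DataFlow with a known size() works
val_input = QueueInput(FakeData([[32, 32, 3], [1]], size=100))
test_input = QueueInput(FakeData([[32, 32, 3], [1]], size=100))

callbacks = [
    # distinct alphanumeric prefixes keep the two predict towers from colliding
    FeedfreeInferenceRunner(val_input, [ScalarStats('cost')], prefix='val'),
    FeedfreeInferenceRunner(test_input, [ScalarStats('cost')], prefix='test'),
]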
@@ -138,9 +138,9 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
         if ctx is not None and ctx.is_main_training_tower:
             non_grad_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
             if non_grad_updates:
+                logger.info("Apply UPDATE_OPS collection on cost.")
                 with tf.control_dependencies(non_grad_updates):
-                    barrier = tf.control_flow_ops.no_op(name='update_ops_barrier')
-                    cost = tf.control_flow_ops.with_dependencies([barrier], cost)
+                    cost = tf.identity(cost)
         return cost

     def _get_cost(self, *args):
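Both the removed barrier construction and the new tf.identity serve the same purpose: make the returned cost depend on the UPDATE_OPS collection (typically batch-norm moving-average updates) so those ops run whenever the cost is evaluated. A minimal TF1-style sketch of the new pattern, independent of tensorpack:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
# batch_normalization registers its moving-average update ops in UPDATE_OPS
h = tf.layers.batch_normalization(x, training=True)
cost = tf.reduce_mean(tf.square(h))

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    with tf.control_dependencies(update_ops):
        # the identity cannot be computed before the update ops have run
        cost = tf.identity(cost, name='cost_with_update_ops')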
@@ -144,17 +144,20 @@ def get_predict_func(config):
     return OfflinePredictor(config)

-def build_prediction_graph(build_tower_fn, towers=[0]):
+def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
     """
     Args:
         build_tower_fn: a function that will be called inside each tower,
             taking tower id as the argument.
         towers: a list of relative GPU ids.
+        prefix: an extra prefix in tower name. The final tower name will be
+            determined by :meth:`TowerContext.get_predict_tower_name`.
     """
     for k in towers:
         logger.info(
-            "Building graph for predictor tower {}...".format(k))
+            "Building prediction graph for towerid={} with prefix='{}' ...".format(k, prefix))
+        towername = TowerContext.get_predict_tower_name(prefix, k)
         with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                TowerContext('{}{}'.format(PREDICT_TOWER, k)):
+                TowerContext(towername, is_training=False):
             build_tower_fn(k)
             tf.get_variable_scope().reuse_variables()
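A hedged sketch of calling the extended function; build_my_tower is a made-up stand-in for the closure a caller passes (FeedfreeInferenceRunner passes such a closure as fn in the hunk above):

from tensorpack.predict import build_prediction_graph

def build_my_tower(towerid):
    # a real caller would build the model's inference graph here,
    # e.g. model.build_graph(input_tensors) as FeedfreeInferenceRunner does
    pass

# builds one predict tower on GPU 0 under the name returned by
# TowerContext.get_predict_tower_name('val', 0)
build_prediction_graph(build_my_tower, towers=[0], prefix='val')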
@@ -72,6 +72,20 @@ class TowerContext(object):
         newname = re.sub(predict_tower_prefix, 'tower0/', name)
         return graph.get_tensor_by_name(newname)

+    @staticmethod
+    def get_predict_tower_name(prefix, towerid=0):
+        """
+        Args:
+            prefix(str): an alphanumeric prefix.
+            towerid(int): an integer, the id of this predict tower, usually
+                used to choose the GPU id.
+        Returns:
+            str: the final tower name used to create a predict tower.
+                Currently it is ``PREDICT_TOWER + prefix + towerid``.
+        """
+        assert prefix == '' or prefix.isalnum()
+        return PREDICT_TOWER + prefix + str(towerid)
+
     def __enter__(self):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, \
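Assuming PREDICT_TOWER is the 'towerp' constant implied by the 'tower[p0-9]+/' regexes elsewhere in this commit, the new helper would produce names like these (illustrative only; the method itself only concatenates strings):

from tensorpack.tfutils import TowerContext

TowerContext.get_predict_tower_name('', 0)        # -> 'towerp0'  (the previously hard-coded name)
TowerContext.get_predict_tower_name('val', 0)     # -> 'towerpval0'
TowerContext.get_predict_tower_name('val', 1)     # -> 'towerpval1'
# TowerContext.get_predict_tower_name('my-run', 0) would raise AssertionError:
# the prefix must be empty or alphanumeric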
@@ -105,8 +105,9 @@ class Trainer(object):
             summary (tf.Summary or str): a summary object, or a str which will
                 be interpreted as a serialized tf.Summary protobuf.
         """
-        if isinstance(summary, six.string_types):
+        if isinstance(summary, six.binary_type):
             summary = tf.Summary.FromString(summary)
+        assert isinstance(summary, tf.Summary), type(summary)
         for val in summary.value:
             if val.WhichOneof('value') == 'simple_value':
                 val.tag = re.sub('tower[p0-9]+/', '', val.tag)   # TODO move to subclasses
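The switch from six.string_types to six.binary_type matters on Python 3, where the serialized summary returned by sess.run(summary_op) is bytes rather than str, so the old check skipped parsing and the later .value access failed. A small self-contained sketch of the new preamble (normalize_summary is a made-up name):

import six
import tensorflow as tf

def normalize_summary(summary):
    # hedged restatement of the new add_summary() preamble
    if isinstance(summary, six.binary_type):
        # sess.run(summary_op) returns a serialized protobuf: str on py2, bytes on py3
        summary = tf.Summary.FromString(summary)
    assert isinstance(summary, tf.Summary), type(summary)
    return summary

# simulate the round-trip that a fetched summary goes through
proto = tf.Summary(value=[tf.Summary.Value(tag='train/cost', simple_value=1.0)])
assert normalize_summary(proto.SerializeToString()).value[0].tag == 'train/cost'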
@@ -75,8 +75,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):

 class SimpleFeedfreeTrainer(
-        MultiPredictorTowerTrainer,
-        SingleCostFeedfreeTrainer):
+        SingleCostFeedfreeTrainer,
+        MultiPredictorTowerTrainer):
     """
     A trainer with single cost, single training tower, any number of
     prediction tower, and feed-free input.
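The only change here is the base-class order, which changes the method resolution order: with SingleCostFeedfreeTrainer listed first, its overrides and its position in the cooperative super() chain take precedence over MultiPredictorTowerTrainer's. A generic illustration with toy classes (not tensorpack's):

class Base(object):
    def setup(self):
        return ['Base']

class SingleCostMixin(Base):
    def setup(self):
        return ['SingleCost'] + super(SingleCostMixin, self).setup()

class MultiPredictorMixin(Base):
    def setup(self):
        return ['MultiPredictor'] + super(MultiPredictorMixin, self).setup()

class TrainerA(MultiPredictorMixin, SingleCostMixin):   # old base order
    pass

class TrainerB(SingleCostMixin, MultiPredictorMixin):   # new base order
    pass

print(TrainerA().setup())  # ['MultiPredictor', 'SingleCost', 'Base']
print(TrainerB().setup())  # ['SingleCost', 'MultiPredictor', 'Base']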