Commit 3080e91e authored by Yuxin Wu

some rename and alias

parent 42e5481a
@@ -6,7 +6,7 @@
 import tensorflow as tf
 from collections import namedtuple
 import six
-from six.moves import zip
+from six.moves import zip, range
 from ..dataflow import DataFlow
 from .base import Callback
@@ -61,7 +61,7 @@ class InferenceRunner(Callback):
     def _find_input_tensors(self):
         if self.input_tensors is None:
-            input_vars = self.trainer.model.get_input_vars()
+            input_vars = self.trainer.model.get_reuse_placehdrs()
             # TODO even if it works here, sparse still is unavailable
             # because get_tensor_by_name doesn't work for sparse
             def get_name(x):
@@ -125,13 +125,58 @@ class FeedfreeInferenceRunner(Callback):
         self.input_tensor_names = input_tensors
 
     def _setup_graph(self):
-        self._input_data._setup(self.trainer)
-        # only 1 prediction tower will be used for inference
-        self._input_tensors = self._input_data.get_input_tensors()
-        # TODO filter by names
-        self._find_output_tensors()
+        self._find_input_tensors()  # tensors
+        self._find_output_tensors()
+        # TODO build tower
+
+    def _find_input_tensors(self):
+        self._input_data._setup(self.trainer)
+        # only 1 prediction tower will be used for inference
+        self._input_tensors = self._input_data.get_input_tensors()
+        model_placehdrs = self.trainer.model.get_reuse_placehdrs()
+        assert len(self._input_tensors) == len(model_placehdrs), \
+            "FeedfreeInput doesn't produce correct number of output tensors"
+        if self.input_tensor_names is not None:
+            assert isinstance(self.input_tensor_names, list)
+            self._input_tensors = [k for idx, k in enumerate(self._input_tensors)
+                                   if model_placehdrs[idx].name in self.input_tensor_names]
+            assert len(self._input_tensors) == len(self.input_tensor_names), \
+                "names of input tensors are not defined in the Model"
 
     def _find_output_tensors(self):
-        pass
+        # doesn't support output an input tensor
+        dispatcer = OutputTensorDispatcer()
+        for inf in self.infs:
+            dispatcer.add_entry(inf.get_output_tensors())
+        all_names = dispatcer.get_all_names()
+
+        IOTensor = InferenceRunner.IOTensor
+        self.output_tensors = all_names
+
+        def find_oid(idxs):
+            ret = []
+            for idx in idxs:
+                name = all_names[idx]
+                ret.append(IOTensor(self.output_tensors.index(name), True))
+            return ret
+        self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
+        # list of list of (var_name: IOTensor)
+
+    def _trigger_epoch(self):
+        for inf in self.infs:
+            inf.before_inference()
+
+        sess = tf.get_default_session()
+        sz = self._input_data.size()
+        with get_tqdm(total=sz) as pbar:
+            for _ in range(sz):
+                #outputs = self.pred_func(dp)
+                #for inf, tensormap in zip(self.infs, self.inf_to_tensors):
+                    #inf_output = [(outputs if k.isOutput else dp)[k.index]
+                                  #for k in tensormap]
+                    #inf.datapoint(inf_output)
+                pbar.update()
+        self._write_summary_after_inference()
+
+    def _write_summary_after_inference(self):
+        summary_inferencer(self.trainer, self.infs)
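
Note: the new _find_output_tensors leans on OutputTensorDispatcer to merge the output-tensor names requested by all inferencers into one deduplicated fetch list, then records, per inferencer, which positions in that list it should read back (as IOTensor entries). A minimal sketch of that dedup-and-dispatch idea, assuming the dispatcher behaves roughly like the hypothetical MockOutputTensorDispatcher below (an illustrative stand-in, not tensorpack's implementation):

    from collections import namedtuple

    # Hypothetical stand-in for OutputTensorDispatcer: deduplicate the names
    # requested by each inferencer and remember per-inferencer indices.
    class MockOutputTensorDispatcher(object):
        def __init__(self):
            self._all_names = []   # deduplicated global list of tensor names
            self._idxs = []        # for each inferencer, indices into _all_names

        def add_entry(self, names):
            entry = []
            for n in names:
                if n not in self._all_names:
                    self._all_names.append(n)
                entry.append(self._all_names.index(n))
            self._idxs.append(entry)

        def get_all_names(self):
            return self._all_names

        def get_idx_for_each_entry(self):
            return self._idxs

    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])

    dispatcher = MockOutputTensorDispatcher()
    dispatcher.add_entry(['loss:0', 'accuracy:0'])           # inferencer A
    dispatcher.add_entry(['accuracy:0', 'predictions:0'])    # inferencer B

    all_names = dispatcher.get_all_names()
    # ['loss:0', 'accuracy:0', 'predictions:0'] -- 'accuracy:0' fetched only once
    inf_to_tensors = [[IOTensor(i, True) for i in idxs]
                      for idxs in dispatcher.get_idx_for_each_entry()]

The actual per-step dispatch in _trigger_epoch is still commented out in this commit; the mapping built here is what it will consume.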
@@ -43,11 +43,14 @@ class ModelDesc(object):
         """
         if hasattr(self, 'reuse_input_vars'):
             return self.reuse_input_vars
-        ret = self.get_placeholders()
+        ret = self.build_placeholders()
         self.reuse_input_vars = ret
         return ret
 
-    def get_placeholders(self, prefix=''):
+    # alias
+    get_reuse_placehdrs = get_input_vars
+
+    def build_placeholders(self, prefix=''):
         """ build placeholders with optional prefix, for each InputVar
         """
         input_vars = self._get_input_vars()
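
Note: this hunk splits the old get_placeholders into two names: build_placeholders always constructs fresh placeholders (with an optional per-tower name prefix), while get_input_vars memoizes them, and get_reuse_placehdrs is added as a class-level alias for get_input_vars so callers such as InferenceRunner above can use either name. A small self-contained sketch of the pattern, using a hypothetical ToyModelDesc with TF1-style placeholders (not the real tensorpack class):

    import tensorflow as tf

    class ToyModelDesc(object):
        def get_input_vars(self):
            # memoized: repeated calls return the same placeholders
            if hasattr(self, 'reuse_input_vars'):
                return self.reuse_input_vars
            ret = self.build_placeholders()
            self.reuse_input_vars = ret
            return ret

        # class-level alias: call sites using the other name resolve to the same method
        get_reuse_placehdrs = get_input_vars

        def build_placeholders(self, prefix=''):
            # always builds fresh placeholders, optionally name-prefixed per tower
            return [tf.placeholder(tf.float32, [None, 28, 28], name=prefix + 'input')]

    m = ToyModelDesc()
    assert m.get_reuse_placehdrs() is m.get_input_vars()   # alias shares the cache
    tower_phs = m.build_placeholders(prefix='towerp0-')    # fresh, prefixed copies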
@@ -151,7 +151,8 @@ class DataParallelOfflinePredictor(OnlinePredictor):
         output_vars = []
         for k in towers:
             towername = PREDICT_TOWER + str(k)
-            input_vars = config.model.get_placeholders(prefix=towername + '-')
+            input_vars = config.model.build_placeholders(
+                prefix=towername + '-')
             logger.info(
                 "Building graph for predictor tower {}...".format(k))
             with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \