Commit 3080e91e authored by Yuxin Wu's avatar Yuxin Wu

some rename and alias

parent 42e5481a
......@@ -6,7 +6,7 @@
import tensorflow as tf
from collections import namedtuple
import six
from six.moves import zip
from six.moves import zip, range
from ..dataflow import DataFlow
from .base import Callback
......@@ -61,7 +61,7 @@ class InferenceRunner(Callback):
def _find_input_tensors(self):
if self.input_tensors is None:
input_vars = self.trainer.model.get_input_vars()
input_vars = self.trainer.model.get_reuse_placehdrs()
# TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
def get_name(x):
......@@ -125,13 +125,58 @@ class FeedfreeInferenceRunner(Callback):
self.input_tensor_names = input_tensors
def _setup_graph(self):
self._find_input_tensors() # tensors
self._find_output_tensors()
# TODO build tower
def _find_input_tensors(self):
self._input_data._setup(self.trainer)
# only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors()
# TODO filter by names
self._find_output_tensors()
model_placehdrs = self.trainer.model.get_reuse_placehdrs()
assert len(self._input_tensors) == len(model_placehdrs), \
"FeedfreeInput doesn't produce correct number of output tensors"
if self.input_tensor_names is not None:
assert isinstance(self.input_tensor_names, list)
self._input_tensors = [k for idx, k in enumerate(self._input_tensors)
if model_placehdrs[idx].name in self.input_tensor_names]
assert len(self._input_tensors) == len(self.input_tensor_names), \
"names of input tensors are not defined in the Model"
def _find_output_tensors(self):
pass
# doesn't support output an input tensor
dispatcer = OutputTensorDispatcer()
for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names()
IOTensor = InferenceRunner.IOTensor
self.output_tensors = all_names
def find_oid(idxs):
ret = []
for idx in idxs:
name = all_names[idx]
ret.append(IOTensor(self.output_tensors.index(name), True))
return ret
self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
# list of list of (var_name: IOTensor)
def _trigger_epoch(self):
for inf in self.infs:
inf.before_inference()
sess = tf.get_default_session()
sz = self._input_data.size()
with get_tqdm(total=sz) as pbar:
for _ in range(sz):
#outputs = self.pred_func(dp)
#for inf, tensormap in zip(self.infs, self.inf_to_tensors):
#inf_output = [(outputs if k.isOutput else dp)[k.index]
#for k in tensormap]
#inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference()
def _write_summary_after_inference(self):
summary_inferencer(self.trainer, self.infs)
......@@ -43,11 +43,14 @@ class ModelDesc(object):
"""
if hasattr(self, 'reuse_input_vars'):
return self.reuse_input_vars
ret = self.get_placeholders()
ret = self.build_placeholders()
self.reuse_input_vars = ret
return ret
def get_placeholders(self, prefix=''):
# alias
get_reuse_placehdrs = get_input_vars
def build_placeholders(self, prefix=''):
""" build placeholders with optional prefix, for each InputVar
"""
input_vars = self._get_input_vars()
......
......@@ -151,7 +151,8 @@ class DataParallelOfflinePredictor(OnlinePredictor):
output_vars = []
for k in towers:
towername = PREDICT_TOWER + str(k)
input_vars = config.model.get_placeholders(prefix=towername + '-')
input_vars = config.model.build_placeholders(
prefix=towername + '-')
logger.info(
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment