Commit f363d2e8 authored by Yuxin Wu

Call predictor with positional arguments

parent a6a2aba4
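For orientation, here is a minimal sketch of the calling convention this commit moves to. The toy_predictor below is a hypothetical stand-in, not a tensorpack API; only the call shapes mirror the hunks that follow.

    import numpy as np

    # Hypothetical stand-in for a predictor with a single image input;
    # only the call signature matters here.
    def toy_predictor(*args):
        # one positional argument per input tensor, each a batched array
        return [np.asarray(a).mean(axis=(1, 2, 3)) for a in args]

    im = np.zeros((368, 368, 3), dtype='float32')   # one HWC image

    # Old style (deprecated by this commit): one datapoint argument,
    # i.e. a list with one batched entry per input tensor:
    #     out = predict_func([[im]])[0][0]
    # New style: positional arguments, one batched array per input tensor.
    out = toy_predictor(im[None, :, :, :])[0][0]
    print(out)   # 0.0 for the all-zero image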
@@ -116,7 +116,7 @@ def run_test(model_path, img_file):
     im = cv2.imread(img_file, cv2.IMREAD_COLOR).astype('float32')
     im = cv2.resize(im, (368, 368))
-    out = predict_func([[im]])[0][0]
+    out = predict_func(im[None, :, :, :])[0][0]
     hm = out[:, :, :14].sum(axis=2)
     viz = colorize(im, hm)
     cv2.imwrite("output.jpg", viz)
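The im[None, :, :, :] pattern that recurs in these hunks is plain numpy: indexing with None (an alias for np.newaxis) inserts a leading batch axis of size 1. A quick self-contained check:

    import numpy as np

    im = np.random.rand(368, 368, 3).astype('float32')
    batched = im[None, :, :, :]
    assert batched.shape == (1, 368, 368, 3)

    # Equivalent ways to add the batch axis; the first two also appear
    # in later hunks of this commit.
    assert np.array_equal(batched, im[np.newaxis, :, :, :])
    assert np.array_equal(batched, np.reshape(im, (1, 368, 368, 3)))
    assert np.array_equal(batched, np.expand_dims(im, 0))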
@@ -20,7 +20,7 @@ def play_one_episode(env, func, render=False):
         """
         Map from observation to action, with 0.001 greedy.
         """
-        act = func([[s]])[0][0].argmax()
+        act = func(s[None, :, :, :])[0][0].argmax()
         if random.random() < 0.001:
             spc = env.action_space
             act = spc.sample()
@@ -199,7 +199,7 @@ class ExpReplay(DataFlow, Callback):
             history = np.stack(history, axis=2)
             # assume batched network
-            q_values = self.predictor([[history]])[0][0]  # this is the bottleneck
+            q_values = self.predictor(history[None, :, :, :])[0][0]  # this is the bottleneck
             act = np.argmax(q_values)
         self._current_ob, reward, isOver, info = self.player.step(act)
         if isOver:
@@ -284,7 +284,7 @@ def run_image(model, sess_init, inputs):
         assert img is not None
         img = transformers.augment(img)[np.newaxis, :, :, :]
-        outputs = predictor([img])[0]
+        outputs = predictor(img)[0]
         prob = outputs[0]
         ret = prob.argsort()[-10:][::-1]
@@ -139,7 +139,7 @@ def run_image(model, sess_init, inputs):
         assert img is not None
         img = transformers.augment(img)[np.newaxis, :, :, :]
-        o = predict_func([img])
+        o = predict_func(img)
         prob = o[0][0]
         ret = prob.argsort()[-10:][::-1]
@@ -78,7 +78,7 @@ class OnlineTensorboardExport(Callback):
             x /= x.max()
             return x
-        o = self.pred([self.theta])
+        o = self.pred(self.theta)
         gt_filters = np.concatenate([self.filters[i, :, :] for i in range(8)], axis=0)
         pred_filters = np.concatenate([o[0][i, :, :, 0] for i in range(8)], axis=0)
@@ -94,7 +94,7 @@ def detect_one_image(img, model_func):
     resizer = CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE)
     resized_img = resizer.augment(img)
     scale = (resized_img.shape[0] * 1.0 / img.shape[0] + resized_img.shape[1] * 1.0 / img.shape[1]) / 2
-    fg_probs, fg_boxes = model_func([resized_img])
+    fg_probs, fg_boxes = model_func(resized_img)
     fg_boxes = fg_boxes / scale
     fg_boxes = clip_boxes(fg_boxes, img.shape[:2])
     return nms_fastrcnn_results(fg_boxes, fg_probs)
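A note on the rescaling around the changed call: boxes come out in resized-image coordinates, so dividing by the average scale factor maps them back to the original image. A small worked example with made-up sizes:

    import numpy as np

    img_shape = (480, 640)        # original H, W (hypothetical)
    resized_shape = (600, 800)    # after CustomResize (hypothetical)

    scale = (resized_shape[0] / img_shape[0] + resized_shape[1] / img_shape[1]) / 2
    print(scale)                  # 1.25

    fg_boxes = np.array([[100., 50., 400., 300.]])   # x1, y1, x2, y2 on the resized image
    print(fg_boxes / scale)                          # [[ 80.  40. 320. 240.]]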
@@ -198,8 +198,10 @@ def run(model_path, image_path, output):
     predictor = OfflinePredictor(pred_config)
     im = cv2.imread(image_path)
     assert im is not None
-    im = cv2.resize(im, (im.shape[1] // 16 * 16, im.shape[0] // 16 * 16))
-    outputs = predictor([[im.astype('float32')]])
+    im = cv2.resize(
+        im, (im.shape[1] // 16 * 16, im.shape[0] // 16 * 16)
+    )[None, :, :, :].astype('float32')
+    outputs = predictor(im)
     if output is None:
         for k in range(6):
             pred = outputs[k][0]
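The rounding in the resize above, x // 16 * 16, snaps each dimension down to the nearest multiple of 16, presumably so that the model's repeated downsampling divides evenly; for example:

    h, w = 375, 500                      # hypothetical input size
    print(w // 16 * 16, h // 16 * 16)    # 496 368, both multiples of 16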
@@ -98,7 +98,7 @@ def run_test(params, input):
     im = cv2.imread(input).astype('float32')
     im = prepro.augment(im)
     im = np.reshape(im, (1, 224, 224, 3))
-    outputs = predict_func([im])
+    outputs = predict_func(im)
     prob = outputs[0]
     ret = prob[0].argsort()[-10:][::-1]
@@ -42,7 +42,7 @@ def run(model_path, image_path):
     im = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE))
     im = im.astype(np.float32)[:, :, ::-1]
-    saliency_images = predictor([im])[0]
+    saliency_images = predictor(im)[0]
     abs_saliency = np.abs(saliency_images).max(axis=-1)
     pos_saliency = np.maximum(0, saliency_images)
@@ -387,7 +387,7 @@ def visualize(model_path, model, algo_name):
     for offset, dp in enumerate(ds.get_data()):
         digit, label = dp
-        prediction = pred([digit])[0]
+        prediction = pred(digit)[0]
         embed[offset * BATCH_SIZE:offset * BATCH_SIZE + BATCH_SIZE, ...] = prediction
         images[offset * BATCH_SIZE:offset * BATCH_SIZE + BATCH_SIZE, ...] = digit
         offset += 1
@@ -140,7 +140,7 @@ def view_warp(modelpath):
     ds.reset_state()
     for k in ds.get_data():
         img, label = k
-        outputs, affine1, affine2 = pred([img])
+        outputs, affine1, affine2 = pred(img)
         for idx, viz in enumerate(outputs):
             viz = cv2.cvtColor(viz, cv2.COLOR_GRAY2BGR)
             # Here we assume the second branch focuses on the first digit
@@ -65,7 +65,7 @@ def run_test(path, input):
     assert im is not None, input
     im = cv2.resize(im, (227, 227))[:, :, ::-1].reshape(
         (1, 227, 227, 3)).astype('float32') - 110
-    outputs = predictor([im])[0]
+    outputs = predictor(im)[0]
     prob = outputs[0]
     ret = prob.argsort()[-10:][::-1]
     print("Top10 predictions:", ret)
@@ -76,7 +76,7 @@ def run_test(path, input):
     im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
     im = cv2.resize(im, (224, 224)).reshape((1, 224, 224, 3)).astype('float32')
     im = im - 110
-    outputs = predict_func([im])[0]
+    outputs = predict_func(im)[0]
     prob = outputs[0]
     ret = prob.argsort()[-10:][::-1]
     print("Top10 predictions:", ret)
@@ -10,6 +10,7 @@ import six
 from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import TowerContext
 from ..input_source import PlaceholderInput
+from ..utils.develop import log_deprecated

 __all__ = ['PredictorBase', 'AsyncPredictorBase',
            'OnlinePredictor', 'OfflinePredictor',
@@ -30,22 +31,21 @@ class PredictorBase(object):
         """
         Call the predictor on some inputs.

-        If ``len(args) == 1``, assume ``args[0]`` is a datapoint (a list).
-        otherwise, assume ``args`` is a datapoinnt
-
         Examples:
-            When you have a predictor which takes a datapoint [e1, e2], you
-            can call it in two ways:
+            When you have a predictor defined with two inputs, call it with:

             .. code-block:: python

                 predictor(e1, e2)
-                predictor([e1, e2])
         """
-        if len(args) != 1:
-            dp = args
+        if len(args) == 1 and isinstance(args[0], (list, tuple)):
+            dp = args[0]  # backward-compatibility
+            log_deprecated(
+                "Calling a predictor with one datapoint",
+                "Call it with positional arguments instead!",
+                "2018-3-1")
         else:
-            dp = args[0]
+            dp = args
         output = self._do_call(dp)
         if self.return_input:
             return (dp, output)
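A standalone sketch of how the new dispatch in __call__ behaves; this is a mock, not the tensorpack class, and log_deprecated is replaced by a plain warning:

    import warnings

    def normalize_call_args(*args):
        # Mirrors the branch added above: a single list/tuple argument is
        # treated as an old-style datapoint (deprecated); anything else is
        # taken as positional per-input arguments.
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            warnings.warn("Calling a predictor with one datapoint is deprecated; "
                          "call it with positional arguments instead.")
            return list(args[0])
        return list(args)

    # New style: two inputs, two positional arguments.
    assert normalize_call_args('e1', 'e2') == ['e1', 'e2']
    # Old style: a single datapoint list still works, with a warning.
    assert normalize_call_args(['e1', 'e2']) == ['e1', 'e2']
    # Corner case: a single non-list argument becomes a one-element datapoint.
    assert normalize_call_args('e1') == ['e1']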
@@ -3,6 +3,7 @@
 # File: concurrency.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

+import numpy as np
 import multiprocessing
 import six
 from six.moves import queue, range
@@ -71,7 +72,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
                 self.outqueue.put((DIE, None))
                 return
             else:
-                self.outqueue.put((tid, self.predictor(dp)))
+                self.outqueue.put((tid, self.predictor(*dp)))


 class PredictorWorkerThread(StoppableThread, ShareSessionThread):
@@ -89,7 +90,7 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread):
             while not self.stopped():
                 batched, futures = self.fetch_batch()
                 try:
-                    outputs = self.func(batched)
+                    outputs = self.func(*batched)
                 except tf.errors.CancelledError:
                     for f in futures:
                         f.cancel()
@@ -122,6 +123,9 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread):
                     futures.append(f)
                 except queue.Empty:
                     break  # do not wait
+
+        for k in range(nr_input_var):
+            batched[k] = np.asarray(batched[k])
         return batched, futures
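For context on the two changes above: fetch_batch now returns one numpy array per input variable, and self.func(*batched) unpacks them as positional arguments. A rough standalone illustration of that per-input batching (names and shapes are made up, not tensorpack internals):

    import numpy as np

    # Queued datapoints: each is a list with one entry per input variable.
    queued = [
        [np.zeros((84, 84, 4)), np.int64(0)],
        [np.ones((84, 84, 4)), np.int64(1)],
        [np.ones((84, 84, 4)), np.int64(2)],
    ]

    nr_input_var = len(queued[0])
    batched = [[] for _ in range(nr_input_var)]
    for dp in queued:
        for k in range(nr_input_var):
            batched[k].append(dp[k])
    # The step this commit adds: turn each per-input list into one array.
    for k in range(nr_input_var):
        batched[k] = np.asarray(batched[k])

    print(batched[0].shape, batched[1].shape)   # (3, 84, 84, 4) (3,)
    # The worker then calls the batched predictor positionally:
    #     outputs = self.func(*batched)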
@@ -73,7 +73,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
             sz = 0
         with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
             for dp in self.dataset.get_data():
-                res = self.predictor(dp)
+                res = self.predictor(*dp)
                 yield res
                 pbar.update()
@@ -277,3 +277,8 @@ class TowerTensorHandle(object):
         The output returned by the tower function.
         """
         return self._output
+
+    # def make_callable(self, input_names, output_names):
+    #     input_tensors = self.get_tensors(input_names)
+    #     output_tensors = self.get_tensors(output_names)
+    #     pass