Commit 3e61aacd authored by Yuxin Wu's avatar Yuxin Wu

standardize the name "predictor" instead of "predict_func"

parent 088521fc
......@@ -349,7 +349,7 @@ class OnlineExport(Callback):
self.example_input = color.rgb2lab(cv2.imread('myimage.jpg')[:, :, ::-1])[:, :, 0] # read rgb image and extract luminance
def _setup_graph(self):
self.predictor = self.trainer.get_predict_func(['luminance'], ['prediction/output'])
self.predictor = self.trainer.get_predictor(['luminance'], ['prediction/output'])
def _trigger_epoch(self):
pass
......@@ -367,7 +367,7 @@ you can simply `print(prediction)` to find out the name.
These two names allows us to build the inference part of the network in
```python
self.trainer.get_predict_func(['luminance', 'prediction/output'])
self.trainer.get_predictor(['luminance', 'prediction/output'])
```
This is very convenient because in the `_tigger_epoch` we can use:
......@@ -385,7 +385,7 @@ class OnlineExport(Callback):
self.example_input = color.rgb2lab(cv2.imread('myimage.jpg')[:, :, [2, 1, 0]])[:, :, 0]
def _setup_graph(self):
self.trainer.get_predict_func(['luminance', 'prediction/output'])
self.trainer.get_predictor(['luminance', 'prediction/output'])
def _trigger_epoch(self):
hopefully_cool_rgb = self.pred([[self.example_input]])[0][0]
......
......@@ -151,7 +151,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _setup_graph(self):
self.async_predictor = MultiThreadAsyncPredictor(
self.trainer.get_predict_funcs(['state'], ['logitsT', 'pred_value'],
self.trainer.get_predictors(['state'], ['logitsT', 'pred_value'],
PREDICTOR_THREAD), batch_size=15)
def _before_train(self):
......
......@@ -38,7 +38,7 @@ def play_model(cfg):
print("Total:", score)
def eval_with_funcs(predict_funcs, nr_eval):
def eval_with_funcs(predictors, nr_eval):
class Worker(StoppableThread, ShareSessionThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
......@@ -62,7 +62,7 @@ def eval_with_funcs(predict_funcs, nr_eval):
self.queue_put_stoppable(self.q, score)
q = queue.Queue()
threads = [Worker(f, q) for f in predict_funcs]
threads = [Worker(f, q) for f in predictors]
for k in threads:
k.start()
......@@ -103,7 +103,7 @@ class Evaluator(Triggerable):
def _setup_graph(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 20)
self.pred_funcs = [self.trainer.get_predict_func(
self.pred_funcs = [self.trainer.get_predictor(
self.input_names, self.output_names)] * NR_PROC
def _trigger(self):
......
......@@ -229,7 +229,7 @@ class ExpReplay(DataFlow, Callback):
return [state, action, reward, isOver]
def _setup_graph(self):
self.predictor = self.trainer.get_predict_func(*self.predictor_io_names)
self.predictor = self.trainer.get_predictor(*self.predictor_io_names)
def _before_train(self):
self._init_memory()
......
......@@ -258,7 +258,7 @@ def run_image(model, sess_init, inputs):
input_names=['input'],
output_names=['output']
)
predict_func = OfflinePredictor(pred_config)
predictor = OfflinePredictor(pred_config)
meta = dataset.ILSVRCMeta()
pp_mean = meta.get_per_pixel_mean()
pp_mean_224 = pp_mean[16:-16, 16:-16, :]
......@@ -282,7 +282,7 @@ def run_image(model, sess_init, inputs):
assert img is not None
img = transformers.augment(img)[np.newaxis, :, :, :]
outputs = predict_func([img])[0]
outputs = predictor([img])[0]
prob = outputs[0]
ret = prob.argsort()[-10:][::-1]
......
......@@ -192,11 +192,11 @@ def run(model_path, image_path, output):
session_init=get_model_loader(model_path),
input_names=['image'],
output_names=['output' + str(k) for k in range(1, 7)])
predict_func = OfflinePredictor(pred_config)
predictor = OfflinePredictor(pred_config)
im = cv2.imread(image_path)
assert im is not None
im = cv2.resize(im, (im.shape[1] // 16 * 16, im.shape[0] // 16 * 16))
outputs = predict_func([[im.astype('float32')]])
outputs = predictor([[im.astype('float32')]])
if output is None:
for k in range(6):
pred = outputs[k][0]
......
......@@ -30,7 +30,7 @@ class Model(tp.ModelDesc):
def run(model_path, image_path):
predict_func = tp.OfflinePredictor(tp.PredictConfig(
predictor = tp.OfflinePredictor(tp.PredictConfig(
model=Model(),
session_init=tp.get_model_loader(model_path),
input_names=['image'],
......@@ -42,7 +42,7 @@ def run(model_path, image_path):
im = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE))
im = im.astype(np.float32)[:, :, ::-1]
saliency_images = predict_func([im])[0]
saliency_images = predictor([im])[0]
abs_saliency = np.abs(saliency_images).max(axis=-1)
pos_saliency = np.maximum(0, saliency_images)
......
......@@ -54,7 +54,7 @@ class Model(ModelDesc):
def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item()
predict_func = OfflinePredictor(PredictConfig(
predictor = OfflinePredictor(PredictConfig(
model=Model(),
session_init=ParamRestore(param_dict),
input_names=['input'],
......@@ -65,7 +65,7 @@ def run_test(path, input):
assert im is not None, input
im = cv2.resize(im, (227, 227))[:, :, ::-1].reshape(
(1, 227, 227, 3)).astype('float32') - 110
outputs = predict_func([im])[0]
outputs = predictor([im])[0]
prob = outputs[0]
ret = prob.argsort()[-10:][::-1]
print("Top10 predictions:", ret)
......
......@@ -92,7 +92,7 @@ class InferenceRunner(Triggerable):
def _setup_graph(self):
self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name
self.pred_func = self.trainer.get_predict_func(
self.predictor = self.trainer.get_predictor(
self.input_tensors, self.output_tensors)
def _find_input_tensors(self):
......@@ -135,7 +135,7 @@ class InferenceRunner(Triggerable):
self.ds.reset_state()
with get_tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
outputs = self.pred_func(dp)
outputs = self.predictor(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap]
......
......@@ -218,8 +218,8 @@ class ScheduledHyperParamSetter(HyperParamSetter):
param: same as in :class:`HyperParamSetter`.
schedule (list): with the format ``[(epoch1, val1), (epoch2, val2), (epoch3, val3)]``.
Each ``(ep, val)`` pair means to set the param
to "val" after the completion of `ep` th epoch.
If ep == 0, the value will be set before training.
to "val" __after__ the completion of `ep` th epoch.
If ep == 0, the value will be set before the first epoch.
interp: None: no interpolation. 'linear': linear interpolation
Example:
......@@ -263,6 +263,7 @@ class HyperParamSetterWithFunc(HyperParamSetter):
Args:
param: same as in :class:`HyperParamSetter`.
func: ``param`` will be set by ``new_value = func(epoch_num, old_value)``.
``epoch_num`` is the number of epochs that have finished.
Example:
Decrease by a factor of 0.9 every two epochs:
......
......@@ -7,7 +7,7 @@ from abc import abstractmethod, ABCMeta
import tensorflow as tf
import six
from ..utils import logger
from ..utils import logger, deprecated
from ..utils.argtools import memoized
from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..tfutils import get_tensors_by_names, TowerContext
......@@ -146,6 +146,7 @@ class OfflinePredictor(OnlinePredictor):
input_tensors, output_tensors, config.return_input, sess)
@deprecated("Use OfflinePredictor instead!", "2017-05-20")
def get_predict_func(config):
"""
Equivalent to ``OfflinePredictor(config)``.
......
......@@ -189,7 +189,7 @@ class Trainer(object):
self.summary_writer.close()
self.monitored_sess.close()
def get_predict_func(self, input_names, output_names, tower=0):
def get_predictor(self, input_names, output_names, tower=0):
"""
Args:
input_names (list), output_names(list): list of names
......@@ -200,16 +200,25 @@ class Trainer(object):
"""
if not hasattr(self, '_predictor_factory'):
self._predictor_factory = PredictorFactory(self)
nr_tower = len(self.config.predict_tower)
if nr_tower < tower:
logger.warn(
"Requested the {}th predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin.".format(tower, nr_tower))
tower = tower % nr_tower
return self._predictor_factory.get_predictor(input_names, output_names, tower)
def get_predict_funcs(self, input_names, output_names, n):
def get_predictors(self, input_names, output_names, n):
""" Return n predictors. """
nr_tower = len(self.config.predict_tower)
if nr_tower < n:
logger.warn(
"Requested {} predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin.".format(n, nr_tower))
return [self.get_predict_func(input_names, output_names, k % nr_tower) for k in range(n)]
return [self.get_predictor(input_names, output_names, k) for k in range(n)]
@deprecated("Use get_predictor instead!", "2017-05-20")
def get_predict_func(self, input_names, output_names, tower=0):
return self.get_predictor(input_names, output_names, tower)
@deprecated("Use get_predictors instead!", "2017-05-20")
def get_predict_funcs(self, input_names, output_names, n):
return self.get_predictors(input_names, output_names, n)
@deprecated("Don't need to call it any more!", "2017-03-20")
def _setup_predictor_factory(self):
......
......@@ -26,7 +26,6 @@ class PredictorFactory(object):
self._tower_builder = PredictorTowerBuilder(fn)
assert isinstance(self.towers, list)
# TODO sess option
def get_predictor(self, input_names, output_names, tower):
"""
Args:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment