Commit 3e61aacd authored by Yuxin Wu's avatar Yuxin Wu

standardize the name "predictor" instead of "predict_func"

parent 088521fc
...@@ -349,7 +349,7 @@ class OnlineExport(Callback): ...@@ -349,7 +349,7 @@ class OnlineExport(Callback):
self.example_input = color.rgb2lab(cv2.imread('myimage.jpg')[:, :, ::-1])[:, :, 0] # read rgb image and extract luminance self.example_input = color.rgb2lab(cv2.imread('myimage.jpg')[:, :, ::-1])[:, :, 0] # read rgb image and extract luminance
def _setup_graph(self): def _setup_graph(self):
self.predictor = self.trainer.get_predict_func(['luminance'], ['prediction/output']) self.predictor = self.trainer.get_predictor(['luminance'], ['prediction/output'])
def _trigger_epoch(self): def _trigger_epoch(self):
pass pass
...@@ -367,7 +367,7 @@ you can simply `print(prediction)` to find out the name. ...@@ -367,7 +367,7 @@ you can simply `print(prediction)` to find out the name.
These two names allows us to build the inference part of the network in These two names allows us to build the inference part of the network in
```python ```python
self.trainer.get_predict_func(['luminance', 'prediction/output']) self.trainer.get_predictor(['luminance', 'prediction/output'])
``` ```
This is very convenient because in the `_tigger_epoch` we can use: This is very convenient because in the `_tigger_epoch` we can use:
...@@ -385,7 +385,7 @@ class OnlineExport(Callback): ...@@ -385,7 +385,7 @@ class OnlineExport(Callback):
self.example_input = color.rgb2lab(cv2.imread('myimage.jpg')[:, :, [2, 1, 0]])[:, :, 0] self.example_input = color.rgb2lab(cv2.imread('myimage.jpg')[:, :, [2, 1, 0]])[:, :, 0]
def _setup_graph(self): def _setup_graph(self):
self.trainer.get_predict_func(['luminance', 'prediction/output']) self.trainer.get_predictor(['luminance', 'prediction/output'])
def _trigger_epoch(self): def _trigger_epoch(self):
hopefully_cool_rgb = self.pred([[self.example_input]])[0][0] hopefully_cool_rgb = self.pred([[self.example_input]])[0][0]
......
...@@ -151,7 +151,7 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -151,7 +151,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _setup_graph(self): def _setup_graph(self):
self.async_predictor = MultiThreadAsyncPredictor( self.async_predictor = MultiThreadAsyncPredictor(
self.trainer.get_predict_funcs(['state'], ['logitsT', 'pred_value'], self.trainer.get_predictors(['state'], ['logitsT', 'pred_value'],
PREDICTOR_THREAD), batch_size=15) PREDICTOR_THREAD), batch_size=15)
def _before_train(self): def _before_train(self):
......
...@@ -38,7 +38,7 @@ def play_model(cfg): ...@@ -38,7 +38,7 @@ def play_model(cfg):
print("Total:", score) print("Total:", score)
def eval_with_funcs(predict_funcs, nr_eval): def eval_with_funcs(predictors, nr_eval):
class Worker(StoppableThread, ShareSessionThread): class Worker(StoppableThread, ShareSessionThread):
def __init__(self, func, queue): def __init__(self, func, queue):
super(Worker, self).__init__() super(Worker, self).__init__()
...@@ -62,7 +62,7 @@ def eval_with_funcs(predict_funcs, nr_eval): ...@@ -62,7 +62,7 @@ def eval_with_funcs(predict_funcs, nr_eval):
self.queue_put_stoppable(self.q, score) self.queue_put_stoppable(self.q, score)
q = queue.Queue() q = queue.Queue()
threads = [Worker(f, q) for f in predict_funcs] threads = [Worker(f, q) for f in predictors]
for k in threads: for k in threads:
k.start() k.start()
...@@ -103,7 +103,7 @@ class Evaluator(Triggerable): ...@@ -103,7 +103,7 @@ class Evaluator(Triggerable):
def _setup_graph(self): def _setup_graph(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 20) NR_PROC = min(multiprocessing.cpu_count() // 2, 20)
self.pred_funcs = [self.trainer.get_predict_func( self.pred_funcs = [self.trainer.get_predictor(
self.input_names, self.output_names)] * NR_PROC self.input_names, self.output_names)] * NR_PROC
def _trigger(self): def _trigger(self):
......
...@@ -229,7 +229,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -229,7 +229,7 @@ class ExpReplay(DataFlow, Callback):
return [state, action, reward, isOver] return [state, action, reward, isOver]
def _setup_graph(self): def _setup_graph(self):
self.predictor = self.trainer.get_predict_func(*self.predictor_io_names) self.predictor = self.trainer.get_predictor(*self.predictor_io_names)
def _before_train(self): def _before_train(self):
self._init_memory() self._init_memory()
......
...@@ -258,7 +258,7 @@ def run_image(model, sess_init, inputs): ...@@ -258,7 +258,7 @@ def run_image(model, sess_init, inputs):
input_names=['input'], input_names=['input'],
output_names=['output'] output_names=['output']
) )
predict_func = OfflinePredictor(pred_config) predictor = OfflinePredictor(pred_config)
meta = dataset.ILSVRCMeta() meta = dataset.ILSVRCMeta()
pp_mean = meta.get_per_pixel_mean() pp_mean = meta.get_per_pixel_mean()
pp_mean_224 = pp_mean[16:-16, 16:-16, :] pp_mean_224 = pp_mean[16:-16, 16:-16, :]
...@@ -282,7 +282,7 @@ def run_image(model, sess_init, inputs): ...@@ -282,7 +282,7 @@ def run_image(model, sess_init, inputs):
assert img is not None assert img is not None
img = transformers.augment(img)[np.newaxis, :, :, :] img = transformers.augment(img)[np.newaxis, :, :, :]
outputs = predict_func([img])[0] outputs = predictor([img])[0]
prob = outputs[0] prob = outputs[0]
ret = prob.argsort()[-10:][::-1] ret = prob.argsort()[-10:][::-1]
......
...@@ -192,11 +192,11 @@ def run(model_path, image_path, output): ...@@ -192,11 +192,11 @@ def run(model_path, image_path, output):
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
input_names=['image'], input_names=['image'],
output_names=['output' + str(k) for k in range(1, 7)]) output_names=['output' + str(k) for k in range(1, 7)])
predict_func = OfflinePredictor(pred_config) predictor = OfflinePredictor(pred_config)
im = cv2.imread(image_path) im = cv2.imread(image_path)
assert im is not None assert im is not None
im = cv2.resize(im, (im.shape[1] // 16 * 16, im.shape[0] // 16 * 16)) im = cv2.resize(im, (im.shape[1] // 16 * 16, im.shape[0] // 16 * 16))
outputs = predict_func([[im.astype('float32')]]) outputs = predictor([[im.astype('float32')]])
if output is None: if output is None:
for k in range(6): for k in range(6):
pred = outputs[k][0] pred = outputs[k][0]
......
...@@ -30,7 +30,7 @@ class Model(tp.ModelDesc): ...@@ -30,7 +30,7 @@ class Model(tp.ModelDesc):
def run(model_path, image_path): def run(model_path, image_path):
predict_func = tp.OfflinePredictor(tp.PredictConfig( predictor = tp.OfflinePredictor(tp.PredictConfig(
model=Model(), model=Model(),
session_init=tp.get_model_loader(model_path), session_init=tp.get_model_loader(model_path),
input_names=['image'], input_names=['image'],
...@@ -42,7 +42,7 @@ def run(model_path, image_path): ...@@ -42,7 +42,7 @@ def run(model_path, image_path):
im = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE)) im = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE))
im = im.astype(np.float32)[:, :, ::-1] im = im.astype(np.float32)[:, :, ::-1]
saliency_images = predict_func([im])[0] saliency_images = predictor([im])[0]
abs_saliency = np.abs(saliency_images).max(axis=-1) abs_saliency = np.abs(saliency_images).max(axis=-1)
pos_saliency = np.maximum(0, saliency_images) pos_saliency = np.maximum(0, saliency_images)
......
...@@ -54,7 +54,7 @@ class Model(ModelDesc): ...@@ -54,7 +54,7 @@ class Model(ModelDesc):
def run_test(path, input): def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item() param_dict = np.load(path, encoding='latin1').item()
predict_func = OfflinePredictor(PredictConfig( predictor = OfflinePredictor(PredictConfig(
model=Model(), model=Model(),
session_init=ParamRestore(param_dict), session_init=ParamRestore(param_dict),
input_names=['input'], input_names=['input'],
...@@ -65,7 +65,7 @@ def run_test(path, input): ...@@ -65,7 +65,7 @@ def run_test(path, input):
assert im is not None, input assert im is not None, input
im = cv2.resize(im, (227, 227))[:, :, ::-1].reshape( im = cv2.resize(im, (227, 227))[:, :, ::-1].reshape(
(1, 227, 227, 3)).astype('float32') - 110 (1, 227, 227, 3)).astype('float32') - 110
outputs = predict_func([im])[0] outputs = predictor([im])[0]
prob = outputs[0] prob = outputs[0]
ret = prob.argsort()[-10:][::-1] ret = prob.argsort()[-10:][::-1]
print("Top10 predictions:", ret) print("Top10 predictions:", ret)
......
...@@ -92,7 +92,7 @@ class InferenceRunner(Triggerable): ...@@ -92,7 +92,7 @@ class InferenceRunner(Triggerable):
def _setup_graph(self): def _setup_graph(self):
self._find_input_tensors() # these are all tensor names self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name self._find_output_tensors() # may be either tensor name or op name
self.pred_func = self.trainer.get_predict_func( self.predictor = self.trainer.get_predictor(
self.input_tensors, self.output_tensors) self.input_tensors, self.output_tensors)
def _find_input_tensors(self): def _find_input_tensors(self):
...@@ -135,7 +135,7 @@ class InferenceRunner(Triggerable): ...@@ -135,7 +135,7 @@ class InferenceRunner(Triggerable):
self.ds.reset_state() self.ds.reset_state()
with get_tqdm(total=self.ds.size()) as pbar: with get_tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
outputs = self.pred_func(dp) outputs = self.predictor(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors): for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index] inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap] for k in tensormap]
......
...@@ -218,8 +218,8 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -218,8 +218,8 @@ class ScheduledHyperParamSetter(HyperParamSetter):
param: same as in :class:`HyperParamSetter`. param: same as in :class:`HyperParamSetter`.
schedule (list): with the format ``[(epoch1, val1), (epoch2, val2), (epoch3, val3)]``. schedule (list): with the format ``[(epoch1, val1), (epoch2, val2), (epoch3, val3)]``.
Each ``(ep, val)`` pair means to set the param Each ``(ep, val)`` pair means to set the param
to "val" after the completion of `ep` th epoch. to "val" __after__ the completion of `ep` th epoch.
If ep == 0, the value will be set before training. If ep == 0, the value will be set before the first epoch.
interp: None: no interpolation. 'linear': linear interpolation interp: None: no interpolation. 'linear': linear interpolation
Example: Example:
...@@ -263,6 +263,7 @@ class HyperParamSetterWithFunc(HyperParamSetter): ...@@ -263,6 +263,7 @@ class HyperParamSetterWithFunc(HyperParamSetter):
Args: Args:
param: same as in :class:`HyperParamSetter`. param: same as in :class:`HyperParamSetter`.
func: ``param`` will be set by ``new_value = func(epoch_num, old_value)``. func: ``param`` will be set by ``new_value = func(epoch_num, old_value)``.
``epoch_num`` is the number of epochs that have finished.
Example: Example:
Decrease by a factor of 0.9 every two epochs: Decrease by a factor of 0.9 every two epochs:
......
...@@ -7,7 +7,7 @@ from abc import abstractmethod, ABCMeta ...@@ -7,7 +7,7 @@ from abc import abstractmethod, ABCMeta
import tensorflow as tf import tensorflow as tf
import six import six
from ..utils import logger from ..utils import logger, deprecated
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..utils.naming import SUMMARY_BACKUP_KEYS from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..tfutils import get_tensors_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext
...@@ -146,6 +146,7 @@ class OfflinePredictor(OnlinePredictor): ...@@ -146,6 +146,7 @@ class OfflinePredictor(OnlinePredictor):
input_tensors, output_tensors, config.return_input, sess) input_tensors, output_tensors, config.return_input, sess)
@deprecated("Use OfflinePredictor instead!", "2017-05-20")
def get_predict_func(config): def get_predict_func(config):
""" """
Equivalent to ``OfflinePredictor(config)``. Equivalent to ``OfflinePredictor(config)``.
......
...@@ -189,7 +189,7 @@ class Trainer(object): ...@@ -189,7 +189,7 @@ class Trainer(object):
self.summary_writer.close() self.summary_writer.close()
self.monitored_sess.close() self.monitored_sess.close()
def get_predict_func(self, input_names, output_names, tower=0): def get_predictor(self, input_names, output_names, tower=0):
""" """
Args: Args:
input_names (list), output_names(list): list of names input_names (list), output_names(list): list of names
...@@ -200,16 +200,25 @@ class Trainer(object): ...@@ -200,16 +200,25 @@ class Trainer(object):
""" """
if not hasattr(self, '_predictor_factory'): if not hasattr(self, '_predictor_factory'):
self._predictor_factory = PredictorFactory(self) self._predictor_factory = PredictorFactory(self)
nr_tower = len(self.config.predict_tower)
if nr_tower < tower:
logger.warn(
"Requested the {}th predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin.".format(tower, nr_tower))
tower = tower % nr_tower
return self._predictor_factory.get_predictor(input_names, output_names, tower) return self._predictor_factory.get_predictor(input_names, output_names, tower)
def get_predict_funcs(self, input_names, output_names, n): def get_predictors(self, input_names, output_names, n):
""" Return n predictors. """ """ Return n predictors. """
nr_tower = len(self.config.predict_tower) return [self.get_predictor(input_names, output_names, k) for k in range(n)]
if nr_tower < n:
logger.warn( @deprecated("Use get_predictor instead!", "2017-05-20")
"Requested {} predictor but only have {} predict towers! " def get_predict_func(self, input_names, output_names, tower=0):
"Predictors will be assigned to GPUs in round-robin.".format(n, nr_tower)) return self.get_predictor(input_names, output_names, tower)
return [self.get_predict_func(input_names, output_names, k % nr_tower) for k in range(n)]
@deprecated("Use get_predictors instead!", "2017-05-20")
def get_predict_funcs(self, input_names, output_names, n):
return self.get_predictors(input_names, output_names, n)
@deprecated("Don't need to call it any more!", "2017-03-20") @deprecated("Don't need to call it any more!", "2017-03-20")
def _setup_predictor_factory(self): def _setup_predictor_factory(self):
......
...@@ -26,7 +26,6 @@ class PredictorFactory(object): ...@@ -26,7 +26,6 @@ class PredictorFactory(object):
self._tower_builder = PredictorTowerBuilder(fn) self._tower_builder = PredictorTowerBuilder(fn)
assert isinstance(self.towers, list) assert isinstance(self.towers, list)
# TODO sess option
def get_predictor(self, input_names, output_names, tower): def get_predictor(self, input_names, output_names, tower):
""" """
Args: Args:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment