Commit 77bcc8b1 authored by Yuxin Wu's avatar Yuxin Wu

input_names instead of input_var_names

parent c6c9a4ba
......@@ -123,7 +123,7 @@ class Model(ModelDesc):
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
self.cost = symbf.huber_loss(target - pred_action_value, name='cost')
self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value), BATCH_SIZE, name='cost')
summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms']) ]) # monitor all W
......@@ -200,8 +200,8 @@ if __name__ == '__main__':
cfg = PredictConfig(
model=Model(),
session_init=SaverRestore(args.load),
input_var_names=['state'],
output_var_names=['Qvalue'])
input_names=['state'],
output_names=['Qvalue'])
if args.task == 'play':
play_model(cfg)
elif args.task == 'eval':
......
......@@ -245,8 +245,8 @@ def run_image(model, sess_init, inputs):
model=model,
session_init=sess_init,
session_config=get_default_sess_config(0.9),
input_var_names=['input'],
output_var_names=['output']
input_names=['input'],
output_names=['output']
)
predict_func = get_predict_func(pred_config)
meta = dataset.ILSVRCMeta()
......
......@@ -184,9 +184,9 @@ def get_config():
def run(model_path, image_path):
pred_config = PredictConfig(
model=Model(),
input_data_mapping=[0],
session_init=get_model_loader(model_path),
output_var_names=['output' + str(k) for k in range(1, 7)])
input_names=['image'],
output_names=['output' + str(k) for k in range(1, 7)])
predict_func = get_predict_func(pred_config)
im = cv2.imread(image_path)
assert im is not None
......
......@@ -95,6 +95,6 @@ if __name__ == '__main__':
cfg = PredictConfig(
model=Model(),
session_init=SaverRestore(args.load),
input_var_names=['state'],
output_var_names=['logits'])
input_names=['state'],
output_names=['logits'])
run_submission(cfg, args.output, args.episode)
......@@ -235,8 +235,8 @@ if __name__ == '__main__':
cfg = PredictConfig(
model=Model(),
session_init=SaverRestore(args.load),
input_var_names=['state'],
output_var_names=['logits'])
input_names=['state'],
output_names=['logits'])
if args.task == 'play':
play_model(cfg)
elif args.task == 'eval':
......
## imagenet-resnet.py
ImageNet training code of pre-activation ResNet. It follows the setup in
Training code of pre-activation ResNet on ImageNet. It follows the setup in
[fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and gets similar performance (with much fewer lines of code).
More results to come.
......
......@@ -213,9 +213,9 @@ def eval_on_ILSVRC12(model_file, data_dir):
ds = get_data('val')
pred_config = PredictConfig(
model=Model(),
input_var_names=['input', 'label'],
session_init=get_model_loader(model_file),
output_var_names=['wrong-top1', 'wrong-top5']
input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
)
pred = SimpleDatasetPredictor(pred_config, ds)
acc1, acc5 = RatioCounter(), RatioCounter()
......
......@@ -111,8 +111,8 @@ def run_test(params, input):
pred_config = PredictConfig(
model=Model(),
session_init=ParamRestore(params),
input_var_names=['input'],
output_var_names=['prob']
input_names=['input'],
output_names=['prob']
)
predict_func = get_predict_func(pred_config)
......@@ -134,9 +134,9 @@ def eval_on_ILSVRC12(params, data_dir):
ds = BatchData(ds, 128, remainder=True)
pred_config = PredictConfig(
model=Model(),
input_var_names=['input', 'label'],
session_init=ParamRestore(params),
output_var_names=['wrong-top1', 'wrong-top5']
input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
)
pred = SimpleDatasetPredictor(pred_config, ds)
acc1, acc5 = RatioCounter(), RatioCounter()
......
......@@ -109,8 +109,8 @@ def view_warp(modelpath):
pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(modelpath),
model=Model(),
input_var_names=['input'],
output_var_names=['viz', 'STN1/affine', 'STN2/affine']))
input_names=['input'],
output_names=['viz', 'STN1/affine', 'STN2/affine']))
xys = np.array([[0, 0, 1],
[WARP_TARGET_SIZE, 0, 1],
......
......@@ -53,10 +53,10 @@ def run_test(path, input):
pred_config = PredictConfig(
model=Model(),
input_var_names=['input'],
session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9),
output_var_names=['output'] # the variable 'output' is the probability distribution
input_names=['input'],
output_names=['output'] # the variable 'output' is the probability distribution
)
predict_func = get_predict_func(pred_config)
......
......@@ -72,10 +72,10 @@ def run_test(path, input):
param_dict = np.load(path).item()
pred_config = PredictConfig(
model=Model(),
input_var_names=['input'],
input_names=['input'],
session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9),
output_var_names=['output'] # output:0 is the probability distribution
output_names=['output'] # output:0 is the probability distribution
)
predict_func = get_predict_func(pred_config)
......
......@@ -26,6 +26,7 @@ def serve_data(ds, addr):
try:
ds.reset_state()
logger.info("Serving data at {}".format(addr))
# TODO print statistics here
while True:
for dp in ds.get_data():
socket.send(dumps(dp), copy=False)
......
......@@ -8,7 +8,7 @@ import tensorflow as tf
import six
from ..utils import logger
from ..tfutils import get_vars_by_names, TowerContext
from ..tfutils import get_tensors_by_names, TowerContext
__all__ = ['OnlinePredictor', 'OfflinePredictor',
'AsyncPredictorBase',
......@@ -41,7 +41,7 @@ class PredictorBase(object):
@abstractmethod
def _do_call(self, dp):
"""
:param dp: input datapoint. must have the same length as input_var_names
:param dp: input datapoint. must have the same length as input_names
:return: output as defined by the config
"""
......@@ -67,18 +67,18 @@ class AsyncPredictorBase(PredictorBase):
return fut.result()
class OnlinePredictor(PredictorBase):
def __init__(self, sess, input_vars, output_vars, return_input=False):
def __init__(self, sess, input_tensors, output_tensors, return_input=False):
self.session = sess
self.return_input = return_input
self.input_vars = input_vars
self.output_vars = output_vars
self.input_tensors = input_tensors
self.output_tensors = output_tensors
def _do_call(self, dp):
assert len(dp) == len(self.input_vars), \
"{} != {}".format(len(dp), len(self.input_vars))
feed = dict(zip(self.input_vars, dp))
output = self.session.run(self.output_vars, feed_dict=feed)
assert len(dp) == len(self.input_tensors), \
"{} != {}".format(len(dp), len(self.input_tensors))
feed = dict(zip(self.input_tensors, dp))
output = self.session.run(self.output_tensors, feed_dict=feed)
return output
......@@ -91,8 +91,8 @@ class OfflinePredictor(OnlinePredictor):
with TowerContext('', False):
config.model.build_graph(input_vars)
input_vars = get_vars_by_names(config.input_var_names)
output_vars = get_vars_by_names(config.output_var_names)
input_vars = get_tensors_by_names(config.input_names)
output_vars = get_tensors_by_names(config.output_names)
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
......@@ -124,12 +124,12 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.sess = tf.Session(config=config.session_config)
config.session_init.init(self.sess)
input_vars = get_vars_by_names(config.input_var_names)
input_vars = get_tensors_by_names(config.input_names)
for k in towers:
output_vars = get_vars_by_names(
output_vars = get_tensors_by_names(
['{}{}/'.format(self.PREFIX, k) + n \
for n in config.output_var_names])
for n in config.output_names])
self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input))
......
......@@ -26,9 +26,9 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param input_var_names: a list of input variable names.
:param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the
:param input_names: a list of input variable names.
:param output_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False.
......@@ -45,15 +45,24 @@ class PredictConfig(object):
assert_type(self.model, ModelDesc)
# inputs & outputs
self.input_var_names = kwargs.pop('input_var_names', None)
if self.input_var_names is None:
# TODO add deprecated warning later
self.input_names = kwargs.pop('input_names', None)
if self.input_names is None:
self.input_names = kwargs.pop('input_var_names', None)
if self.input_names is not None:
pass
#logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
if self.input_names is None:
# neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [k.name for k in raw_vars]
self.output_var_names = kwargs.pop('output_var_names')
assert len(self.input_var_names), self.input_var_names
for v in self.input_var_names: assert_type(v, six.string_types)
assert len(self.output_var_names), self.output_var_names
self.input_names = [k.name for k in raw_vars]
self.output_names = kwargs.pop('output_names', None)
if self.output_names is None:
self.output_names = kwargs.pop('output_var_names')
#logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
assert len(self.input_names), self.input_names
for v in self.input_names: assert_type(v, six.string_types)
assert len(self.output_names), self.output_names
self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment