Commit 77bcc8b1 authored by Yuxin Wu's avatar Yuxin Wu

input_names instead of input_var_names

parent c6c9a4ba
...@@ -123,7 +123,7 @@ class Model(ModelDesc): ...@@ -123,7 +123,7 @@ class Model(ModelDesc):
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v) target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
self.cost = symbf.huber_loss(target - pred_action_value, name='cost') self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value), BATCH_SIZE, name='cost')
summary.add_param_summary([('conv.*/W', ['histogram', 'rms']), summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms']) ]) # monitor all W ('fc.*/W', ['histogram', 'rms']) ]) # monitor all W
...@@ -200,8 +200,8 @@ if __name__ == '__main__': ...@@ -200,8 +200,8 @@ if __name__ == '__main__':
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
session_init=SaverRestore(args.load), session_init=SaverRestore(args.load),
input_var_names=['state'], input_names=['state'],
output_var_names=['Qvalue']) output_names=['Qvalue'])
if args.task == 'play': if args.task == 'play':
play_model(cfg) play_model(cfg)
elif args.task == 'eval': elif args.task == 'eval':
......
...@@ -245,8 +245,8 @@ def run_image(model, sess_init, inputs): ...@@ -245,8 +245,8 @@ def run_image(model, sess_init, inputs):
model=model, model=model,
session_init=sess_init, session_init=sess_init,
session_config=get_default_sess_config(0.9), session_config=get_default_sess_config(0.9),
input_var_names=['input'], input_names=['input'],
output_var_names=['output'] output_names=['output']
) )
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
meta = dataset.ILSVRCMeta() meta = dataset.ILSVRCMeta()
......
...@@ -184,9 +184,9 @@ def get_config(): ...@@ -184,9 +184,9 @@ def get_config():
def run(model_path, image_path): def run(model_path, image_path):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
input_data_mapping=[0],
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
output_var_names=['output' + str(k) for k in range(1, 7)]) input_names=['image'],
output_names=['output' + str(k) for k in range(1, 7)])
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
im = cv2.imread(image_path) im = cv2.imread(image_path)
assert im is not None assert im is not None
......
...@@ -95,6 +95,6 @@ if __name__ == '__main__': ...@@ -95,6 +95,6 @@ if __name__ == '__main__':
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
session_init=SaverRestore(args.load), session_init=SaverRestore(args.load),
input_var_names=['state'], input_names=['state'],
output_var_names=['logits']) output_names=['logits'])
run_submission(cfg, args.output, args.episode) run_submission(cfg, args.output, args.episode)
...@@ -235,8 +235,8 @@ if __name__ == '__main__': ...@@ -235,8 +235,8 @@ if __name__ == '__main__':
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
session_init=SaverRestore(args.load), session_init=SaverRestore(args.load),
input_var_names=['state'], input_names=['state'],
output_var_names=['logits']) output_names=['logits'])
if args.task == 'play': if args.task == 'play':
play_model(cfg) play_model(cfg)
elif args.task == 'eval': elif args.task == 'eval':
......
## imagenet-resnet.py ## imagenet-resnet.py
ImageNet training code of pre-activation ResNet. It follows the setup in Training code of pre-activation ResNet on ImageNet. It follows the setup in
[fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and gets similar performance (with much fewer lines of code). [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and gets similar performance (with much fewer lines of code).
More results to come. More results to come.
......
...@@ -213,9 +213,9 @@ def eval_on_ILSVRC12(model_file, data_dir): ...@@ -213,9 +213,9 @@ def eval_on_ILSVRC12(model_file, data_dir):
ds = get_data('val') ds = get_data('val')
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
input_var_names=['input', 'label'],
session_init=get_model_loader(model_file), session_init=get_model_loader(model_file),
output_var_names=['wrong-top1', 'wrong-top5'] input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
) )
pred = SimpleDatasetPredictor(pred_config, ds) pred = SimpleDatasetPredictor(pred_config, ds)
acc1, acc5 = RatioCounter(), RatioCounter() acc1, acc5 = RatioCounter(), RatioCounter()
......
...@@ -111,8 +111,8 @@ def run_test(params, input): ...@@ -111,8 +111,8 @@ def run_test(params, input):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
session_init=ParamRestore(params), session_init=ParamRestore(params),
input_var_names=['input'], input_names=['input'],
output_var_names=['prob'] output_names=['prob']
) )
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
...@@ -134,9 +134,9 @@ def eval_on_ILSVRC12(params, data_dir): ...@@ -134,9 +134,9 @@ def eval_on_ILSVRC12(params, data_dir):
ds = BatchData(ds, 128, remainder=True) ds = BatchData(ds, 128, remainder=True)
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
input_var_names=['input', 'label'],
session_init=ParamRestore(params), session_init=ParamRestore(params),
output_var_names=['wrong-top1', 'wrong-top5'] input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
) )
pred = SimpleDatasetPredictor(pred_config, ds) pred = SimpleDatasetPredictor(pred_config, ds)
acc1, acc5 = RatioCounter(), RatioCounter() acc1, acc5 = RatioCounter(), RatioCounter()
......
...@@ -109,8 +109,8 @@ def view_warp(modelpath): ...@@ -109,8 +109,8 @@ def view_warp(modelpath):
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(modelpath), session_init=get_model_loader(modelpath),
model=Model(), model=Model(),
input_var_names=['input'], input_names=['input'],
output_var_names=['viz', 'STN1/affine', 'STN2/affine'])) output_names=['viz', 'STN1/affine', 'STN2/affine']))
xys = np.array([[0, 0, 1], xys = np.array([[0, 0, 1],
[WARP_TARGET_SIZE, 0, 1], [WARP_TARGET_SIZE, 0, 1],
......
...@@ -53,10 +53,10 @@ def run_test(path, input): ...@@ -53,10 +53,10 @@ def run_test(path, input):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
input_var_names=['input'],
session_init=ParamRestore(param_dict), session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9), session_config=get_default_sess_config(0.9),
output_var_names=['output'] # the variable 'output' is the probability distribution input_names=['input'],
output_names=['output'] # the variable 'output' is the probability distribution
) )
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
......
...@@ -72,10 +72,10 @@ def run_test(path, input): ...@@ -72,10 +72,10 @@ def run_test(path, input):
param_dict = np.load(path).item() param_dict = np.load(path).item()
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
input_var_names=['input'], input_names=['input'],
session_init=ParamRestore(param_dict), session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9), session_config=get_default_sess_config(0.9),
output_var_names=['output'] # output:0 is the probability distribution output_names=['output'] # output:0 is the probability distribution
) )
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
......
...@@ -26,6 +26,7 @@ def serve_data(ds, addr): ...@@ -26,6 +26,7 @@ def serve_data(ds, addr):
try: try:
ds.reset_state() ds.reset_state()
logger.info("Serving data at {}".format(addr)) logger.info("Serving data at {}".format(addr))
# TODO print statistics here
while True: while True:
for dp in ds.get_data(): for dp in ds.get_data():
socket.send(dumps(dp), copy=False) socket.send(dumps(dp), copy=False)
......
...@@ -8,7 +8,7 @@ import tensorflow as tf ...@@ -8,7 +8,7 @@ import tensorflow as tf
import six import six
from ..utils import logger from ..utils import logger
from ..tfutils import get_vars_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext
__all__ = ['OnlinePredictor', 'OfflinePredictor', __all__ = ['OnlinePredictor', 'OfflinePredictor',
'AsyncPredictorBase', 'AsyncPredictorBase',
...@@ -41,7 +41,7 @@ class PredictorBase(object): ...@@ -41,7 +41,7 @@ class PredictorBase(object):
@abstractmethod @abstractmethod
def _do_call(self, dp): def _do_call(self, dp):
""" """
:param dp: input datapoint. must have the same length as input_var_names :param dp: input datapoint. must have the same length as input_names
:return: output as defined by the config :return: output as defined by the config
""" """
...@@ -67,18 +67,18 @@ class AsyncPredictorBase(PredictorBase): ...@@ -67,18 +67,18 @@ class AsyncPredictorBase(PredictorBase):
return fut.result() return fut.result()
class OnlinePredictor(PredictorBase): class OnlinePredictor(PredictorBase):
def __init__(self, sess, input_vars, output_vars, return_input=False): def __init__(self, sess, input_tensors, output_tensors, return_input=False):
self.session = sess self.session = sess
self.return_input = return_input self.return_input = return_input
self.input_vars = input_vars self.input_tensors = input_tensors
self.output_vars = output_vars self.output_tensors = output_tensors
def _do_call(self, dp): def _do_call(self, dp):
assert len(dp) == len(self.input_vars), \ assert len(dp) == len(self.input_tensors), \
"{} != {}".format(len(dp), len(self.input_vars)) "{} != {}".format(len(dp), len(self.input_tensors))
feed = dict(zip(self.input_vars, dp)) feed = dict(zip(self.input_tensors, dp))
output = self.session.run(self.output_vars, feed_dict=feed) output = self.session.run(self.output_tensors, feed_dict=feed)
return output return output
...@@ -91,8 +91,8 @@ class OfflinePredictor(OnlinePredictor): ...@@ -91,8 +91,8 @@ class OfflinePredictor(OnlinePredictor):
with TowerContext('', False): with TowerContext('', False):
config.model.build_graph(input_vars) config.model.build_graph(input_vars)
input_vars = get_vars_by_names(config.input_var_names) input_vars = get_tensors_by_names(config.input_names)
output_vars = get_vars_by_names(config.output_var_names) output_vars = get_tensors_by_names(config.output_names)
sess = tf.Session(config=config.session_config) sess = tf.Session(config=config.session_config)
config.session_init.init(sess) config.session_init.init(sess)
...@@ -124,12 +124,12 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -124,12 +124,12 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.sess = tf.Session(config=config.session_config) self.sess = tf.Session(config=config.session_config)
config.session_init.init(self.sess) config.session_init.init(self.sess)
input_vars = get_vars_by_names(config.input_var_names) input_vars = get_tensors_by_names(config.input_names)
for k in towers: for k in towers:
output_vars = get_vars_by_names( output_vars = get_tensors_by_names(
['{}{}/'.format(self.PREFIX, k) + n \ ['{}{}/'.format(self.PREFIX, k) + n \
for n in config.output_var_names]) for n in config.output_names])
self.predictors.append(OnlinePredictor( self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input)) self.sess, input_vars, output_vars, config.return_input))
......
...@@ -26,9 +26,9 @@ class PredictConfig(object): ...@@ -26,9 +26,9 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to :param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session. initialize variables of a session.
:param input_var_names: a list of input variable names.
:param model: a `ModelDesc` instance :param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the :param input_names: a list of input variable names.
:param output_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
Predict specific output might not require all input variables. Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False. :param return_input: whether to return (input, output) pair or just output. default to False.
...@@ -45,15 +45,24 @@ class PredictConfig(object): ...@@ -45,15 +45,24 @@ class PredictConfig(object):
assert_type(self.model, ModelDesc) assert_type(self.model, ModelDesc)
# inputs & outputs # inputs & outputs
self.input_var_names = kwargs.pop('input_var_names', None) # TODO add deprecated warning later
if self.input_var_names is None: self.input_names = kwargs.pop('input_names', None)
if self.input_names is None:
self.input_names = kwargs.pop('input_var_names', None)
if self.input_names is not None:
pass
#logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
if self.input_names is None:
# neither options is set, assume all inputs # neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc() raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [k.name for k in raw_vars] self.input_names = [k.name for k in raw_vars]
self.output_var_names = kwargs.pop('output_var_names') self.output_names = kwargs.pop('output_names', None)
assert len(self.input_var_names), self.input_var_names if self.output_names is None:
for v in self.input_var_names: assert_type(v, six.string_types) self.output_names = kwargs.pop('output_var_names')
assert len(self.output_var_names), self.output_var_names #logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
assert len(self.input_names), self.input_names
for v in self.input_names: assert_type(v, six.string_types)
assert len(self.output_names), self.output_names
self.return_input = kwargs.pop('return_input', False) self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment