Commit fbf93d44 authored by Yuxin Wu

fix bug in inferencer, and better naming

parent 941093a9
@@ -9,7 +9,7 @@ A simplest example:
 $ cat examples/train_log/mnist-convnet/stat.json \
     | jq '.[] | .train_error, .validation_error' \
     | paste - - \
-    | plot-point.py --legend 'train,val' --title 'error'
+    | plot-point.py --legend 'train,val' --xlabel 'epoch' --ylabel 'error'
 For more usage, see `plot-point.py -h` or the code.
 """
@@ -23,8 +23,9 @@ import argparse, sys
 from collections import defaultdict
 from itertools import chain
-#from matplotlib import rc
+from matplotlib import rc
 #rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
+#rc('font',**{'family':'sans-serif','sans-serif':['Microsoft Yahei']})
 #rc('text', usetex=True)
 STDIN_FNAME = '-'
@@ -168,7 +169,7 @@ def plot_args_from_column_desc(desc):
 def do_plot(data_xs, data_ys):
     """
-    data_xs: list of 1d array, either of size 1 of size len(data_ys)
+    data_xs: list of 1d array, either of size 1 or size len(data_ys)
     data_ys: list of 1d array
     """
     fig = plt.figure(figsize = (16.18/1.2, 10/1.2))
@@ -214,8 +215,8 @@ def do_plot(data_xs, data_ys):
         if args.annotate_maximum or args.annotate_minimum:
             annotate_min_max(truncate_data_x, data_y, ax)
-    plt.xlabel(args.xlabel, fontsize='xx-large')
-    plt.ylabel(args.ylabel, fontsize='xx-large')
+    plt.xlabel(args.xlabel.decode('utf-8'), fontsize='xx-large')
+    plt.ylabel(args.ylabel.decode('utf-8'), fontsize='xx-large')
     plt.legend(loc='best', fontsize='xx-large')
     # adjust maxx
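
A note on the `.decode('utf-8')` change above: under Python 2, `argparse` hands back byte strings, and decoding them lets matplotlib render non-ASCII labels (which is also why a CJK font option shows up in the commented `rc()` lines earlier). The following is a minimal, hedged sketch of that idea, not code from this commit; it also runs unchanged on Python 3, where the labels are already text.

import argparse
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

def as_text(s):
    # Python 2 argparse returns bytes; decode so matplotlib can draw
    # non-ASCII (e.g. CJK) axis labels. On Python 3 this is a no-op.
    return s.decode('utf-8') if isinstance(s, bytes) else s

parser = argparse.ArgumentParser()
parser.add_argument('--xlabel', default='epoch')
parser.add_argument('--ylabel', default='error')
args = parser.parse_args([])

plt.xlabel(as_text(args.xlabel), fontsize='xx-large')
plt.ylabel(as_text(args.ylabel), fontsize='xx-large')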
@@ -232,7 +233,7 @@ def do_plot(data_xs, data_ys):
     plt.title(args.title, fontdict={'fontsize': '20'})
     if args.output != '':
-        plt.savefig(args.output)
+        plt.savefig(args.output, bbox_inches='tight')
     if args.show:
         plt.show()
......
@@ -55,7 +55,7 @@ class Inferencer(object):
         """
         Return a list of tensor names needed for this inference
         """
-        return self._get_output_vars()
+        return self._get_output_tensors()

     @abstractmethod
     def _get_output_tensors(self):
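
For context on the rename, here is a minimal sketch of a concrete `Inferencer`. It is not code from this repository: the class name and the statistic are made up, and the assumption that `before_inference`, `datapoint`, and `after_inference` are the methods to override comes only from the calls `InferenceRunner` makes in the hunks below.

class SketchMeanLoss(Inferencer):
    """Hypothetical inferencer: averages a scalar 'loss:0' tensor over the dataset."""

    def _get_output_tensors(self):
        # names of the tensors to fetch for every datapoint
        return ['loss:0']

    def before_inference(self):
        self._sum, self._count = 0.0, 0

    def datapoint(self, dp, output):
        # `output` holds the fetched values, ordered as in _get_output_tensors()
        self._sum += float(output[0])
        self._count += 1

    def after_inference(self):
        # the returned dict is written out as scalar summaries by the runner
        return {'mean_loss': self._sum / max(self._count, 1)}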
@@ -66,18 +66,18 @@ class InferenceRunner(Callback):
     A callback that runs different kinds of inferencer.
     """

-    def __init__(self, ds, vcs):
+    def __init__(self, ds, infs):
         """
         :param ds: inference dataset. a `DataFlow` instance.
-        :param vcs: a list of `Inferencer` instance.
+        :param infs: a list of `Inferencer` instance.
         """
         assert isinstance(ds, DataFlow), type(ds)
         self.ds = ds
-        if not isinstance(vcs, list):
-            self.vcs = [vcs]
+        if not isinstance(infs, list):
+            self.infs = [infs]
         else:
-            self.vcs = vcs
-        for v in self.vcs:
+            self.infs = infs
+        for v in self.infs:
             assert isinstance(v, Inferencer), str(v)

     def _setup_graph(self):
@@ -89,21 +89,21 @@ class InferenceRunner(Callback):
     def _find_output_tensors(self):
         self.output_tensors = [] # list of names
-        self.vc_to_vars = [] # list of list of (var_name: output_idx)
-        for vc in self.vcs:
-            vc_vars = vc._get_output_tensors()
-            def find_oid(var):
-                if var in self.output_tensors:
-                    return self.output_tensors.index(var)
+        self.inf_to_tensors = [] # list of list of (var_name: output_idx)
+        for inf in self.infs:
+            inf_tensors = inf.get_output_tensors()
+            def find_oid(t):
+                if t in self.output_tensors:
+                    return self.output_tensors.index(t)
                 else:
-                    self.output_tensors.append(var)
+                    self.output_tensors.append(t)
                     return len(self.output_tensors) - 1
-            vc_vars = [(var, find_oid(var)) for var in vc_vars]
-            self.vc_to_vars.append(vc_vars)
+            inf_tensors = [(t, find_oid(t)) for t in inf_tensors]
+            self.inf_to_tensors.append(inf_tensors)

     def _trigger_epoch(self):
-        for vc in self.vcs:
-            vc.before_inference()
+        for inf in self.infs:
+            inf.before_inference()

         sess = tf.get_default_session()
         self.ds.reset_state()
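
The `find_oid` bookkeeping above deduplicates tensor names so that a tensor requested by several inferencers is fetched only once per datapoint, while each inferencer keeps a map from its own tensor names to indices in the shared fetch list. A self-contained sketch of that technique (plain Python, hypothetical names, no TensorFlow):

def build_fetch_plan(requested):
    """requested: one list of tensor names per inferencer.
    Returns (fetch_list, per-inferencer list of (name, index) pairs)."""
    fetch_list = []
    index_maps = []
    for names in requested:
        mapping = []
        for name in names:
            if name not in fetch_list:
                fetch_list.append(name)
            mapping.append((name, fetch_list.index(name)))
        index_maps.append(mapping)
    return fetch_list, index_maps

# Two inferencers both ask for 'loss:0'; it appears only once in the fetch list.
fetches, maps = build_fetch_plan([['loss:0', 'accuracy:0'], ['loss:0']])
assert fetches == ['loss:0', 'accuracy:0']
assert maps == [[('loss:0', 0), ('accuracy:0', 1)], [('loss:0', 0)]]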
@@ -112,18 +112,18 @@ class InferenceRunner(Callback):
                 #feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
                 #outputs = sess.run(self.output_tensors, feed_dict=feed)
                 outputs = self.pred_func(dp)
-                for vc, varsmap in zip(self.vcs, self.vc_to_vars):
-                    vc_output = [outputs[k[1]] for k in varsmap]
-                    vc.datapoint(dp, vc_output)
+                for inf, tensormap in zip(self.infs, self.inf_to_tensors):
+                    inf_output = [outputs[k[1]] for k in tensormap]
+                    inf.datapoint(dp, inf_output)
                 pbar.update()

-        for vc in self.vcs:
-            ret = vc.after_inference()
+        for inf in self.infs:
+            ret = inf.after_inference()
             for k, v in six.iteritems(ret):
                 try:
                     v = float(v)
                 except:
-                    logger.warn("{} returns a non-scalar statistics!".format(type(vc).__name__))
+                    logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
                     continue
                 self.trainer.write_scalar_summary(k, v)
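
Finally, a hedged usage sketch of the renamed API. `my_dataflow` is a placeholder for any DataFlow, `SketchMeanLoss` comes from the sketch earlier, and how the callback is attached to a trainer is outside the scope of this commit.

# __init__ wraps a single Inferencer into a list, so both forms are accepted:
runner = InferenceRunner(my_dataflow, SketchMeanLoss())
runner = InferenceRunner(my_dataflow, [SketchMeanLoss(), SketchMeanLoss()])
# Each epoch the runner feeds every datapoint through pred_func once, hands the
# fetched tensors to each inferencer's datapoint(), and writes the scalar values
# returned by after_inference() as summaries (warning on non-scalar values).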
......