Commit 5626e04d authored by Yuxin Wu

Update plot-point.py

parent 3e9f164d
...@@ -86,28 +86,100 @@ def get_args(): ...@@ -86,28 +86,100 @@ def get_args():
args.show = True args.show = True
def read_entire_matrix():
    """Read the delimited numeric table named by ``args.input`` into columns.

    Lines are read from ``sys.stdin`` when ``args.input == STDIN_FNAME`` and
    from the file at that path otherwise.  Every line is split on
    ``args.delimeter``; the first line fixes the number of columns.  Columns
    may have different lengths: an empty cell marks the end of its column,
    and no value may appear in that column afterwards.

    Returns:
        A list with one entry per column, each a list of floats
        (i.e. the data is laid out as #col x #row).

    Raises:
        ValueError: if the input has no lines at all, or (from ``float``)
            if a non-empty cell is not a number.
        AssertionError: if a row has more cells than the first row, or if a
            value follows an empty cell in the same column.
    """
    if args.input == STDIN_FNAME:
        # stdin is owned by the interpreter; read it but never close it.
        all_lines = sys.stdin.readlines()
    else:
        # The context manager closes the file even if reading fails.
        with open(args.input) as fin:
            all_lines = fin.readlines()
    if not all_lines:
        # Without this guard the column count below fails with a bare
        # IndexError that says nothing about the cause.
        raise ValueError("Input is empty!")

    rows = [line.rstrip('\n').split(args.delimeter) for line in all_lines]
    nr_column = len(rows[0])

    # read the entire matrix to 'data'
    data = [[] for _ in range(nr_column)]
    ended = set()  # indices of columns that have already hit an empty cell
    for row in rows:
        assert len(row) <= nr_column, \
            """One row have too many columns (separated by {})!
Line: {}""".format(repr(args.delimeter), row)
        for idx, cell in enumerate(row):
            if cell == '':
                ended.add(idx)
                continue
            val = float(cell)
            assert idx not in ended, "Column {} has hole!".format(idx)
            data[idx].append(val)
    return data
class Sequence(object):
    """One curve: paired x/y samples plus the state needed to draw it.

    Attributes are plain and public:
        xs, ys: numpy arrays of equal length.
        plot_args: dict of keyword arguments kept for the plotting call.
        legend: legend text, or None when unset.
        drawables: objects exposing get_visible()/set_visible(); filled in
            by the plotting code and used by toggle_vis().
    """

    def __init__(self, xs, ys, plot_args=None):
        """
        Args:
            xs, ys: sequences of numbers. ``xs`` may be longer than ``ys``;
                it is truncated to ``len(ys)``.
            plot_args: optional dict stored as ``self.plot_args``; an empty
                dict is used when omitted.

        Raises:
            AssertionError: if ``xs`` is shorter than ``ys``.
        """
        assert len(xs) >= len(ys), \
            "x column is shorter than y column! {} < {}".format(len(xs), len(ys))
        # Keep private copies.  ys is forced to float because scale_y()
        # multiplies in place: that must not write through to the caller's
        # array, and must not fail (or truncate) on integer input.
        self.ys = np.array(ys, dtype=float)
        self.xs = np.array(xs)[:len(ys)]
        if plot_args is None:
            plot_args = {}
        self.plot_args = plot_args
        self.legend = None
        self.drawables = []

    def exponential_smooth(self, alpha):
        """Replace ``self.ys`` by its exponentially smoothed version.

        Each output is ``previous_output * alpha + sample * (1 - alpha)``,
        seeded with the first sample.  An empty sequence is left untouched.
        """
        data = self.ys
        if len(data) == 0:
            return
        ret = np.copy(data)
        now = data[0]
        for k, sample in enumerate(data):
            now = now * alpha + sample * (1 - alpha)
            ret[k] = now
        self.ys = ret

    def scale_y(self, scale):
        """Multiply ``self.ys`` by ``scale`` and note it in the legend.

        A scale of exactly 1.0 is a no-op.  The legend is only annotated
        when one has been set.
        """
        if scale == 1.0:
            return
        self.ys *= scale
        if self.legend:
            self.legend = "{},scaley={:.2g}".format(self.legend, scale)

    @property
    def xrange(self):
        """Array ``[min(xs), max(xs)]``."""
        return np.array([min(self.xs), max(self.xs)])

    def toggle_vis(self):
        """Flip the visibility of every drawable and return the new state.

        The new state is the negation of the first drawable's visibility.

        Raises:
            AssertionError: if no drawables have been registered yet.
        """
        assert len(self.drawables), "Called before plot()!"
        vis = not self.drawables[0].get_visible()
        for d in self.drawables:
            d.set_visible(vis)
        return vis
def annotate_min_max(data_x, data_y, ax): def annotate_min_max(data_x, data_y, ax):
"""
Annotate on top of ax, given one sequence of X and Y.
"""
def filter_valid_range(points, rect):
    """Keep the (x, y) points lying inside ``rect`` (bounds inclusive).

    Args:
        points: sequence of (x, y) pairs.
        rect: (min_x, max_x, min_y, max_y).

    Returns:
        A list of (x, y) tuples inside the rectangle.  If none qualifies,
        the first input point is returned alone as a fallback; if ``points``
        is empty the result is an empty list.
    """
    min_x, max_x, min_y, max_y = rect[0], rect[1], rect[2], rect[3]
    ret = [(x, y) for x, y in points
           if min_x <= x <= max_x and min_y <= y <= max_y]
    # Fall back to the first point so callers always get something to use,
    # but only when there is a first point to fall back to.
    if not ret and len(points) > 0:
        ret.append(points[0])
    return ret
max_x, min_x = max(data_x), min(data_x) max_x, min_x = max(data_x), min(data_x)
max_y, min_y = max(data_y), min(data_y) max_y, min_y = max(data_y), min(data_y)
x_range = max_x - min_x x_range = max_x - min_x
...@@ -175,77 +247,63 @@ def plot_args_from_column_desc(desc): ...@@ -175,77 +247,63 @@ def plot_args_from_column_desc(desc):
return ret return ret
def do_plot(data_xs, data_ys): def do_plot(seqs):
""" """
data_xs: list of 1d array, either of size 1 or size len(data_ys) seqs: [Sequence]
data_ys: list of 1d array
""" """
fig = plt.figure(figsize=(16.18 / 1.2, 10 / 1.2)) fig = plt.figure(figsize=(16.18 / 1.2, 10 / 1.2))
ax = fig.add_axes((0.1, 0.2, 0.8, 0.7)) ax = fig.add_axes((0.1, 0.2, 0.8, 0.7))
nr_y = len(data_ys)
y_column = args.y_column
# parse legend and y-scale for seq in seqs:
if args.legend: curve_obj = plt.plot(seq.xs, seq.ys, label=seq.legend, **seq.plot_args)[0]
legends = args.legend.split(',')
assert len(legends) == nr_y c = curve_obj.get_color()
else: fill_obj = plt.fill_between(seq.xs, seq.ys, alpha=0.1, facecolor=c)
legends = None # range(nr_y) #None
if args.scale:
scale = map(float, args.scale.split(','))
assert len(scale) == nr_y
else:
scale = [1.0] * nr_y
for yidx in range(nr_y):
plotargs = plot_args_from_column_desc(y_column[yidx][1:])
now_scale = scale[yidx]
data_y = data_ys[yidx] * now_scale
leg = legends[yidx] if legends else None
if now_scale != 1:
leg = "{}*{}".format(now_scale if int(now_scale) != now_scale else int(now_scale), leg)
data_x = data_xs[0] if len(data_xs) == 1 else data_xs[yidx]
assert len(data_x) >= len(data_y), \
"x column is shorter than y column! {} < {}".format(
len(data_x), len(data_y))
truncate_data_x = data_x[:len(data_y)]
p = plt.plot(truncate_data_x, data_y, label=leg, **plotargs)
c = p[0].get_color()
plt.fill_between(truncate_data_x, data_y, alpha=0.1, facecolor=c)
# ax.set_aspect('equal', 'datalim')
# ax.spines['right'].set_color('none')
# ax.spines['left'].set_color('none')
# plt.xticks([])
# plt.yticks([])
if args.annotate_maximum or args.annotate_minimum: if args.annotate_maximum or args.annotate_minimum:
annotate_min_max(truncate_data_x, data_y, ax) annotate_min_max(seq.xs, seq.ys, ax)
seq.drawables.extend([curve_obj, fill_obj])
# deal with label and xlim
if args.xlabel: if args.xlabel:
plt.xlabel(args.xlabel, fontsize='xx-large') plt.xlabel(args.xlabel, fontsize='xx-large')
if args.ylabel: if args.ylabel:
plt.ylabel(args.ylabel, fontsize='xx-large') plt.ylabel(args.ylabel, fontsize='xx-large')
if args.xlim: if args.xlim:
plt.xlim(args.xlim[0], args.xlim[1]) plt.xlim(args.xlim[0], args.xlim[1])
if args.ylim: else:
plt.ylim(args.ylim[0], args.ylim[1])
plt.legend(loc='best', fontsize='xx-large')
# adjust maxx # adjust maxx
minx, maxx = min(data_x), max(data_x) all_xrange = np.asarray([s.xrange for s in seqs])
minx, maxx = min(all_xrange[:, 0]), max(all_xrange[:, 1])
new_maxx = maxx + (maxx - minx) * 0.05 new_maxx = maxx + (maxx - minx) * 0.05
plt.xlim(minx, new_maxx) plt.xlim(minx, new_maxx)
if args.ylim:
plt.ylim(args.ylim[0], args.ylim[1])
legend_obj = plt.legend(loc='best', fontsize='xx-large')
# setup click behavior
legend_line_to_seq = {}
for legend_line, seq in zip(legend_obj.get_lines(), seqs):
legend_line.set_picker(5)
legend_line_to_seq[legend_line] = seq
def onclick(event):
legline = event.artist
seq = legend_line_to_seq[legline]
vis = seq.toggle_vis()
legline.set_alpha(1.0 if vis else 0.2)
fig.canvas.draw()
fig.canvas.mpl_connect('pick_event', onclick)
for label in chain.from_iterable( for label in chain.from_iterable(
[ax.get_xticklabels(), ax.get_yticklabels()]): [ax.get_xticklabels(), ax.get_yticklabels()]):
label.set_fontproperties(fontm.FontProperties(size=15)) label.set_fontproperties(fontm.FontProperties(size=15))
ax.grid(color='gray', linestyle='dashed') ax.grid(color='gray', linestyle='dashed')
plt.title(args.title, fontdict={'fontsize': '20'}) plt.title(args.title, fontdict={'fontsize': '20'})
if args.output != '': if args.output != '':
plt.savefig(args.output, bbox_inches='tight') plt.savefig(args.output, bbox_inches='tight')
if args.show: if args.show:
...@@ -254,17 +312,10 @@ def do_plot(data_xs, data_ys): ...@@ -254,17 +312,10 @@ def do_plot(data_xs, data_ys):
def main(): def main():
get_args() get_args()
# parse input args data = read_entire_matrix() # #col x #row
if args.input == STDIN_FNAME:
fin = sys.stdin
else:
fin = open(args.input)
all_inputs = fin.readlines()
if args.input != STDIN_FNAME:
fin.close()
# parse column format # parse column format
nr_column = len(all_inputs[0].rstrip('\n').split(args.delimeter)) nr_column = len(data)
if args.column is None: if args.column is None:
column = ['y'] * nr_column column = ['y'] * nr_column
else: else:
...@@ -272,57 +323,60 @@ def main(): ...@@ -272,57 +323,60 @@ def main():
for k in column: for k in column:
assert k[0] in ['x', 'y', 'n'] assert k[0] in ['x', 'y', 'n']
assert nr_column == len(column), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column)) assert nr_column == len(column), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column))
args.y_column = [v for v in column if v[0] == 'y']
args.y_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'y']
args.x_column = [v for v in column if v[0] == 'x']
args.x_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'x']
nr_x_column = len(args.x_column)
nr_y_column = len(args.y_column)
if nr_x_column > 1:
assert nr_x_column == nr_y_column, \
"If multiple x columns are used, nr_x_column must equals to nr_y_column"
# read and parse data
data = [[] for _ in range(nr_column)]
ended = defaultdict(bool)
for lineno, line in enumerate(all_inputs):
line = line.rstrip('\n').split(args.delimeter)
assert len(line) <= nr_column, \
"""One row have too many columns (separated by {})!
Line: {}""".format(repr(args.delimeter), line)
for idx, val in enumerate(line):
if val == '':
ended[idx] = True
continue
else:
val = float(val)
assert not ended[idx], "Column {} has hole!".format(idx)
data[idx].append(val)
data_ys = [data[k] for k in args.y_column_idx] # split data into Xs and Ys
data_xs, data_ys, desc_ys = [], [], [] # #col x #row
for column_data, column_desc in zip(data, column):
if column_desc[0] == 'y':
data_ys.append(column_data)
desc_ys.append(column_desc)
elif column_desc[0] == 'x':
data_xs.append(column_data)
num_curve = len(data_ys)
length_ys = [len(t) for t in data_ys] length_ys = [len(t) for t in data_ys]
print("Length of each column:", length_ys) print("Length of each Y column:", length_ys)
max_ysize = max(length_ys)
# populate default xs
if nr_x_column: if len(data_xs) > 1:
data_xs = [data[k] for k in args.x_column_idx] assert len(data_xs) == num_curve, \
"If multiple x columns are used, num_x_column must equals to nr_y_column"
elif len(data_xs) == 1:
data_xs = data_xs * num_curve
else: else:
data_xs = [list(range(1, max_ysize + 1))] data_xs = [list(range(1, max(length_ys) + 1))] * num_curve
# put into seq
seqs = []
assert len(data_xs) == len(data_ys)
for idx, (X, Y) in enumerate(zip(data_xs, data_ys)):
col_desc = desc_ys[idx]
seqs.append(Sequence(
X, Y,
plot_args=plot_args_from_column_desc(col_desc[1:])))
for idx, data_y in enumerate(data_ys):
data_ys[idx] = np.asarray(data_y)
if args.decay != 0: if args.decay != 0:
data_ys[idx] = exponential_smooth(data_y, args.decay) for s in seqs:
s.exponential_smooth(args.decay)
# if idx == 0: # TODO allow different decay for each y # if idx == 0: # TODO allow different decay for each y
# data_ys[idx] = exponential_smooth(data_y, 0.5) # data_ys[idx] = exponential_smooth(data_y, 0.5)
for idx, data_x in enumerate(data_xs):
data_xs[idx] = np.asarray(data_x) if args.legend:
legends = args.legend.split(',')
assert len(legends) == num_curve
for legend, seq in zip(legends, seqs):
seq.legend = legend
if args.scale:
scales = list(map(float, args.scale.split(',')))
assert len(scales) == num_curve
for scale, seq in zip(scales, seqs):
seq.scale_y(scale)
if args.xkcd: if args.xkcd:
with plt.xkcd(): with plt.xkcd():
do_plot(data_xs, data_ys) do_plot(seqs)
else: else:
do_plot(data_xs, data_ys) do_plot(seqs)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment