Commit 58f26c39 authored by Yuxin Wu's avatar Yuxin Wu

add plot script

parent 64a63c5e
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
A general curve plotter used to create curves like:
https://github.com/ppwwyyxx/tensorpack/tree/master/examples/ResNet
A simplest example:
$ cat examples/train_log/mnist-convnet/stat.json \
| jq '.[] | .train_error, .validation_error' \
| paste - - \
| plot-point.py --legend 'train,val' --title 'error'
For more usage, see `plot-point.py -h` or the code.
"""
from math import *
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fontm
import argparse, sys
from collections import defaultdict
from itertools import chain
#from matplotlib import rc
#rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
#rc('text', usetex=True)
STDIN_FNAME = '-'
def get_args():
description = "plot points into graph."
parser = argparse.ArgumentParser(description=description)
parser.add_argument('-i', '--input',
help='input data file, use "-" for stdin. Default stdin. Input \
format is many rows of DELIMIETER-separated data',
default='-')
parser.add_argument('-o', '--output',
help='output image', default='')
parser.add_argument('--show',
help='show the figure after rendered',
action='store_true')
parser.add_argument('-c', '--column',
help="describe each column in data, for example 'x,y,y'. \
Default to 'y' for one column and 'x,y' for two columns. \
Plot attributes can be appended after 'y', like 'ythick;cr'. \
By default, assume all columns are y. \
")
parser.add_argument('-t', '--title',
help='title of the graph',
default='')
parser.add_argument('--xlabel',
help='x label',
default = 'x')
parser.add_argument('--ylabel',
help='y label',
default='y')
parser.add_argument('-s', '--scale',
help='scale of each y, separated by comma')
parser.add_argument('--annotate-maximum',
help = 'annonate maximum value in graph',
action = 'store_true')
parser.add_argument('--annotate-minimum',
help = 'annonate minimum value in graph',
action = 'store_true')
parser.add_argument('--xkcd',
help = 'xkcd style',
action = 'store_true')
parser.add_argument('--decay',
help='exponential decay rate to smooth Y',
type=float, default=0)
parser.add_argument('-l', '--legend',
help='legend for each y')
parser.add_argument('-d', '--delimeter',
help='column delimeter', default='\t')
global args
args = parser.parse_args();
if not args.show and not args.output:
args.show = True
def filter_valid_range(points, rect):
"""rect = (min_x, max_x, min_y, max_y)"""
ret = []
for x, y in points:
if x >= rect[0] and x <= rect[1] and y >= rect[2] and y <= rect[3]:
ret.append((x, y))
if len(ret) == 0:
ret.append(points[0])
return ret
def exponential_smooth(data, alpha):
""" smooth data by alpha. returned a smoothed version"""
ret = np.copy(data)
now = data[0]
for k in range(len(data)):
ret[k] = now * alpha + data[k] * (1-alpha)
now = ret[k]
return ret
def annotate_min_max(data_x, data_y, ax):
max_x, min_x = max(data_x), min(data_x)
max_y, min_y = max(data_y), min(data_y)
x_range = max_x - min_x
y_range = max_y - min_y
x_max, y_max = data_y[0], data_y[0]
x_min, y_min = data_x[0], data_y[0]
for i in xrange(1, len(data_x)):
if data_y[i] > y_max:
y_max = data_y[i]
x_max = data_x[i]
if data_y[i] < y_min:
y_min = data_y[i]
x_min = data_x[i]
rect = ax.axis()
if args.annotate_maximum:
text_x, text_y = filter_valid_range([
(x_max + 0.05 * x_range,
y_max + 0.025 * y_range),
(x_max - 0.05 * x_range,
y_max + 0.025 * y_range),
(x_max + 0.05 * x_range,
y_max - 0.025 * y_range),
(x_max - 0.05 * x_range,
y_max - 0.025 * y_range)],
rect)[0]
ax.annotate('maximum ({:d},{:.3f})' . format(int(x_max), y_max),
xy = (x_max, y_max),
xytext = (text_x, text_y),
arrowprops = dict(arrowstyle = '->'))
if args.annotate_minimum:
text_x, text_y = filter_valid_range([
(x_min + 0.05 * x_range,
y_min - 0.025 * y_range),
(x_min - 0.05 * x_range,
y_min - 0.025 * y_range),
(x_min + 0.05 * x_range,
y_min + 0.025 * y_range),
(x_min - 0.05 * x_range,
y_min + 0.025 * y_range)],
rect)[0]
ax.annotate('minimum ({:d},{:.3f})' . format(int(x_min), y_min),
xy = (x_min, y_min),
xytext = (text_x, text_y),
arrowprops = dict(arrowstyle = '->'))
#ax.annotate('{:.3f}' . format(y_min),
#xy = (x_min, y_min),
#xytext = (text_x, text_y),
#arrowprops = dict(arrowstyle = '->'))
def plot_args_from_column_desc(desc):
if not desc:
return {}
ret = {}
desc = desc.split(';')
if 'thick' in desc:
ret['lw'] = 5
if 'dash' in desc:
ret['ls'] = '--'
for v in desc:
if v.startswith('c'):
ret['color'] = v[1:]
return ret
def do_plot(data_xs, data_ys):
"""
data_xs: list of 1d array, either of size 1 of size len(data_ys)
data_ys: list of 1d array
"""
fig = plt.figure(figsize = (16.18/1.2, 10/1.2))
ax = fig.add_axes((0.1, 0.2, 0.8, 0.7))
nr_y = len(data_ys)
y_column = args.y_column
# parse legend and y-scale
if args.legend:
legends = args.legend.split(',')
assert len(legends) == nr_y
else:
legends = None #range(nr_y) #None
if args.scale:
scale = map(float, args.scale.split(','))
assert len(scale) == nr_y
else:
scale = [1.0] * nr_y
for yidx in range(nr_y):
plotargs = plot_args_from_column_desc(y_column[yidx][1:])
now_scale = scale[yidx]
data_y = data_ys[yidx] * now_scale
leg = legends[yidx] if legends else None
if now_scale != 1:
leg = "{}*{}".format(now_scale if int(now_scale) != now_scale else int(now_scale), leg)
data_x = data_xs[0] if len(data_xs) == 1 else data_xs[yidx]
assert len(data_x) >= len(data_y), \
"x column is shorter than y column! {} < {}".format(
len(data_x), len(data_y))
truncate_data_x = data_x[:len(data_y)]
p = plt.plot(truncate_data_x, data_y, label=leg, **plotargs)
c = p[0].get_color()
plt.fill_between(truncate_data_x, data_y, alpha=0.1, facecolor=c)
#ax.set_aspect('equal', 'datalim')
#ax.spines['right'].set_color('none')
#ax.spines['left'].set_color('none')
#plt.xticks([])
#plt.yticks([])
if args.annotate_maximum or args.annotate_minimum:
annotate_min_max(truncate_data_x, data_y, ax)
plt.xlabel(args.xlabel, fontsize='xx-large')
plt.ylabel(args.ylabel, fontsize='xx-large')
plt.legend(loc='best', fontsize='xx-large')
# adjust maxx
minx, maxx = min(data_x), max(data_x)
new_maxx = maxx + (maxx - minx) * 0.05
plt.xlim(minx, new_maxx)
for label in chain.from_iterable(
[ax.get_xticklabels(), ax.get_yticklabels()]):
label.set_fontproperties(fontm.FontProperties(size=15))
ax.grid(color = 'gray', linestyle = 'dashed')
plt.title(args.title, fontdict={'fontsize': '20'})
if args.output != '':
plt.savefig(args.output)
if args.show:
plt.show()
def main():
get_args()
# parse input args
if args.input == STDIN_FNAME:
fin = sys.stdin
else:
fin = open(args.input)
all_inputs = fin.readlines()
if args.input != STDIN_FNAME:
fin.close()
# parse column format
nr_column = len(all_inputs[0].rstrip().split())
if args.column is None:
column = ['y'] * nr_column
else:
column = args.column.strip().split(',')
for k in column: assert k[0] in ['x', 'y']
assert nr_column == len(column), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column))
args.y_column = [v for v in column if v[0] == 'y']
args.y_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'y']
args.x_column = [v for v in column if v[0] == 'x']
args.x_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'x']
nr_x_column = len(args.x_column)
nr_y_column = len(args.y_column)
if nr_x_column > 1:
assert nr_x_column == nr_y_column, \
"If multiple x columns are used, nr_x_column must equals to nr_y_column"
x_column_set = set(args.x_column)
# read and parse data
data = [[] for _ in range(nr_column)]
ended = defaultdict(bool)
data_format = -1
for lineno, line in enumerate(all_inputs):
line = line.rstrip('\n').split(args.delimeter)
assert len(line) <= nr_column, \
"""One row have too many columns (separated by {})!
Line: {}""".format(repr(args.delimeter), line)
for idx, val in enumerate(line):
if val == '':
ended[idx] = True
continue
else:
val = float(val)
assert not ended[idx], "Column {} has hole!".format(idx)
data[idx].append(val)
data_ys = [data[k] for k in args.y_column_idx]
max_ysize = max([len(t) for t in data_ys])
print "Size of the longest y column: ", max_ysize
if nr_x_column:
data_xs = [data[k] for k in args.x_column_idx]
else:
data_xs = [list(range(max_ysize))]
for idx, data_y in enumerate(data_ys):
data_ys[idx] = np.asarray(data_y)
if args.decay != 0:
data_ys[idx] = exponential_smooth(data_y, args.decay)
#if idx == 0: # TODO allow different decay for each y
#data_ys[idx] = exponential_smooth(data_y, 0.5)
for idx, data_x in enumerate(data_xs):
data_xs[idx] = np.asarray(data_x)
if args.xkcd:
with plt.xkcd():
do_plot(data_xs, data_ys)
else:
do_plot(data_xs, data_ys)
if __name__ == '__main__':
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment