Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
58f26c39
Commit
58f26c39
authored
Jul 30, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add plot script
parent
64a63c5e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
312 additions
and
0 deletions
+312
-0
scripts/plot-point.py
scripts/plot-point.py
+312
-0
No files found.
scripts/plot-point.py
0 → 100755
View file @
58f26c39
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
A general curve plotter used to create curves like:
https://github.com/ppwwyyxx/tensorpack/tree/master/examples/ResNet

A minimal example:

    $ cat examples/train_log/mnist-convnet/stat.json \
        | jq '.[] | .train_error, .validation_error' \
        | paste - - \
        | plot-point.py --legend 'train,val' --title 'error'

For more usage, see `plot-point.py -h` or the code.
"""
from
math
import
*
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
matplotlib.font_manager
as
fontm
import
argparse
,
sys
from
collections
import
defaultdict
from
itertools
import
chain
#from matplotlib import rc
#rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
#rc('text', usetex=True)
# Value of the --input option that stands for "no file given".
STDIN_FNAME = "-"
def get_args(argv=None):
    """Parse the command line and publish the result as the global ``args``.

    Args:
        argv: optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the historical behavior.

    Returns:
        The parsed ``argparse.Namespace``; the very same object is also
        stored in the module-level ``args`` that the other functions read.
    """
    global args

    parser = argparse.ArgumentParser(description="plot points into graph.")
    parser.add_argument(
        '-i', '--input',
        help='input data file, use "-" for stdin. Default stdin. Input '
             'format is many rows of DELIMITER-separated data',
        default='-')
    parser.add_argument('-o', '--output', help='output image', default='')
    parser.add_argument(
        '--show', help='show the figure after rendered',
        action='store_true')
    parser.add_argument(
        '-c', '--column',
        help="describe each column in data, for example 'x,y,y'. "
             "Plot attributes can be appended after 'y', separated by ';', "
             "like 'ythick;cr'. "
             "By default, assume all columns are y.")
    parser.add_argument(
        '-t', '--title', help='title of the graph', default='')
    parser.add_argument('--xlabel', help='x label', default='x')
    parser.add_argument('--ylabel', help='y label', default='y')
    parser.add_argument(
        '-s', '--scale', help='scale of each y, separated by comma')
    parser.add_argument(
        '--annotate-maximum', help='annotate maximum value in graph',
        action='store_true')
    parser.add_argument(
        '--annotate-minimum', help='annotate minimum value in graph',
        action='store_true')
    parser.add_argument('--xkcd', help='xkcd style', action='store_true')
    parser.add_argument(
        '--decay', help='exponential decay rate to smooth Y',
        type=float, default=0)
    parser.add_argument('-l', '--legend', help='legend for each y')
    # The option keeps its historical spelling: the rest of the script reads
    # the value as ``args.delimeter``.
    parser.add_argument(
        '-d', '--delimeter', help='column delimeter', default='\t')

    args = parser.parse_args(argv)

    # With neither --show nor --output the rendered figure would simply be
    # discarded, so fall back to displaying it.
    if not args.show and not args.output:
        args.show = True
    return args
def filter_valid_range(points, rect):
    """Keep the (x, y) points lying inside ``rect`` (borders included).

    rect = (min_x, max_x, min_y, max_y)

    When no point is inside, the first input point is returned on its own so
    that the result is never empty.
    """
    lo_x, hi_x, lo_y, hi_y = rect[0], rect[1], rect[2], rect[3]
    inside = [
        (px, py) for px, py in points
        if lo_x <= px <= hi_x and lo_y <= py <= hi_y
    ]
    if not inside:
        inside.append(points[0])
    return inside
def exponential_smooth(data, alpha):
    """Return an exponentially smoothed copy of ``data``.

    Each output value is ``previous_output * alpha + value * (1 - alpha)``,
    with the first input value used as the initial "previous output".

    Args:
        data: sequence or array of numbers; it is not modified.
        alpha: weight given to the running value; 0 disables smoothing.

    Returns:
        A new float numpy array with the same length as ``data``. Empty
        input yields an empty array.
    """
    # Work in floating point: copying an integer input as-is would truncate
    # every smoothed value, and the error would accumulate through ``now``.
    ret = np.array(data, dtype=float)
    if ret.size == 0:
        return ret
    now = ret[0]
    for k, value in enumerate(np.asarray(data, dtype=float)):
        now = now * alpha + value * (1 - alpha)
        ret[k] = now
    return ret
def _annotate_extremum(ax, label, x, y, dx, dy, rect):
    """Annotate the point (x, y) on ``ax`` with an arrow and a text label.

    The text is placed at the first of four offsets around the point that
    lies inside ``rect``; ``dy`` carries the preferred vertical direction.
    """
    candidates = [
        (x + dx, y + dy),
        (x - dx, y + dy),
        (x + dx, y - dy),
        (x - dx, y - dy),
    ]
    text_x, text_y = filter_valid_range(candidates, rect)[0]
    ax.annotate(
        '{} ({:d},{:.3f})'.format(label, int(x), y),
        xy=(x, y),
        xytext=(text_x, text_y),
        arrowprops=dict(arrowstyle='->'))


def annotate_min_max(data_x, data_y, ax):
    """Annotate the highest and/or lowest point of one curve on ``ax``.

    Which of the two annotations is drawn is decided by the module-level
    ``args.annotate_maximum`` and ``args.annotate_minimum`` flags.

    Args:
        data_x: x coordinates of the curve.
        data_y: y coordinates of the curve, indexed like ``data_x``.
        ax: matplotlib axes to draw on; its current axis limits bound the
            position of the annotation text.
    """
    x_range = max(data_x) - min(data_x)
    y_range = max(data_y) - min(data_y)

    # First index holding the largest / smallest y. The x of the maximum
    # must come from data_x (it was previously seeded from data_y[0]).
    indices = range(len(data_x))
    idx_max = max(indices, key=lambda i: data_y[i])
    idx_min = min(indices, key=lambda i: data_y[i])
    x_max, y_max = data_x[idx_max], data_y[idx_max]
    x_min, y_min = data_x[idx_min], data_y[idx_min]

    rect = ax.axis()
    dx = 0.05 * x_range
    dy = 0.025 * y_range
    if args.annotate_maximum:
        # Prefer placing the text above the maximum.
        _annotate_extremum(ax, 'maximum', x_max, y_max, dx, dy, rect)
    if args.annotate_minimum:
        # Prefer placing the text below the minimum.
        _annotate_extremum(ax, 'minimum', x_min, y_min, dx, -dy, rect)
def plot_args_from_column_desc(desc):
    """Translate a ';'-separated attribute string into plot keyword args.

    Recognised attributes: 'thick' (line width 5), 'dash' (dashed line) and
    any token starting with 'c', whose remainder is used as the color; when
    several such tokens are given the last one wins. An empty or missing
    description gives an empty dict.
    """
    style = {}
    if not desc:
        return style

    tokens = desc.split(';')
    for flag, key, value in (('thick', 'lw', 5), ('dash', 'ls', '--')):
        if flag in tokens:
            style[key] = value

    colors = [token[1:] for token in tokens if token.startswith('c')]
    if colors:
        style['color'] = colors[-1]
    return style
def do_plot(data_xs, data_ys):
    """Draw every y series on one figure, then save and/or show it.

    Legend, scale, labels, title, output path, show flag, annotation flags
    and the per-series column descriptors (``y_column``) are read from the
    module-level ``args``.

    Args:
        data_xs: list of 1d arrays, either of size 1 (one x axis shared by
            all series) or of size len(data_ys) (one x axis per series).
        data_ys: list of 1d arrays, one per series.

    Raises:
        AssertionError: if the number of legends or scales differs from the
            number of series, or an x column is shorter than its y column.
    """
    fig = plt.figure(figsize=(16.18 / 1.2, 10 / 1.2))
    ax = fig.add_axes((0.1, 0.2, 0.8, 0.7))
    nr_y = len(data_ys)
    y_column = args.y_column

    # parse legend and y-scale
    if args.legend:
        legends = args.legend.split(',')
        assert len(legends) == nr_y
    else:
        legends = None
    if args.scale:
        # A real list is needed: len() of a lazy map object fails on Python 3.
        scale = [float(s) for s in args.scale.split(',')]
        assert len(scale) == nr_y
    else:
        scale = [1.0] * nr_y

    for yidx in range(nr_y):
        # The leading 'y' of the descriptor is dropped; the rest are the
        # plot attributes.
        plotargs = plot_args_from_column_desc(y_column[yidx][1:])
        now_scale = scale[yidx]
        data_y = data_ys[yidx] * now_scale
        leg = legends[yidx] if legends else None
        if now_scale != 1:
            # Show integral scales without a trailing ".0".
            leg = "{}*{}".format(
                now_scale if int(now_scale) != now_scale else int(now_scale),
                leg)
        data_x = data_xs[0] if len(data_xs) == 1 else data_xs[yidx]
        assert len(data_x) >= len(data_y), \
            "x column is shorter than y column! {} < {}".format(
                len(data_x), len(data_y))
        truncate_data_x = data_x[:len(data_y)]
        p = plt.plot(truncate_data_x, data_y, label=leg, **plotargs)

        # Shade the area under the curve with the curve's own color.
        c = p[0].get_color()
        plt.fill_between(truncate_data_x, data_y, alpha=0.1, facecolor=c)

        if args.annotate_maximum or args.annotate_minimum:
            annotate_min_max(truncate_data_x, data_y, ax)

    plt.xlabel(args.xlabel, fontsize='xx-large')
    plt.ylabel(args.ylabel, fontsize='xx-large')
    plt.legend(loc='best', fontsize='xx-large')

    # adjust maxx: span every x column (not only the one used by the last
    # series) and leave a 5% margin on the right.
    minx = min(min(xs) for xs in data_xs)
    maxx = max(max(xs) for xs in data_xs)
    new_maxx = maxx + (maxx - minx) * 0.05
    plt.xlim(minx, new_maxx)

    for label in chain.from_iterable(
            [ax.get_xticklabels(), ax.get_yticklabels()]):
        label.set_fontproperties(fontm.FontProperties(size=15))

    ax.grid(color='gray', linestyle='dashed')

    plt.title(args.title, fontdict={'fontsize': '20'})

    if args.output != '':
        plt.savefig(args.output)
    if args.show:
        plt.show()
def main():
    """Script entry point: parse options, read the table, plot it.

    The input (a file, or stdin when --input is "-") holds one row per
    line, with columns separated by ``args.delimeter``. Columns may have
    different lengths, but a column may not contain a value after an empty
    cell.

    Raises:
        SystemExit: if the input contains no lines.
        AssertionError: if the column description does not match the data,
            a row has too many columns, or a column has a hole.
    """
    get_args()

    # parse input args
    if args.input == STDIN_FNAME:
        all_inputs = sys.stdin.readlines()
    else:
        with open(args.input) as fin:
            all_inputs = fin.readlines()
    if not all_inputs:
        sys.exit("No input data to plot.")

    # parse column format; the column count must be derived with the same
    # delimiter that is used to split the data rows below.
    nr_column = len(all_inputs[0].rstrip('\r\n').split(args.delimeter))
    if args.column is None:
        column = ['y'] * nr_column
    else:
        column = args.column.strip().split(',')
    for k in column:
        # k[:1] keeps an empty descriptor a validation failure, not an
        # IndexError.
        assert k[:1] in ['x', 'y']
    assert nr_column == len(column), \
        "Column and data doesn't have same length. {}!={}".format(
            nr_column, len(column))
    args.y_column = [v for v in column if v[0] == 'y']
    args.y_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'y']
    args.x_column = [v for v in column if v[0] == 'x']
    args.x_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'x']
    nr_x_column = len(args.x_column)
    nr_y_column = len(args.y_column)
    if nr_x_column > 1:
        assert nr_x_column == nr_y_column, \
            "If multiple x columns are used, nr_x_column must equals to nr_y_column"

    # read and parse data
    data = [[] for _ in range(nr_column)]
    # ended[idx] becomes True once column idx has produced an empty cell.
    ended = defaultdict(bool)
    for line in all_inputs:
        line = line.rstrip('\r\n').split(args.delimeter)
        assert len(line) <= nr_column, \
            """One row have too many columns (separated by {})!
Line: {}""".format(repr(args.delimeter), line)
        for idx, val in enumerate(line):
            if val == '':
                ended[idx] = True
                continue
            val = float(val)
            assert not ended[idx], "Column {} has hole!".format(idx)
            data[idx].append(val)

    data_ys = [data[k] for k in args.y_column_idx]
    max_ysize = max([len(t) for t in data_ys])
    # Same text as the former print statement, valid on Python 2 and 3.
    print("Size of the longest y column:  {}".format(max_ysize))

    if nr_x_column:
        data_xs = [data[k] for k in args.x_column_idx]
    else:
        # No x column given: use the row index as x.
        data_xs = [list(range(max_ysize))]

    for idx, data_y in enumerate(data_ys):
        data_ys[idx] = np.asarray(data_y)
        if args.decay != 0:
            # TODO allow different decay for each y
            data_ys[idx] = exponential_smooth(data_y, args.decay)
    for idx, data_x in enumerate(data_xs):
        data_xs[idx] = np.asarray(data_x)

    if args.xkcd:
        with plt.xkcd():
            do_plot(data_xs, data_ys)
    else:
        do_plot(data_xs, data_ys)
# Run the plotter only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment