Commit ccf9a521 authored by Yuxin Wu's avatar Yuxin Wu

fix pep8 style in /scripts

parent 9b69e860
......@@ -7,8 +7,6 @@
import numpy as np
from tensorpack.tfutils.varmanip import dump_chkpt_vars
from tensorpack.utils import logger
import tensorflow as tf
import sys
import argparse
parser = argparse.ArgumentParser()
......@@ -26,4 +24,5 @@ logger.info(str(params.keys()))
if args.dump:
np.save(args.dump, params)
if args.shell:
import IPython as IP; IP.embed(config=IP.terminal.ipapp.load_default_config())
import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: dump_train_config.py
# File: dump-dataflow.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse
import cv2
import tensorflow as tf
import imp
import tqdm
import os
from tensorpack.utils import logger
from tensorpack.utils.fs import mkdir_p
from tensorpack.dataflow import *
from tensorpack.dataflow import RepeatedData
parser = argparse.ArgumentParser()
......@@ -54,6 +53,3 @@ with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
if idx > NR_DP_TEST:
break
pbar.update()
......@@ -8,8 +8,9 @@ import argparse
import tensorflow as tf
import imp
from tensorpack import *
from tensorpack import TowerContext, logger, ModelFromMetaGraph
from tensorpack.tfutils import sessinit, varmanip
from tensorpack.utils.naming import EXTRA_SAVE_VARS_KEY
parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file')
......@@ -51,6 +52,6 @@ with tf.Graph().as_default() as G:
logger.info("Variables to dump:")
logger.info(", ".join(var_dict.keys()))
saver = tf.train.Saver(
var_list=var_dict,
write_version=tf.train.SaverDef.V2)
var_list=var_dict,
write_version=tf.train.SaverDef.V2)
saver.save(sess, args.output, write_meta_graph=False)
......@@ -14,96 +14,99 @@ $ cat examples/train_log/mnist-convnet/stat.json \
For more usage, see `plot-point.py -h` or the code.
"""
from math import *
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fontm
import argparse, sys
import argparse
import sys
from collections import defaultdict
from itertools import chain
import six
from matplotlib import rc
#rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
#rc('font',**{'family':'sans-serif','sans-serif':['Microsoft Yahei']})
#rc('text', usetex=True)
# from matplotlib import rc
# rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
# rc('font',**{'family':'sans-serif','sans-serif':['Microsoft Yahei']})
# rc('text', usetex=True)
STDIN_FNAME = '-'
def get_args():
description = "plot points into graph."
parser = argparse.ArgumentParser(description=description)
parser.add_argument('-i', '--input',
help='input data file, use "-" for stdin. Default stdin. Input \
help='input data file, use "-" for stdin. Default stdin. Input \
format is many rows of DELIMIETER-separated data',
default='-')
default='-')
parser.add_argument('-o', '--output',
help='output image', default='')
help='output image', default='')
parser.add_argument('--show',
help='show the figure after rendered',
action='store_true')
help='show the figure after rendered',
action='store_true')
parser.add_argument('-c', '--column',
help="describe each column in data, for example 'x,y,y'. \
help="describe each column in data, for example 'x,y,y'. \
Default to 'y' for one column and 'x,y' for two columns. \
Plot attributes can be appended after 'y', like 'ythick;cr'. \
By default, assume all columns are y. \
")
parser.add_argument('-t', '--title',
help='title of the graph',
default='')
help='title of the graph',
default='')
parser.add_argument('--xlabel',
help='x label', type=six.text_type)
help='x label', type=six.text_type)
parser.add_argument('--ylabel',
help='y label', type=six.text_type)
help='y label', type=six.text_type)
parser.add_argument('--xlim',
help='x lim', type=float, nargs=2)
help='x lim', type=float, nargs=2)
parser.add_argument('--ylim',
help='y lim', type=float, nargs=2)
help='y lim', type=float, nargs=2)
parser.add_argument('-s', '--scale',
help='scale of each y, separated by comma')
help='scale of each y, separated by comma')
parser.add_argument('--annotate-maximum',
help = 'annonate maximum value in graph',
action = 'store_true')
help='annonate maximum value in graph',
action='store_true')
parser.add_argument('--annotate-minimum',
help = 'annonate minimum value in graph',
action = 'store_true')
help='annonate minimum value in graph',
action='store_true')
parser.add_argument('--xkcd',
help = 'xkcd style',
action = 'store_true')
help='xkcd style',
action='store_true')
parser.add_argument('--decay',
help='exponential decay rate to smooth Y',
type=float, default=0)
help='exponential decay rate to smooth Y',
type=float, default=0)
parser.add_argument('-l', '--legend',
help='legend for each y')
help='legend for each y')
parser.add_argument('-d', '--delimeter',
help='column delimeter', default='\t')
help='column delimeter', default='\t')
global args
args = parser.parse_args();
args = parser.parse_args()
if not args.show and not args.output:
args.show = True
def filter_valid_range(points, rect):
"""rect = (min_x, max_x, min_y, max_y)"""
ret = []
for x, y in points:
if x >= rect[0] and x <= rect[1] and y >= rect[2] and y <= rect[3]:
ret.append((x, y))
ret.append((x, y))
if len(ret) == 0:
ret.append(points[0])
return ret
def exponential_smooth(data, alpha):
""" smooth data by alpha. returned a smoothed version"""
ret = np.copy(data)
now = data[0]
for k in range(len(data)):
ret[k] = now * alpha + data[k] * (1-alpha)
ret[k] = now * alpha + data[k] * (1 - alpha)
now = ret[k]
return ret
def annotate_min_max(data_x, data_y, ax):
max_x, min_x = max(data_x), min(data_x)
max_y, min_y = max(data_y), min(data_y)
......@@ -133,9 +136,9 @@ def annotate_min_max(data_x, data_y, ax):
y_max - 0.025 * y_range)],
rect)[0]
ax.annotate('maximum ({:d},{:.3f})' . format(int(x_max), y_max),
xy = (x_max, y_max),
xytext = (text_x, text_y),
arrowprops = dict(arrowstyle = '->'))
xy=(x_max, y_max),
xytext=(text_x, text_y),
arrowprops=dict(arrowstyle='->'))
if args.annotate_minimum:
text_x, text_y = filter_valid_range([
(x_min + 0.05 * x_range,
......@@ -148,13 +151,14 @@ def annotate_min_max(data_x, data_y, ax):
y_min + 0.025 * y_range)],
rect)[0]
ax.annotate('minimum ({:d},{:.3f})' . format(int(x_min), y_min),
xy = (x_min, y_min),
xytext = (text_x, text_y),
arrowprops = dict(arrowstyle = '->'))
#ax.annotate('{:.3f}' . format(y_min),
#xy = (x_min, y_min),
#xytext = (text_x, text_y),
#arrowprops = dict(arrowstyle = '->'))
xy=(x_min, y_min),
xytext=(text_x, text_y),
arrowprops=dict(arrowstyle='->'))
# ax.annotate('{:.3f}' . format(y_min),
# xy = (x_min, y_min),
# xytext = (text_x, text_y),
# arrowprops = dict(arrowstyle = '->'))
def plot_args_from_column_desc(desc):
if not desc:
......@@ -170,12 +174,13 @@ def plot_args_from_column_desc(desc):
ret['color'] = v[1:]
return ret
def do_plot(data_xs, data_ys):
"""
data_xs: list of 1d array, either of size 1 or size len(data_ys)
data_ys: list of 1d array
"""
fig = plt.figure(figsize = (16.18/1.2, 10/1.2))
fig = plt.figure(figsize=(16.18 / 1.2, 10 / 1.2))
ax = fig.add_axes((0.1, 0.2, 0.8, 0.7))
nr_y = len(data_ys)
y_column = args.y_column
......@@ -185,7 +190,7 @@ def do_plot(data_xs, data_ys):
legends = args.legend.split(',')
assert len(legends) == nr_y
else:
legends = None #range(nr_y) #None
legends = None # range(nr_y) #None
if args.scale:
scale = map(float, args.scale.split(','))
assert len(scale) == nr_y
......@@ -201,19 +206,19 @@ def do_plot(data_xs, data_ys):
leg = "{}*{}".format(now_scale if int(now_scale) != now_scale else int(now_scale), leg)
data_x = data_xs[0] if len(data_xs) == 1 else data_xs[yidx]
assert len(data_x) >= len(data_y), \
"x column is shorter than y column! {} < {}".format(
len(data_x), len(data_y))
"x column is shorter than y column! {} < {}".format(
len(data_x), len(data_y))
truncate_data_x = data_x[:len(data_y)]
p = plt.plot(truncate_data_x, data_y, label=leg, **plotargs)
c = p[0].get_color()
plt.fill_between(truncate_data_x, data_y, alpha=0.1, facecolor=c)
#ax.set_aspect('equal', 'datalim')
#ax.spines['right'].set_color('none')
#ax.spines['left'].set_color('none')
#plt.xticks([])
#plt.yticks([])
# ax.set_aspect('equal', 'datalim')
# ax.spines['right'].set_color('none')
# ax.spines['left'].set_color('none')
# plt.xticks([])
# plt.yticks([])
if args.annotate_maximum or args.annotate_minimum:
annotate_min_max(truncate_data_x, data_y, ax)
......@@ -237,7 +242,7 @@ def do_plot(data_xs, data_ys):
[ax.get_xticklabels(), ax.get_yticklabels()]):
label.set_fontproperties(fontm.FontProperties(size=15))
ax.grid(color = 'gray', linestyle = 'dashed')
ax.grid(color='gray', linestyle='dashed')
plt.title(args.title, fontdict={'fontsize': '20'})
......@@ -246,6 +251,7 @@ def do_plot(data_xs, data_ys):
if args.show:
plt.show()
def main():
get_args()
# parse input args
......@@ -263,7 +269,8 @@ def main():
column = ['y'] * nr_column
else:
column = args.column.strip().split(',')
for k in column: assert k[0] in ['x', 'y']
for k in column:
assert k[0] in ['x', 'y']
assert nr_column == len(column), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column))
args.y_column = [v for v in column if v[0] == 'y']
args.y_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'y']
......@@ -275,11 +282,9 @@ def main():
assert nr_x_column == nr_y_column, \
"If multiple x columns are used, nr_x_column must equals to nr_y_column"
x_column_set = set(args.x_column)
# read and parse data
data = [[] for _ in range(nr_column)]
ended = defaultdict(bool)
data_format = -1
for lineno, line in enumerate(all_inputs):
line = line.rstrip('\n').split(args.delimeter)
assert len(line) <= nr_column, \
......@@ -302,14 +307,14 @@ Line: {}""".format(repr(args.delimeter), line)
if nr_x_column:
data_xs = [data[k] for k in args.x_column_idx]
else:
data_xs = [list(range(1, max_ysize+1))]
data_xs = [list(range(1, max_ysize + 1))]
for idx, data_y in enumerate(data_ys):
data_ys[idx] = np.asarray(data_y)
if args.decay != 0:
data_ys[idx] = exponential_smooth(data_y, args.decay)
#if idx == 0: # TODO allow different decay for each y
#data_ys[idx] = exponential_smooth(data_y, 0.5)
# if idx == 0: # TODO allow different decay for each y
# data_ys[idx] = exponential_smooth(data_y, 0.5)
for idx, data_x in enumerate(data_xs):
data_xs[idx] = np.asarray(data_x)
......@@ -319,5 +324,6 @@ Line: {}""".format(repr(args.delimeter), line)
else:
do_plot(data_xs, data_ys)
if __name__ == '__main__':
main()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: serve_data.py
# File: serve-data.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import argparse
import imp
#import cv2
#import os
from tensorpack.dataflow import serve_data
parser = argparse.ArgumentParser()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment