Commit ccf9a521 authored by Yuxin Wu's avatar Yuxin Wu

fix pep8 style in /scripts

parent 9b69e860
......@@ -7,8 +7,6 @@
import numpy as np
from tensorpack.tfutils.varmanip import dump_chkpt_vars
from tensorpack.utils import logger
import tensorflow as tf
import sys
import argparse
parser = argparse.ArgumentParser()
......@@ -26,4 +24,5 @@ logger.info(str(params.keys()))
if args.dump:
np.save(args.dump, params)
if args.shell:
import IPython as IP; IP.embed(config=IP.terminal.ipapp.load_default_config())
import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: dump_train_config.py
# File: dump-dataflow.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse
import cv2
import tensorflow as tf
import imp
import tqdm
import os
from tensorpack.utils import logger
from tensorpack.utils.fs import mkdir_p
from tensorpack.dataflow import *
from tensorpack.dataflow import RepeatedData
parser = argparse.ArgumentParser()
......@@ -54,6 +53,3 @@ with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
if idx > NR_DP_TEST:
break
pbar.update()
......@@ -8,8 +8,9 @@ import argparse
import tensorflow as tf
import imp
from tensorpack import *
from tensorpack import TowerContext, logger, ModelFromMetaGraph
from tensorpack.tfutils import sessinit, varmanip
from tensorpack.utils.naming import EXTRA_SAVE_VARS_KEY
parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file')
......
......@@ -14,23 +14,23 @@ $ cat examples/train_log/mnist-convnet/stat.json \
For more usage, see `plot-point.py -h` or the code.
"""
from math import *
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fontm
import argparse, sys
import argparse
import sys
from collections import defaultdict
from itertools import chain
import six
from matplotlib import rc
#rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
#rc('font',**{'family':'sans-serif','sans-serif':['Microsoft Yahei']})
#rc('text', usetex=True)
# from matplotlib import rc
# rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
# rc('font',**{'family':'sans-serif','sans-serif':['Microsoft Yahei']})
# rc('text', usetex=True)
STDIN_FNAME = '-'
def get_args():
description = "plot points into graph."
parser = argparse.ArgumentParser(description=description)
......@@ -63,14 +63,14 @@ def get_args():
parser.add_argument('-s', '--scale',
help='scale of each y, separated by comma')
parser.add_argument('--annotate-maximum',
help = 'annonate maximum value in graph',
action = 'store_true')
help='annonate maximum value in graph',
action='store_true')
parser.add_argument('--annotate-minimum',
help = 'annonate minimum value in graph',
action = 'store_true')
help='annonate minimum value in graph',
action='store_true')
parser.add_argument('--xkcd',
help = 'xkcd style',
action = 'store_true')
help='xkcd style',
action='store_true')
parser.add_argument('--decay',
help='exponential decay rate to smooth Y',
type=float, default=0)
......@@ -80,11 +80,12 @@ def get_args():
help='column delimeter', default='\t')
global args
args = parser.parse_args();
args = parser.parse_args()
if not args.show and not args.output:
args.show = True
def filter_valid_range(points, rect):
"""rect = (min_x, max_x, min_y, max_y)"""
ret = []
......@@ -95,15 +96,17 @@ def filter_valid_range(points, rect):
ret.append(points[0])
return ret
def exponential_smooth(data, alpha):
""" smooth data by alpha. returned a smoothed version"""
ret = np.copy(data)
now = data[0]
for k in range(len(data)):
ret[k] = now * alpha + data[k] * (1-alpha)
ret[k] = now * alpha + data[k] * (1 - alpha)
now = ret[k]
return ret
def annotate_min_max(data_x, data_y, ax):
max_x, min_x = max(data_x), min(data_x)
max_y, min_y = max(data_y), min(data_y)
......@@ -133,9 +136,9 @@ def annotate_min_max(data_x, data_y, ax):
y_max - 0.025 * y_range)],
rect)[0]
ax.annotate('maximum ({:d},{:.3f})' . format(int(x_max), y_max),
xy = (x_max, y_max),
xytext = (text_x, text_y),
arrowprops = dict(arrowstyle = '->'))
xy=(x_max, y_max),
xytext=(text_x, text_y),
arrowprops=dict(arrowstyle='->'))
if args.annotate_minimum:
text_x, text_y = filter_valid_range([
(x_min + 0.05 * x_range,
......@@ -148,13 +151,14 @@ def annotate_min_max(data_x, data_y, ax):
y_min + 0.025 * y_range)],
rect)[0]
ax.annotate('minimum ({:d},{:.3f})' . format(int(x_min), y_min),
xy = (x_min, y_min),
xytext = (text_x, text_y),
arrowprops = dict(arrowstyle = '->'))
#ax.annotate('{:.3f}' . format(y_min),
#xy = (x_min, y_min),
#xytext = (text_x, text_y),
#arrowprops = dict(arrowstyle = '->'))
xy=(x_min, y_min),
xytext=(text_x, text_y),
arrowprops=dict(arrowstyle='->'))
# ax.annotate('{:.3f}' . format(y_min),
# xy = (x_min, y_min),
# xytext = (text_x, text_y),
# arrowprops = dict(arrowstyle = '->'))
def plot_args_from_column_desc(desc):
if not desc:
......@@ -170,12 +174,13 @@ def plot_args_from_column_desc(desc):
ret['color'] = v[1:]
return ret
def do_plot(data_xs, data_ys):
"""
data_xs: list of 1d array, either of size 1 or size len(data_ys)
data_ys: list of 1d array
"""
fig = plt.figure(figsize = (16.18/1.2, 10/1.2))
fig = plt.figure(figsize=(16.18 / 1.2, 10 / 1.2))
ax = fig.add_axes((0.1, 0.2, 0.8, 0.7))
nr_y = len(data_ys)
y_column = args.y_column
......@@ -185,7 +190,7 @@ def do_plot(data_xs, data_ys):
legends = args.legend.split(',')
assert len(legends) == nr_y
else:
legends = None #range(nr_y) #None
legends = None # range(nr_y) #None
if args.scale:
scale = map(float, args.scale.split(','))
assert len(scale) == nr_y
......@@ -209,11 +214,11 @@ def do_plot(data_xs, data_ys):
c = p[0].get_color()
plt.fill_between(truncate_data_x, data_y, alpha=0.1, facecolor=c)
#ax.set_aspect('equal', 'datalim')
#ax.spines['right'].set_color('none')
#ax.spines['left'].set_color('none')
#plt.xticks([])
#plt.yticks([])
# ax.set_aspect('equal', 'datalim')
# ax.spines['right'].set_color('none')
# ax.spines['left'].set_color('none')
# plt.xticks([])
# plt.yticks([])
if args.annotate_maximum or args.annotate_minimum:
annotate_min_max(truncate_data_x, data_y, ax)
......@@ -237,7 +242,7 @@ def do_plot(data_xs, data_ys):
[ax.get_xticklabels(), ax.get_yticklabels()]):
label.set_fontproperties(fontm.FontProperties(size=15))
ax.grid(color = 'gray', linestyle = 'dashed')
ax.grid(color='gray', linestyle='dashed')
plt.title(args.title, fontdict={'fontsize': '20'})
......@@ -246,6 +251,7 @@ def do_plot(data_xs, data_ys):
if args.show:
plt.show()
def main():
get_args()
# parse input args
......@@ -263,7 +269,8 @@ def main():
column = ['y'] * nr_column
else:
column = args.column.strip().split(',')
for k in column: assert k[0] in ['x', 'y']
for k in column:
assert k[0] in ['x', 'y']
assert nr_column == len(column), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column))
args.y_column = [v for v in column if v[0] == 'y']
args.y_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'y']
......@@ -275,11 +282,9 @@ def main():
assert nr_x_column == nr_y_column, \
"If multiple x columns are used, nr_x_column must equals to nr_y_column"
x_column_set = set(args.x_column)
# read and parse data
data = [[] for _ in range(nr_column)]
ended = defaultdict(bool)
data_format = -1
for lineno, line in enumerate(all_inputs):
line = line.rstrip('\n').split(args.delimeter)
assert len(line) <= nr_column, \
......@@ -302,14 +307,14 @@ Line: {}""".format(repr(args.delimeter), line)
if nr_x_column:
data_xs = [data[k] for k in args.x_column_idx]
else:
data_xs = [list(range(1, max_ysize+1))]
data_xs = [list(range(1, max_ysize + 1))]
for idx, data_y in enumerate(data_ys):
data_ys[idx] = np.asarray(data_y)
if args.decay != 0:
data_ys[idx] = exponential_smooth(data_y, args.decay)
#if idx == 0: # TODO allow different decay for each y
#data_ys[idx] = exponential_smooth(data_y, 0.5)
# if idx == 0: # TODO allow different decay for each y
# data_ys[idx] = exponential_smooth(data_y, 0.5)
for idx, data_x in enumerate(data_xs):
data_xs[idx] = np.asarray(data_x)
......@@ -319,5 +324,6 @@ Line: {}""".format(repr(args.delimeter), line)
else:
do_plot(data_xs, data_ys)
if __name__ == '__main__':
main()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: serve_data.py
# File: serve-data.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import argparse
import imp
#import cv2
#import os
from tensorpack.dataflow import serve_data
parser = argparse.ArgumentParser()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment