Commit ccf9a521 authored by Yuxin Wu's avatar Yuxin Wu

fix pep8 style in /scripts

parent 9b69e860
...@@ -7,8 +7,6 @@ ...@@ -7,8 +7,6 @@
import numpy as np import numpy as np
from tensorpack.tfutils.varmanip import dump_chkpt_vars from tensorpack.tfutils.varmanip import dump_chkpt_vars
from tensorpack.utils import logger from tensorpack.utils import logger
import tensorflow as tf
import sys
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -26,4 +24,5 @@ logger.info(str(params.keys())) ...@@ -26,4 +24,5 @@ logger.info(str(params.keys()))
if args.dump: if args.dump:
np.save(args.dump, params) np.save(args.dump, params)
if args.shell: if args.shell:
import IPython as IP; IP.embed(config=IP.terminal.ipapp.load_default_config()) import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: dump_train_config.py # File: dump-dataflow.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse import argparse
import cv2 import cv2
import tensorflow as tf
import imp import imp
import tqdm import tqdm
import os import os
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.fs import mkdir_p from tensorpack.utils.fs import mkdir_p
from tensorpack.dataflow import * from tensorpack.dataflow import RepeatedData
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -54,6 +53,3 @@ with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar: ...@@ -54,6 +53,3 @@ with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
if idx > NR_DP_TEST: if idx > NR_DP_TEST:
break break
pbar.update() pbar.update()
...@@ -8,8 +8,9 @@ import argparse ...@@ -8,8 +8,9 @@ import argparse
import tensorflow as tf import tensorflow as tf
import imp import imp
from tensorpack import * from tensorpack import TowerContext, logger, ModelFromMetaGraph
from tensorpack.tfutils import sessinit, varmanip from tensorpack.tfutils import sessinit, varmanip
from tensorpack.utils.naming import EXTRA_SAVE_VARS_KEY
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file') parser.add_argument('--config', help='config file')
......
...@@ -14,23 +14,23 @@ $ cat examples/train_log/mnist-convnet/stat.json \ ...@@ -14,23 +14,23 @@ $ cat examples/train_log/mnist-convnet/stat.json \
For more usage, see `plot-point.py -h` or the code. For more usage, see `plot-point.py -h` or the code.
""" """
from math import *
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.font_manager as fontm import matplotlib.font_manager as fontm
import argparse, sys import argparse
import sys
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
import six import six
from matplotlib import rc # from matplotlib import rc
#rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) # rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
#rc('font',**{'family':'sans-serif','sans-serif':['Microsoft Yahei']}) # rc('font',**{'family':'sans-serif','sans-serif':['Microsoft Yahei']})
#rc('text', usetex=True) # rc('text', usetex=True)
STDIN_FNAME = '-' STDIN_FNAME = '-'
def get_args(): def get_args():
description = "plot points into graph." description = "plot points into graph."
parser = argparse.ArgumentParser(description=description) parser = argparse.ArgumentParser(description=description)
...@@ -63,14 +63,14 @@ def get_args(): ...@@ -63,14 +63,14 @@ def get_args():
parser.add_argument('-s', '--scale', parser.add_argument('-s', '--scale',
help='scale of each y, separated by comma') help='scale of each y, separated by comma')
parser.add_argument('--annotate-maximum', parser.add_argument('--annotate-maximum',
help = 'annonate maximum value in graph', help='annonate maximum value in graph',
action = 'store_true') action='store_true')
parser.add_argument('--annotate-minimum', parser.add_argument('--annotate-minimum',
help = 'annonate minimum value in graph', help='annonate minimum value in graph',
action = 'store_true') action='store_true')
parser.add_argument('--xkcd', parser.add_argument('--xkcd',
help = 'xkcd style', help='xkcd style',
action = 'store_true') action='store_true')
parser.add_argument('--decay', parser.add_argument('--decay',
help='exponential decay rate to smooth Y', help='exponential decay rate to smooth Y',
type=float, default=0) type=float, default=0)
...@@ -80,11 +80,12 @@ def get_args(): ...@@ -80,11 +80,12 @@ def get_args():
help='column delimeter', default='\t') help='column delimeter', default='\t')
global args global args
args = parser.parse_args(); args = parser.parse_args()
if not args.show and not args.output: if not args.show and not args.output:
args.show = True args.show = True
def filter_valid_range(points, rect): def filter_valid_range(points, rect):
"""rect = (min_x, max_x, min_y, max_y)""" """rect = (min_x, max_x, min_y, max_y)"""
ret = [] ret = []
...@@ -95,15 +96,17 @@ def filter_valid_range(points, rect): ...@@ -95,15 +96,17 @@ def filter_valid_range(points, rect):
ret.append(points[0]) ret.append(points[0])
return ret return ret
def exponential_smooth(data, alpha): def exponential_smooth(data, alpha):
""" smooth data by alpha. returned a smoothed version""" """ smooth data by alpha. returned a smoothed version"""
ret = np.copy(data) ret = np.copy(data)
now = data[0] now = data[0]
for k in range(len(data)): for k in range(len(data)):
ret[k] = now * alpha + data[k] * (1-alpha) ret[k] = now * alpha + data[k] * (1 - alpha)
now = ret[k] now = ret[k]
return ret return ret
def annotate_min_max(data_x, data_y, ax): def annotate_min_max(data_x, data_y, ax):
max_x, min_x = max(data_x), min(data_x) max_x, min_x = max(data_x), min(data_x)
max_y, min_y = max(data_y), min(data_y) max_y, min_y = max(data_y), min(data_y)
...@@ -133,9 +136,9 @@ def annotate_min_max(data_x, data_y, ax): ...@@ -133,9 +136,9 @@ def annotate_min_max(data_x, data_y, ax):
y_max - 0.025 * y_range)], y_max - 0.025 * y_range)],
rect)[0] rect)[0]
ax.annotate('maximum ({:d},{:.3f})' . format(int(x_max), y_max), ax.annotate('maximum ({:d},{:.3f})' . format(int(x_max), y_max),
xy = (x_max, y_max), xy=(x_max, y_max),
xytext = (text_x, text_y), xytext=(text_x, text_y),
arrowprops = dict(arrowstyle = '->')) arrowprops=dict(arrowstyle='->'))
if args.annotate_minimum: if args.annotate_minimum:
text_x, text_y = filter_valid_range([ text_x, text_y = filter_valid_range([
(x_min + 0.05 * x_range, (x_min + 0.05 * x_range,
...@@ -148,13 +151,14 @@ def annotate_min_max(data_x, data_y, ax): ...@@ -148,13 +151,14 @@ def annotate_min_max(data_x, data_y, ax):
y_min + 0.025 * y_range)], y_min + 0.025 * y_range)],
rect)[0] rect)[0]
ax.annotate('minimum ({:d},{:.3f})' . format(int(x_min), y_min), ax.annotate('minimum ({:d},{:.3f})' . format(int(x_min), y_min),
xy = (x_min, y_min), xy=(x_min, y_min),
xytext = (text_x, text_y), xytext=(text_x, text_y),
arrowprops = dict(arrowstyle = '->')) arrowprops=dict(arrowstyle='->'))
#ax.annotate('{:.3f}' . format(y_min), # ax.annotate('{:.3f}' . format(y_min),
#xy = (x_min, y_min), # xy = (x_min, y_min),
#xytext = (text_x, text_y), # xytext = (text_x, text_y),
#arrowprops = dict(arrowstyle = '->')) # arrowprops = dict(arrowstyle = '->'))
def plot_args_from_column_desc(desc): def plot_args_from_column_desc(desc):
if not desc: if not desc:
...@@ -170,12 +174,13 @@ def plot_args_from_column_desc(desc): ...@@ -170,12 +174,13 @@ def plot_args_from_column_desc(desc):
ret['color'] = v[1:] ret['color'] = v[1:]
return ret return ret
def do_plot(data_xs, data_ys): def do_plot(data_xs, data_ys):
""" """
data_xs: list of 1d array, either of size 1 or size len(data_ys) data_xs: list of 1d array, either of size 1 or size len(data_ys)
data_ys: list of 1d array data_ys: list of 1d array
""" """
fig = plt.figure(figsize = (16.18/1.2, 10/1.2)) fig = plt.figure(figsize=(16.18 / 1.2, 10 / 1.2))
ax = fig.add_axes((0.1, 0.2, 0.8, 0.7)) ax = fig.add_axes((0.1, 0.2, 0.8, 0.7))
nr_y = len(data_ys) nr_y = len(data_ys)
y_column = args.y_column y_column = args.y_column
...@@ -185,7 +190,7 @@ def do_plot(data_xs, data_ys): ...@@ -185,7 +190,7 @@ def do_plot(data_xs, data_ys):
legends = args.legend.split(',') legends = args.legend.split(',')
assert len(legends) == nr_y assert len(legends) == nr_y
else: else:
legends = None #range(nr_y) #None legends = None # range(nr_y) #None
if args.scale: if args.scale:
scale = map(float, args.scale.split(',')) scale = map(float, args.scale.split(','))
assert len(scale) == nr_y assert len(scale) == nr_y
...@@ -209,11 +214,11 @@ def do_plot(data_xs, data_ys): ...@@ -209,11 +214,11 @@ def do_plot(data_xs, data_ys):
c = p[0].get_color() c = p[0].get_color()
plt.fill_between(truncate_data_x, data_y, alpha=0.1, facecolor=c) plt.fill_between(truncate_data_x, data_y, alpha=0.1, facecolor=c)
#ax.set_aspect('equal', 'datalim') # ax.set_aspect('equal', 'datalim')
#ax.spines['right'].set_color('none') # ax.spines['right'].set_color('none')
#ax.spines['left'].set_color('none') # ax.spines['left'].set_color('none')
#plt.xticks([]) # plt.xticks([])
#plt.yticks([]) # plt.yticks([])
if args.annotate_maximum or args.annotate_minimum: if args.annotate_maximum or args.annotate_minimum:
annotate_min_max(truncate_data_x, data_y, ax) annotate_min_max(truncate_data_x, data_y, ax)
...@@ -237,7 +242,7 @@ def do_plot(data_xs, data_ys): ...@@ -237,7 +242,7 @@ def do_plot(data_xs, data_ys):
[ax.get_xticklabels(), ax.get_yticklabels()]): [ax.get_xticklabels(), ax.get_yticklabels()]):
label.set_fontproperties(fontm.FontProperties(size=15)) label.set_fontproperties(fontm.FontProperties(size=15))
ax.grid(color = 'gray', linestyle = 'dashed') ax.grid(color='gray', linestyle='dashed')
plt.title(args.title, fontdict={'fontsize': '20'}) plt.title(args.title, fontdict={'fontsize': '20'})
...@@ -246,6 +251,7 @@ def do_plot(data_xs, data_ys): ...@@ -246,6 +251,7 @@ def do_plot(data_xs, data_ys):
if args.show: if args.show:
plt.show() plt.show()
def main(): def main():
get_args() get_args()
# parse input args # parse input args
...@@ -263,7 +269,8 @@ def main(): ...@@ -263,7 +269,8 @@ def main():
column = ['y'] * nr_column column = ['y'] * nr_column
else: else:
column = args.column.strip().split(',') column = args.column.strip().split(',')
for k in column: assert k[0] in ['x', 'y'] for k in column:
assert k[0] in ['x', 'y']
assert nr_column == len(column), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column)) assert nr_column == len(column), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column))
args.y_column = [v for v in column if v[0] == 'y'] args.y_column = [v for v in column if v[0] == 'y']
args.y_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'y'] args.y_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'y']
...@@ -275,11 +282,9 @@ def main(): ...@@ -275,11 +282,9 @@ def main():
assert nr_x_column == nr_y_column, \ assert nr_x_column == nr_y_column, \
"If multiple x columns are used, nr_x_column must equals to nr_y_column" "If multiple x columns are used, nr_x_column must equals to nr_y_column"
x_column_set = set(args.x_column)
# read and parse data # read and parse data
data = [[] for _ in range(nr_column)] data = [[] for _ in range(nr_column)]
ended = defaultdict(bool) ended = defaultdict(bool)
data_format = -1
for lineno, line in enumerate(all_inputs): for lineno, line in enumerate(all_inputs):
line = line.rstrip('\n').split(args.delimeter) line = line.rstrip('\n').split(args.delimeter)
assert len(line) <= nr_column, \ assert len(line) <= nr_column, \
...@@ -302,14 +307,14 @@ Line: {}""".format(repr(args.delimeter), line) ...@@ -302,14 +307,14 @@ Line: {}""".format(repr(args.delimeter), line)
if nr_x_column: if nr_x_column:
data_xs = [data[k] for k in args.x_column_idx] data_xs = [data[k] for k in args.x_column_idx]
else: else:
data_xs = [list(range(1, max_ysize+1))] data_xs = [list(range(1, max_ysize + 1))]
for idx, data_y in enumerate(data_ys): for idx, data_y in enumerate(data_ys):
data_ys[idx] = np.asarray(data_y) data_ys[idx] = np.asarray(data_y)
if args.decay != 0: if args.decay != 0:
data_ys[idx] = exponential_smooth(data_y, args.decay) data_ys[idx] = exponential_smooth(data_y, args.decay)
#if idx == 0: # TODO allow different decay for each y # if idx == 0: # TODO allow different decay for each y
#data_ys[idx] = exponential_smooth(data_y, 0.5) # data_ys[idx] = exponential_smooth(data_y, 0.5)
for idx, data_x in enumerate(data_xs): for idx, data_x in enumerate(data_xs):
data_xs[idx] = np.asarray(data_x) data_xs[idx] = np.asarray(data_x)
...@@ -319,5 +324,6 @@ Line: {}""".format(repr(args.delimeter), line) ...@@ -319,5 +324,6 @@ Line: {}""".format(repr(args.delimeter), line)
else: else:
do_plot(data_xs, data_ys) do_plot(data_xs, data_ys)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: serve_data.py # File: serve-data.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import argparse import argparse
import imp import imp
#import cv2
#import os
from tensorpack.dataflow import serve_data from tensorpack.dataflow import serve_data
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment