Commit 881c6c4b authored by Yuxin Wu's avatar Yuxin Wu

add resnet curve

parent d6723566
......@@ -10,6 +10,9 @@ More results to come.
| ResNet 18 | 10.67% | 29.50% |
| ResNet 34 | 8.66% | 26.45% |
| ResNet 50 | 7.13% | 24.12% |
| ResNet 101 | 6.54% | 22.89% |
![imagenet](imagenet-resnet.png)
## load-resnet.py
......
......@@ -56,6 +56,10 @@ def get_args():
help='x label', type=six.text_type)
parser.add_argument('--ylabel',
help='y label', type=six.text_type)
parser.add_argument('--xlim',
help='x lim', type=float, nargs=2)
parser.add_argument('--ylim',
help='y lim', type=float, nargs=2)
parser.add_argument('-s', '--scale',
help='scale of each y, separated by comma')
parser.add_argument('--annotate-maximum',
......@@ -218,6 +222,10 @@ def do_plot(data_xs, data_ys):
plt.xlabel(args.xlabel, fontsize='xx-large')
if args.ylabel:
plt.ylabel(args.ylabel, fontsize='xx-large')
if args.xlim:
plt.xlim(args.xlim[0], args.xlim[1])
if args.ylim:
plt.ylim(args.ylim[0], args.ylim[1])
plt.legend(loc='best', fontsize='xx-large')
# adjust maxx
......@@ -250,7 +258,7 @@ def main():
fin.close()
# parse column format
nr_column = len(all_inputs[0].rstrip().split(args.delimeter))
nr_column = len(all_inputs[0].rstrip('\n').split(args.delimeter))
if args.column is None:
column = ['y'] * nr_column
else:
......@@ -287,7 +295,9 @@ Line: {}""".format(repr(args.delimeter), line)
data[idx].append(val)
data_ys = [data[k] for k in args.y_column_idx]
max_ysize = max([len(t) for t in data_ys])
length_ys = [len(t) for t in data_ys]
print("Length of each column:", length_ys)
max_ysize = max(length_ys)
print("Size of the longest y column: ", max_ysize)
if nr_x_column:
......
......@@ -60,7 +60,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
sz = self.dataset.size()
except NotImplementedError:
sz = 0
with tqdm(total=sz) as pbar:
with tqdm(total=sz, disable=(sz==0)) as pbar:
for dp in self.dataset.get_data():
res = self.predictor(dp)
yield res
......@@ -119,7 +119,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
sz = self.dataset.size()
except NotImplementedError:
sz = 0
with tqdm(total=sz) as pbar:
with tqdm(total=sz, disable=(sz==0)) as pbar:
die_cnt = 0
while True:
res = self.result_queue.get()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment