Commit 881c6c4b authored by Yuxin Wu's avatar Yuxin Wu

add resnet curve

parent d6723566
...@@ -10,6 +10,9 @@ More results to come. ...@@ -10,6 +10,9 @@ More results to come.
| ResNet 18 | 10.67% | 29.50% | | ResNet 18 | 10.67% | 29.50% |
| ResNet 34 | 8.66% | 26.45% | | ResNet 34 | 8.66% | 26.45% |
| ResNet 50 | 7.13% | 24.12% | | ResNet 50 | 7.13% | 24.12% |
| ResNet 101 | 6.54% | 22.89% |
![imagenet](imagenet-resnet.png)
## load-resnet.py ## load-resnet.py
......
...@@ -56,6 +56,10 @@ def get_args(): ...@@ -56,6 +56,10 @@ def get_args():
help='x label', type=six.text_type) help='x label', type=six.text_type)
parser.add_argument('--ylabel', parser.add_argument('--ylabel',
help='y label', type=six.text_type) help='y label', type=six.text_type)
parser.add_argument('--xlim',
help='x lim', type=float, nargs=2)
parser.add_argument('--ylim',
help='y lim', type=float, nargs=2)
parser.add_argument('-s', '--scale', parser.add_argument('-s', '--scale',
help='scale of each y, separated by comma') help='scale of each y, separated by comma')
parser.add_argument('--annotate-maximum', parser.add_argument('--annotate-maximum',
...@@ -218,6 +222,10 @@ def do_plot(data_xs, data_ys): ...@@ -218,6 +222,10 @@ def do_plot(data_xs, data_ys):
plt.xlabel(args.xlabel, fontsize='xx-large') plt.xlabel(args.xlabel, fontsize='xx-large')
if args.ylabel: if args.ylabel:
plt.ylabel(args.ylabel, fontsize='xx-large') plt.ylabel(args.ylabel, fontsize='xx-large')
if args.xlim:
plt.xlim(args.xlim[0], args.xlim[1])
if args.ylim:
plt.ylim(args.ylim[0], args.ylim[1])
plt.legend(loc='best', fontsize='xx-large') plt.legend(loc='best', fontsize='xx-large')
# adjust maxx # adjust maxx
...@@ -250,7 +258,7 @@ def main(): ...@@ -250,7 +258,7 @@ def main():
fin.close() fin.close()
# parse column format # parse column format
nr_column = len(all_inputs[0].rstrip().split(args.delimeter)) nr_column = len(all_inputs[0].rstrip('\n').split(args.delimeter))
if args.column is None: if args.column is None:
column = ['y'] * nr_column column = ['y'] * nr_column
else: else:
...@@ -287,7 +295,9 @@ Line: {}""".format(repr(args.delimeter), line) ...@@ -287,7 +295,9 @@ Line: {}""".format(repr(args.delimeter), line)
data[idx].append(val) data[idx].append(val)
data_ys = [data[k] for k in args.y_column_idx] data_ys = [data[k] for k in args.y_column_idx]
max_ysize = max([len(t) for t in data_ys]) length_ys = [len(t) for t in data_ys]
print("Length of each column:", length_ys)
max_ysize = max(length_ys)
print("Size of the longest y column: ", max_ysize) print("Size of the longest y column: ", max_ysize)
if nr_x_column: if nr_x_column:
......
...@@ -60,7 +60,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase): ...@@ -60,7 +60,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
sz = self.dataset.size() sz = self.dataset.size()
except NotImplementedError: except NotImplementedError:
sz = 0 sz = 0
with tqdm(total=sz) as pbar: with tqdm(total=sz, disable=(sz==0)) as pbar:
for dp in self.dataset.get_data(): for dp in self.dataset.get_data():
res = self.predictor(dp) res = self.predictor(dp)
yield res yield res
...@@ -119,7 +119,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase): ...@@ -119,7 +119,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
sz = self.dataset.size() sz = self.dataset.size()
except NotImplementedError: except NotImplementedError:
sz = 0 sz = 0
with tqdm(total=sz) as pbar: with tqdm(total=sz, disable=(sz==0)) as pbar:
die_cnt = 0 die_cnt = 0
while True: while True:
res = self.result_queue.get() res = self.result_queue.get()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment