Commit caa93135 authored by Yuxin Wu's avatar Yuxin Wu

work for python3

parent e9c72570
...@@ -10,6 +10,8 @@ To use the script. You'll need: ...@@ -10,6 +10,8 @@ To use the script. You'll need:
+ [TensorFlow](https://tensorflow.org) >= 0.8 + [TensorFlow](https://tensorflow.org) >= 0.8
+ OpenCV Bindings for Python
+ [tensorpack](https://github.com/ppwwyyxx/tensorpack): + [tensorpack](https://github.com/ppwwyyxx/tensorpack):
``` ```
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# Author: Yuheng Zou, Yuxin Wu {zouyuheng,wyx}@megvii.com # Author: Yuheng Zou, Yuxin Wu {zouyuheng,wyx}@megvii.com
...@@ -108,8 +108,8 @@ def run_test(model, sess_init, inputs): ...@@ -108,8 +108,8 @@ def run_test(model, sess_init, inputs):
meta = dataset.ILSVRCMeta().get_synset_words_1000() meta = dataset.ILSVRCMeta().get_synset_words_1000()
names = [meta[i] for i in ret] names = [meta[i] for i in ret]
print f + ":" print(f + ":")
print list(zip(names, prob[ret])) print(list(zip(names, prob[ret])))
# save the metagraph # save the metagraph
#saver = tf.train.Saver() #saver = tf.train.Saver()
...@@ -131,14 +131,14 @@ if __name__ == '__main__': ...@@ -131,14 +131,14 @@ if __name__ == '__main__':
M = ModelFromMetaGraph(args.graph) M = ModelFromMetaGraph(args.graph)
else: else:
# build the graph from scratch # build the graph from scratch
logger.warn("Building the graph from scratch might result \ logger.warn("[DoReFa-Net] Building the graph from scratch might result \
in compatibility issues in the future, if TensorFlow changes some of its \ in compatibility issues in the future, if TensorFlow changes some of its \
op/variable names") op/variable names")
M = Model() M = Model()
if args.load.endswith('.npy'): if args.load.endswith('.npy'):
# load from a parameter dict # load from a parameter dict
param_dict= np.load(args.load).item() param_dict= np.load(args.load, encoding='latin1').item()
sess_init = ParamRestore(param_dict) sess_init = ParamRestore(param_dict)
elif args.load.endswith('.tfmodel'): elif args.load.endswith('.tfmodel'):
sess_init = SaverRestore(args.load) sess_init = SaverRestore(args.load)
......
...@@ -56,7 +56,7 @@ class ILSVRCMeta(object): ...@@ -56,7 +56,7 @@ class ILSVRCMeta(object):
def get_image_list(self, name): def get_image_list(self, name):
""" """
:param name: 'train' or 'val' or 'test' :param name: 'train' or 'val' or 'test'
:returns list of image filenames :returns list of (image filename, cls)
""" """
assert name in ['train', 'val', 'test'] assert name in ['train', 'val', 'test']
fname = os.path.join(self.dir, name + '.txt') fname = os.path.join(self.dir, name + '.txt')
...@@ -66,7 +66,7 @@ class ILSVRCMeta(object): ...@@ -66,7 +66,7 @@ class ILSVRCMeta(object):
for line in f.readlines(): for line in f.readlines():
name, cls = line.strip().split() name, cls = line.strip().split()
ret.append((name, int(cls))) ret.append((name, int(cls)))
return ret return ret
def get_per_pixel_mean(self, size=None): def get_per_pixel_mean(self, size=None):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment