Commit ded6e834 authored by Yuxin Wu

prefer npz over npy

parent e372f57f
...@@ -14,7 +14,7 @@ uses it to print all variables and their shapes in a checkpoint. ...@@ -14,7 +14,7 @@ uses it to print all variables and their shapes in a checkpoint.
[scripts/dump-model-params.py](../scripts/dump-model-params.py) can be used to remove unnecessary variables in a checkpoint. [scripts/dump-model-params.py](../scripts/dump-model-params.py) can be used to remove unnecessary variables in a checkpoint.
It takes a metagraph file (which is also saved by `ModelSaver`) and only saves variables that the model needs at inference time. It takes a metagraph file (which is also saved by `ModelSaver`) and only saves variables that the model needs at inference time.
It can dump the model to a `var-name: value` dict saved in npy/npz format. It can dump the model to a `var-name: value` dict saved in npz format.
## Load a Model ## Load a Model
...@@ -25,7 +25,7 @@ which restores a TF checkpoint, ...@@ -25,7 +25,7 @@ which restores a TF checkpoint,
or [session_init=DictRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore) which restores a dict or [session_init=DictRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore) which restores a dict
([get_model_loader](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader) ([get_model_loader](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader)
is a small helper to decide which one to use from a file name). is a small helper to decide which one to use from a file name).
To load multiple models, use [ChainInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.ChainInit) To load multiple models, use [ChainInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.ChainInit).
Variable restoring is completely based on __name match__ between Variable restoring is completely based on __name match__ between
......
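As a rough illustration of the `var-name: value` dict format described in this hunk, the sketch below saves a dict of arrays to an npz file and reads it back with `dict(np.load(...))`, the same loading pattern the scripts in this commit switch to. The variable names, shapes, and file name are hypothetical and chosen only for the example; the resulting dict is the kind of object one would pass to `DictRestore`.

```python
import os
import tempfile

import numpy as np

# Hypothetical variable names and shapes, for illustration only.
params = {
    "conv1/W": np.random.rand(3, 3, 3, 8).astype(np.float32),
    "conv1/b": np.zeros(8, dtype=np.float32),
}

# Each dict entry becomes one named array inside the npz archive.
path = os.path.join(tempfile.mkdtemp(), "model.npz")
np.savez_compressed(path, **params)

# np.load returns a lazy NpzFile; dict(...) reads every array into memory.
with np.load(path) as npz:
    restored = dict(npz)

print(sorted(restored))
```

Unlike a dict pickled into an npy file, an npz archive stores plain arrays, so loading it needs neither `.item()` nor `encoding='latin1'`.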
...@@ -34,10 +34,10 @@ multiprocess Python program to get a cgroup dedicated for the task. ...@@ -34,10 +34,10 @@ multiprocess Python program to get a cgroup dedicated for the task.
Download models from [model zoo](https://goo.gl/9yIol2). Download models from [model zoo](https://goo.gl/9yIol2).
Watch the agent play: Watch the agent play:
`./train-atari.py --task play --env Breakout-v0 --load Breakout-v0.npy` `./train-atari.py --task play --env Breakout-v0 --load Breakout-v0.npz`
Generate gym submissions: Generate gym submissions:
`./train-atari.py --task gen_submit --load Breakout-v0.npy --env Breakout-v0 --output output_dir` `./train-atari.py --task gen_submit --load Breakout-v0.npz --env Breakout-v0 --output output_dir`
Models are available for the following atari environments (click to watch videos of my agent): Models are available for the following atari environments (click to watch videos of my agent):
......
...@@ -13,14 +13,14 @@ Prepare the model: ...@@ -13,14 +13,14 @@ Prepare the model:
wget http://pearl.vasc.ri.cmu.edu/caffe_model_github/model/_trained_MPI/pose_iter_320000.caffemodel wget http://pearl.vasc.ri.cmu.edu/caffe_model_github/model/_trained_MPI/pose_iter_320000.caffemodel
wget https://github.com/shihenw/convolutional-pose-machines-release/raw/master/model/_trained_MPI/pose_deploy_resize.prototxt wget https://github.com/shihenw/convolutional-pose-machines-release/raw/master/model/_trained_MPI/pose_deploy_resize.prototxt
# convert the model to a dict: # convert the model to a dict:
python -m tensorpack.utils.loadcaffe pose_deploy_resize.prototxt pose_iter_320000.caffemodel CPM-original.npy python -m tensorpack.utils.loadcaffe pose_deploy_resize.prototxt pose_iter_320000.caffemodel CPM-original.npz
``` ```
Or you can download the converted model from [model zoo](http://models.tensorpack.com/caffe/). Or you can download the converted model from [model zoo](http://models.tensorpack.com/caffe/).
Run it on an image, and produce `output.jpg`: Run it on an image, and produce `output.jpg`:
``` ```
python load-cpm.py --load CPM-original.npy --input test.jpg python load-cpm.py --load CPM-original.npz --input test.jpg
``` ```
Input image will get resized to 368x368. Note that this CPM comes without person detection, so the Input image will get resized to 368x368. Note that this CPM comes without person detection, so the
person has to be in the center of the image (and not too small). person has to be in the center of the image (and not too small).
......
...@@ -97,7 +97,7 @@ def CPM(image): ...@@ -97,7 +97,7 @@ def CPM(image):
def run_test(model_path, img_file): def run_test(model_path, img_file):
param_dict = np.load(model_path, encoding='latin1').item() param_dict = dict(np.load(model_path))
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
inputs_desc=[InputDesc(tf.float32, (None, 368, 368, 3), 'input')], inputs_desc=[InputDesc(tf.float32, (None, 368, 368, 3), 'input')],
tower_func=CPM, tower_func=CPM,
...@@ -116,7 +116,7 @@ def run_test(model_path, img_file): ...@@ -116,7 +116,7 @@ def run_test(model_path, img_file):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--load', required=True, help='.npy model file') parser.add_argument('--load', required=True, help='.npz model file')
parser.add_argument('--input', required=True, help='input image') parser.add_argument('--input', required=True, help='input image')
args = parser.parse_args() args = parser.parse_args()
run_test(args.load, args.input) run_test(args.load, args.input)
...@@ -65,8 +65,8 @@ To Train, for example: ...@@ -65,8 +65,8 @@ To Train, for example:
More than 20 CPU cores (for data processing) More than 20 CPU cores (for data processing)
More than 10G of free memory More than 10G of free memory
To Run Pretrained Model: To run pretrained model:
./alexnet-dorefa.py --load alexnet-126.npy --run a.jpg --dorefa 1,2,6 ./alexnet-dorefa.py --load alexnet-126.npz --run a.jpg --dorefa 1,2,6
""" """
BITW = 1 BITW = 1
...@@ -240,7 +240,7 @@ def run_image(model, sess_init, inputs): ...@@ -240,7 +240,7 @@ def run_image(model, sess_init, inputs):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='the physical ids of GPUs to use') parser.add_argument('--gpu', help='the physical ids of GPUs to use')
parser.add_argument('--load', help='load a checkpoint, or a npy (given as the pretrained model)') parser.add_argument('--load', help='load a checkpoint, or a npz (given as the pretrained model)')
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--dorefa', parser.add_argument('--dorefa',
help='number of bits for W,A,G, separated by comma', required=True) help='number of bits for W,A,G, separated by comma', required=True)
...@@ -253,8 +253,8 @@ if __name__ == '__main__': ...@@ -253,8 +253,8 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.run: if args.run:
assert args.load.endswith('.npy') assert args.load.endswith('.npz')
run_image(Model(), DictRestore(np.load(args.load, encoding='latin1').item()), args.run) run_image(Model(), DictRestore(dict(np.load(args.load))), args.run)
sys.exit() sys.exit()
nr_tower = max(get_nr_gpu(), 1) nr_tower = max(get_nr_gpu(), 1)
......
...@@ -20,10 +20,10 @@ This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32) ...@@ -20,10 +20,10 @@ This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32)
It has 59.2% top-1 and 81.5% top-5 validation error on ILSVRC12 validation set. It has 59.2% top-1 and 81.5% top-5 validation error on ILSVRC12 validation set.
To run on images: To run on images:
./resnet-dorefa.py --load pretrained.npy --run a.jpg b.jpg ./resnet-dorefa.py --load ResNet-18-14f.npz --run a.jpg b.jpg
To eval on ILSVRC validation set: To eval on ILSVRC validation set:
./resnet-dorefa.py --load pretrained.npy --eval --data /path/to/ILSVRC ./resnet-dorefa.py --load ResNet-18-14f.npz --eval --data /path/to/ILSVRC
""" """
BITW = 1 BITW = 1
...@@ -145,7 +145,7 @@ def run_image(model, sess_init, inputs): ...@@ -145,7 +145,7 @@ def run_image(model, sess_init, inputs):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='the physical ids of GPUs to use') parser.add_argument('--gpu', help='the physical ids of GPUs to use')
parser.add_argument('--load', help='load a npy pretrained model') parser.add_argument('--load', help='load a npz pretrained model')
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--dorefa', parser.add_argument('--dorefa',
help='number of bits for W,A,G, separated by comma. Defaults to \'1,4,32\'', help='number of bits for W,A,G, separated by comma. Defaults to \'1,4,32\'',
...@@ -166,6 +166,5 @@ if __name__ == '__main__': ...@@ -166,6 +166,5 @@ if __name__ == '__main__':
ds = BatchData(ds, 192, remainder=True) ds = BatchData(ds, 192, remainder=True)
eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds) eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds)
elif args.run: elif args.run:
assert args.load.endswith('.npy') assert args.load.endswith('.npz')
run_image(Model(), DictRestore( run_image(Model(), DictRestore(dict(np.load(args.load))), args.run)
np.load(args.load, encoding='latin1').item()), args.run)
...@@ -25,7 +25,7 @@ To view augmented training images: ...@@ -25,7 +25,7 @@ To view augmented training images:
To start training: To start training:
```bash ```bash
./hed.py --load vgg16.npy ./hed.py --load vgg16.npz
``` ```
It takes about 100k steps (~10 hours on a TitanX) to reach a reasonable performance. It takes about 100k steps (~10 hours on a TitanX) to reach a reasonable performance.
......
...@@ -37,10 +37,10 @@ Note that the architecture is different from the `imagenet-resnet.py` script and ...@@ -37,10 +37,10 @@ Note that the architecture is different from the `imagenet-resnet.py` script and
Usage: Usage:
```bash ```bash
# download and convert caffe model to npy format # download and convert caffe model to npz format
python -m tensorpack.utils.loadcaffe PATH/TO/{ResNet-101-deploy.prototxt,ResNet-101-model.caffemodel} ResNet101.npy python -m tensorpack.utils.loadcaffe PATH/TO/{ResNet-101-deploy.prototxt,ResNet-101-model.caffemodel} ResNet101.npz
# run on an image # run on an image
./load-resnet.py --load ResNet-101.npy --input cat.jpg --depth 101 ./load-resnet.py --load ResNet-101.npz --input cat.jpg --depth 101
``` ```
The converted models are verified on ILSVRC12 validation set. The converted models are verified on ILSVRC12 validation set.
......
...@@ -155,7 +155,7 @@ def convert_param_name(param): ...@@ -155,7 +155,7 @@ def convert_param_name(param):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--load', required=True, parser.add_argument('--load', required=True,
help='.npy model file generated by tensorpack.utils.loadcaffe') help='.npz model file generated by tensorpack.utils.loadcaffe')
parser.add_argument('-d', '--depth', help='resnet depth', required=True, type=int, choices=[50, 101, 152]) parser.add_argument('-d', '--depth', help='resnet depth', required=True, type=int, choices=[50, 101, 152])
parser.add_argument('--input', help='an input image') parser.add_argument('--input', help='an input image')
parser.add_argument('--convert', help='npz output file to save the converted model') parser.add_argument('--convert', help='npz output file to save the converted model')
...@@ -164,7 +164,7 @@ if __name__ == '__main__': ...@@ -164,7 +164,7 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
DEPTH = args.depth DEPTH = args.depth
param = np.load(args.load, encoding='latin1').item() param = dict(np.load(args.load))
param = convert_param_name(param) param = convert_param_name(param)
if args.convert: if args.convert:
......
...@@ -22,5 +22,5 @@ To train (takes about 300 epochs to reach 8.8% error): ...@@ -22,5 +22,5 @@ To train (takes about 300 epochs to reach 8.8% error):
To draw the above visualization with [pretrained model](http://models.tensorpack.com/SpatialTransformer/): To draw the above visualization with [pretrained model](http://models.tensorpack.com/SpatialTransformer/):
```bash ```bash
./mnist-addition.py --load pretrained.npy --view ./mnist-addition.py --load mnist-addition.npz --view
``` ```
...@@ -20,17 +20,17 @@ produce a 4x resolution image using different loss functions. ...@@ -20,17 +20,17 @@ produce a 4x resolution image using different loss functions.
```bash ```bash
wget http://images.cocodataset.org/zips/train2017.zip wget http://images.cocodataset.org/zips/train2017.zip
wget http://models.tensorpack.com/caffe/vgg19.npy wget http://models.tensorpack.com/caffe/vgg19.npz
``` ```
2. Train an EnhanceNet-PAT using: 2. Train an EnhanceNet-PAT using:
```bash ```bash
python enet-pat.py --vgg19 /path/to/vgg19.npy --data train2017.zip python enet-pat.py --vgg19 /path/to/vgg19.npz --data train2017.zip
# or: convert to an lmdb first and train with lmdb: # or: convert to an lmdb first and train with lmdb:
python data_sampler.py --lmdb train2017.lmdb --input train2017.zip --create python data_sampler.py --lmdb train2017.lmdb --input train2017.zip --create
python enet-pat.py --vgg19 /path/to/vgg19.npy --data train2017.lmdb python enet-pat.py --vgg19 /path/to/vgg19.npz --data train2017.lmdb
``` ```
Training is highly unstable and does not often give results as good as the pretrained model. Training is highly unstable and does not often give results as good as the pretrained model.
......
...@@ -282,7 +282,7 @@ if __name__ == '__main__': ...@@ -282,7 +282,7 @@ if __name__ == '__main__':
session_init = SaverRestore(args.load) session_init = SaverRestore(args.load)
else: else:
assert os.path.isfile(args.vgg19) assert os.path.isfile(args.vgg19)
param_dict = np.load(args.vgg19, encoding='latin1').item() param_dict = dict(np.load(args.vgg19))
param_dict = {'VGG19/' + name: value for name, value in six.iteritems(param_dict)} param_dict = {'VGG19/' + name: value for name, value in six.iteritems(param_dict)}
session_init = DictRestore(param_dict) session_init = DictRestore(param_dict)
......
...@@ -21,12 +21,12 @@ Usage: ...@@ -21,12 +21,12 @@ Usage:
Install caffe python bindings. Install caffe python bindings.
python -m tensorpack.utils.loadcaffe PATH/TO/CAFFE/{deploy.prototxt,bvlc_alexnet.caffemodel} alexnet.npy python -m tensorpack.utils.loadcaffe PATH/TO/CAFFE/{deploy.prototxt,bvlc_alexnet.caffemodel} alexnet.npz
Or download a converted caffe model from http://models.tensorpack.com/caffe/ Or download a converted caffe model from http://models.tensorpack.com/caffe/
Then, run it: Then, run it:
./load-alexnet.py --load alexnet.npy --input cat.png ./load-alexnet.py --load alexnet.npz --input cat.png
""" """
...@@ -54,7 +54,7 @@ def tower_func(image): ...@@ -54,7 +54,7 @@ def tower_func(image):
def run_test(path, input): def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item() param_dict = dict(np.load(path))
predictor = OfflinePredictor(PredictConfig( predictor = OfflinePredictor(PredictConfig(
inputs_desc=[InputDesc(tf.float32, (None, 227, 227, 3), 'input')], inputs_desc=[InputDesc(tf.float32, (None, 227, 227, 3), 'input')],
tower_func=tower_func, tower_func=tower_func,
...@@ -79,10 +79,10 @@ if __name__ == '__main__': ...@@ -79,10 +79,10 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', required=True, parser.add_argument('--load', required=True,
help='.npy model file generated by tensorpack.utils.loadcaffe') help='.npz model file generated by tensorpack.utils.loadcaffe')
parser.add_argument('--input', help='an input image', required=True) parser.add_argument('--input', help='an input image', required=True)
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# run alexnet with given model (in npy format) # run alexnet with given model (in npz format)
run_test(args.load, args.input) run_test(args.load, args.input)
...@@ -23,12 +23,12 @@ Usage: ...@@ -23,12 +23,12 @@ Usage:
Install caffe python bindings. Install caffe python bindings.
python -m tensorpack.utils.loadcaffe \ python -m tensorpack.utils.loadcaffe \
PATH/TO/VGG/{VGG_ILSVRC_16_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg16.npy PATH/TO/VGG/{VGG_ILSVRC_16_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg16.npz
Or download a converted caffe model from http://models.tensorpack.com/caffe/ Or download a converted caffe model from http://models.tensorpack.com/caffe/
Then, run it: Then, run it:
./load-vgg16.py --load vgg16.npy --input cat.png ./load-vgg16.py --load vgg16.npz --input cat.png
""" """
...@@ -67,7 +67,7 @@ def tower_func(image): ...@@ -67,7 +67,7 @@ def tower_func(image):
def run_test(path, input): def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item() param_dict = dict(np.load(path))
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')], inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')],
tower_func=tower_func, tower_func=tower_func,
...@@ -98,7 +98,7 @@ if __name__ == '__main__': ...@@ -98,7 +98,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', required=True, parser.add_argument('--load', required=True,
help='.npy model file generated by tensorpack.utils.loadcaffe') help='.npz model file generated by tensorpack.utils.loadcaffe')
parser.add_argument('--input', help='an input image', required=True) parser.add_argument('--input', help='an input image', required=True)
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
......
...@@ -17,11 +17,11 @@ from tensorpack.dataflow.dataset import ILSVRCMeta ...@@ -17,11 +17,11 @@ from tensorpack.dataflow.dataset import ILSVRCMeta
""" """
Usage: Usage:
python -m tensorpack.utils.loadcaffe \ python -m tensorpack.utils.loadcaffe \
PATH/TO/VGG/{VGG_ILSVRC_19_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg19.npy PATH/TO/VGG/{VGG_ILSVRC_19_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg19.npz
./load-vgg19.py --load vgg19.npy --input cat.png ./load-vgg19.py --load vgg19.npz --input cat.png
Or download a converted caffe model from http://models.tensorpack.com/caffe/ Or download a converted caffe model from http://models.tensorpack.com/caffe/
./load-vgg19.py --load vgg19.npy --input cat.png ./load-vgg19.py --load vgg19.npz --input cat.png
""" """
...@@ -63,7 +63,7 @@ def tower_func(image): ...@@ -63,7 +63,7 @@ def tower_func(image):
def run_test(path, input): def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item() param_dict = dict(np.load(path))
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')], inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')],
tower_func=tower_func, tower_func=tower_func,
...@@ -93,7 +93,7 @@ if __name__ == '__main__': ...@@ -93,7 +93,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', required=True, parser.add_argument('--load', required=True,
help='.npy model file generated by tensorpack.utils.loadcaffe') help='.npz model file generated by tensorpack.utils.loadcaffe')
parser.add_argument('--input', help='an input image', required=True) parser.add_argument('--input', help='an input image', required=True)
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
......
...@@ -13,12 +13,12 @@ import argparse ...@@ -13,12 +13,12 @@ import argparse
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('model') parser.add_argument('model')
parser.add_argument('--dump', help='dump to an npy file') parser.add_argument('--dump', help='dump to an npz file')
parser.add_argument('--shell', action='store_true', help='start a shell with the params') parser.add_argument('--shell', action='store_true', help='start a shell with the params')
args = parser.parse_args() args = parser.parse_args()
if args.model.endswith('.npy'): if args.model.endswith('.npy'):
params = np.load(args.model).item() params = np.load(args.model, encoding='latin1').item()
elif args.model.endswith('.npz'): elif args.model.endswith('.npz'):
params = dict(np.load(args.model)) params = dict(np.load(args.model))
else: else:
...@@ -27,8 +27,8 @@ if __name__ == '__main__': ...@@ -27,8 +27,8 @@ if __name__ == '__main__':
logger.info(str(params.keys())) logger.info(str(params.keys()))
if args.dump: if args.dump:
assert args.dump.endswith('.npy'), args.dump assert args.dump.endswith('.npz'), args.dump
np.save(args.dump, params) np.savez_compressed(args.dump, **params)
if args.shell: if args.shell:
# params is a dict. play with it # params is a dict. play with it
......
...@@ -14,7 +14,7 @@ if __name__ == '__main__': ...@@ -14,7 +14,7 @@ if __name__ == '__main__':
description='Keep only TRAINABLE and MODEL variables in a checkpoint.') description='Keep only TRAINABLE and MODEL variables in a checkpoint.')
parser.add_argument('--meta', help='metagraph file', required=True) parser.add_argument('--meta', help='metagraph file', required=True)
parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint') parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint')
parser.add_argument(dest='output', help='output model file, can be npy/npz or TF checkpoint') parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
args = parser.parse_args() args = parser.parse_args()
tf.train.import_meta_graph(args.meta) tf.train.import_meta_graph(args.meta)
......
...@@ -17,6 +17,9 @@ if __name__ == '__main__': ...@@ -17,6 +17,9 @@ if __name__ == '__main__':
if fpath.endswith('.npy'): if fpath.endswith('.npy'):
params = np.load(fpath, encoding='latin1').item() params = np.load(fpath, encoding='latin1').item()
dic = {k: v.shape for k, v in six.iteritems(params)} dic = {k: v.shape for k, v in six.iteritems(params)}
elif fpath.endswith('.npz'):
params = dict(np.load(fpath))
dic = {k: v.shape for k, v in six.iteritems(params)}
else: else:
path = get_checkpoint_path(sys.argv[1]) path = get_checkpoint_path(sys.argv[1])
reader = tf.train.NewCheckpointReader(path) reader = tf.train.NewCheckpointReader(path)
......
...@@ -113,10 +113,10 @@ class SessionUpdate(object): ...@@ -113,10 +113,10 @@ class SessionUpdate(object):
def dump_session_params(path): def dump_session_params(path):
""" """
Dump value of all TRAINABLE + MODEL variables to a dict, and save as Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npy/npz format (loadable by :class:`DictRestore`). npz format (loadable by :class:`DictRestore`).
Args: Args:
path(str): the file name to save the parameters. Must end with npy or npz. path(str): the file name to save the parameters. Must end with npz.
""" """
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
......
...@@ -153,8 +153,14 @@ if __name__ == '__main__': ...@@ -153,8 +153,14 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('model', help='.prototxt file') parser.add_argument('model', help='.prototxt file')
parser.add_argument('weights', help='.caffemodel file') parser.add_argument('weights', help='.caffemodel file')
parser.add_argument('output', help='output npy file') parser.add_argument('output', help='output npz file')
args = parser.parse_args() args = parser.parse_args()
ret = load_caffe(args.model, args.weights) ret = load_caffe(args.model, args.weights)
if args.output.endswith('.npz'):
np.savez_compressed(args.output, **ret)
elif args.output.endswith('.npy'):
logger.warn("Please use npz format instead!")
np.save(args.output, ret) np.save(args.output, ret)
else:
raise ValueError("Unknown format {}".format(args.output))
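For model files that were already dumped in the old npy format, a one-off conversion along the following lines may be enough. This is a minimal sketch, not part of the commit: `convert_npy_to_npz` is a hypothetical helper, the file names are placeholders, and it assumes the npy file holds a pickled `{name: array}` dict as produced by the old `np.save(path, dict)` path.

```python
import os
import tempfile

import numpy as np


def convert_npy_to_npz(npy_path, npz_path):
    """Hypothetical helper: re-save a pickled {name: array} dict as an npz archive."""
    # allow_pickle=True is required on NumPy >= 1.16.3, where the default became
    # False; only enable it for files you trust. encoding='latin1' mirrors the
    # loading code in this diff and matters for files pickled under Python 2.
    params = np.load(npy_path, allow_pickle=True, encoding="latin1").item()
    np.savez_compressed(npz_path, **params)
    return sorted(params)


# Build a small legacy-style file to demonstrate the conversion.
tmpdir = tempfile.mkdtemp()
old_path = os.path.join(tmpdir, "legacy.npy")
new_path = os.path.join(tmpdir, "legacy.npz")
legacy = {
    "fc/W": np.ones((4, 2), dtype=np.float32),
    "fc/b": np.zeros(2, dtype=np.float32),
}
np.save(old_path, legacy)  # stores the dict as a pickled 0-d object array

names = convert_npy_to_npz(old_path, new_path)
with np.load(new_path) as npz:
    converted = dict(npz)

print(names)
```

The converted file can then be loaded with `dict(np.load(path))`, matching the updated scripts.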