Commit 0d54d791 authored by Yuxin Wu

update model location

parent 030a1d31
@@ -5,8 +5,8 @@ Code and model for the paper:
 We hosted a demo at CVPR16 on behalf of Megvii, Inc, running real-time half-VGG size DoReFa-Net on both ARM and FPGA.
 But we're not planning to release those runtime bit-op libraries for now. In these examples, bit operations are run in float32.
-Pretrained model for 1-2-6-AlexNet is available
-[here](https://github.com/ppwwyyxx/tensorpack/releases/tag/alexnet-dorefa).
+Pretrained model for 1-2-6-AlexNet is available at
+[google drive](https://drive.google.com/a/%20megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
 It's provided in the format of numpy dictionary, so it should be very easy to port into other applications.
 ## Preparation:
...
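The README says the pretrained model ships as a numpy dictionary. A minimal sketch (not from the repo) of inspecting such a file, assuming it stores a dict mapping parameter names to arrays saved with np.save; the filename is hypothetical:

import numpy as np

# Hypothetical filename; the release page lists the actual one.
# Newer numpy versions additionally need allow_pickle=True here.
params = np.load('alexnet-126.npy').item()
for name, value in params.items():
    print(name, value.shape, value.dtype)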
@@ -10,4 +10,4 @@ The validation error here is computed on test set.
 ![cifar10](cifar10-resnet.png)
 Download model:
-[Cifar10 ResNet-110 (n=18)](https://github.com/ppwwyyxx/tensorpack/releases/tag/cifar10-resnet-110)
+[Cifar10 ResNet-110 (n=18)](https://drive.google.com/open?id=0B9IPQTvr2BBkTXBlZmh1cmlnQ0k)
@@ -148,6 +148,7 @@ def get_config():
     sess_config = get_default_sess_config(0.9)
+    get_global_step_var()
     lr = tf.Variable(0.01, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
...
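The added call creates the global-step variable while get_config() is still at the root variable scope, so the lookup by name inside get_global_step_var() (patched further below) keeps working. A small illustrative sketch of the intended ordering; the 'model' scope name is made up:

import tensorflow as tf

# get_global_step_var() is the repo function shown in the last hunk below.
get_global_step_var()                         # creates the variable at the root scope
with tf.variable_scope('model'):              # scopes opened afterwards do not affect it
    w = tf.get_variable('w', shape=[10, 10])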
@@ -3,6 +3,7 @@
 # File: dump-model-params.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
+import numpy as np
 import argparse
 import tensorflow as tf
 import imp
@@ -28,7 +29,10 @@ with tf.Graph().as_default() as G:
     M = ModelFromMetaGraph(args.meta)
     # loading...
-    init = sessinit.SaverRestore(args.model)
+    if args.model.endswith('.npy'):
+        init = sessinit.ParamRestore(np.load(args.model).item())
+    else:
+        init = sessinit.SaverRestore(args.model)
     sess = tf.Session()
     init.init(sess)
...
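With the new branch, dump-model-params.py accepts either a TF checkpoint or an .npy dict of parameters. A rough sketch (TF 0.x style, not from the repo) of producing such a dict from a live session; the exact key format ParamRestore expects (with or without the ':0' suffix) should be checked against sessinit, and the filename is illustrative:

import numpy as np
import tensorflow as tf

# Assumes a graph with variables has already been built in the default graph.
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())                    # TF 0.x-era initializer
    params = {v.name: sess.run(v) for v in tf.all_variables()}
np.save('params.npy', params)                                  # later restored via np.load(...).item()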
@@ -40,6 +40,9 @@ def get_global_step_var():
     try:
         return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     except KeyError:
+        scope = tf.get_variable_scope()
+        assert scope.name == '', \
+            "Creating global_step_var under a variable scope would cause problems!"
         var = tf.Variable(
             0, trainable=False, name=GLOBAL_STEP_OP_NAME)
         return var
...
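The assertion guards against calling get_global_step_var() for the first time inside a variable scope: the variable would be created under the scope prefix, so later lookups by the fixed GLOBAL_STEP_VAR_NAME would keep raising KeyError and create duplicate variables. An illustrative sketch of that failure, assuming the names are 'global_step' / 'global_step:0'; the scope name is made up:

import tensorflow as tf

with tf.Graph().as_default() as g:
    with tf.variable_scope('tower0'):                          # illustrative scope name
        v = tf.Variable(0, trainable=False, name='global_step')
    print(v.name)                                              # 'tower0/global_step:0'
    g.get_tensor_by_name('global_step:0')                      # raises KeyError: no such tensor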