Commit 4ee67733 authored by Yuxin Wu's avatar Yuxin Wu

dorefa & resnet models

parent ea093029
......@@ -21,7 +21,7 @@ def play_one_episode(player, func, verbose=False):
def f(s):
spc = player.get_action_space()
act = func([[s]])[0][0].argmax()
if random.random() < 0.01:
if random.random() < 0.001:
act = spc.sample()
if verbose:
print(act)
......
......@@ -5,8 +5,8 @@ Code and model for the paper:
We hosted a demo at CVPR16 on behalf of Megvii, Inc, running real-time half-VGG size DoReFa-Net on both ARM and FPGA.
But we're not planning to release those runtime bit-op libraries for now. In these examples, bit operations are run in float32.
Pretrained model for 1-2-6-AlexNet is available at
[google drive](https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
Pretrained model for 1-2-6-AlexNet is available
[here](https://github.com/ppwwyyxx/tensorpack/releases/tag/alexnet-dorefa).
It's provided in the format of numpy dictionary, so it should be very easy to port into other applications.
## Preparation:
......
......@@ -9,7 +9,5 @@ The validation error here is computed on test set.
![cifar10](cifar10-resnet.png)
<!--
-Download model:
-[Cifar10 n=18](https://drive.google.com/open?id=0B308TeQzmFDLeHpSaHAxWGV1WDg)
-->
Download model:
[Cifar10 ResNet-110 (n=18)](https://github.com/ppwwyyxx/tensorpack/releases/tag/cifar10-resnet-110)
......@@ -36,7 +36,7 @@ class ModelSaver(Callback):
vars = tf.all_variables()
var_dict = {}
for v in vars:
name = v.op.name
name = v.name
if re.match('tower[p1-9]', name):
#logger.info("Skip {} when saving model.".format(name))
continue
......
......@@ -11,7 +11,7 @@ import six
from ..utils import logger, EXTRA_SAVE_VARS_KEY
from .common import get_op_var_name
from .varmanip import SessionUpdate
from .varmanip import SessionUpdate, get_savename_from_varname
__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit',
......@@ -112,19 +112,17 @@ class SaverRestore(SessionInit):
var_dict = defaultdict(list)
chkpt_vars_used = set()
for v in vars_to_restore:
name = v.op.name
if 'towerp' in name:
logger.error("No variable should be under 'towerp' name scope".format(v.name))
# don't overwrite anything in the current prediction graph
continue
if 'tower' in name:
name = re.sub('tower[p0-9]+/', '', name)
if self.prefix and name.startswith(self.prefix):
name = name[len(self.prefix)+1:]
name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
# try to load both 'varname' and 'opname' from checkpoint
# because some old checkpoint might not have ':0'
if name in vars_available:
var_dict[name].append(v)
chkpt_vars_used.add(name)
#vars_available.remove(name)
elif name.endswith(':0'):
name = name[:-2]
if name in vars_available:
var_dict[name].append(v)
chkpt_vars_used.add(name)
else:
logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
if len(chkpt_vars_used) < len(vars_available):
......@@ -141,12 +139,13 @@ class ParamRestore(SessionInit):
"""
:param param_dict: a dict of {name: value}
"""
# use varname (with :0) for consistency
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys.VARIABLES)
variable_names = set([k.name for k in variables])
variable_names = set([get_savename_from_varname(k.name) for k in variables])
param_names = set(six.iterkeys(self.prms))
intersect = variable_names & param_names
......@@ -159,7 +158,9 @@ class ParamRestore(SessionInit):
logger.warn("Variable {} in the dict not found in the graph!".format(k))
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
upd = SessionUpdate(sess,
[v for v in variables if \
get_savename_from_varname(v.name) in intersect])
logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
......
......@@ -5,10 +5,37 @@
import six
import tensorflow as tf
from collections import defaultdict
import re
import numpy as np
from ..utils import logger
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars']
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'get_savename_from_varname']
def get_savename_from_varname(
        varname, varname_prefix=None,
        savename_prefix=None):
    """
    Map a variable's name in the graph to the name used to save it.

    :param varname: a variable name in the graph
    :param varname_prefix: an optional prefix that may need to be removed in varname
    :param savename_prefix: an optional prefix to append to all savename
    :returns: the name used to save the variable, or ``None`` if the
        variable should not be saved at all (it lives under a 'towerp' scope)
    """
    name = varname
    if 'towerp' in name:
        # 'towerp' is the prediction tower; its variables mirror the training
        # tower's, so saving them would clobber the real ones.
        # BUGFIX: original used `.format(v.name)` where `v` was undefined in
        # this scope (NameError), and the format string had no placeholder.
        logger.error(
            "No variable should be under 'towerp' name scope: {}".format(varname))
        return None
    if 'tower' in name:
        # strip the per-tower scope so replicated variables share one savename
        name = re.sub('tower[p0-9]+/', '', name)
    if varname_prefix is not None \
            and name.startswith(varname_prefix):
        # also drop the '/' following the prefix, hence the +1
        name = name[len(varname_prefix) + 1:]
    if savename_prefix is not None:
        name = savename_prefix + '/' + name
    return name
class SessionUpdate(object):
""" Update the variables in a session """
......@@ -17,10 +44,14 @@ class SessionUpdate(object):
:param vars_to_update: a collection of variables to update
"""
self.sess = sess
self.assign_ops = {}
self.assign_ops = defaultdict(list)
for v in vars_to_update:
p = tf.placeholder(v.dtype, shape=v.get_shape())
self.assign_ops[v.name] = (p, v.assign(p))
#p = tf.placeholder(v.dtype, shape=v.get_shape())
with tf.device('/cpu:0'):
p = tf.placeholder(v.dtype)
savename = get_savename_from_varname(v.name)
# multiple vars might share one savename
self.assign_ops[savename].append((p, v, v.assign(p)))
def update(self, prms):
"""
......@@ -28,15 +59,25 @@ class SessionUpdate(object):
Any name in prms must be in the graph and in vars_to_update.
"""
for name, value in six.iteritems(prms):
p, op = self.assign_ops[name]
varshape = tuple(p.get_shape().as_list())
if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
assert np.prod(varshape) == np.prod(value.shape), \
"{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape)
self.sess.run(op, feed_dict={p: value})
assert name in self.assign_ops
for p, v, op in self.assign_ops[name]:
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
varshape = tuple(v.get_shape().as_list())
if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
assert np.prod(varshape) == np.prod(value.shape), \
"{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape)
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
self.sess.run(op, feed_dict={p: value})
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
def dump_session_params(path):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment