Commit 4ee67733 authored by Yuxin Wu's avatar Yuxin Wu

dorefa & resnet models

parent ea093029
...@@ -21,7 +21,7 @@ def play_one_episode(player, func, verbose=False): ...@@ -21,7 +21,7 @@ def play_one_episode(player, func, verbose=False):
def f(s): def f(s):
spc = player.get_action_space() spc = player.get_action_space()
act = func([[s]])[0][0].argmax() act = func([[s]])[0][0].argmax()
if random.random() < 0.01: if random.random() < 0.001:
act = spc.sample() act = spc.sample()
if verbose: if verbose:
print(act) print(act)
......
...@@ -5,8 +5,8 @@ Code and model for the paper: ...@@ -5,8 +5,8 @@ Code and model for the paper:
We hosted a demo at CVPR16 on behalf of Megvii, Inc, running real-time half-VGG size DoReFa-Net on both ARM and FPGA. We hosted a demo at CVPR16 on behalf of Megvii, Inc, running real-time half-VGG size DoReFa-Net on both ARM and FPGA.
But we're not planning to release those runtime bit-op libraries for now. In these examples, bit operations are run in float32. But we're not planning to release those runtime bit-op libraries for now. In these examples, bit operations are run in float32.
Pretrained model for 1-2-6-AlexNet is available at Pretrained model for 1-2-6-AlexNet is available
[google drive](https://drive.google.com/a/ megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ). [here](https://github.com/ppwwyyxx/tensorpack/releases/tag/alexnet-dorefa).
It's provided in the format of numpy dictionary, so it should be very easy to port into other applications. It's provided in the format of numpy dictionary, so it should be very easy to port into other applications.
## Preparation: ## Preparation:
......
...@@ -9,7 +9,5 @@ The validation error here is computed on test set. ...@@ -9,7 +9,5 @@ The validation error here is computed on test set.
![cifar10](cifar10-resnet.png) ![cifar10](cifar10-resnet.png)
<!-- Download model:
-Download model: [Cifar10 ResNet-110 (n=18)](https://github.com/ppwwyyxx/tensorpack/releases/tag/cifar10-resnet-110)
-[Cifar10 n=18](https://drive.google.com/open?id=0B308TeQzmFDLeHpSaHAxWGV1WDg)
-->
...@@ -36,7 +36,7 @@ class ModelSaver(Callback): ...@@ -36,7 +36,7 @@ class ModelSaver(Callback):
vars = tf.all_variables() vars = tf.all_variables()
var_dict = {} var_dict = {}
for v in vars: for v in vars:
name = v.op.name name = v.name
if re.match('tower[p1-9]', name): if re.match('tower[p1-9]', name):
#logger.info("Skip {} when saving model.".format(name)) #logger.info("Skip {} when saving model.".format(name))
continue continue
......
...@@ -11,7 +11,7 @@ import six ...@@ -11,7 +11,7 @@ import six
from ..utils import logger, EXTRA_SAVE_VARS_KEY from ..utils import logger, EXTRA_SAVE_VARS_KEY
from .common import get_op_var_name from .common import get_op_var_name
from .varmanip import SessionUpdate from .varmanip import SessionUpdate, get_savename_from_varname
__all__ = ['SessionInit', 'NewSession', 'SaverRestore', __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit', 'ParamRestore', 'ChainInit',
...@@ -112,19 +112,17 @@ class SaverRestore(SessionInit): ...@@ -112,19 +112,17 @@ class SaverRestore(SessionInit):
var_dict = defaultdict(list) var_dict = defaultdict(list)
chkpt_vars_used = set() chkpt_vars_used = set()
for v in vars_to_restore: for v in vars_to_restore:
name = v.op.name name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
if 'towerp' in name: # try to load both 'varname' and 'opname' from checkpoint
logger.error("No variable should be under 'towerp' name scope".format(v.name)) # because some old checkpoint might not have ':0'
# don't overwrite anything in the current prediction graph if name in vars_available:
continue var_dict[name].append(v)
if 'tower' in name: chkpt_vars_used.add(name)
name = re.sub('tower[p0-9]+/', '', name) elif name.endswith(':0'):
if self.prefix and name.startswith(self.prefix): name = name[:-2]
name = name[len(self.prefix)+1:]
if name in vars_available: if name in vars_available:
var_dict[name].append(v) var_dict[name].append(v)
chkpt_vars_used.add(name) chkpt_vars_used.add(name)
#vars_available.remove(name)
else: else:
logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name)) logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
if len(chkpt_vars_used) < len(vars_available): if len(chkpt_vars_used) < len(vars_available):
...@@ -141,12 +139,13 @@ class ParamRestore(SessionInit): ...@@ -141,12 +139,13 @@ class ParamRestore(SessionInit):
""" """
:param param_dict: a dict of {name: value} :param param_dict: a dict of {name: value}
""" """
# use varname (with :0) for consistency
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)} self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess): def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys.VARIABLES) variables = tf.get_collection(tf.GraphKeys.VARIABLES)
variable_names = set([k.name for k in variables]) variable_names = set([get_savename_from_varname(k.name) for k in variables])
param_names = set(six.iterkeys(self.prms)) param_names = set(six.iterkeys(self.prms))
intersect = variable_names & param_names intersect = variable_names & param_names
...@@ -159,7 +158,9 @@ class ParamRestore(SessionInit): ...@@ -159,7 +158,9 @@ class ParamRestore(SessionInit):
logger.warn("Variable {} in the dict not found in the graph!".format(k)) logger.warn("Variable {} in the dict not found in the graph!".format(k))
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect]) upd = SessionUpdate(sess,
[v for v in variables if \
get_savename_from_varname(v.name) in intersect])
logger.info("Restoring from dict ...") logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect}) upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
......
...@@ -5,10 +5,37 @@ ...@@ -5,10 +5,37 @@
import six import six
import tensorflow as tf import tensorflow as tf
from collections import defaultdict
import re
import numpy as np import numpy as np
from ..utils import logger from ..utils import logger
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars'] __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'get_savename_from_varname']
def get_savename_from_varname(
        varname, varname_prefix=None,
        savename_prefix=None):
    """
    Map a variable name in the graph to the name used to save it.

    :param varname: a variable name in the graph (e.g. ``tower0/conv/W:0``)
    :param varname_prefix: an optional prefix that may need to be removed in varname
    :param savename_prefix: an optional prefix to append to all savename
    :returns: the name used to save the variable, or None if the variable
        should not be saved at all (prediction-tower variables).
    """
    name = varname
    if 'towerp' in name:
        # bugfix: original called "...".format(v.name) — `v` is undefined here
        # and the string had no placeholder, so the branch raised NameError.
        logger.error("No variable should be under 'towerp' name scope: {}".format(varname))
        # don't overwrite anything in the current prediction graph
        return None
    if 'tower' in name:
        # training towers share variables; strip the per-tower scope
        name = re.sub('tower[p0-9]+/', '', name)
    if varname_prefix is not None \
            and name.startswith(varname_prefix):
        # also drop the '/' following the prefix
        name = name[len(varname_prefix)+1:]
    if savename_prefix is not None:
        name = savename_prefix + '/' + name
    return name
class SessionUpdate(object): class SessionUpdate(object):
""" Update the variables in a session """ """ Update the variables in a session """
...@@ -17,10 +44,14 @@ class SessionUpdate(object): ...@@ -17,10 +44,14 @@ class SessionUpdate(object):
:param vars_to_update: a collection of variables to update :param vars_to_update: a collection of variables to update
""" """
self.sess = sess self.sess = sess
self.assign_ops = {} self.assign_ops = defaultdict(list)
for v in vars_to_update: for v in vars_to_update:
p = tf.placeholder(v.dtype, shape=v.get_shape()) #p = tf.placeholder(v.dtype, shape=v.get_shape())
self.assign_ops[v.name] = (p, v.assign(p)) with tf.device('/cpu:0'):
p = tf.placeholder(v.dtype)
savename = get_savename_from_varname(v.name)
# multiple vars might share one savename
self.assign_ops[savename].append((p, v, v.assign(p)))
def update(self, prms): def update(self, prms):
""" """
...@@ -28,15 +59,25 @@ class SessionUpdate(object): ...@@ -28,15 +59,25 @@ class SessionUpdate(object):
Any name in prms must be in the graph and in vars_to_update. Any name in prms must be in the graph and in vars_to_update.
""" """
for name, value in six.iteritems(prms): for name, value in six.iteritems(prms):
p, op = self.assign_ops[name] assert name in self.assign_ops
varshape = tuple(p.get_shape().as_list()) for p, v, op in self.assign_ops[name]:
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
varshape = tuple(v.get_shape().as_list())
if varshape != value.shape: if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis # TODO only allow reshape when shape different by empty axis
assert np.prod(varshape) == np.prod(value.shape), \ assert np.prod(varshape) == np.prod(value.shape), \
"{}: {}!={}".format(name, varshape, value.shape) "{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name)) logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape) value = value.reshape(varshape)
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
self.sess.run(op, feed_dict={p: value}) self.sess.run(op, feed_dict={p: value})
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
def dump_session_params(path): def dump_session_params(path):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as """ Dump value of all trainable + to_save variables to a dict and save to `path` as
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment