Commit e04bb2d5 authored by Yuxin Wu's avatar Yuxin Wu

fix LinearWrap imports. use varreplace in DoReFa.

parent d2d3e3c0
language: "python" language: "python"
python: python:
- "2.7" - "2.7"
- "3.5"
sudo: false sudo: false
cache: pip cache: pip
before_script: install:
- pip install flake8 - pip install flake8
before_script:
- flake8 --version - flake8 --version
script: script:
- flake8 . - flake8 .
- cd examples && flake8 . - cd examples && flake8 .
......
...@@ -15,6 +15,7 @@ import sys ...@@ -15,6 +15,7 @@ import sys
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import replace_get_variable
from dorefa import get_dorefa from dorefa import get_dorefa
""" """
...@@ -82,9 +83,10 @@ class Model(ModelDesc): ...@@ -82,9 +83,10 @@ class Model(ModelDesc):
image = image / 255.0 image = image / 255.0
fw, fa, fg = get_dorefa(BITW, BITA, BITG) fw, fa, fg = get_dorefa(BITW, BITA, BITG)
# monkey-patch tf.get_variable to apply fw
old_get_variable = tf.get_variable old_get_variable = tf.get_variable
# monkey-patch tf.get_variable to apply fw
def new_get_variable(name, shape=None, **kwargs): def new_get_variable(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs) v = old_get_variable(name, shape, **kwargs)
# don't binarize first and last layer # don't binarize first and last layer
...@@ -93,7 +95,6 @@ class Model(ModelDesc): ...@@ -93,7 +95,6 @@ class Model(ModelDesc):
else: else:
logger.info("Binarizing weight {}".format(v.op.name)) logger.info("Binarizing weight {}".format(v.op.name))
return fw(v) return fw(v)
tf.get_variable = new_get_variable
def nonlin(x): def nonlin(x):
if BITA == 32: if BITA == 32:
...@@ -103,7 +104,8 @@ class Model(ModelDesc): ...@@ -103,7 +104,8 @@ class Model(ModelDesc):
def activate(x): def activate(x):
return fa(nonlin(x)) return fa(nonlin(x))
with argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ with replace_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity): argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
.Conv2D('conv0', 96, 12, stride=4, padding='VALID') .Conv2D('conv0', 96, 12, stride=4, padding='VALID')
...@@ -141,7 +143,6 @@ class Model(ModelDesc): ...@@ -141,7 +143,6 @@ class Model(ModelDesc):
.BatchNorm('bnfc1') .BatchNorm('bnfc1')
.apply(nonlin) .apply(nonlin)
.FullyConnected('fct', 1000, use_bias=True)()) .FullyConnected('fct', 1000, use_bias=True)())
tf.get_variable = old_get_variable
prob = tf.nn.softmax(logits, name='output') prob = tf.nn.softmax(logits, name='output')
......
...@@ -11,6 +11,7 @@ import os ...@@ -11,6 +11,7 @@ import os
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import replace_get_variable
from dorefa import get_dorefa from dorefa import get_dorefa
""" """
...@@ -52,9 +53,10 @@ class Model(ModelDesc): ...@@ -52,9 +53,10 @@ class Model(ModelDesc):
is_training = get_current_tower_context().is_training is_training = get_current_tower_context().is_training
fw, fa, fg = get_dorefa(BITW, BITA, BITG) fw, fa, fg = get_dorefa(BITW, BITA, BITG)
# monkey-patch tf.get_variable to apply fw
old_get_variable = tf.get_variable old_get_variable = tf.get_variable
# monkey-patch tf.get_variable to apply fw
def new_get_variable(name, shape=None, **kwargs): def new_get_variable(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs) v = old_get_variable(name, shape, **kwargs)
# don't binarize first and last layer # don't binarize first and last layer
...@@ -63,7 +65,6 @@ class Model(ModelDesc): ...@@ -63,7 +65,6 @@ class Model(ModelDesc):
else: else:
logger.info("Binarizing weight {}".format(v.op.name)) logger.info("Binarizing weight {}".format(v.op.name))
return fw(v) return fw(v)
tf.get_variable = new_get_variable
def cabs(x): def cabs(x):
return tf.minimum(1.0, tf.abs(x), name='cabs') return tf.minimum(1.0, tf.abs(x), name='cabs')
...@@ -73,7 +74,8 @@ class Model(ModelDesc): ...@@ -73,7 +74,8 @@ class Model(ModelDesc):
image = image / 256.0 image = image / 256.0
with argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ with replace_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity): argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
.Conv2D('conv0', 48, 5, padding='VALID', use_bias=True) .Conv2D('conv0', 48, 5, padding='VALID', use_bias=True)
...@@ -108,7 +110,6 @@ class Model(ModelDesc): ...@@ -108,7 +110,6 @@ class Model(ModelDesc):
.apply(fg).BatchNorm('bn6') .apply(fg).BatchNorm('bn6')
.apply(cabs) .apply(cabs)
.FullyConnected('fc1', 10, nl=tf.identity)()) .FullyConnected('fc1', 10, nl=tf.identity)())
tf.get_variable = old_get_variable
prob = tf.nn.softmax(logits, name='output') prob = tf.nn.softmax(logits, name='output')
# compute the number of failed samples # compute the number of failed samples
......
...@@ -7,6 +7,8 @@ from types import ModuleType ...@@ -7,6 +7,8 @@ from types import ModuleType
import six import six
import os import os
import os.path import os.path
# this line is necessary for TFModuleFunc to work
import tensorflow as tf # noqa: F401
from ..utils import logger from ..utils import logger
__all__ = ['LinearWrap'] __all__ = ['LinearWrap']
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: varreplace.py
# Credit: Qinyao He
import tensorflow as tf
from tensorflow.python.ops import variable_scope
from contextlib import contextmanager
__all__ = ['replace_get_variable']
@contextmanager
def replace_get_variable(fn):
    """Temporarily monkey-patch ``tf.get_variable`` with ``fn``.

    Both the public ``tf.get_variable`` and the internal
    ``variable_scope.get_variable`` are replaced, since layer code may
    resolve the function through either path.

    Args:
        fn: a callable with the same signature as ``tf.get_variable``;
            it is installed for the duration of the ``with`` block.

    Yields:
        None. The original functions are restored on exit.
    """
    old_getv = tf.get_variable
    old_vars_getv = variable_scope.get_variable
    tf.get_variable = fn
    variable_scope.get_variable = fn
    try:
        yield
    finally:
        # Restore the originals even if the body raises, so an exception
        # inside the block cannot leave TensorFlow permanently patched.
        tf.get_variable = old_getv
        variable_scope.get_variable = old_vars_getv
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment