Commit e04bb2d5 authored by Yuxin Wu

fix LinearWrap imports. use varreplace in DoReFa.

parent d2d3e3c0
language: "python"
python:
- "2.7"
- "3.5"
sudo: false
cache: pip
install:
- pip install flake8
before_script:
- flake8 --version
script:
- flake8 .
- cd examples && flake8 .
...
@@ -15,6 +15,7 @@ import sys
 from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
+from tensorpack.tfutils.varreplace import replace_get_variable
 from dorefa import get_dorefa

 """
@@ -82,9 +83,10 @@ class Model(ModelDesc):
         image = image / 255.0

         fw, fa, fg = get_dorefa(BITW, BITA, BITG)
-        # monkey-patch tf.get_variable to apply fw
         old_get_variable = tf.get_variable
+
+        # monkey-patch tf.get_variable to apply fw
         def new_get_variable(name, shape=None, **kwargs):
             v = old_get_variable(name, shape, **kwargs)
             # don't binarize first and last layer
@@ -93,7 +95,6 @@ class Model(ModelDesc):
             else:
                 logger.info("Binarizing weight {}".format(v.op.name))
                 return fw(v)
-        tf.get_variable = new_get_variable

         def nonlin(x):
             if BITA == 32:
@@ -103,7 +104,8 @@ class Model(ModelDesc):
         def activate(x):
             return fa(nonlin(x))

-        with argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
+        with replace_get_variable(new_get_variable), \
+                argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
                 argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity):
             logits = (LinearWrap(image)
                       .Conv2D('conv0', 96, 12, stride=4, padding='VALID')
@@ -141,7 +143,6 @@ class Model(ModelDesc):
                       .BatchNorm('bnfc1')
                       .apply(nonlin)
                       .FullyConnected('fct', 1000, use_bias=True)())
-        tf.get_variable = old_get_variable

         prob = tf.nn.softmax(logits, name='output')
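
Taken together, these hunks swap the fragile pattern (assign tf.get_variable, build the graph, remember to assign it back) for a scoped replacement. A condensed sketch of the resulting control flow, assuming the TensorFlow 1.x API; fw below is a stand-in for the DoReFa weight quantizer from get_dorefa, and the layer filter is simplified from the condition elided in the hunks above:

import tensorflow as tf
from tensorpack.tfutils.varreplace import replace_get_variable

def fw(v):
    return v  # stand-in: the real fw from get_dorefa quantizes the weight

old_get_variable = tf.get_variable  # capture the original before the swap

def new_get_variable(name, shape=None, **kwargs):
    v = old_get_variable(name, shape, **kwargs)
    if name != 'W':  # the real code also skips the first/last layer by op name
        return v
    return fw(v)

with replace_get_variable(new_get_variable):
    # every tf.get_variable call here is routed through new_get_variable ...
    w = tf.get_variable('W', shape=[3, 3])
# ... and the original is restored on exit, so there is no manual
# "tf.get_variable = old_get_variable" cleanup to forget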
...
@@ -11,6 +11,7 @@ import os
 from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
+from tensorpack.tfutils.varreplace import replace_get_variable
 from dorefa import get_dorefa

 """
@@ -52,9 +53,10 @@ class Model(ModelDesc):
         is_training = get_current_tower_context().is_training

         fw, fa, fg = get_dorefa(BITW, BITA, BITG)
-        # monkey-patch tf.get_variable to apply fw
         old_get_variable = tf.get_variable
+
+        # monkey-patch tf.get_variable to apply fw
         def new_get_variable(name, shape=None, **kwargs):
             v = old_get_variable(name, shape, **kwargs)
             # don't binarize first and last layer
@@ -63,7 +65,6 @@ class Model(ModelDesc):
             else:
                 logger.info("Binarizing weight {}".format(v.op.name))
                 return fw(v)
-        tf.get_variable = new_get_variable

         def cabs(x):
             return tf.minimum(1.0, tf.abs(x), name='cabs')
@@ -73,7 +74,8 @@ class Model(ModelDesc):
         image = image / 256.0

-        with argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
+        with replace_get_variable(new_get_variable), \
+                argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
                 argscope(Conv2D, use_bias=False, nl=tf.identity):
             logits = (LinearWrap(image)
                       .Conv2D('conv0', 48, 5, padding='VALID', use_bias=True)
@@ -108,7 +110,6 @@ class Model(ModelDesc):
                       .apply(fg).BatchNorm('bn6')
                       .apply(cabs)
                       .FullyConnected('fc1', 10, nl=tf.identity)())
-        tf.get_variable = old_get_variable
         prob = tf.nn.softmax(logits, name='output')

         # compute the number of failed samples
...
@@ -7,6 +7,8 @@ from types import ModuleType
 import six
 import os
 import os.path
+# this line is necessary for TFModuleFunc to work
+import tensorflow as tf  # noqa: F401

 from ..utils import logger

 __all__ = ['LinearWrap']
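
The added import looks unused, hence the noqa: F401 suppression for flake8, but TFModuleFunc resolves attribute names against the tf module object at runtime, so the name tf must be bound in this module even though no line references it statically. A toy sketch of that proxy technique, using hypothetical ChainWrap/_ModuleFunc names and the stdlib operator module standing in for tf (not tensorpack's actual code):

import operator

class ChainWrap:
    # wraps a value so that functions from a module can be chained fluently
    def __init__(self, value):
        self._value = value

    @property
    def mod(self):
        # hand out a proxy that resolves names on the module at call time
        return _ModuleFunc(operator, self._value)

    def unwrap(self):
        return self._value

class _ModuleFunc:
    def __init__(self, module, value):
        self._module = module
        self._value = value

    def __getattr__(self, name):
        fn = getattr(self._module, name)  # dynamic lookup: linters see no usage
        def call(*args, **kwargs):
            return ChainWrap(fn(self._value, *args, **kwargs))
        return call

print(ChainWrap(3).mod.add(4).unwrap())  # 7
print(ChainWrap(3).mod.mul(5).unwrap())  # 15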
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: varreplace.py
# Credit: Qinyao He

import tensorflow as tf
from tensorflow.python.ops import variable_scope
from contextlib import contextmanager

__all__ = ['replace_get_variable']


@contextmanager
def replace_get_variable(fn):
    # Temporarily route both the public tf.get_variable and the internal
    # variable_scope.get_variable through `fn`; the originals are put back
    # when the `with` block exits.
    old_getv = tf.get_variable
    old_vars_getv = variable_scope.get_variable

    tf.get_variable = fn
    variable_scope.get_variable = fn
    yield
    tf.get_variable = old_getv
    variable_scope.get_variable = old_vars_getv
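
For reference, a minimal usage sketch of the new context manager (TensorFlow 1.x API; the logging wrapper is purely illustrative):

import tensorflow as tf
from tensorpack.tfutils.varreplace import replace_get_variable

orig_get_variable = tf.get_variable

def logging_get_variable(name, shape=None, **kwargs):
    print('creating variable:', name, shape)  # illustrative side effect
    return orig_get_variable(name, shape, **kwargs)

with replace_get_variable(logging_get_variable):
    w = tf.get_variable('W', shape=[3, 3])  # routed through the wrapper
b = tf.get_variable('b', shape=[3])  # original behavior is back

Note that the restore runs as straight-line code after the yield, so an exception raised inside the with block would leave the replacement installed; a try/finally around the yield would make the restore unconditional.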