Commit 78c4cf13 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

support tf.layers in argscope (#778)

* support tf.layers in argscope

* rename

* "lib"->"module"; add docs

* typo
parent 099975c5
...@@ -20,6 +20,7 @@ from tensorpack.tfutils import summary, get_current_tower_context ...@@ -20,6 +20,7 @@ from tensorpack.tfutils import summary, get_current_tower_context
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
IMAGE_SIZE = 28 IMAGE_SIZE = 28
enable_argscope_for_module(tf.layers)
class Model(ModelDesc): class Model(ModelDesc):
...@@ -38,16 +39,17 @@ class Model(ModelDesc): ...@@ -38,16 +39,17 @@ class Model(ModelDesc):
image = image * 2 - 1 # center the pixels values at zero image = image * 2 - 1 # center the pixels values at zero
l = tf.layers.conv2d(image, 32, 3, padding='same', activation=tf.nn.relu, name='conv0') with argscope([tf.layers.conv2d], padding='same', activation=tf.nn.relu):
l = tf.layers.max_pooling2d(l, 2, 2, padding='valid') l = tf.layers.conv2d(image, 32, 3, name='conv0')
l = tf.layers.conv2d(l, 32, 3, padding='same', activation=tf.nn.relu, name='conv1') l = tf.layers.max_pooling2d(l, 2, 2, padding='valid')
l = tf.layers.conv2d(l, 32, 3, padding='same', activation=tf.nn.relu, name='conv2') l = tf.layers.conv2d(l, 32, 3, name='conv1')
l = tf.layers.max_pooling2d(l, 2, 2, padding='valid') l = tf.layers.conv2d(l, 32, 3, name='conv2')
l = tf.layers.conv2d(l, 32, 3, padding='same', activation=tf.nn.relu, name='conv3') l = tf.layers.max_pooling2d(l, 2, 2, padding='valid')
l = tf.layers.flatten(l) l = tf.layers.conv2d(l, 32, 3, name='conv3')
l = tf.layers.dense(l, 512, activation=tf.nn.relu, name='fc0') l = tf.layers.flatten(l)
l = tf.layers.dropout(l, rate=0.5, l = tf.layers.dense(l, 512, activation=tf.nn.relu, name='fc0')
training=get_current_tower_context().is_training) l = tf.layers.dropout(l, rate=0.5,
training=get_current_tower_context().is_training)
logits = tf.layers.dense(l, 10, activation=tf.identity, name='fc1') logits = tf.layers.dense(l, 10, activation=tf.identity, name='fc1')
tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities
...@@ -60,7 +62,7 @@ class Model(ModelDesc): ...@@ -60,7 +62,7 @@ class Model(ModelDesc):
accuracy = tf.reduce_mean(correct, name='accuracy') accuracy = tf.reduce_mean(correct, name='accuracy')
# This will monitor training error (in a moving_average fashion): # This will monitor training error (in a moving_average fashion):
# 1. write the value to tensosrboard # 1. write the value to tensorboard
# 2. write the value to stat.json # 2. write the value to stat.json
# 3. print the value after each epoch # 3. print the value after each epoch
train_error = tf.reduce_mean(1 - correct, name='train_error') train_error = tf.reduce_mean(1 - correct, name='train_error')
......
...@@ -4,8 +4,10 @@ ...@@ -4,8 +4,10 @@
from contextlib import contextmanager from contextlib import contextmanager
from collections import defaultdict from collections import defaultdict
import copy import copy
from functools import wraps
from inspect import isfunction, getmembers
__all__ = ['argscope', 'get_arg_scope'] __all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module']
_ArgScopeStack = [] _ArgScopeStack = []
...@@ -60,3 +62,28 @@ def get_arg_scope(): ...@@ -60,3 +62,28 @@ def get_arg_scope():
return _ArgScopeStack[-1] return _ArgScopeStack[-1]
else: else:
return defaultdict(dict) return defaultdict(dict)
def argscope_mapper(func):
"""Decorator for function to support argscope
"""
@wraps(func)
def wrapped_func(*args, **kwargs):
actual_args = copy.copy(get_arg_scope()[func.__name__])
actual_args.update(kwargs)
out_tensor = func(*args, **actual_args)
return out_tensor
# argscope requires this property
wrapped_func.symbolic_function = None
return wrapped_func
def enable_argscope_for_module(module):
"""
Overwrite all functions of a given module to support argscope.
Note that this function monkey-patches the module and therefore could have unexpected consequences.
It has been only tested to work well with `tf.layers` module.
"""
for name, obj in getmembers(module):
if isfunction(obj):
setattr(module, name, argscope_mapper(obj))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment