Commit 6d9d89a1 authored by Yuxin Wu's avatar Yuxin Wu

auto reuse variable scope

parent 14ea8af3
......@@ -25,10 +25,18 @@ tensorpack.tfutils.gradproc module
:undoc-members:
:show-inheritance:
tensorpack.tfutils.modelutils module
tensorpack.tfutils.model_utils module
------------------------------------
.. automodule:: tensorpack.tfutils.modelutils
.. automodule:: tensorpack.tfutils.model_utils
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.scope_utils module
------------------------------------
.. automodule:: tensorpack.tfutils.scope_utils
:members:
:undoc-members:
:show-inheritance:
......
......@@ -13,6 +13,7 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from GAN import GANTrainer, RandomZData, GANModelDesc
"""
......@@ -47,6 +48,7 @@ class Model(GANModelDesc):
l = tf.nn.tanh(l, name='gen')
return l
@auto_reuse_variable_scope
def discriminator(self, imgs, y):
""" return a (b, 1) logits"""
yv = y
......@@ -86,7 +88,6 @@ class Model(GANModelDesc):
tf.summary.image('gen', image_gen, 30)
with tf.variable_scope('discrim'):
vecpos = self.discriminator(image_pos, y)
with tf.variable_scope('discrim', reuse=True):
vecneg = self.discriminator(image_gen, y)
self.build_losses(vecpos, vecneg)
......
......@@ -10,6 +10,7 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import GANTrainer, RandomZData, GANModelDesc
......@@ -51,6 +52,7 @@ class Model(GANModelDesc):
l = tf.tanh(l, name='gen')
return l
@auto_reuse_variable_scope
def discriminator(self, imgs):
""" return a (b, 1) logits"""
nf = 64
......@@ -81,7 +83,6 @@ class Model(GANModelDesc):
tf.summary.image('generated-samples', image_gen, max_outputs=30)
with tf.variable_scope('discrim'):
vecpos = self.discriminator(image_pos)
with tf.variable_scope('discrim', reuse=True):
vecneg = self.discriminator(image_gen)
self.build_losses(vecpos, vecneg)
......
......@@ -12,6 +12,7 @@ from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import SeparateGANTrainer, GANModelDesc
......@@ -46,7 +47,9 @@ class Model(GANModelDesc):
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'),
InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB')]
@auto_reuse_variable_scope
def generator(self, img):
assert img is not None
with argscope([Conv2D, Deconv2D],
nl=BNLReLU, kernel_shape=4, stride=2), \
argscope(Deconv2D, nl=BNReLU):
......@@ -62,6 +65,7 @@ class Model(GANModelDesc):
.tf.sigmoid()())
return l
@auto_reuse_variable_scope
def discriminator(self, img):
with argscope(Conv2D, nl=BNLReLU, kernel_shape=4, stride=2):
l = Conv2D('conv0', img, NF, nl=LeakyReLU)
......@@ -100,9 +104,8 @@ class Model(GANModelDesc):
AB = self.generator(A)
with tf.variable_scope('A'):
BA = self.generator(B)
with tf.variable_scope('A', reuse=True):
ABA = self.generator(AB)
with tf.variable_scope('B', reuse=True):
with tf.variable_scope('B'):
BAB = self.generator(BA)
viz_A_recon = tf.concat([A, AB, ABA], axis=3, name='viz_A_recon')
......@@ -113,12 +116,10 @@ class Model(GANModelDesc):
with tf.variable_scope('discrim'):
with tf.variable_scope('A'):
A_dis_real, A_feats_real = self.discriminator(A)
with tf.variable_scope('A', reuse=True):
A_dis_fake, A_feats_fake = self.discriminator(BA)
with tf.variable_scope('B'):
B_dis_real, B_feats_real = self.discriminator(B)
with tf.variable_scope('B', reuse=True):
B_dis_fake, B_feats_fake = self.discriminator(AB)
with tf.name_scope('LossA'):
......
......@@ -15,6 +15,7 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, GANModelDesc
......@@ -84,6 +85,7 @@ class Model(GANModelDesc):
.ConcatWith(e1, 3)
.Deconv2D('deconv8', OUT_CH, nl=tf.tanh)())
@auto_reuse_variable_scope
def discriminator(self, inputs, outputs):
""" return a (b, 1) logits"""
l = tf.concat([inputs, outputs], 3)
......@@ -110,7 +112,6 @@ class Model(GANModelDesc):
fake_output = self.generator(input)
with tf.variable_scope('discrim'):
real_pred = self.discriminator(input, output)
with tf.variable_scope('discrim', reuse=True):
fake_pred = self.discriminator(input, fake_output)
self.build_losses(real_pred, fake_pred)
......
......@@ -13,6 +13,7 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.distributions import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
from GAN import GANTrainer, GANModelDesc
......@@ -54,6 +55,7 @@ class Model(GANModelDesc):
l = tf.sigmoid(l, name='gen')
return l
@auto_reuse_variable_scope
def discriminator(self, imgs):
with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2), \
argscope(LeakyReLU, alpha=0.2):
......@@ -101,8 +103,6 @@ class Model(GANModelDesc):
# may need to investigate how bn stats should be updated across two discrim
with tf.variable_scope('discrim'):
real_pred, _ = self.discriminator(real_sample)
with tf.variable_scope('discrim', reuse=True):
fake_pred, dist_param = self.discriminator(fake_sample)
"""
......
......@@ -8,7 +8,7 @@ import six
import copy
from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
from ..tfutils.model_utils import get_shape_str
from ..utils import logger
from ..utils.develop import building_rtfd
......
......@@ -12,7 +12,7 @@ from ..utils import logger
from ..utils.naming import INPUTS_KEY
from ..utils.develop import deprecated, log_deprecated
from ..utils.argtools import memoized
from ..tfutils.modelutils import apply_slim_collections
from ..tfutils.model_utils import apply_slim_collections
__all__ = ['InputDesc', 'InputVar', 'ModelDesc', 'ModelFromMetaGraph']
......
......@@ -11,7 +11,7 @@ import tensorflow as tf
from ..utils import logger
from ..utils.develop import deprecated
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.modelutils import describe_model
from ..tfutils.model_utils import describe_model
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
......
......@@ -20,7 +20,7 @@ _TO_IMPORT = set([
'common',
'sessinit',
'argscope',
'tower'
'tower',
])
_CURR_DIR = os.path.dirname(__file__)
......
......@@ -20,7 +20,6 @@ __all__ = ['get_default_sess_config',
'get_op_tensor_name',
'get_tensors_by_names',
'get_op_or_tensor_by_name',
'get_name_scope_name',
]
......@@ -133,15 +132,3 @@ def get_op_or_tensor_by_name(name):
return f(name)
else:
return list(map(f, name))
def get_name_scope_name():
"""
Returns:
str: the name of the current name scope, without the ending '/'.
"""
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
import tensorflow as tf
from functools import wraps
import numpy as np
from ..tfutils import get_name_scope_name
from .scope_utils import get_name_scope_name
__all__ = ['Distribution',
'CategoricalDistribution', 'GaussianDistribution',
......
# -*- coding: UTF-8 -*-
# File: modelutils.py
# File: model_utils.py
# Author: tensorpack contributors
import tensorflow as tf
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: scope_utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import six
if six.PY2:
import functools32 as functools
else:
import functools
__all__ = ['get_name_scope_name', 'auto_reuse_variable_scope']
def get_name_scope_name():
"""
Returns:
str: the name of the current name scope, without the ending '/'.
"""
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
def auto_reuse_variable_scope(func):
used_scope = set()
@functools.wraps(func)
def wrapper(*args, **kwargs):
scope = tf.get_variable_scope()
h = hash((tf.get_default_graph(), scope.name))
# print("Entering " + scope.name + " reuse: " + str(h in used_scope))
if h in used_scope:
with tf.variable_scope(scope, reuse=True):
return func(*args, **kwargs)
else:
used_scope.add(h)
return func(*args, **kwargs)
return wrapper
......@@ -19,7 +19,7 @@ from ..utils import logger
from ..utils.develop import deprecated, log_deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
from ..tfutils.model_utils import describe_model
from ..tfutils.sesscreate import ReuseSessionCreator
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment