Commit 6d9d89a1 authored by Yuxin Wu

auto reuse variable scope

parent 14ea8af3
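
This commit adds tensorpack.tfutils.scope_utils with an auto_reuse_variable_scope decorator, renames tfutils.modelutils to tfutils.model_utils, moves get_name_scope_name into the new module, and updates the GAN examples so that the second discriminator/generator call no longer needs an explicit reuse=True scope. A minimal sketch of the before/after pattern, using a hypothetical toy discriminator (not from the repo) in TF1 graph mode:

import tensorflow as tf
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope

@auto_reuse_variable_scope
def discriminator(x):
    # hypothetical toy discriminator: the first call under a scope creates 'W',
    # later calls under the same scope reuse it
    W = tf.get_variable('W', [x.get_shape().as_list()[1], 1])
    return tf.matmul(x, W)

image_pos = tf.placeholder(tf.float32, [None, 8])
image_gen = tf.placeholder(tf.float32, [None, 8])

# before this commit the examples opened the scope twice:
#   with tf.variable_scope('discrim'):
#       vecpos = discriminator(image_pos)
#   with tf.variable_scope('discrim', reuse=True):
#       vecneg = discriminator(image_gen)
# with the decorator, both calls can share one scope:
with tf.variable_scope('discrim'):
    vecpos = discriminator(image_pos)   # creates discrim/W
    vecneg = discriminator(image_gen)   # reuses discrim/W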
...@@ -25,10 +25,18 @@ tensorpack.tfutils.gradproc module
:undoc-members:
:show-inheritance:
tensorpack.tfutils.modelutils module
tensorpack.tfutils.model_utils module
------------------------------------
.. automodule:: tensorpack.tfutils.modelutils
.. automodule:: tensorpack.tfutils.model_utils
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.scope_utils module
------------------------------------
.. automodule:: tensorpack.tfutils.scope_utils
:members:
:undoc-members:
:show-inheritance:
...
...@@ -13,6 +13,7 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from GAN import GANTrainer, RandomZData, GANModelDesc
"""
...@@ -47,6 +48,7 @@ class Model(GANModelDesc):
l = tf.nn.tanh(l, name='gen')
return l
@auto_reuse_variable_scope
def discriminator(self, imgs, y):
""" return a (b, 1) logits"""
yv = y
...@@ -86,7 +88,6 @@ class Model(GANModelDesc):
tf.summary.image('gen', image_gen, 30)
with tf.variable_scope('discrim'):
vecpos = self.discriminator(image_pos, y)
with tf.variable_scope('discrim', reuse=True):
vecneg = self.discriminator(image_gen, y)
self.build_losses(vecpos, vecneg)
...
...@@ -10,6 +10,7 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import GANTrainer, RandomZData, GANModelDesc
...@@ -51,6 +52,7 @@ class Model(GANModelDesc):
l = tf.tanh(l, name='gen')
return l
@auto_reuse_variable_scope
def discriminator(self, imgs):
""" return a (b, 1) logits"""
nf = 64
...@@ -81,7 +83,6 @@ class Model(GANModelDesc):
tf.summary.image('generated-samples', image_gen, max_outputs=30)
with tf.variable_scope('discrim'):
vecpos = self.discriminator(image_pos)
with tf.variable_scope('discrim', reuse=True):
vecneg = self.discriminator(image_gen)
self.build_losses(vecpos, vecneg)
...
...@@ -12,6 +12,7 @@ from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import SeparateGANTrainer, GANModelDesc
...@@ -46,7 +47,9 @@ class Model(GANModelDesc):
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'),
InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB')]
@auto_reuse_variable_scope
def generator(self, img):
assert img is not None
with argscope([Conv2D, Deconv2D],
nl=BNLReLU, kernel_shape=4, stride=2), \
argscope(Deconv2D, nl=BNReLU):
...@@ -62,6 +65,7 @@ class Model(GANModelDesc):
.tf.sigmoid()())
return l
@auto_reuse_variable_scope
def discriminator(self, img):
with argscope(Conv2D, nl=BNLReLU, kernel_shape=4, stride=2):
l = Conv2D('conv0', img, NF, nl=LeakyReLU)
...@@ -100,9 +104,8 @@ class Model(GANModelDesc):
AB = self.generator(A)
with tf.variable_scope('A'):
BA = self.generator(B)
with tf.variable_scope('A', reuse=True):
ABA = self.generator(AB)
with tf.variable_scope('B', reuse=True):
with tf.variable_scope('B'):
BAB = self.generator(BA)
viz_A_recon = tf.concat([A, AB, ABA], axis=3, name='viz_A_recon')
...@@ -113,12 +116,10 @@ class Model(GANModelDesc):
with tf.variable_scope('discrim'):
with tf.variable_scope('A'):
A_dis_real, A_feats_real = self.discriminator(A)
with tf.variable_scope('A', reuse=True):
A_dis_fake, A_feats_fake = self.discriminator(BA)
with tf.variable_scope('B'):
B_dis_real, B_feats_real = self.discriminator(B)
with tf.variable_scope('B', reuse=True):
B_dis_fake, B_feats_fake = self.discriminator(AB)
with tf.name_scope('LossA'):
...
...@@ -15,6 +15,7 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, GANModelDesc
...@@ -84,6 +85,7 @@ class Model(GANModelDesc):
.ConcatWith(e1, 3)
.Deconv2D('deconv8', OUT_CH, nl=tf.tanh)())
@auto_reuse_variable_scope
def discriminator(self, inputs, outputs):
""" return a (b, 1) logits"""
l = tf.concat([inputs, outputs], 3)
...@@ -110,7 +112,6 @@ class Model(GANModelDesc):
fake_output = self.generator(input)
with tf.variable_scope('discrim'):
real_pred = self.discriminator(input, output)
with tf.variable_scope('discrim', reuse=True):
fake_pred = self.discriminator(input, fake_output)
self.build_losses(real_pred, fake_pred)
...
...@@ -13,6 +13,7 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.distributions import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
from GAN import GANTrainer, GANModelDesc
...@@ -54,6 +55,7 @@ class Model(GANModelDesc):
l = tf.sigmoid(l, name='gen')
return l
@auto_reuse_variable_scope
def discriminator(self, imgs):
with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2), \
argscope(LeakyReLU, alpha=0.2):
...@@ -101,8 +103,6 @@ class Model(GANModelDesc):
# may need to investigate how bn stats should be updated across two discrim
with tf.variable_scope('discrim'):
real_pred, _ = self.discriminator(real_sample)
with tf.variable_scope('discrim', reuse=True):
fake_pred, dist_param = self.discriminator(fake_sample)
"""
...
...@@ -8,7 +8,7 @@ import six
import copy
from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
from ..tfutils.model_utils import get_shape_str
from ..utils import logger
from ..utils.develop import building_rtfd
...
...@@ -12,7 +12,7 @@ from ..utils import logger
from ..utils.naming import INPUTS_KEY
from ..utils.develop import deprecated, log_deprecated
from ..utils.argtools import memoized
from ..tfutils.modelutils import apply_slim_collections
from ..tfutils.model_utils import apply_slim_collections
__all__ = ['InputDesc', 'InputVar', 'ModelDesc', 'ModelFromMetaGraph']
...
...@@ -11,7 +11,7 @@ import tensorflow as tf
from ..utils import logger
from ..utils.develop import deprecated
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.modelutils import describe_model
from ..tfutils.model_utils import describe_model
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
...
...@@ -20,7 +20,7 @@ _TO_IMPORT = set([
'common',
'sessinit',
'argscope',
'tower'
'tower',
])
_CURR_DIR = os.path.dirname(__file__)
...
...@@ -20,7 +20,6 @@ __all__ = ['get_default_sess_config',
'get_op_tensor_name',
'get_tensors_by_names',
'get_op_or_tensor_by_name',
'get_name_scope_name',
]
...@@ -133,15 +132,3 @@ def get_op_or_tensor_by_name(name):
return f(name)
else:
return list(map(f, name))
def get_name_scope_name():
"""
Returns:
str: the name of the current name scope, without the ending '/'.
"""
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
import tensorflow as tf
from functools import wraps
import numpy as np
from ..tfutils import get_name_scope_name
from .scope_utils import get_name_scope_name
__all__ = ['Distribution',
'CategoricalDistribution', 'GaussianDistribution',
...
# -*- coding: UTF-8 -*-
# File: modelutils.py
# File: model_utils.py
# Author: tensorpack contributors
import tensorflow as tf
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: scope_utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import six
if six.PY2:
import functools32 as functools
else:
import functools
__all__ = ['get_name_scope_name', 'auto_reuse_variable_scope']
def get_name_scope_name():
"""
Returns:
str: the name of the current name scope, without the ending '/'.
"""
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
def auto_reuse_variable_scope(func):
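"""
Decorator: the first call under a given (graph, variable scope) creates
variables; later calls under the same scope run with reuse=True.
"""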
used_scope = set()
@functools.wraps(func)
def wrapper(*args, **kwargs):
scope = tf.get_variable_scope()
h = hash((tf.get_default_graph(), scope.name))
# print("Entering " + scope.name + " reuse: " + str(h in used_scope))
if h in used_scope:
with tf.variable_scope(scope, reuse=True):
return func(*args, **kwargs)
else:
used_scope.add(h)
return func(*args, **kwargs)
return wrapper
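
A hypothetical usage sketch (not from the repo): because the wrapper keys on the enclosing variable scope, separate scopes such as the 'A' and 'B' generators in the example earlier in this diff keep independent weights, while repeated calls inside one scope share them.

import tensorflow as tf
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope

@auto_reuse_variable_scope
def generator(x):
    w = tf.get_variable('w', [])   # one 'w' per enclosing variable scope
    return x * w

x = tf.placeholder(tf.float32, [None])
with tf.variable_scope('gen'):
    with tf.variable_scope('A'):
        ab = generator(x)     # creates gen/A/w
        aba = generator(ab)   # reuses gen/A/w
    with tf.variable_scope('B'):
        ba = generator(x)     # creates gen/B/w, independent of gen/A/w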
...@@ -19,7 +19,7 @@ from ..utils import logger
from ..utils.develop import deprecated, log_deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
from ..tfutils.model_utils import describe_model
from ..tfutils.sesscreate import ReuseSessionCreator
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
...