Commit ab1951d5 authored by Yuxin Wu's avatar Yuxin Wu

expose layer_register() in common.py

parent a5688de9
......@@ -13,7 +13,7 @@ from collections import defaultdict
import six
from six.moves import queue
from ..models._common import disable_layer_logging
from ..models.common import disable_layer_logging
from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor
......
......@@ -10,7 +10,7 @@ from tensorflow.python.training import moving_averages
from ..tfutils.common import get_tf_version
from ..tfutils.tower import get_current_tower_context
from ..utils import logger, building_rtfd
from ._common import layer_register
from .common import layer_register
__all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
......
# -*- coding: UTF-8 -*-
# File: _common.py
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
......@@ -11,15 +11,19 @@ from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
from ..tfutils.summary import add_activation_summary
from ..utils import logger, building_rtfd
from ..utils.argtools import shape2d
# make sure each layer is only logged once
_layer_logged = set()
__all__ = ['layer_register', 'disable_layer_logging']
def disable_layer_logging():
"""
Disable the shape logging for all layers from this moment on. Can be
useful when creating multiple towers.
"""
class ContainEverything:
def __contains__(self, x):
return True
# can use nonlocal in python3, but how
......@@ -32,12 +36,15 @@ def layer_register(
use_scope=True):
"""
Register a layer.
:param summary_activation: Define the default behavior of whether to
Args:
summary_activation (bool): Define the default behavior of whether to
summary the output(activation) of this layer.
Can be overridden when creating the layer.
:param log_shape: log input/output shape of this layer
:param use_scope: whether to call this layer with an extra first argument as scope
if set to False, will try to figure out whether the first argument is scope name
log_shape (bool): log input/output shape of this layer
use_scope (bool): whether to call this layer with an extra first argument as scope.
If set to False, will try to figure out whether the first argument
is scope name or not.
"""
def wrapper(func):
......@@ -104,8 +111,3 @@ def layer_register(
wrapper = decorator(wrapper)
return wrapper
def shape4d(a):
    """Expand ``a`` into a length-4 shape list for TensorFlow NHWC ops.

    The height/width part is produced by ``shape2d``; batch and channel
    dimensions are padded with 1.
    """
    hw = shape2d(a)
    return [1, *hw, 1]
......@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from ._common import layer_register, shape2d, shape4d
from .common import layer_register
from ..utils.argtools import shape2d, shape4d
__all__ = ['Conv2D', 'Deconv2D']
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
from ._common import layer_register
from .common import layer_register
from ..tfutils import symbolic_functions as symbf
__all__ = ['FullyConnected']
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
from ._common import layer_register
from .common import layer_register
from ._test import TestModel
__all__ = ['ImageSample']
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
from ._common import layer_register
from .common import layer_register
from .batch_norm import BatchNorm
__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']
......
......@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
from ._common import layer_register, shape4d
from ..utils.argtools import shape2d
from .common import layer_register
from ..utils.argtools import shape2d, shape4d
from ..tfutils import symbolic_functions as symbf
from ._test import TestModel
......
......@@ -8,7 +8,7 @@ import re
from ..utils import logger
from ..utils.argtools import memoized
from ..tfutils.tower import get_current_tower_context
from ._common import layer_register
from .common import layer_register
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
......
......@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ._common import layer_register
from .common import layer_register
__all__ = ['ConcatWith']
......
......@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ._common import layer_register
from .common import layer_register
__all__ = ['SoftMax']
......
......@@ -45,7 +45,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
have workers that run on multiGPUs
"""
if self.idx != 0:
from tensorpack.models._common import disable_layer_logging
from tensorpack.models.common import disable_layer_logging
disable_layer_logging()
self.predictor = OfflinePredictor(self.config)
if self.idx == 0:
......
......@@ -12,7 +12,8 @@ if six.PY2:
else:
import functools
__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs', 'log_once']
__all__ = ['map_arg', 'memoized', 'shape2d', 'shape4d',
'memoized_ignoreargs', 'log_once']
def map_arg(**maps):
......@@ -85,6 +86,19 @@ def shape2d(a):
raise RuntimeError("Illegal shape: {}".format(a))
def shape4d(a):
    """
    Ensure a 4D shape, to use with NHWC functions.

    Args:
        a: an int or tuple/list of length 2

    Returns:
        list: of length 4. If ``a`` is an int, return ``[1, a, a, 1]``.
    """
    # Pad batch and channel dims around the 2D spatial shape.
    return [1] + shape2d(a) + [1]
@memoized
def log_once(message, func):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment