Commit e457e2db authored by Yuxin Wu's avatar Yuxin Wu

ConcatWith

parent 881c6c4b
...@@ -33,6 +33,7 @@ def layer_register( ...@@ -33,6 +33,7 @@ def layer_register(
summary the output(activation) of this layer. summary the output(activation) of this layer.
Can be overriden when creating the layer. Can be overriden when creating the layer.
:param log_shape: log input/output shape of this layer :param log_shape: log input/output shape of this layer
:param use_scope: whether to call this layer with an extra first argument as scope
""" """
def wrapper(func): def wrapper(func):
......
...@@ -64,6 +64,7 @@ def LeakyReLU(x, alpha, name=None): ...@@ -64,6 +64,7 @@ def LeakyReLU(x, alpha, name=None):
#x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x)) #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
#return tf.mul(x, 0.5, name=name) #return tf.mul(x, 0.5, name=name)
# TODO wrap it as a layer with use_scope=False?
def BNReLU(x, name=None): def BNReLU(x, name=None):
x = BatchNorm('bn', x, use_local_stat=None) x = BatchNorm('bn', x, use_local_stat=None)
x = tf.nn.relu(x, name=name) x = tf.nn.relu(x, name=name)
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: shapes.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ._common import layer_register
__all__ = ['ConcatWith']
@layer_register(use_scope=False, log_shape=False)
def ConcatWith(x, dim, tensor):
    """
    A wrapper around `tf.concat` to support `LinearWrap`.

    :param x: the input tensor
    :param dim: the dimension along which to concatenate
    :param tensor: a tensor or a list of tensors to concatenate with x.
        x will be at the beginning of the concatenation.
    :return: tf.concat(dim, [x] + tensor)
    """
    # Use isinstance rather than a type() comparison so that list
    # subclasses are also accepted as-is.
    if not isinstance(tensor, list):
        tensor = [tensor]
    return tf.concat(dim, [x] + tensor)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment