Commit 224b0da7 authored by Yuxin Wu's avatar Yuxin Wu

better shape logging, add GAP

parent c22edc77
......@@ -39,12 +39,12 @@ def layer_register(summary_activation=False):
inputs = args[0]
with tf.variable_scope(name) as scope:
outputs = self.f(*args, **kwargs)
if name not in _layer_logged:
if scope.name not in _layer_logged:
# log shape info and add activation
logger.info("{} input: {}".format(
name, get_shape_str(inputs)))
scope.name, get_shape_str(inputs)))
logger.info("{} output: {}".format(
name, get_shape_str(outputs)))
scope.name, get_shape_str(outputs)))
if do_summary:
if isinstance(outputs, list):
......@@ -52,7 +52,7 @@ def layer_register(summary_activation=False):
add_activation_summary(x, scope.name)
else:
add_activation_summary(outputs, scope.name)
_layer_logged.add(name)
_layer_logged.add(scope.name)
return outputs
return WrapedObject(func)
return wrapper
......
......@@ -8,7 +8,7 @@ import numpy
from ._common import *
from ..utils.symbolic_functions import *
__all__ = ['MaxPooling', 'FixedUnPooling']
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling']
@layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'):
......@@ -26,6 +26,26 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register()
def AvgPooling(x, shape, stride=None, padding='VALID'):
"""
shape, stride: int or list/tuple of length 2
if stride is None, use shape by default
padding: 'VALID' or 'SAME'
"""
padding = padding.upper()
shape = shape4d(shape)
if stride is None:
stride = shape
else:
stride = shape4d(stride)
return tf.nn.avg_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register()
def GlobalAvgPooling(x):
assert x.get_shape().ndims == 4
return tf.reduce_mean(x, [1, 2])
@layer_register()
def FixedUnPooling(x, shape, unpool_mat=None):
......
......@@ -41,13 +41,16 @@ def getlogger():
logger.addHandler(handler)
return logger
def get_time_str():
return datetime.now().strftime('%m%d-%H%M%S')
logger = getlogger()
# logger file and directory:
global LOG_FILE, LOG_DIR
def _set_file(path):
if os.path.isfile(path):
backup_name = path + datetime.now().strftime('.%d-%H%M%S')
backup_name = path + '.' + get_time_str()
shutil.move(path, backup_name)
info("Log file '{}' backuped to '{}'".format(path, backup_name))
hdl = logging.FileHandler(
......@@ -65,15 +68,14 @@ unless you're resuming from a previous task.""".format(dirname))
act = input().lower()
if act:
break
timestr = datetime.now().strftime('%m%d-%H%M%S')
if act == 'b':
backup_name = dirname + timestr
backup_name = dirname + get_time_str()
shutil.move(dirname, backup_name)
info("Directory'{}' backuped to '{}'".format(dirname, backup_name))
elif act == 'd':
shutil.rmtree(dirname)
elif act == 'n':
dirname = dirname + timestr
dirname = dirname + get_time_str()
info("Use a different log directory {}".format(dirname))
elif act == 'k':
pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment