Commit 224b0da7 authored by Yuxin Wu's avatar Yuxin Wu

better shape logging, add GAP

parent c22edc77
...@@ -39,12 +39,12 @@ def layer_register(summary_activation=False): ...@@ -39,12 +39,12 @@ def layer_register(summary_activation=False):
inputs = args[0] inputs = args[0]
with tf.variable_scope(name) as scope: with tf.variable_scope(name) as scope:
outputs = self.f(*args, **kwargs) outputs = self.f(*args, **kwargs)
if name not in _layer_logged: if scope.name not in _layer_logged:
# log shape info and add activation # log shape info and add activation
logger.info("{} input: {}".format( logger.info("{} input: {}".format(
name, get_shape_str(inputs))) scope.name, get_shape_str(inputs)))
logger.info("{} output: {}".format( logger.info("{} output: {}".format(
name, get_shape_str(outputs))) scope.name, get_shape_str(outputs)))
if do_summary: if do_summary:
if isinstance(outputs, list): if isinstance(outputs, list):
...@@ -52,7 +52,7 @@ def layer_register(summary_activation=False): ...@@ -52,7 +52,7 @@ def layer_register(summary_activation=False):
add_activation_summary(x, scope.name) add_activation_summary(x, scope.name)
else: else:
add_activation_summary(outputs, scope.name) add_activation_summary(outputs, scope.name)
_layer_logged.add(name) _layer_logged.add(scope.name)
return outputs return outputs
return WrapedObject(func) return WrapedObject(func)
return wrapper return wrapper
......
...@@ -8,7 +8,7 @@ import numpy ...@@ -8,7 +8,7 @@ import numpy
from ._common import * from ._common import *
from ..utils.symbolic_functions import * from ..utils.symbolic_functions import *
__all__ = ['MaxPooling', 'FixedUnPooling'] __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling']
@layer_register() @layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'): def MaxPooling(x, shape, stride=None, padding='VALID'):
...@@ -26,6 +26,26 @@ def MaxPooling(x, shape, stride=None, padding='VALID'): ...@@ -26,6 +26,26 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding) return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register()
def AvgPooling(x, shape, stride=None, padding='VALID'):
"""
shape, stride: int or list/tuple of length 2
if stride is None, use shape by default
padding: 'VALID' or 'SAME'
"""
padding = padding.upper()
shape = shape4d(shape)
if stride is None:
stride = shape
else:
stride = shape4d(stride)
return tf.nn.avg_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register()
def GlobalAvgPooling(x):
assert x.get_shape().ndims == 4
return tf.reduce_mean(x, [1, 2])
@layer_register() @layer_register()
def FixedUnPooling(x, shape, unpool_mat=None): def FixedUnPooling(x, shape, unpool_mat=None):
......
...@@ -41,13 +41,16 @@ def getlogger(): ...@@ -41,13 +41,16 @@ def getlogger():
logger.addHandler(handler) logger.addHandler(handler)
return logger return logger
def get_time_str():
return datetime.now().strftime('%m%d-%H%M%S')
logger = getlogger() logger = getlogger()
# logger file and directory: # logger file and directory:
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
def _set_file(path): def _set_file(path):
if os.path.isfile(path): if os.path.isfile(path):
backup_name = path + datetime.now().strftime('.%d-%H%M%S') backup_name = path + '.' + get_time_str()
shutil.move(path, backup_name) shutil.move(path, backup_name)
info("Log file '{}' backuped to '{}'".format(path, backup_name)) info("Log file '{}' backuped to '{}'".format(path, backup_name))
hdl = logging.FileHandler( hdl = logging.FileHandler(
...@@ -65,15 +68,14 @@ unless you're resuming from a previous task.""".format(dirname)) ...@@ -65,15 +68,14 @@ unless you're resuming from a previous task.""".format(dirname))
act = input().lower() act = input().lower()
if act: if act:
break break
timestr = datetime.now().strftime('%m%d-%H%M%S')
if act == 'b': if act == 'b':
backup_name = dirname + timestr backup_name = dirname + get_time_str()
shutil.move(dirname, backup_name) shutil.move(dirname, backup_name)
info("Directory'{}' backuped to '{}'".format(dirname, backup_name)) info("Directory'{}' backuped to '{}'".format(dirname, backup_name))
elif act == 'd': elif act == 'd':
shutil.rmtree(dirname) shutil.rmtree(dirname)
elif act == 'n': elif act == 'n':
dirname = dirname + timestr dirname = dirname + get_time_str()
info("Use a different log directory {}".format(dirname)) info("Use a different log directory {}".format(dirname))
elif act == 'k': elif act == 'k':
pass pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment