Commit 838b1df7 authored by ppwwyyxx

add shape summary

parent 5102a8f3
......@@ -35,15 +35,15 @@ def get_model(inputs):
keep_prob = tf.get_default_graph().get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel
image = tf.expand_dims(image, 3)
conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5,
padding='valid')
conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
pool0 = MaxPooling('pool0', conv0, 2)
conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
pool1 = MaxPooling('pool1', conv1, 2)
conv2 = Conv2D('conv2', pool1, out_channel=32, kernel_shape=3)
fc0 = FullyConnected('fc0', pool1, 1024)
fc0 = FullyConnected('fc0', conv2, 1024)
fc0 = tf.nn.dropout(fc0, keep_prob)
# fc will have activation summary by default. disable this for the output layer
......
......@@ -5,6 +5,7 @@
import tensorflow as tf
from utils.summary import *
from utils import logger
def layer_register(summary_activation=False):
"""
......@@ -17,15 +18,28 @@ def layer_register(summary_activation=False):
args = args[1:]
do_summary = kwargs.pop(
'summary_activation', summary_activation)
inputs = args[0]
if isinstance(inputs, list):
shape_str = ",".join(
map(str(x.get_shape().as_list()), inputs))
else:
shape_str = str(inputs.get_shape().as_list())
logger.info("{} input: {}".format(name, shape_str))
with tf.variable_scope(name) as scope:
ret = func(*args, **kwargs)
if do_summary:
ndim = ret.get_shape().ndims
assert ndim >= 2, \
"Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
add_activation_summary(ret, scope.name)
return ret
outputs = func(*args, **kwargs)
if isinstance(outputs, list):
shape_str = ",".join(
map(str(x.get_shape().as_list()), outputs))
if do_summary:
for x in outputs:
add_activation_summary(x, scope.name)
else:
shape_str = str(outputs.get_shape().as_list())
if do_summary:
add_activation_summary(outputs, scope.name)
logger.info("{} output: {}".format(name, shape_str))
return outputs
return inner
return wrapper
......@@ -35,7 +49,7 @@ def shape2d(a):
"""
if type(a) == int:
return [a, a]
if type(a) in [list, tuple]:
if isinstance(a, (list, tuple)):
assert len(a) == 2
return list(a)
raise RuntimeError("Illegal shape: {}".format(a))
......
......@@ -42,13 +42,13 @@ def start_train(config):
max_epoch = int(config['max_epoch'])
# build graph
G = tf.get_default_graph()
for v in input_vars:
G.add_to_collection(INPUT_VARS_KEY, v)
for v in output_vars:
G.add_to_collection(OUTPUT_VARS_KEY, v)
summary_model()
global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
......
......@@ -9,6 +9,7 @@ import time
import sys
from contextlib import contextmanager
import logger
import tensorflow as tf
def global_import(name):
p = __import__(name, globals(), locals())
......@@ -29,3 +30,18 @@ def timed_operation(msg, log_start=False):
yield
logger.info('finished {}, time={:.2f}sec.'.format(
msg, time.time() - start))
def summary_model():
    """Log every trainable variable's name, static shape and element count,
    then the total parameter count of the model.

    Emits a single multi-line ``logger.info`` message; returns nothing.
    Reads the default graph's TRAINABLE_VARIABLES collection.
    """
    # Leading empty entry so the joined message starts on a fresh line
    # after the "Model Params: " prefix.
    lines = [""]
    total = 0
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        var_shape = var.get_shape()
        n_elem = var_shape.num_elements()
        total += n_elem
        lines.append("{}: shape={}, dim={}".format(
            var.name, var_shape.as_list(), n_elem))
    lines.append("Total dim={}".format(total))
    logger.info("Model Params: {}".format('\n'.join(lines)))
......@@ -22,6 +22,9 @@ def add_activation_summary(x, name=None):
Summary for an activation tensor x.
If name is None, use x.name
"""
ndim = x.get_shape().ndims
assert ndim >= 2, \
"Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
if name is None:
name = x.name
tf.histogram_summary(name + '/activations', x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment