Commit 838b1df7 authored by ppwwyyxx

add shape summary

parent 5102a8f3
@@ -35,15 +35,15 @@ def get_model(inputs):
     keep_prob = tf.get_default_graph().get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
     image, label = inputs
-    image = tf.expand_dims(image, 3)
-    conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5,
-                   padding='valid')
+    image = tf.expand_dims(image, 3)    # add a single channel
+    conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
     pool0 = MaxPooling('pool0', conv0, 2)
     conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
     pool1 = MaxPooling('pool1', conv1, 2)
-    fc0 = FullyConnected('fc0', pool1, 1024)
+    conv2 = Conv2D('conv2', pool1, out_channel=32, kernel_shape=3)
+    fc0 = FullyConnected('fc0', conv2, 1024)
     fc0 = tf.nn.dropout(fc0, keep_prob)
     # fc will have activation summary by default. disable this for the output layer
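With padding='valid' dropped, conv0 falls back to the default SAME padding, so spatial size now shrinks only at the poolings. Below is a rough shape walk-through of the new stack, assuming the usual 28x28 single-channel MNIST input and 2x2 max-pooling; the two helper functions are illustrative stand-ins, not library code:

    # Shape progression under SAME convolutions and 2x2 max-pooling.
    def same_conv(shape, out_channel):
        h, w, _ = shape
        return (h, w, out_channel)     # SAME padding keeps h and w

    def pool2(shape):
        h, w, c = shape
        return (h // 2, w // 2, c)

    shape = (28, 28, 1)            # after tf.expand_dims(image, 3)
    shape = same_conv(shape, 32)   # conv0: (28, 28, 32)
    shape = pool2(shape)           # pool0: (14, 14, 32)
    shape = same_conv(shape, 40)   # conv1: (14, 14, 40)
    shape = pool2(shape)           # pool1: (7, 7, 40)
    shape = same_conv(shape, 32)   # conv2: (7, 7, 32)
    print(shape)                   # fc0 flattens 7*7*32 = 1568 inputs down to 1024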
@@ -5,6 +5,7 @@
 import tensorflow as tf
 from utils.summary import *
+from utils import logger

 def layer_register(summary_activation=False):
     """
@@ -17,15 +18,28 @@ def layer_register(summary_activation=False):
         args = args[1:]
         do_summary = kwargs.pop(
             'summary_activation', summary_activation)
+        inputs = args[0]
+        if isinstance(inputs, list):
+            shape_str = ",".join(
+                str(x.get_shape().as_list()) for x in inputs)
+        else:
+            shape_str = str(inputs.get_shape().as_list())
+        logger.info("{} input: {}".format(name, shape_str))
         with tf.variable_scope(name) as scope:
-            ret = func(*args, **kwargs)
-            if do_summary:
-                ndim = ret.get_shape().ndims
-                assert ndim >= 2, \
-                    "Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
-                add_activation_summary(ret, scope.name)
-        return ret
+            outputs = func(*args, **kwargs)
+        if isinstance(outputs, list):
+            shape_str = ",".join(
+                str(x.get_shape().as_list()) for x in outputs)
+            if do_summary:
+                for x in outputs:
+                    add_activation_summary(x, scope.name)
+        else:
+            shape_str = str(outputs.get_shape().as_list())
+            if do_summary:
+                add_activation_summary(outputs, scope.name)
+        logger.info("{} output: {}".format(name, shape_str))
+        return outputs
         return inner
     return wrapper
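The decorator change is easier to see without TensorFlow in the way. Below is a minimal, framework-free sketch of the same pattern (log input shapes, run the layer, log output shapes); it uses numpy arrays whose .shape plays the role of get_shape().as_list(), and all names are illustrative:

    import numpy as np

    def layer_register():
        def wrapper(func):
            def inner(name, *args, **kwargs):
                inputs = args[0]
                xs = inputs if isinstance(inputs, list) else [inputs]
                print("{} input: {}".format(
                    name, ",".join(str(list(x.shape)) for x in xs)))
                outputs = func(*args, **kwargs)
                ys = outputs if isinstance(outputs, list) else [outputs]
                print("{} output: {}".format(
                    name, ",".join(str(list(y.shape)) for y in ys)))
                return outputs
            return inner
        return wrapper

    @layer_register()
    def Flatten(x):
        # toy layer: collapse everything but the batch dimension
        return x.reshape(x.shape[0], -1)

    Flatten('flat', np.zeros((4, 7, 7, 32)))
    # flat input: [4, 7, 7, 32]
    # flat output: [4, 1568]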
@@ -35,7 +49,7 @@ def shape2d(a):
     """
     if type(a) == int:
         return [a, a]
-    if type(a) in [list, tuple]:
+    if isinstance(a, (list, tuple)):
         assert len(a) == 2
         return list(a)
     raise RuntimeError("Illegal shape: {}".format(a))
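Replacing type(a) in [list, tuple] with isinstance(a, (list, tuple)) is not just style: the exact-type test rejects subclasses. A small illustration (the Pair class is made up):

    class Pair(tuple):
        pass

    a = Pair((3, 3))
    print(type(a) in [list, tuple])      # False: exact type check fails
    print(isinstance(a, (list, tuple)))  # True: subclasses are accepted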
@@ -42,13 +42,13 @@ def start_train(config):
     max_epoch = int(config['max_epoch'])

     # build graph
     G = tf.get_default_graph()
     for v in input_vars:
         G.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
+    summary_model()

     global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
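summary_model() has to run here, after the model graph is built, because it reads the TRAINABLE_VARIABLES collection, which layers populate as a side effect of creating their variables. A toy illustration of that ordering constraint (all names made up):

    TRAINABLE_VARIABLES = []

    def make_variable(name, shape):
        v = (name, shape)
        TRAINABLE_VARIABLES.append(v)    # layers register variables on creation
        return v

    print(len(TRAINABLE_VARIABLES))      # 0: nothing to summarize yet
    make_variable('conv0/W', [5, 5, 1, 32])
    print(len(TRAINABLE_VARIABLES))      # 1: now a summary has data to report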
@@ -9,6 +9,7 @@ import time
 import sys
 from contextlib import contextmanager
 import logger
+import tensorflow as tf

 def global_import(name):
     p = __import__(name, globals(), locals())
@@ -29,3 +30,18 @@ def timed_operation(msg, log_start=False):
     yield
     logger.info('finished {}, time={:.2f}sec.'.format(
         msg, time.time() - start))
+
+def summary_model():
+    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+    msg = [""]
+    total = 0
+    for v in train_vars:
+        shape = v.get_shape()
+        ele = shape.num_elements()
+        total += ele
+        msg.append("{}: shape={}, dim={}".format(
+            v.name, shape.as_list(), ele))
+    msg.append("Total dim={}".format(total))
+    logger.info("Model Params: {}".format('\n'.join(msg)))
@@ -22,6 +22,9 @@ def add_activation_summary(x, name=None):
     Summary for an activation tensor x.
     If name is None, use x.name
     """
+    ndim = x.get_shape().ndims
+    assert ndim >= 2, \
+        "Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
     if name is None:
         name = x.name
     tf.histogram_summary(name + '/activations', x)
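Moving the rank assertion into add_activation_summary means every caller is checked, not only layers wrapped by layer_register. The check itself, in an isolated sketch where a shape list stands in for x.get_shape().ndims:

    def check_summarizable(shape):
        ndim = len(shape)
        assert ndim >= 2, \
            "Summary a scalar with histogram? Maybe use scalar instead."
        return ndim

    print(check_summarizable([64, 1024]))    # ok: batch x features
    try:
        check_summarizable([64])             # rank-1: refuse histogram summary
    except AssertionError as e:
        print("rejected:", e)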