Commit 20c7fcb4 authored by Yuxin Wu's avatar Yuxin Wu

tabulated model summary

parent 546af8b2
include requirements.txt
include opt-requirements.txt
......@@ -18,26 +18,11 @@ except ImportError:
long_description = open('README.md').read()
# configure requirements
req = [
'numpy',
'six',
'termcolor',
'tqdm>4.11.1',
'msgpack-python',
'msgpack-numpy',
'pyzmq',
'subprocess32;python_version<"3.0"',
'functools32;python_version<"3.0"',
]
extra_req = [
'pillow',
'scipy',
'h5py',
'lmdb',
'matplotlib',
'scikit-learn',
'tornado;python_version<"3.0"',
]
reqfile = os.path.join(CURRENT_DIR, 'requirements.txt')
req = [x.strip() for x in open(reqfile).readlines()]
reqfile = os.path.join(CURRENT_DIR, 'opt-requirements.txt')
extra_req = [x.strip() for x in open(reqfile).readlines()]
# parse scripts
scripts = ['scripts/plot-point.py', 'scripts/dump-model-params.py']
......
......@@ -4,6 +4,7 @@
import tensorflow as tf
from termcolor import colored
from tabulate import tabulate
from ..utils import logger
from .summary import add_moving_summary
......@@ -18,18 +19,18 @@ def describe_model():
if len(train_vars) == 0:
logger.info("No trainable variables in the graph!")
return
msg = [""]
total = 0
data = []
for v in train_vars:
shape = v.get_shape()
ele = shape.num_elements()
total += ele
msg.append("{}: shape={}, dim={}".format(
v.name, shape.as_list(), ele))
data.append([v.name, shape.as_list(), ele])
table = tabulate(data, headers=['name', 'shape', 'dim'])
size_mb = total * 4 / 1024.0**2
msg.append(colored(
"Total #param={} ({:.02f} MB assuming all float32)".format(total, size_mb), 'cyan'))
logger.info(colored("Model Parameters: ", 'cyan') + '\n'.join(msg))
summary_msg = colored(
"\nTotal #param={} ({:.02f} MB assuming all float32)".format(total, size_mb), 'cyan')
logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg)
def get_shape_str(tensors):
......
......@@ -117,10 +117,8 @@ class Trainer(object):
self._callbacks.setup_graph(weakref.proxy(self))
# create session
sess_creator = self.config.session_creator
logger.info("Finalize the graph, create the session ...")
self.sess = sess_creator.create_session()
self.sess = self.config.session_creator.create_session()
self._monitored_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment