Commit 20c7fcb4 authored by Yuxin Wu's avatar Yuxin Wu

tabulated model summary

parent 546af8b2
include requirements.txt
include opt-requirements.txt
...@@ -18,26 +18,11 @@ except ImportError: ...@@ -18,26 +18,11 @@ except ImportError:
long_description = open('README.md').read() long_description = open('README.md').read()
# configure requirements # configure requirements
req = [ reqfile = os.path.join(CURRENT_DIR, 'requirements.txt')
'numpy', req = [x.strip() for x in open(reqfile).readlines()]
'six',
'termcolor', reqfile = os.path.join(CURRENT_DIR, 'opt-requirements.txt')
'tqdm>4.11.1', extra_req = [x.strip() for x in open(reqfile).readlines()]
'msgpack-python',
'msgpack-numpy',
'pyzmq',
'subprocess32;python_version<"3.0"',
'functools32;python_version<"3.0"',
]
extra_req = [
'pillow',
'scipy',
'h5py',
'lmdb',
'matplotlib',
'scikit-learn',
'tornado;python_version<"3.0"',
]
# parse scripts # parse scripts
scripts = ['scripts/plot-point.py', 'scripts/dump-model-params.py'] scripts = ['scripts/plot-point.py', 'scripts/dump-model-params.py']
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import tensorflow as tf import tensorflow as tf
from termcolor import colored from termcolor import colored
from tabulate import tabulate
from ..utils import logger from ..utils import logger
from .summary import add_moving_summary from .summary import add_moving_summary
...@@ -18,18 +19,18 @@ def describe_model(): ...@@ -18,18 +19,18 @@ def describe_model():
if len(train_vars) == 0: if len(train_vars) == 0:
logger.info("No trainable variables in the graph!") logger.info("No trainable variables in the graph!")
return return
msg = [""]
total = 0 total = 0
data = []
for v in train_vars: for v in train_vars:
shape = v.get_shape() shape = v.get_shape()
ele = shape.num_elements() ele = shape.num_elements()
total += ele total += ele
msg.append("{}: shape={}, dim={}".format( data.append([v.name, shape.as_list(), ele])
v.name, shape.as_list(), ele)) table = tabulate(data, headers=['name', 'shape', 'dim'])
size_mb = total * 4 / 1024.0**2 size_mb = total * 4 / 1024.0**2
msg.append(colored( summary_msg = colored(
"Total #param={} ({:.02f} MB assuming all float32)".format(total, size_mb), 'cyan')) "\nTotal #param={} ({:.02f} MB assuming all float32)".format(total, size_mb), 'cyan')
logger.info(colored("Model Parameters: ", 'cyan') + '\n'.join(msg)) logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg)
def get_shape_str(tensors): def get_shape_str(tensors):
......
...@@ -117,10 +117,8 @@ class Trainer(object): ...@@ -117,10 +117,8 @@ class Trainer(object):
self._callbacks.setup_graph(weakref.proxy(self)) self._callbacks.setup_graph(weakref.proxy(self))
# create session # create session
sess_creator = self.config.session_creator
logger.info("Finalize the graph, create the session ...") logger.info("Finalize the graph, create the session ...")
self.sess = self.config.session_creator.create_session()
self.sess = sess_creator.create_session()
self._monitored_sess = tf.train.MonitoredSession( self._monitored_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=None) session_creator=ReuseSessionCreator(self.sess), hooks=None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment