Commit 0594a9ad authored by Yuxin Wu's avatar Yuxin Wu

Improve model summary message.

parent 2da6f9ed
...@@ -196,4 +196,4 @@ class PeakMemoryTracker(Callback): ...@@ -196,4 +196,4 @@ class PeakMemoryTracker(Callback):
def _after_run(self, _, rv): def _after_run(self, _, rv):
results = rv.results results = rv.results
for mem, dev in zip(results, self._devices): for mem, dev in zip(results, self._devices):
self.trainer.monitors.put_scalar('PeakMemory(MB)' + dev, mem / 1e6) self.trainer.monitors.put_scalar('PeakMemory(MB) ' + dev, mem / 1e6)
...@@ -21,6 +21,7 @@ def describe_trainable_vars(): ...@@ -21,6 +21,7 @@ def describe_trainable_vars():
logger.warn("No trainable variables in the graph!") logger.warn("No trainable variables in the graph!")
return return
total = 0 total = 0
total_bytes = 0
data = [] data = []
devices = set() devices = set()
for v in train_vars: for v in train_vars:
...@@ -29,6 +30,7 @@ def describe_trainable_vars(): ...@@ -29,6 +30,7 @@ def describe_trainable_vars():
shape = v.get_shape() shape = v.get_shape()
ele = shape.num_elements() ele = shape.num_elements()
total += ele total += ele
total_bytes += ele * v.dtype.size
devices.add(v.device) devices.add(v.device)
data.append([v.name, shape.as_list(), ele, v.device]) data.append([v.name, shape.as_list(), ele, v.device])
...@@ -40,9 +42,9 @@ def describe_trainable_vars(): ...@@ -40,9 +42,9 @@ def describe_trainable_vars():
else: else:
table = tabulate(data, headers=['name', 'shape', 'dim', 'device']) table = tabulate(data, headers=['name', 'shape', 'dim', 'device'])
size_mb = total * 4 / 1024.0**2 size_mb = total_bytes / 1024.0**2
summary_msg = colored( summary_msg = colored(
"\nTotal #vars={}, #param={} ({:.02f} MB assuming all float32)".format( "\nTotal #vars={}, #params={}, size={:.02f}MB".format(
len(data), total, size_mb), 'cyan') len(data), total, size_mb), 'cyan')
logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg) logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment