Commit 0594a9ad authored by Yuxin Wu's avatar Yuxin Wu

Improve model summary message.

parent 2da6f9ed
......@@ -196,4 +196,4 @@ class PeakMemoryTracker(Callback):
def _after_run(self, _, rv):
results = rv.results
for mem, dev in zip(results, self._devices):
self.trainer.monitors.put_scalar('PeakMemory(MB)' + dev, mem / 1e6)
self.trainer.monitors.put_scalar('PeakMemory(MB) ' + dev, mem / 1e6)
......@@ -21,6 +21,7 @@ def describe_trainable_vars():
logger.warn("No trainable variables in the graph!")
return
total = 0
total_bytes = 0
data = []
devices = set()
for v in train_vars:
......@@ -29,6 +30,7 @@ def describe_trainable_vars():
shape = v.get_shape()
ele = shape.num_elements()
total += ele
total_bytes += ele * v.dtype.size
devices.add(v.device)
data.append([v.name, shape.as_list(), ele, v.device])
......@@ -40,9 +42,9 @@ def describe_trainable_vars():
else:
table = tabulate(data, headers=['name', 'shape', 'dim', 'device'])
size_mb = total * 4 / 1024.0**2
size_mb = total_bytes / 1024.0**2
summary_msg = colored(
"\nTotal #vars={}, #param={} ({:.02f} MB assuming all float32)".format(
"\nTotal #vars={}, #params={}, size={:.02f}MB".format(
len(data), total, size_mb), 'cyan')
logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment