Commit 87fad54b authored by Yuxin Wu's avatar Yuxin Wu

Fix bug in resnet; improve logs for #1100

parent ed1030b7
......@@ -135,7 +135,7 @@ if __name__ == '__main__':
model = Model(args.depth, args.mode)
model.data_format = args.data_format
if model.weight_decay_norm:
if args.weight_decay_norm:
model.weight_decay_pattern = ".*/W|.*/gamma|.*/beta"
if args.eval:
......
......@@ -40,10 +40,11 @@ def describe_trainable_vars():
total += ele
total_bytes += ele * v.dtype.size
data.append([v.name, shape, ele, v.device, v.dtype.base_dtype.name])
headers = ['name', 'shape', 'dim', 'device', 'dtype']
headers = ['name', 'shape', '#elements', 'device', 'dtype']
dtypes = set([x[4] for x in data])
if len(dtypes) == 1:
dtypes = list(set([x[4] for x in data]))
if len(dtypes) == 1 and dtypes[0] == "float32":
# don't log the dtype if all vars are float32 (default dtype)
for x in data:
del x[4]
del headers[4]
......@@ -59,9 +60,11 @@ def describe_trainable_vars():
size_mb = total_bytes / 1024.0**2
summary_msg = colored(
"\nTotal #vars={}, #params={}, size={:.02f}MB".format(
len(data), total, size_mb), 'cyan')
logger.info(colored("Trainable Variables: \n", 'cyan') + table + summary_msg)
"\nNumber of trainable variables: {}".format(len(data)) +
"\nNumber of parameters (elements): {}".format(total) +
"\nStorage space needed for all trainable variables: {:.02f}MB".format(size_mb),
'cyan')
logger.info(colored("List of Trainable Variables: \n", 'cyan') + table + summary_msg)
def get_shape_str(tensors):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment