Commit 43a44c1d authored by Yuxin Wu's avatar Yuxin Wu

print TF build info

parent 775aa3ca
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: common.py # File: common.py
from collections import defaultdict from collections import defaultdict, OrderedDict
from six.moves import map from six.moves import map
from tabulate import tabulate from tabulate import tabulate
import os import os
...@@ -165,6 +165,28 @@ def get_tf_version_tuple(): ...@@ -165,6 +165,28 @@ def get_tf_version_tuple():
return tuple(map(int, tf.__version__.split('.')[:2])) return tuple(map(int, tf.__version__.split('.')[:2]))
def parse_TF_build_info():
ret = OrderedDict()
from tensorflow.python.platform import build_info
try:
for k, v in list(build_info.build_info.items()):
if k == "cuda_version":
ret["TF built with CUDA"] = v
elif k == "cudnn_version":
ret["TF built with CUDNN"] = v
elif k == "cuda_compute_capabilities":
ret["TF compute capabilities"] = ",".join([k.replace("compute_", "") for k in v])
return ret
except AttributeError:
pass
try:
ret["TF built with CUDA"] = build_info.cuda_version_number
ret["TF built with CUDNN"] = build_info.cudnn_version_number
except AttributeError:
pass
return ret
def collect_env_info(): def collect_env_info():
""" """
Returns: Returns:
...@@ -195,9 +217,11 @@ def collect_env_info(): ...@@ -195,9 +217,11 @@ def collect_env_info():
if has_cuda: if has_cuda:
data.append(("Nvidia Driver", find_library("nvidia-ml"))) data.append(("Nvidia Driver", find_library("nvidia-ml")))
data.append(("CUDA", find_library("cudart"))) data.append(("CUDA libs", find_library("cudart")))
data.append(("CUDNN", find_library("cudnn"))) data.append(("CUDNN libs", find_library("cudnn")))
data.append(("NCCL", find_library("nccl"))) for k, v in parse_TF_build_info().items():
data.append((k, v))
data.append(("NCCL libs", find_library("nccl")))
# List devices with NVML # List devices with NVML
data.append( data.append(
......
...@@ -195,7 +195,7 @@ def get_checkpoint_path(path): ...@@ -195,7 +195,7 @@ def get_checkpoint_path(path):
Args: Args:
path: a user-input path path: a user-input path
Returns: Returns:
str: the argument that can be passed to NewCheckpointReader str: the argument that can be passed to `tf.train.NewCheckpointReader`
""" """
if os.path.basename(path) == path: if os.path.basename(path) == path:
path = os.path.join('.', path) # avoid #4921 and #6142 path = os.path.join('.', path) # avoid #4921 and #6142
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment