Commit 3356b8de authored by Yuxin Wu's avatar Yuxin Wu

Check GPU availability from session

parent f34f454e
......@@ -48,6 +48,7 @@ except ImportError:
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name)
sys.modules['cv2'].__version__ = '3.2.1' # fake version
sys.modules['msgpack'].version = (0, 5, 2)
import tensorpack
......
......@@ -15,6 +15,7 @@ from ..utils import logger
from ..utils.concurrency import ensure_proc_terminate, start_proc_mask_signal
from ..utils.gpu import get_num_gpu
from ..utils.nvml import NVMLContext
from ..tfutils.common import gpu_available_in_session
__all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker']
......@@ -53,7 +54,7 @@ class GPUUtilizationTracker(Callback):
assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!"
def _before_train(self):
# assert tf.test.is_gpu_available()
assert gpu_available_in_session(), "[GPUUtilizationTracker] needs GPU!"
self._evt = mp.Event()
self._stop_evt = mp.Event()
self._queue = mp.Queue()
......@@ -212,8 +213,10 @@ class PeakMemoryTracker(Callback):
ops.append(MaxBytesInUse())
self._fetches = tf.train.SessionRunArgs(fetches=ops)
def _before_train(self):
    # Fail fast before training starts: this callback fetches GPU memory
    # stats, so it is meaningless unless the session can see a GPU.
    assert gpu_available_in_session(), "PeakMemoryTracker only supports GPU!"
def _before_run(self, _):
    # NOTE(review): stale commented-out check, superseded by the
    # gpu_available_in_session() assert in _before_train.
    # assert tf.test.is_gpu_available(), "PeakMemoryTracker only supports GPU!"
    # Fetch the peak-memory ops only on the last step of each epoch, so the
    # extra fetch cost is not paid on every iteration.
    if self.local_step == self.trainer.steps_per_epoch - 1:
        return self._fetches
    # On all other steps, fetch nothing extra.
    return None
......
......@@ -139,6 +139,14 @@ def get_op_or_tensor_by_name(name):
return list(map(f, name))
def gpu_available_in_session():
    """
    Check whether the current default TF session has access to a GPU device.

    Unlike ``tf.test.is_gpu_available()`` (see the commented-out calls at the
    call sites), this inspects the devices of the session actually in use.

    Returns:
        bool: True if the default session lists at least one device whose
            ``device_type`` is ``'GPU'`` (case-insensitive), False otherwise.

    Raises:
        AssertionError: if there is no default session.
    """
    sess = tf.get_default_session()
    # Give a clear error instead of an AttributeError on `None` below.
    assert sess is not None, "gpu_available_in_session() must be called under a default session!"
    # Idiomatic short-circuiting scan over the session's devices.
    return any(dev.device_type.lower() == 'gpu' for dev in sess.list_devices())
@deprecated("You should use get_tf_version_tuple instead due to the existence of TF 1.10")
def get_tf_version_number():
    """
    Return the TensorFlow version as a float built from its first two
    components, e.g. ``1.12`` for TF ``1.12.0``.
    """
    major, minor = tf.VERSION.split('.')[:2]
    return float('{}.{}'.format(major, minor))
......
......@@ -430,8 +430,6 @@ class HorovodTrainer(SingleCostTrainer):
except AttributeError: # old horovod does not have local_size
pass
super(HorovodTrainer, self).initialize(session_creator, session_init)
# if not tf.test.is_gpu_available():
# logger.error("tf.test.is_gpu_available() == False")
# This broadcast belongs to the "initialize" stage
# It should not be delayed to the "before_train" stage.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment