Commit 55b2e0f1 authored by Yuxin Wu's avatar Yuxin Wu

TF2 compat

parent 9c1b1b7b
......@@ -252,9 +252,15 @@ class GPUMemoryTracker(Callback):
assert isinstance(devices, (list, tuple)), devices
devices = ['/gpu:{}'.format(x) if isinstance(x, int) else x for x in devices]
self._devices = devices
self._disabled = False
def _setup_graph(self):
from tensorflow.contrib.memory_stats import MaxBytesInUse
try:
from tensorflow.contrib.memory_stats import MaxBytesInUse
except ImportError:
logger.warning("GPUMemoryTracker is not available in TF2.")
self._disabled = True
return
ops = []
for dev in self._devices:
with tf.device(dev):
......@@ -262,10 +268,12 @@ class GPUMemoryTracker(Callback):
self._fetches = tf.train.SessionRunArgs(fetches=ops)
def _before_train(self):
assert gpu_available_in_session(), "PeakMemoryTracker only supports GPU!"
if not gpu_available_in_session():
self._disabled = True
logger.warning("GPUMemoryTracker only supports GPU!")
def _before_run(self, _):
if self.local_step == self.trainer.steps_per_epoch - 1:
if not self._disabled and self.local_step == self.trainer.steps_per_epoch - 1:
return self._fetches
return None
......
......@@ -149,9 +149,10 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
a list, contains the return values of `tower_fn` on each tower.
"""
raw_devices = ['/gpu:{}'.format(k) for k in self.towers]
if self.ps_device == 'gpu':
if self.ps_device == 'gpu' or get_tf_version_tuple() >= (2, 0):
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
else:
# TODO not working in TF2 - cause model to run on CPUs
devices = [tf.train.replica_device_setter(
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
......@@ -204,6 +205,8 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
super(SyncMultiGPUReplicatedBuilder, self).__init__(towers)
self._average = average
assert mode in ['nccl', 'cpu', 'hierarchical'], mode
if get_tf_version_tuple() >= (2, 0) and mode == 'cpu':
mode = 'nccl' # cpu mode causes the entire model to get located on cpu
self._mode = mode
if self._mode == 'hierarchical' and len(towers) != 8:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment