Commit 9227aa8e authored by Yuxin Wu's avatar Yuxin Wu

fix typo

parent d46d3926
...@@ -68,7 +68,7 @@ class GPUUtilizationTracker(Callback): ...@@ -68,7 +68,7 @@ class GPUUtilizationTracker(Callback):
self._evt.set() self._evt.set()
stats = self._queue.get() stats = self._queue.get()
for idx, dev in enumerate(self._devices): for idx, dev in enumerate(self._devices):
self.trainer.monitors.put_scalar('GPUUtil/{:.2f}'.format(dev), stats[idx]) self.trainer.monitors.put_scalar('GPUUtil/{}'.format(dev), stats[idx])
def _after_train(self): def _after_train(self):
self._stop_evt.set() self._stop_evt.set()
......
...@@ -9,6 +9,7 @@ try: ...@@ -9,6 +9,7 @@ try:
except ImportError: except ImportError:
pass pass
from contextlib import contextmanager
from itertools import chain from itertools import chain
from six.moves import range, zip from six.moves import range, zip
import threading import threading
...@@ -503,14 +504,15 @@ class StagingInput(FeedfreeInput): ...@@ -503,14 +504,15 @@ class StagingInput(FeedfreeInput):
self._prefill() self._prefill()
return self.fetches return self.fetches
def __init__(self, input, towers=None, nr_stage=1): def __init__(self, input, towers=None, nr_stage=1, device=None):
""" """
Args: Args:
input (FeedfreeInput): input (FeedfreeInput):
nr_stage: number of elements to prefetch on each GPU. nr_stage: number of elements to prefetch into each StagingArea, at the beginning.
Since enqueue and dequeue are synchronized, prefetching 1 Since enqueue and dequeue are synchronized, prefetching 1
element should be sufficient. element should be sufficient.
towers: deprecated towers: deprecated
device (str or None): if not None, place the StagingArea on a specific device. e.g., '/cpu:0'.
""" """
assert isinstance(input, FeedfreeInput), input assert isinstance(input, FeedfreeInput), input
self._input = input self._input = input
...@@ -521,6 +523,7 @@ class StagingInput(FeedfreeInput): ...@@ -521,6 +523,7 @@ class StagingInput(FeedfreeInput):
self._areas = [] self._areas = []
self._stage_ops = [] self._stage_ops = []
self._unstage_ops = [] self._unstage_ops = []
self._device = device
def _setup(self, inputs): def _setup(self, inputs):
self._input.setup(inputs) self._input.setup(inputs)
...@@ -530,6 +533,7 @@ class StagingInput(FeedfreeInput): ...@@ -530,6 +533,7 @@ class StagingInput(FeedfreeInput):
def _get_callbacks(self):
    """Collect callbacks from the wrapped input, then append this
    input's own staging callback.

    The staging callback is appended *after* the wrapped input's
    callbacks so that stacked StagingInputs run in the right order.

    Returns:
        list: callbacks of the wrapped input plus one StagingCallback.
    """
    callbacks = self._input.get_callbacks()
    staging_cb = StagingInput.StagingCallback(self, self._nr_stage)
    callbacks.append(staging_cb)
    return callbacks
...@@ -537,8 +541,16 @@ class StagingInput(FeedfreeInput): ...@@ -537,8 +541,16 @@ class StagingInput(FeedfreeInput):
def _size(self): def _size(self):
return self._input.size() return self._input.size()
@contextmanager
def _device_ctx(self):
if not self._device:
yield
else:
with tf.device(self._device):
yield
def _get_input_tensors(self): def _get_input_tensors(self):
with self.cached_name_scope(): with self.cached_name_scope(), self._device_ctx():
inputs = self._input.get_input_tensors() inputs = self._input.get_input_tensors()
# Putting variables to stagingarea will cause trouble # Putting variables to stagingarea will cause trouble
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment