Commit 9227aa8e authored by Yuxin Wu's avatar Yuxin Wu

fix typo

parent d46d3926
......@@ -68,7 +68,7 @@ class GPUUtilizationTracker(Callback):
self._evt.set()
stats = self._queue.get()
for idx, dev in enumerate(self._devices):
self.trainer.monitors.put_scalar('GPUUtil/{:.2f}'.format(dev), stats[idx])
self.trainer.monitors.put_scalar('GPUUtil/{}'.format(dev), stats[idx])
def _after_train(self):
self._stop_evt.set()
......
......@@ -9,6 +9,7 @@ try:
except ImportError:
pass
from contextlib import contextmanager
from itertools import chain
from six.moves import range, zip
import threading
......@@ -503,14 +504,15 @@ class StagingInput(FeedfreeInput):
self._prefill()
return self.fetches
def __init__(self, input, towers=None, nr_stage=1):
def __init__(self, input, towers=None, nr_stage=1, device=None):
"""
Args:
input (FeedfreeInput):
nr_stage: number of elements to prefetch on each GPU.
nr_stage: number of elements to prefetch into each StagingArea, at the beginning.
Since enqueue and dequeue are synchronized, prefetching 1
element should be sufficient.
towers: deprecated
device (str or None): if not None, place the StagingArea on a specific device. e.g., '/cpu:0'.
"""
assert isinstance(input, FeedfreeInput), input
self._input = input
......@@ -521,6 +523,7 @@ class StagingInput(FeedfreeInput):
self._areas = []
self._stage_ops = []
self._unstage_ops = []
self._device = device
def _setup(self, inputs):
self._input.setup(inputs)
......@@ -530,6 +533,7 @@ class StagingInput(FeedfreeInput):
def _get_callbacks(self):
cbs = self._input.get_callbacks()
# this callback has to happen after others, so StagingInput can be stacked together
cbs.append(
StagingInput.StagingCallback(self, self._nr_stage))
return cbs
......@@ -537,8 +541,16 @@ class StagingInput(FeedfreeInput):
def _size(self):
return self._input.size()
@contextmanager
def _device_ctx(self):
if not self._device:
yield
else:
with tf.device(self._device):
yield
def _get_input_tensors(self):
with self.cached_name_scope():
with self.cached_name_scope(), self._device_ctx():
inputs = self._input.get_input_tensors()
# Putting variables to stagingarea will cause trouble
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment