Commit ce1c507d authored by Yuxin Wu's avatar Yuxin Wu

notes on stagingarea.

parent ffa4ed10
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import tensorflow as tf import tensorflow as tf
import copy
import six import six
import re
import pprint
from six.moves import zip, range from six.moves import zip, range
from ..utils import logger from ..utils import logger
...@@ -53,7 +56,19 @@ class DataParallelBuilder(GraphBuilder): ...@@ -53,7 +56,19 @@ class DataParallelBuilder(GraphBuilder):
grad_list: list of list of tuples, shape is Ngpu x Nvar x 2 grad_list: list of list of tuples, shape is Ngpu x Nvar x 2
""" """
nvars = [len(k) for k in grad_list] nvars = [len(k) for k in grad_list]
assert len(set(nvars)) == 1, "Number of gradients from each tower is different! " + str(nvars)
def basename(x):
return re.sub('tower[0-9]+/', '', x.op.name)
if len(set(nvars)) != 1:
names_per_gpu = [set([basename(k[1]) for k in grad_and_vars]) for grad_and_vars in grad_list]
inters = copy.copy(names_per_gpu[0])
for s in names_per_gpu:
inters &= s
for s in names_per_gpu:
s -= inters
logger.error("Unique variables on towers: " + pprint.pformat(names_per_gpu))
raise ValueError("Number of gradients from each tower is different! " + str(nvars))
@staticmethod @staticmethod
def build_on_towers( def build_on_towers(
......
...@@ -499,11 +499,13 @@ class StagingInput(FeedfreeInput): ...@@ -499,11 +499,13 @@ class StagingInput(FeedfreeInput):
self._prefill() self._prefill()
return self.fetches return self.fetches
def __init__(self, input, towers=None, nr_stage=5): def __init__(self, input, towers=None, nr_stage=1):
""" """
Args: Args:
input (FeedfreeInput): input (FeedfreeInput):
nr_stage: number of elements to prefetch on each GPU. nr_stage: number of elements to prefetch on each GPU.
Since enqueue and dequeue are synchronized, prefetching 1
element should be sufficient.
towers: deprecated towers: deprecated
""" """
assert isinstance(input, FeedfreeInput), input assert isinstance(input, FeedfreeInput), input
...@@ -515,7 +517,6 @@ class StagingInput(FeedfreeInput): ...@@ -515,7 +517,6 @@ class StagingInput(FeedfreeInput):
self._areas = [] self._areas = []
self._stage_ops = [] self._stage_ops = []
self._unstage_ops = [] self._unstage_ops = []
# self._size_ops = []
def _setup(self, inputs): def _setup(self, inputs):
self._input.setup(inputs) self._input.setup(inputs)
...@@ -542,6 +543,8 @@ class StagingInput(FeedfreeInput): ...@@ -542,6 +543,8 @@ class StagingInput(FeedfreeInput):
inputs[idx] = tf.identity(inputs[idx]) inputs[idx] = tf.identity(inputs[idx])
dtypes.append(dtype.base_dtype) dtypes.append(dtype.base_dtype)
# TODO tensorflow/benchmarks use static shapes here,
# though it doesn't seem to help. We can use it when it's known.
stage = StagingArea(dtypes, shapes=None) stage = StagingArea(dtypes, shapes=None)
self._stage_ops.append(stage.put(inputs)) self._stage_ops.append(stage.put(inputs))
self._areas.append(stage) self._areas.append(stage)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment