Commit c43956a1 authored by Yuxin Wu's avatar Yuxin Wu

small fix on a3c

parent 3f14f3a7
......@@ -263,11 +263,11 @@ if __name__ == '__main__':
nr_gpu = get_nr_gpu()
if nr_gpu > 0:
if nr_gpu > 1:
predict_tower = range(nr_gpu)[-nr_gpu // 2:]
predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
else:
predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = range(nr_gpu)[:-nr_gpu // 2] or [0]
train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
trainer = AsyncMultiGPUTrainer
......
......@@ -221,7 +221,8 @@ class MapData(ProxyDataFlow):
Note that if you use the filter feature, ``ds.size()`` will be incorrect.
Note:
Be careful if func modifies datapoints.
Please make sure func doesn't modify the components
unless you're certain it's safe.
"""
super(MapData, self).__init__(ds)
self.func = func
......@@ -245,8 +246,9 @@ class MapDataComponent(MapData):
index (int): index of the component.
Note:
This proxy itself doesn't modify the datapoints. But be careful because func
may modify the components.
This proxy itself doesn't modify the datapoints.
But please make sure func doesn't modify the components
unless you're certain it's safe.
"""
def f(dp):
r = func(dp[index])
......
......@@ -24,7 +24,7 @@ class PredictorFactory(object):
def fn(_):
self.model.build_graph(self.model.get_reused_placehdrs())
self._tower_builder = PredictorTowerBuilder(fn)
assert isinstance(self.towers, list)
assert isinstance(self.towers, list), self.towers
def get_predictor(self, input_names, output_names, tower):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment