Commit c43956a1 authored by Yuxin Wu's avatar Yuxin Wu

small fix on a3c

parent 3f14f3a7
...@@ -263,11 +263,11 @@ if __name__ == '__main__': ...@@ -263,11 +263,11 @@ if __name__ == '__main__':
nr_gpu = get_nr_gpu() nr_gpu = get_nr_gpu()
if nr_gpu > 0: if nr_gpu > 0:
if nr_gpu > 1: if nr_gpu > 1:
predict_tower = range(nr_gpu)[-nr_gpu // 2:] predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
else: else:
predict_tower = [0] predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = range(nr_gpu)[:-nr_gpu // 2] or [0] train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format( logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower)))) ','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
trainer = AsyncMultiGPUTrainer trainer = AsyncMultiGPUTrainer
......
...@@ -221,7 +221,8 @@ class MapData(ProxyDataFlow): ...@@ -221,7 +221,8 @@ class MapData(ProxyDataFlow):
Note that if you use the filter feature, ``ds.size()`` will be incorrect. Note that if you use the filter feature, ``ds.size()`` will be incorrect.
Note: Note:
Be careful if func modifies datapoints. Please make sure func doesn't modify the components
unless you're certain it's safe.
""" """
super(MapData, self).__init__(ds) super(MapData, self).__init__(ds)
self.func = func self.func = func
...@@ -245,8 +246,9 @@ class MapDataComponent(MapData): ...@@ -245,8 +246,9 @@ class MapDataComponent(MapData):
index (int): index of the component. index (int): index of the component.
Note: Note:
This proxy itself doesn't modify the datapoints. But be careful because func This proxy itself doesn't modify the datapoints.
may modify the components. But please make sure func doesn't modify the components
unless you're certain it's safe.
""" """
def f(dp): def f(dp):
r = func(dp[index]) r = func(dp[index])
......
...@@ -24,7 +24,7 @@ class PredictorFactory(object): ...@@ -24,7 +24,7 @@ class PredictorFactory(object):
def fn(_): def fn(_):
self.model.build_graph(self.model.get_reused_placehdrs()) self.model.build_graph(self.model.get_reused_placehdrs())
self._tower_builder = PredictorTowerBuilder(fn) self._tower_builder = PredictorTowerBuilder(fn)
assert isinstance(self.towers, list) assert isinstance(self.towers, list), self.towers
def get_predictor(self, input_names, output_names, tower): def get_predictor(self, input_names, output_names, tower):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment