Commit b059ce49 authored by Yuxin Wu

update

parent aed3438b
@@ -154,6 +154,7 @@ class MapData(ProxyDataFlow):
         :param ds: a :mod:`DataFlow` instance.
         :param func: a function that takes an original datapoint and returns a new
             datapoint. Return None to skip this datapoint.
+            Note that if func is used to filter datapoints, ds.size() won't be correct.
         """
         super(MapData, self).__init__(ds)
         self.func = func
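
The newly documented behaviour is easiest to see in a toy example. This is a sketch, not code from this commit; the import path and the minimal DataFlow subclass below are assumed for illustration.

from tensorpack.dataflow import DataFlow, MapData

class Numbers(DataFlow):
    # hypothetical toy flow producing datapoints [0], [1], ..., [9]
    def size(self):
        return 10
    def get_data(self):
        for k in range(10):
            yield [k]

ds = MapData(Numbers(), lambda dp: dp if dp[0] % 2 == 0 else None)
print(ds.size())                  # still 10: size() is forwarded from the wrapped flow
print(len(list(ds.get_data())))   # 5: datapoints mapped to None were skipped
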
@@ -170,7 +171,8 @@ class MapDataComponent(ProxyDataFlow):
         """
         :param ds: a :mod:`DataFlow` instance.
         :param func: a function that takes a datapoint component dp[index] and returns a
             new value of dp[index]. Return None to skip this datapoint.
+            Note that if func is used to filter datapoints, ds.size() won't be correct.
         """
         super(MapDataComponent, self).__init__(ds)
         self.func = func
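
MapDataComponent applies the same idea to a single component of each datapoint. A minimal hedged sketch, assuming a flow ds whose datapoints are [image, label] pairs and assuming the index argument shown here:

from tensorpack.dataflow import MapDataComponent

# hypothetical: drop datapoints with a negative label, negate the remaining labels
ds = MapDataComponent(ds, lambda label: -label if label >= 0 else None, index=1)
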
...
@@ -16,6 +16,13 @@ from ..utils import logger
 # make sure each layer is only logged once
 _layer_logged = set()
 
+def disable_layer_logging():
+    class ContainEverything:
+        def __contains__(self, x):
+            return True
+    # can use nonlocal in python3, but how
+    globals()['_layer_logged'] = ContainEverything()
+
 def layer_register(summary_activation=False, log_shape=True):
     """
     Register a layer.
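
The new disable_layer_logging() works by replacing the module-level set _layer_logged with an object whose __contains__ always returns True, so every "already logged?" membership test succeeds and the logging branch is never taken again. A standalone sketch of the same trick with toy names (not tensorpack code):

class ContainEverything:
    def __contains__(self, x):
        return True

_logged = set()

def log_once(name):
    if name not in _logged:       # only the membership test matters
        print("layer:", name)
        _logged.add(name)

log_once("conv0")                 # prints
_logged = ContainEverything()     # same effect as globals()['_layer_logged'] = ...
log_once("conv1")                 # silent: membership is now always True

The globals()['_layer_logged'] assignment in the diff is needed because a plain assignment inside disable_layer_logging() would only create a local variable; a global declaration would achieve the same thing.
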
...
@@ -115,9 +115,13 @@ class PredictWorker(multiprocessing.Process):
         self.config = config
 
     def run(self):
+        logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
         os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
         G = tf.Graph()   # build a graph for each process, because they don't need to share anything
         with G.as_default(), tf.device('/gpu:0'):
+            if self.idx != 0:
+                from tensorpack.models._common import disable_layer_logging
+                disable_layer_logging()
             self.func = get_predict_func(self.config)
             if self.idx == 0:
                 describe_model()
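
For context, the pattern these lines follow: each worker pins itself to one physical GPU via CUDA_VISIBLE_DEVICES (set early, before the process creates a session and touches the GPU, so the chosen card appears as '/gpu:0' inside that process) and builds its own private tf.Graph, while only worker 0 describes the model and keeps layer logging enabled. A stripped-down sketch of that per-process setup; the function and process wiring below are hypothetical, not the PredictWorker API:

import os
import multiprocessing
import tensorflow as tf

def worker(idx, gpuid):
    # hide all GPUs except one; inside this process it is simply '/gpu:0'
    os.environ['CUDA_VISIBLE_DEVICES'] = gpuid
    g = tf.Graph()                # private graph: worker processes share nothing
    with g.as_default(), tf.device('/gpu:0'):
        pass                      # build the prediction function here

if __name__ == '__main__':
    procs = [multiprocessing.Process(target=worker, args=(i, str(i))) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
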
@@ -173,13 +177,13 @@ class DatasetPredictor(object):
             die_cnt = 0
             while True:
                 res = self.result_queue.get()
+                pbar.update()
                 if res[0] != DIE:
                     yield res[1]
                 else:
                     die_cnt += 1
                     if die_cnt == self.nr_gpu:
                         break
-                pbar.update()
         self.inqueue_proc.join()
         self.inqueue_proc.terminate()
         for p in self.workers:
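
One detail about the pbar.update() move in this hunk: get_result() is a generator, so any statement placed after the yield only runs when the caller requests the next item, and the old update at the bottom of the loop was also skipped on the final break. Updating before the DIE check advances the bar as soon as a result comes off the queue. A toy illustration of the ordering (plain Python, not tensorpack code):

# code after `yield` runs only when the next item is requested,
# and `break` skips anything placed after it in the loop body
def produce():
    for item in ("a", "b", "DIE"):
        print("got", item, "-> update progress here (new placement)")
        if item != "DIE":
            yield item
        else:
            break
        # old placement: an update here runs one step late and never runs for "DIE"

for x in produce():
    pass
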
...
@@ -4,6 +4,7 @@
 import tensorflow as tf
 import threading
+import time
 import copy
 import re
 import functools
...