Commit b059ce49 authored by Yuxin Wu's avatar Yuxin Wu

update

parent aed3438b
......@@ -154,6 +154,7 @@ class MapData(ProxyDataFlow):
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a original datapoint, returns a new
datapoint. return None to skip this data point.
                Note that if ``func`` returns None to filter out data points, ``ds.size()`` will no longer be accurate.
"""
super(MapData, self).__init__(ds)
self.func = func
......@@ -170,7 +171,8 @@ class MapDataComponent(ProxyDataFlow):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint component dp[index], returns a
new value of dp[index]. return None to skip this datapoint.
new value of dp[index]. return None to skip this datapoint.
                Note that if ``func`` returns None to filter out data points, ``ds.size()`` will no longer be accurate.
"""
super(MapDataComponent, self).__init__(ds)
self.func = func
......
......@@ -16,6 +16,13 @@ from ..utils import logger
# make sure each layer is only logged once
_layer_logged = set()
def disable_layer_logging():
    """Disable logging of layer information.

    Rebinds the module-level ``_layer_logged`` set to an object whose
    ``__contains__`` always returns True, so every layer appears to have
    been logged already and is silently skipped from then on.
    """
    global _layer_logged  # the idiomatic way to rebind a module-level name

    class _ContainEverything:
        # Pseudo-container: membership tests succeed for any element.
        def __contains__(self, x):
            return True

    _layer_logged = _ContainEverything()
def layer_register(summary_activation=False, log_shape=True):
"""
Register a layer.
......
......@@ -115,9 +115,13 @@ class PredictWorker(multiprocessing.Process):
self.config = config
def run(self):
logger.info("Worker {} use GPU {}".format(self.idx, self.gpuid))
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:0'):
if self.idx != 0:
from tensorpack.models._common import disable_layer_logging
disable_layer_logging()
self.func = get_predict_func(self.config)
if self.idx == 0:
describe_model()
......@@ -173,13 +177,13 @@ class DatasetPredictor(object):
die_cnt = 0
while True:
res = self.result_queue.get()
pbar.update()
if res[0] != DIE:
yield res[1]
else:
die_cnt += 1
if die_cnt == self.nr_gpu:
break
pbar.update()
self.inqueue_proc.join()
self.inqueue_proc.terminate()
for p in self.workers:
......
......@@ -4,6 +4,7 @@
import tensorflow as tf
import threading
import time
import copy
import re
import functools
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment