Commit 2fc3be15 authored by Yuxin Wu's avatar Yuxin Wu

suppress some prctl warnings

parent 6334b813
...@@ -20,25 +20,18 @@ from tensorpack.utils.gpu import get_num_gpu ...@@ -20,25 +20,18 @@ from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils import viz from tensorpack.utils import viz
from imagenet_utils import ( from imagenet_utils import (
fbresnet_augmentor, image_preprocess, compute_loss_and_error) fbresnet_augmentor, ImageNetModel)
from resnet_model import ( from resnet_model import (
preresnet_basicblock, preresnet_group) preresnet_basicblock, preresnet_group)
TOTAL_BATCH_SIZE = 256 TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
DEPTH = None DEPTH = None
class Model(ModelDesc): class Model(ImageNetModel):
def inputs(self):
return [tf.placeholder(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
tf.placeholder(tf.int32, [None], 'label')]
def build_graph(self, image, label):
image = image_preprocess(image, bgr=True)
image = tf.transpose(image, [0, 3, 1, 2])
def get_logits(self, image):
cfg = { cfg = {
18: ([2, 2, 2, 2], preresnet_basicblock), 18: ([2, 2, 2, 2], preresnet_basicblock),
34: ([3, 4, 6, 3], preresnet_basicblock), 34: ([3, 4, 6, 3], preresnet_basicblock),
...@@ -58,11 +51,7 @@ class Model(ModelDesc): ...@@ -58,11 +51,7 @@ class Model(ModelDesc):
print(convmaps) print(convmaps)
convmaps = GlobalAvgPooling('gap', convmaps) convmaps = GlobalAvgPooling('gap', convmaps)
logits = FullyConnected('linearnew', convmaps, 1000) logits = FullyConnected('linearnew', convmaps, 1000)
return logits
loss = compute_loss_and_error(logits, label)
wd_cost = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_cost)
return tf.add_n([loss, wd_cost], name='cost')
def optimizer(self): def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False) lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
...@@ -150,7 +139,7 @@ if __name__ == '__main__': ...@@ -150,7 +139,7 @@ if __name__ == '__main__':
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--depth', type=int, default=18) parser.add_argument('--depth', type=int, default=18)
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--cam', action='store_true') parser.add_argument('--cam', action='store_true', help='run visualization')
args = parser.parse_args() args = parser.parse_args()
DEPTH = args.depth DEPTH = args.depth
......
...@@ -28,7 +28,6 @@ setup( ...@@ -28,7 +28,6 @@ setup(
"termcolor>=1.1", "termcolor>=1.1",
"tabulate>=0.7.7", "tabulate>=0.7.7",
"tqdm>4.11.1", "tqdm>4.11.1",
"pyarrow>=0.9.0",
"msgpack>=0.5.2", "msgpack>=0.5.2",
"msgpack-numpy>=0.4.0", "msgpack-numpy>=0.4.0",
"pyzmq>=16", "pyzmq>=16",
......
...@@ -363,5 +363,5 @@ class StatMonitorParamSetter(HyperParamSetter): ...@@ -363,5 +363,5 @@ class StatMonitorParamSetter(HyperParamSetter):
self.last_changed_epoch = self.epoch_num self.last_changed_epoch = self.epoch_num
logger.info( logger.info(
"[StatMonitorParamSetter] Triggered, history of {}: ".format( "[StatMonitorParamSetter] Triggered, history of {}: ".format(
self.stat_name) + ','.join(map(str, hist))) self.stat_name) + ','.join([str(round(x, 3)) for x in hist]))
return self.value_func(self.get_current_value()) return self.value_func(self.get_current_value())
...@@ -14,6 +14,7 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', ...@@ -14,6 +14,7 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates',
def check_dtype(img): def check_dtype(img):
assert isinstance(img, np.ndarray), "[Augmentor] Needs an numpy array, but got a {}!".format(type(img))
if isinstance(img.dtype, np.integer): if isinstance(img.dtype, np.integer):
assert img.dtype == np.uint8, \ assert img.dtype == np.uint8, \
"[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype) "[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
......
...@@ -149,13 +149,14 @@ class MultiProcessPrefetchData(ProxyDataFlow): ...@@ -149,13 +149,14 @@ class MultiProcessPrefetchData(ProxyDataFlow):
""" """
class _Worker(mp.Process): class _Worker(mp.Process):
def __init__(self, ds, queue): def __init__(self, ds, queue, idx):
super(MultiProcessPrefetchData._Worker, self).__init__() super(MultiProcessPrefetchData._Worker, self).__init__()
self.ds = ds self.ds = ds
self.queue = queue self.queue = queue
self.idx = idx
def run(self): def run(self):
enable_death_signal() enable_death_signal(_warn=self.idx == 0)
# reset all ds so each process will produce different data # reset all ds so each process will produce different data
self.ds.reset_state() self.ds.reset_state()
while True: while True:
...@@ -186,8 +187,8 @@ lead of failure on some of the code.") ...@@ -186,8 +187,8 @@ lead of failure on some of the code.")
"This assumes the datapoints are i.i.d.") "This assumes the datapoints are i.i.d.")
self.queue = mp.Queue(self.nr_prefetch) self.queue = mp.Queue(self.nr_prefetch)
self.procs = [MultiProcessPrefetchData._Worker(self.ds, self.queue) self.procs = [MultiProcessPrefetchData._Worker(self.ds, self.queue, idx)
for _ in range(self.nr_proc)] for idx in range(self.nr_proc)]
ensure_proc_terminate(self.procs) ensure_proc_terminate(self.procs)
start_proc_mask_signal(self.procs) start_proc_mask_signal(self.procs)
...@@ -249,14 +250,15 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -249,14 +250,15 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
""" """
class _Worker(mp.Process): class _Worker(mp.Process):
def __init__(self, ds, conn_name, hwm): def __init__(self, ds, conn_name, hwm, idx):
super(PrefetchDataZMQ._Worker, self).__init__() super(PrefetchDataZMQ._Worker, self).__init__()
self.ds = ds self.ds = ds
self.conn_name = conn_name self.conn_name = conn_name
self.hwm = hwm self.hwm = hwm
self.idx = idx
def run(self): def run(self):
enable_death_signal() enable_death_signal(_warn=self.idx == 0)
self.ds.reset_state() self.ds.reset_state()
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.PUSH) socket = context.socket(zmq.PUSH)
...@@ -315,8 +317,8 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -315,8 +317,8 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
pipename = _get_pipe_name('dataflow') pipename = _get_pipe_name('dataflow')
_bind_guard(self.socket, pipename) _bind_guard(self.socket, pipename)
self._procs = [PrefetchDataZMQ._Worker(self.ds, pipename, self._hwm) self._procs = [PrefetchDataZMQ._Worker(self.ds, pipename, self._hwm, idx)
for _ in range(self.nr_proc)] for idx in range(self.nr_proc)]
self._start_processes() self._start_processes()
......
...@@ -224,7 +224,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -224,7 +224,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self.hwm = hwm self.hwm = hwm
def run(self): def run(self):
enable_death_signal() enable_death_signal(_warn=self.identity == b'0')
ctx = zmq.Context() ctx = zmq.Context()
socket = ctx.socket(zmq.REP) socket = ctx.socket(zmq.REP)
socket.setsockopt(zmq.IDENTITY, self.identity) socket.setsockopt(zmq.IDENTITY, self.identity)
......
...@@ -356,7 +356,7 @@ class GradientPacker(object): ...@@ -356,7 +356,7 @@ class GradientPacker(object):
split_size_last = self._total_size - split_size * (self._num_split - 1) split_size_last = self._total_size - split_size * (self._num_split - 1)
self._split_sizes = [split_size] * (self._num_split - 1) + [split_size_last] self._split_sizes = [split_size] * (self._num_split - 1) + [split_size_last]
logger.info( logger.info(
"Will pack {} gradients of total number={} into {} splits.".format( "Will pack {} gradients of total dimension={} into {} splits.".format(
len(self._sizes), self._total_size, self._num_split)) len(self._sizes), self._total_size, self._num_split))
return True return True
......
...@@ -173,7 +173,7 @@ def ensure_proc_terminate(proc): ...@@ -173,7 +173,7 @@ def ensure_proc_terminate(proc):
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def enable_death_signal(): def enable_death_signal(_warn=True):
""" """
Set the "death signal" of the current process, so that Set the "death signal" of the current process, so that
the current process will be cleaned with guarantee the current process will be cleaned with guarantee
...@@ -184,6 +184,7 @@ def enable_death_signal(): ...@@ -184,6 +184,7 @@ def enable_death_signal():
try: try:
import prctl # pip install python-prctl import prctl # pip install python-prctl
except ImportError: except ImportError:
if _warn:
log_once('Install python-prctl so that processes can be cleaned with guarantee.', 'warn') log_once('Install python-prctl so that processes can be cleaned with guarantee.', 'warn')
return return
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment