Commit 2fc3be15 authored by Yuxin Wu's avatar Yuxin Wu

suppress some prctl warnings

parent 6334b813
......@@ -20,25 +20,18 @@ from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils import viz
from imagenet_utils import (
fbresnet_augmentor, image_preprocess, compute_loss_and_error)
fbresnet_augmentor, ImageNetModel)
from resnet_model import (
preresnet_basicblock, preresnet_group)
TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
DEPTH = None
class Model(ModelDesc):
def inputs(self):
return [tf.placeholder(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
tf.placeholder(tf.int32, [None], 'label')]
def build_graph(self, image, label):
image = image_preprocess(image, bgr=True)
image = tf.transpose(image, [0, 3, 1, 2])
class Model(ImageNetModel):
def get_logits(self, image):
cfg = {
18: ([2, 2, 2, 2], preresnet_basicblock),
34: ([3, 4, 6, 3], preresnet_basicblock),
......@@ -58,11 +51,7 @@ class Model(ModelDesc):
print(convmaps)
convmaps = GlobalAvgPooling('gap', convmaps)
logits = FullyConnected('linearnew', convmaps, 1000)
loss = compute_loss_and_error(logits, label)
wd_cost = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_cost)
return tf.add_n([loss, wd_cost], name='cost')
return logits
def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
......@@ -150,7 +139,7 @@ if __name__ == '__main__':
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--depth', type=int, default=18)
parser.add_argument('--load', help='load model')
parser.add_argument('--cam', action='store_true')
parser.add_argument('--cam', action='store_true', help='run visualization')
args = parser.parse_args()
DEPTH = args.depth
......
......@@ -28,7 +28,6 @@ setup(
"termcolor>=1.1",
"tabulate>=0.7.7",
"tqdm>4.11.1",
"pyarrow>=0.9.0",
"msgpack>=0.5.2",
"msgpack-numpy>=0.4.0",
"pyzmq>=16",
......
......@@ -363,5 +363,5 @@ class StatMonitorParamSetter(HyperParamSetter):
self.last_changed_epoch = self.epoch_num
logger.info(
"[StatMonitorParamSetter] Triggered, history of {}: ".format(
self.stat_name) + ','.join(map(str, hist)))
self.stat_name) + ','.join([str(round(x, 3)) for x in hist]))
return self.value_func(self.get_current_value())
......@@ -14,6 +14,7 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates',
def check_dtype(img):
assert isinstance(img, np.ndarray), "[Augmentor] Needs an numpy array, but got a {}!".format(type(img))
if isinstance(img.dtype, np.integer):
assert img.dtype == np.uint8, \
"[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
......
......@@ -149,13 +149,14 @@ class MultiProcessPrefetchData(ProxyDataFlow):
"""
class _Worker(mp.Process):
def __init__(self, ds, queue):
def __init__(self, ds, queue, idx):
super(MultiProcessPrefetchData._Worker, self).__init__()
self.ds = ds
self.queue = queue
self.idx = idx
def run(self):
enable_death_signal()
enable_death_signal(_warn=self.idx == 0)
# reset all ds so each process will produce different data
self.ds.reset_state()
while True:
......@@ -186,8 +187,8 @@ lead of failure on some of the code.")
"This assumes the datapoints are i.i.d.")
self.queue = mp.Queue(self.nr_prefetch)
self.procs = [MultiProcessPrefetchData._Worker(self.ds, self.queue)
for _ in range(self.nr_proc)]
self.procs = [MultiProcessPrefetchData._Worker(self.ds, self.queue, idx)
for idx in range(self.nr_proc)]
ensure_proc_terminate(self.procs)
start_proc_mask_signal(self.procs)
......@@ -249,14 +250,15 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
"""
class _Worker(mp.Process):
def __init__(self, ds, conn_name, hwm):
def __init__(self, ds, conn_name, hwm, idx):
super(PrefetchDataZMQ._Worker, self).__init__()
self.ds = ds
self.conn_name = conn_name
self.hwm = hwm
self.idx = idx
def run(self):
enable_death_signal()
enable_death_signal(_warn=self.idx == 0)
self.ds.reset_state()
context = zmq.Context()
socket = context.socket(zmq.PUSH)
......@@ -315,8 +317,8 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
pipename = _get_pipe_name('dataflow')
_bind_guard(self.socket, pipename)
self._procs = [PrefetchDataZMQ._Worker(self.ds, pipename, self._hwm)
for _ in range(self.nr_proc)]
self._procs = [PrefetchDataZMQ._Worker(self.ds, pipename, self._hwm, idx)
for idx in range(self.nr_proc)]
self._start_processes()
......
......@@ -224,7 +224,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self.hwm = hwm
def run(self):
enable_death_signal()
enable_death_signal(_warn=self.identity == b'0')
ctx = zmq.Context()
socket = ctx.socket(zmq.REP)
socket.setsockopt(zmq.IDENTITY, self.identity)
......
......@@ -356,7 +356,7 @@ class GradientPacker(object):
split_size_last = self._total_size - split_size * (self._num_split - 1)
self._split_sizes = [split_size] * (self._num_split - 1) + [split_size_last]
logger.info(
"Will pack {} gradients of total number={} into {} splits.".format(
"Will pack {} gradients of total dimension={} into {} splits.".format(
len(self._sizes), self._total_size, self._num_split))
return True
......
......@@ -173,7 +173,7 @@ def ensure_proc_terminate(proc):
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def enable_death_signal():
def enable_death_signal(_warn=True):
"""
Set the "death signal" of the current process, so that
the current process will be cleaned with guarantee
......@@ -184,6 +184,7 @@ def enable_death_signal():
try:
import prctl # pip install python-prctl
except ImportError:
if _warn:
log_once('Install python-prctl so that processes can be cleaned with guarantee.', 'warn')
return
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment