Commit b5e751b6 authored by Yuxin Wu's avatar Yuxin Wu

fix upsample

parent 6aa4328b
......@@ -25,7 +25,7 @@ WARMUP = 1000 # in steps
STEPS_PER_EPOCH = 500
LR_SCHEDULE = [150000, 230000, 280000]
LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
# LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
#LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
# image resolution --------------------
SHORT_EDGE_SIZE = 800
......@@ -75,5 +75,5 @@ MODE_FPN = True
FPN_NUM_CHANNEL = 256
FASTRCNN_FC_HEAD_DIM = 1024
FPN_RESOLUTION_REQUIREMENT = 32
TRAIN_FPN_NMS_TOPK = 2048
TEST_FPN_NMS_TOPK = 1024
TRAIN_FPN_NMS_TOPK = 2000
TEST_FPN_NMS_TOPK = 1000
......@@ -8,7 +8,7 @@ import itertools
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
imgaug, TestDataSpeed, PrefetchDataZMQ, MapData,
imgaug, TestDataSpeed, PrefetchDataZMQ, MapData, MultiProcessMapDataZMQ,
MapDataComponent, DataFromList)
# import tensorpack.utils.viz as tpviz
......@@ -341,8 +341,8 @@ def get_train_dataflow(add_mask=False):
# tpviz.interactive_imshow(viz)
return ret
ds = MapData(ds, preprocess)
ds = PrefetchDataZMQ(ds, 3)
ds = MultiProcessMapDataZMQ(ds, 5, preprocess)
#ds = PrefetchDataZMQ(ds, 3)
return ds
......
......@@ -578,7 +578,12 @@ def fpn_model(features):
def upsample2x(name, x):
# TODO may not be optimal in speed or math
return FixedUnPooling(name, x, 2, data_format='channels_first')
with tf.name_scope(name):
shape2d = tf.shape(x)[2:]
x = tf.transpose(x, [0, 2, 3, 1])
x = tf.image.resize_nearest_neighbor(x, shape2d * 2, align_corners=True)
x = tf.transpose(x, [0, 3, 1, 2])
return x
with argscope(Conv2D, data_format='channels_first',
nl=tf.identity, use_bias=True,
......@@ -590,7 +595,7 @@ def fpn_model(features):
if idx == 0:
lat_sum_5432.append(lat)
else:
lat = lat + upsample2x('upsample_c{}'.format(5 - idx), lat_sum_5432[-1])
lat = lat + upsample2x('upsample_lat{}'.format(6 - idx), lat_sum_5432[-1])
lat_sum_5432.append(lat)
p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3)
for i, c in enumerate(lat_sum_5432[::-1])]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment