Commit 209da29e authored by Yuxin Wu

[DQN] fix the layout when channel>1

parent c5fd5e17
......@@ -53,7 +53,6 @@ class Model(DQNModel):
super(Model, self).__init__(IMAGE_SIZE, 1, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
def _get_DQN_prediction(self, image):
""" image: [N, H, W, C * history] in [0,255]"""
image = image / 255.0
with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
l = (LinearWrap(image)
......
......@@ -15,10 +15,12 @@ class Model(ModelDesc):
learning_rate = 1e-3
def __init__(self, image_shape, channel, history, method, num_actions, gamma):
assert len(image_shape) == 2, image_shape
self.channel = channel
self._shape2d = image_shape
self._shape3d = image_shape + (channel, )
self._shape4d_for_prediction = (-1, ) + image_shape + (channel * history, )
self._shape2d = tuple(image_shape)
self._shape3d = self._shape2d + (channel, )
self._shape4d_for_prediction = (-1, ) + self._shape2d + (history * channel, )
self._channel = channel
self.history = history
self.method = method
......@@ -31,7 +33,7 @@ class Model(ModelDesc):
# The first h are the current state, and the last h are the next state.
return [tf.placeholder(tf.uint8,
(None,) + self._shape2d +
(self._channel * (self.history + 1),),
((self.history + 1) * self.channel,),
'comb_state'),
tf.placeholder(tf.int64, (None,), 'action'),
tf.placeholder(tf.float32, (None,), 'reward'),
......@@ -43,20 +45,22 @@ class Model(ModelDesc):
@auto_reuse_variable_scope
def get_DQN_prediction(self, image):
""" image: [N, H, W, history * C] in [0,255]"""
return self._get_DQN_prediction(image)
def build_graph(self, comb_state, action, reward, isOver):
comb_state = tf.cast(comb_state, tf.float32)
comb_state = tf.reshape(comb_state, [-1] + list(self._shape3d) + [self.history + 1])
comb_state = tf.reshape(
comb_state, [-1] + list(self._shape2d) + [self.history + 1, self.channel])
state = tf.slice(comb_state, [0, 0, 0, 0, 0], [-1, -1, -1, -1, self.history])
state = tf.slice(comb_state, [0, 0, 0, 0, 0], [-1, -1, -1, self.history, -1])
state = tf.reshape(state, self._shape4d_for_prediction, name='state')
self.predict_value = self.get_DQN_prediction(state)
if not get_current_tower_context().is_training:
return
reward = tf.clip_by_value(reward, -1, 1)
next_state = tf.slice(comb_state, [0, 0, 0, 0, 1], [-1, -1, -1, -1, self.history], name='next_state')
next_state = tf.slice(comb_state, [0, 0, 0, 1, 0], [-1, -1, -1, self.history, -1], name='next_state')
next_state = tf.reshape(next_state, self._shape4d_for_prediction)
action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)
......
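Below is a minimal numpy sketch (illustrative, not part of the commit) of the layout the updated `build_graph` assumes: `comb_state` stores `history + 1` full frames back to back along the last axis, so reshaping to `[..., history + 1, channel]` and slicing along the history axis keeps each frame's channels together. The old reshape to `[..., channel, history + 1]` only lined up when `channel == 1`. Sizes below are hypothetical.
```python
import numpy as np

N, H, W, C, HIST = 2, 84, 84, 3, 4          # illustrative sizes
comb_state = np.zeros((N, H, W, (HIST + 1) * C), dtype=np.uint8)

# Frame-major layout: axis 3 indexes frames, axis 4 indexes channels.
comb = comb_state.reshape(N, H, W, HIST + 1, C)

# The first HIST frames are the current state, the last HIST the next state.
state = comb[:, :, :, :HIST, :].reshape(N, H, W, HIST * C)
next_state = comb[:, :, :, 1:, :].reshape(N, H, W, HIST * C)

assert state.shape == next_state.shape == (N, H, W, HIST * C)
```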
......@@ -27,8 +27,11 @@ class MapState(gym.ObservationWrapper):
class FrameStack(gym.Wrapper):
"""
Buffer observations and stack across channels (last axis).
The output observation has shape (H, W, History * Channel)
"""
def __init__(self, env, k):
"""Buffer observations and stack across channels (last axis)."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
......
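For reference, a small numpy sketch (illustrative, not part of the commit) of what the updated docstring describes: `k` buffered frames of shape `(H, W, C)` stacked across the last axis give an observation of shape `(H, W, k * C)`.
```python
import numpy as np

H, W, C, k = 84, 84, 3, 4                    # illustrative sizes
frames = [np.zeros((H, W, C), dtype=np.uint8) for _ in range(k)]

# Stack across channels (last axis), as the docstring states.
obs = np.concatenate(frames, axis=-1)
assert obs.shape == (H, W, k * C)            # (H, W, History * Channel)
```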
......@@ -25,7 +25,8 @@ class ReplayMemory(object):
def __init__(self, max_size, state_shape, history_len):
self.max_size = int(max_size)
self.state_shape = state_shape
self._state_transpose = list(range(1, len(state_shape) + 1)) + [0]
assert len(state_shape) == 3, state_shape
# self._state_transpose = list(range(1, len(state_shape) + 1)) + [0]
self._channel = state_shape[2] if len(state_shape) == 3 else 1
self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
self.history_len = int(history_len)
......@@ -57,7 +58,7 @@ class ReplayMemory(object):
self._hist.append(exp)
def recent_state(self):
""" return a list of (hist_len-1,) + STATE_SIZE """
""" return a list of ``hist_len-1`` elements, each of shape ``self.state_shape`` """
lst = list(self._hist)
states = [np.zeros(self.state_shape, dtype='uint8')] * (self._hist.maxlen - len(lst))
states.extend([k.state for k in lst])
......@@ -65,7 +66,7 @@ class ReplayMemory(object):
def sample(self, idx):
""" return a tuple of (s,r,a,o),
where s is of shape [H, W, channel * (hist_len+1)]"""
where s is of shape [H, W, (hist_len+1) * channel]"""
idx = (self._curr_pos + idx) % self._curr_size
k = self.history_len + 1
if idx + k <= self._curr_size:
......@@ -84,14 +85,14 @@ class ReplayMemory(object):
# the next_state is a different episode if current_state.isOver==True
def _pad_sample(self, state, reward, action, isOver):
# state: Hist+1,H,W,C
for k in range(self.history_len - 2, -1, -1):
if isOver[k]:
state = copy.deepcopy(state)
state[:k + 1].fill(0)
break
# move the first dim to the last
state = state.transpose(*self._state_transpose)
state = state.reshape(self._shape3d)
state = state.transpose(1, 2, 0, 3).reshape(self._shape3d)
return (state, reward[-2], action[-2], isOver[-2])
def _slice(self, arr, start, end):
......@@ -202,10 +203,11 @@ class ExpReplay(DataFlow, Callback):
# build a history state
history = self.mem.recent_state()
history.append(old_s)
history = np.concatenate(history, axis=-1)
history = np.concatenate(history, axis=-1) # H,W,HistxC
history = np.expand_dims(history, axis=0)
# assume batched network
q_values = self.predictor(np.expand_dims(history, 0))[0][0] # this is the bottleneck
q_values = self.predictor(history)[0][0] # this is the bottleneck
act = np.argmax(q_values)
self._current_ob, reward, isOver, info = self.player.step(act)
self._current_game_score.feed(reward)
......
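A short numpy sketch (illustrative, not part of the commit) of the transpose in `_pad_sample` above: a sampled `state` of shape `(hist_len + 1, H, W, C)` is rearranged to `(H, W, hist_len + 1, C)` and then flattened to `(H, W, (hist_len + 1) * C)`, which keeps each frame's channels contiguous and matches the layout `build_graph` expects.
```python
import numpy as np

HIST1, H, W, C = 5, 84, 84, 3                # hist_len + 1 = 5, illustrative sizes
state = np.zeros((HIST1, H, W, C), dtype=np.uint8)

# (hist+1, H, W, C) -> (H, W, hist+1, C) -> (H, W, (hist+1) * C)
state = state.transpose(1, 2, 0, 3).reshape(H, W, HIST1 * C)
assert state.shape == (H, W, HIST1 * C)
```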
......@@ -52,18 +52,21 @@ Recommended configurations are listed in the table below.
The code is only valid for training with 1, 2, 4 or >=8 GPUs.
Not training with 8 GPUs may result in different performance from the table below.
### Inference:
To predict on an image (and show output in a window):
```
./train.py --predict input.jpg --load /path/to/model --config SAME-AS-TRAINING
```
Evaluate the performance of a model on COCO.
(Several trained models can be downloaded in [model zoo](http://models.tensorpack.com/FasterRCNN)):
To evaluate the performance of a model on COCO:
```
./train.py --evaluate output.json --load /path/to/COCO-R50C4-MaskRCNN-Standard.npz \
--config MODE_MASK=True DATA.BASEDIR=/path/to/COCO/DIR
--config SAME-AS-TRAINING
```
Evaluation or prediction will need the same `--config` used during training.
Several trained models can be downloaded in the table below. Evaluation and
prediction will need to be run with the corresponding training configs.
## Results
......
......@@ -99,7 +99,7 @@ _C.TRAIN.BASE_LR = 1e-2 # defined for a total batch size of 8. Otherwise it wil
_C.TRAIN.WARMUP = 1000 # in terms of iterations. This is not affected by #GPUs
_C.TRAIN.STEPS_PER_EPOCH = 500
# Schedule means "steps" only when total batch size is 8.
# LR_SCHEDULE means "steps" only when total batch size is 8.
# Otherwise the actual steps to decrease learning rate are computed from the schedule.
# LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
_C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
......
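The comment above notes that LR_SCHEDULE is expressed in steps only for a total batch size of 8 and is rescaled otherwise. A hedged sketch of such a rescaling, assuming a linear rule where steps scale inversely with the total batch size; the exact factor used by the training script may differ. `rescale_schedule` is a hypothetical helper, not part of the codebase.
```python
# Hypothetical helper, assuming steps scale inversely with total batch size.
def rescale_schedule(schedule, total_batch_size, base_batch_size=8):
    factor = base_batch_size / total_batch_size
    return [int(step * factor) for step in schedule]

# "2x" schedule defined for batch size 8, trained with batch size 16:
print(rescale_schedule([240000, 320000, 360000], 16))  # [120000, 160000, 180000]
```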
......@@ -49,7 +49,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
Returns:
sampled_boxes: tx4 floatbox, the rois
sampled_labels: t labels, in [0, #class-1]. Positive means foreground.
sampled_labels: t int64 labels, in [0, #class-1]. Positive means foreground.
fg_inds_wrt_gt: #fg indices, each in range [0, m-1].
It contains the matching GT of each foreground roi.
"""
......
......@@ -17,11 +17,11 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
"""
Args:
mask_logits: #fg x #category xhxw
fg_labels: #fg, in 1~#class
fg_labels: #fg, in 1~#class, int64
fg_target_masks: #fgxhxw, int
"""
num_fg = tf.size(fg_labels)
indices = tf.stack([tf.range(num_fg), tf.to_int32(fg_labels) - 1], axis=1) # #fgx2
num_fg = tf.size(fg_labels, out_type=tf.int64)
indices = tf.stack([tf.range(num_fg), fg_labels - 1], axis=1) # #fgx2
mask_logits = tf.gather_nd(mask_logits, indices) # #fgxhxw
mask_probs = tf.sigmoid(mask_logits)
......
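For clarity, a numpy sketch (illustrative, not part of the commit) of what the `tf.stack` / `tf.gather_nd` pair above computes: for each foreground roi `i`, pick the mask logits of its ground-truth category `fg_labels[i]` (labels are 1-based, hence the `- 1`). Sizes below are hypothetical.
```python
import numpy as np

num_fg, num_category, h, w = 4, 80, 28, 28              # illustrative sizes
mask_logits = np.random.randn(num_fg, num_category, h, w).astype('float32')
fg_labels = np.array([3, 1, 80, 7], dtype=np.int64)     # in 1..#class

# Equivalent of tf.gather_nd with indices [[i, fg_labels[i] - 1], ...]:
per_roi_logits = mask_logits[np.arange(num_fg), fg_labels - 1]   # #fg x h x w
assert per_roi_logits.shape == (num_fg, h, w)
```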
......@@ -620,7 +620,6 @@ if __name__ == '__main__':
session_init=session_init,
)
if is_horovod:
# horovod mode has the best speed for this model
trainer = HorovodTrainer(average=False)
else:
# nccl mode has better speed than cpu mode
......
......@@ -303,7 +303,7 @@ class MapDataComponent(MapData):
r = self._func(dp[self._index])
if r is None:
return None
dp = copy(dp) # shallow copy to avoid modifying the list
dp = copy(dp) # shallow copy to avoid modifying the datapoint
dp[self._index] = r
return dp
......