Commit bae419ca authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] move model architectures to generalized_rcnn.py (#1163)

parent 8ce73803
......@@ -20,9 +20,11 @@ Claimed performance in the paper can be reproduced, on several games I've tested
![DQN](curve-breakout.png)
On one GTX 1080Ti, the ALE version took __~2 hours__ of training to reach 21 (maximum) score on
Pong, __~10 hours__ of training to reach 400 score on Breakout.
It runs at 80 batches (~5.1k trained frames, 320 seen frames, 1.3k game frames) per second on GTX 1080Ti.
On one GTX 1080Ti,
the ALE version took
__~2 hours__ of training to reach 21 (maximum) score on Pong,
__~10 hours__ of training to reach 400 score on Breakout.
It runs at 100 batches (6.4k trained frames, 400 seen frames, 1.6k game frames) per second on GTX 1080Ti.
This is likely the fastest open source TF implementation of DQN.
## How to use
......
......@@ -4,9 +4,10 @@ This is a minimal implementation that simply contains these files:
+ dataset.py: load and evaluate COCO dataset
+ data.py: prepare data for training & inference
+ common.py: common data preparation utilities
+ basemodel.py: implement backbones
+ backbone.py: implement backbones
+ model_box.py: implement box-related symbolic functions
+ model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast-/Mask-/Cascade-RCNN models.
+ generalized_rcnn.py: implement variants of generalized R-CNN architecture
+ model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast/Mask/Cascade R-CNN models.
+ train.py: main entry script
+ utils/: third-party helper functions
+ eval.py: evaluation utilities
......
# -*- coding: utf-8 -*-
# File: basemodel.py
# File: backbone.py
import numpy as np
from contextlib import ExitStack, contextmanager
......
......@@ -95,7 +95,7 @@ _C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG
_C.DATA.ABSOLUTE_COORD = True
_C.DATA.NUM_WORKERS = 5 # number of data loading workers
# basemodel ----------------------
# backbone ----------------------
_C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz
_C.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCKS = [3, 4, 23, 3] # for resnet101
......
This diff is collapsed.
......@@ -10,7 +10,7 @@ from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.tower import get_current_tower_context
from basemodel import GroupNorm
from backbone import GroupNorm
from config import config as cfg
from model_box import roi_align
from model_rpn import generate_rpn_proposals, rpn_losses
......
......@@ -10,7 +10,7 @@ from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized_method
from basemodel import GroupNorm
from backbone import GroupNorm
from config import config as cfg
from model_box import decode_bbox_target, encode_bbox_target
from utils.box_ops import pairwise_iou
......
......@@ -8,7 +8,7 @@ from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from basemodel import GroupNorm
from backbone import GroupNorm
from config import config as cfg
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment