
Commit b5fb2b7

[feature]: support mmdet models config (#25)
* support mmdet models
* add mmlab_models_usage_guide.md
* remove tools/test.py
1 parent 10266f5 commit b5fb2b7

14 files changed: +668 −197 lines

Lines changed: 138 additions & 0 deletions

@@ -0,0 +1,138 @@
# model settings
model = dict(
    type='MaskRCNN',
    # EasyCV backbone
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3, 4),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True),
    # mmdet backbone
    # backbone=dict(
    #     type='ResNet',
    #     depth=50,
    #     num_stages=4,
    #     out_indices=(0, 1, 2, 3),
    #     frozen_stages=1,
    #     norm_cfg=dict(type='BN', requires_grad=True),
    #     norm_eval=True,
    #     style='pytorch',
    #     init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))

mmlab_modules = [
    dict(type='mmdet', name='MaskRCNN', module='model'),
    # dict(type=MMDET, name='ResNet', module='backbone'), # comment out, use EasyCV ResNet
    dict(type='mmdet', name='FPN', module='neck'),
    dict(type='mmdet', name='RPNHead', module='head'),
    dict(type='mmdet', name='StandardRoIHead', module='head'),
]
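The usage guide added below refers to this config as `configs/detection/mask_rcnn/mask_rcnn_r50_fpn.py`. As a quick orientation, here is a minimal sketch (not part of the commit) that loads such a config with mmcv and prints which modules are delegated to mmdet; the path is taken from the guide and everything else is illustrative.

```python
# Minimal sketch (not part of this commit): list which modules a config
# delegates to mmdet. Config path taken from the usage guide; purely illustrative.
from mmcv import Config

cfg = Config.fromfile('configs/detection/mask_rcnn/mask_rcnn_r50_fpn.py')
for entry in cfg.get('mmlab_modules', []):
    print('{:>8}: {} -> {}'.format(entry['module'], entry['name'], entry['type']))
# Anything not listed here (e.g. the backbone ResNet) falls back to EasyCV's own implementation.
```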
Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
# Use mmdetection's models in EasyCV

For details of mmdetection, please refer to: https://github.com/open-mmlab/mmdetection

**We only support mmdet's models. Other repositories in mmlab are not supported, and other modules such as transforms, dataset APIs, etc. are not supported either.**

The models module of EasyCV is divided into four parts: `backbone`, `head`, `neck`, and `model`.

So we support combining EasyCV and mmdet models at these four levels.

**We do not adapt the other APIs used inside these four module levels; we wrap each API in its entirety.**

> **Note:**
>
> **If you want to combine the model parts of mmdet and EasyCV, please pay attention to the compatibility between the APIs; we do not guarantee that the APIs of EasyCV and mmdet are compatible.**
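To make "wrap the entire API" concrete, here is a conceptual sketch only, not EasyCV's actual adapter code: the mmdet class is registered as-is into a build registry, so everything it constructs internally stays on the mmdet side. The registry name is illustrative.

```python
# Conceptual sketch only -- not EasyCV's actual adapter code.
# Registering the mmdet class as-is means everything it builds internally
# (anchor generators, box coders, ...) also comes from mmdet.
from mmcv.utils import Registry, build_from_cfg
from mmdet.models import FPN  # reused unchanged from mmdet

NECKS = Registry('neck')                       # illustrative registry name
NECKS.register_module(module=FPN, name='FPN')

neck_cfg = dict(type='FPN', in_channels=[256, 512, 1024, 2048],
                out_channels=256, num_outs=5)
neck = build_from_cfg(neck_cfg, NECKS)         # builds mmdet's FPN from the config dict
```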
Take the `MaskRCNN` model as an example, see [mask_rcnn_r50_fpn.py](https://github.com/alibaba/EasyCV/tree/master/configs/detection/mask_rcnn/mask_rcnn_r50_fpn.py). Except for the backbone, the other parts of this model all use mmdet APIs.

The framework of `MaskRCNN` can be divided into the following parts at the `backbone`, `head`, `neck`, and `model` levels:

- backbone: `ResNet`

- head: `RPNHead`, `StandardRoIHead`

- neck: `FPN`

- model: `MaskRCNN`

The configuration adapting these parts to mmdet is as follows:
```python
mmlab_modules = [
    dict(type='mmdet', name='MaskRCNN', module='model'),
    # dict(type='mmdet', name='ResNet', module='backbone'), # comment out, use EasyCV ResNet
    dict(type='mmdet', name='FPN', module='neck'),
    dict(type='mmdet', name='RPNHead', module='head'),
    dict(type='mmdet', name='StandardRoIHead', module='head'),
]
```

> Parameters:
>
> - type: the name of the source repository; only `mmdet` is supported
> - name: the name of the API
> - module: the name of the module the API belongs to; only `backbone`, `head`, `neck`, and `model` are supported
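These constraints can be summarized in a small check; the helper below is hypothetical, not an EasyCV API, and only mirrors the parameter rules stated above.

```python
# Illustrative, hypothetical check that mirrors the constraints above:
# type must be 'mmdet' and module must be one of the four supported levels.
ALLOWED_TYPES = {'mmdet'}
ALLOWED_MODULES = {'backbone', 'head', 'neck', 'model'}

def check_mmlab_modules(entries):
    for entry in entries:
        assert entry['type'] in ALLOWED_TYPES, 'unsupported type: %s' % entry['type']
        assert entry['module'] in ALLOWED_MODULES, 'unsupported module: %s' % entry['module']

check_mmlab_modules(mmlab_modules)  # passes for the list shown above
```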

In this configuration, the `head`, `neck`, and `model` parts specify the type as `mmdet`; only the `backbone` does not.

**Any part not configured here uses the EasyCV API by default, such as the backbone (`ResNet`).**

**For the parts explicitly configured with type `mmdet`, the mmdet API is used.**

That is:

- `MaskRCNN` (model): uses mmdet's `MaskRCNN` API.

- `ResNet` (backbone): uses EasyCV's `ResNet` API.

> Note that the parameters of mmdet's and EasyCV's `ResNet` are different, so please pay attention to them (see the config excerpt after this list).

- `RPNHead` (head): uses mmdet's `RPNHead` API.

> Note that all the other APIs configured inside `RPNHead`, such as `AnchorGenerator`, `DeltaXYWHBBoxCoder`, etc., are also mmdet APIs, because we wrap the entire API.

- `StandardRoIHead` (head): uses mmdet's `StandardRoIHead` API.

> Note that all the other APIs configured inside `StandardRoIHead`, such as `SingleRoIExtractor`, `FCNMaskHead`, etc., are also mmdet APIs, because we wrap the entire API.

- `FPN` (neck): uses mmdet's `FPN` API.
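To make the `ResNet` parameter difference concrete, here is the backbone contrast taken from the `mask_rcnn_r50_fpn.py` config added in this commit; note the different `out_indices`, which suggests the two implementations index the stages differently.

```python
# Excerpt adapted from the mask_rcnn_r50_fpn.py config in this commit:
# the same four ResNet stages are selected, but the two implementations
# appear to index them differently.
easycv_backbone = dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(1, 2, 3, 4),   # EasyCV ResNet
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=True),
    norm_eval=True)

mmdet_backbone = dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(0, 1, 2, 3),   # mmdet ResNet
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=True),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
```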

easycv/apis/test.py

Lines changed: 9 additions & 2 deletions

@@ -119,9 +119,16 @@ def single_gpu_test(model, data_loader, mode='test', use_fp16=False, **kwargs):
                 results[k].append(v)
 
         if 'img_metas' in data:
-            batch_size = len(data['img_metas'].data[0])
+            if isinstance(data['img_metas'], list):
+                batch_size = len(data['img_metas'][0].data[0])
+            else:
+                batch_size = len(data['img_metas'].data[0])
+
         else:
-            batch_size = data['img'].size(0)
+            if isinstance(data['img'], list):
+                batch_size = data['img'][0].size(0)
+            else:
+                batch_size = data['img'].size(0)
 
         for _ in range(batch_size):
             prog_bar.update()
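The change above covers the case where the test pipeline wraps `img_metas` and `img` in lists (as mmdet-style multi-scale/flip test pipelines do after collation) instead of passing a bare `DataContainer`/tensor. A minimal sketch of the two layouts, assuming mmcv's `DataContainer`; the meta values are dummies.

```python
# Sketch of the two data layouts the patched batch-size logic distinguishes.
import torch
from mmcv.parallel import DataContainer

metas = DataContainer([[dict(ori_shape=(800, 1333, 3))]])  # dummy meta, batch of 1

plain = {'img_metas': metas, 'img': torch.zeros(1, 3, 800, 1333)}
wrapped = {'img_metas': [metas], 'img': [torch.zeros(1, 3, 800, 1333)]}  # e.g. multi-scale/flip test

for data in (plain, wrapped):
    if isinstance(data['img_metas'], list):
        batch_size = len(data['img_metas'][0].data[0])
    else:
        batch_size = len(data['img_metas'].data[0])
    print(batch_size)  # -> 1 in both cases
```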

easycv/apis/train.py

Lines changed: 1 addition & 1 deletion

@@ -151,7 +151,7 @@ def train_model(model,
     if validate:
         interval = cfg.eval_config.pop('interval', 1)
         for idx, eval_pipe in enumerate(cfg.eval_pipelines):
-            data = eval_pipe.data
+            data = eval_pipe.get('data', None) or cfg.data.val
             dist_eval = eval_pipe.get('dist_eval', False)
 
             evaluator_cfg = eval_pipe.evaluators[0]
easycv/core/evaluation/coco_evaluation.py

Lines changed: 31 additions & 1 deletion

@@ -473,7 +473,8 @@ def add_single_detected_image_info(self, image_id, detections_dict):
         groundtruth_masks_shape = self._image_id_to_mask_shape_map[image_id]
         detection_masks = detections_dict[
             standard_fields.DetectionResultFields.detection_masks]
-        if groundtruth_masks_shape[1:] != detection_masks.shape[1:]:
+        if len(detection_masks
+               ) and groundtruth_masks_shape[1:] != detection_masks.shape[1:]:
             raise ValueError(
                 'Spatial shape of groundtruth masks and detection masks '
                 'are incompatible: {} vs {}'.format(groundtruth_masks_shape,

@@ -601,6 +602,9 @@ def _evaluate_impl(self, prediction_dict, groundtruth_dict):
             else:
                 groundtruth_is_crowd = groundtruth_is_crowd_list[idx]
 
+            gt_masks = np.array(
+                [self._ann_to_mask(mask, height, width) for mask in gt_masks],
+                dtype=np.uint8)
             groundtruth_dict = {
                 'groundtruth_boxes': gt_boxes_absolute,
                 'groundtruth_instance_masks': gt_masks,

@@ -609,6 +613,11 @@ def _evaluate_impl(self, prediction_dict, groundtruth_dict):
             }
             self.add_single_ground_truth_image_info(image_id, groundtruth_dict)
 
+            detection_masks = np.array([
+                self._ann_to_mask(mask, height, width)
+                for mask in detection_masks
+            ],
+                                       dtype=np.uint8)
             # add detection info
             detection_dict = {
                 'detection_masks': detection_masks,

@@ -621,6 +630,27 @@ def _evaluate_impl(self, prediction_dict, groundtruth_dict):
         self.clear()
         return eval_dict
 
+    def _ann_to_mask(self, segmentation, height, width):
+        from xtcocotools import mask as maskUtils
+        segm = segmentation
+        h = height
+        w = width
+
+        if type(segm) == list:
+            # polygon -- a single object might consist of multiple parts
+            # we merge all parts into one mask rle code
+            rles = maskUtils.frPyObjects(segm, h, w)
+            rle = maskUtils.merge(rles)
+        elif type(segm['counts']) == list:
+            # uncompressed RLE
+            rle = maskUtils.frPyObjects(segm, h, w)
+        else:
+            # rle
+            rle = segm
+
+        m = maskUtils.decode(rle)
+        return m
+
 
 @EVALUATORS.register_module
 class CoCoPoseTopDownEvaluator(Evaluator):
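The new `_ann_to_mask` helper decodes COCO-style segmentation annotations (polygon lists, uncompressed RLE, or compressed RLE) into binary masks before they reach mask evaluation. A standalone sketch of the same polygon path with `xtcocotools`; the polygon values are dummies.

```python
# Standalone sketch of the polygon -> binary mask path used by _ann_to_mask.
from xtcocotools import mask as maskUtils

height, width = 4, 4
polygon = [[0.5, 0.5, 3.5, 0.5, 3.5, 3.5, 0.5, 3.5]]  # one object made of one part

rles = maskUtils.frPyObjects(polygon, height, width)  # polygon(s) -> RLE(s)
rle = maskUtils.merge(rles)                           # merge parts into a single RLE
m = maskUtils.decode(rle)                             # RLE -> (height, width) uint8 mask
print(m.shape, m.dtype)                               # (4, 4) uint8
```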

easycv/datasets/detection/pipelines/mm_transforms.py

Lines changed: 29 additions & 28 deletions

@@ -1644,20 +1644,21 @@ def _poly2mask(self, mask_ann, img_h, img_w):
         Returns:
             numpy.ndarray: The decode bitmap mask of shape (img_h, img_w).
         """
-        raise NotImplementedError
-        # if isinstance(mask_ann, list):
-        #     # polygon -- a single object might consist of multiple parts
-        #     # we merge all parts into one mask rle code
-        #     rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
-        #     rle = maskUtils.merge(rles)
-        # elif isinstance(mask_ann['counts'], list):
-        #     # uncompressed RLE
-        #     rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
-        # else:
-        #     # rle
-        #     rle = mask_ann
-        # mask = maskUtils.decode(rle)
-        # return mask
+        import xtcocotools.mask as maskUtils
+
+        if isinstance(mask_ann, list):
+            # polygon -- a single object might consist of multiple parts
+            # we merge all parts into one mask rle code
+            rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+            rle = maskUtils.merge(rles)
+        elif isinstance(mask_ann['counts'], list):
+            # uncompressed RLE
+            rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+        else:
+            # rle
+            rle = mask_ann
+        mask = maskUtils.decode(rle)
+        return mask
 
     def process_polygons(self, polygons):
         """Convert polygons to list of ndarray and filter invalid polygons.

@@ -1687,20 +1688,20 @@ def _load_masks(self, results):
             If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
             :obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used.
         """
-        raise NotImplementedError
-
-        # h, w = results['img_info']['height'], results['img_info']['width']
-        # gt_masks = results['ann_info']['masks']
-        # if self.poly2mask:
-        #     gt_masks = BitmapMasks(
-        #         [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
-        # else:
-        #     gt_masks = PolygonMasks(
-        #         [self.process_polygons(polygons) for polygons in gt_masks], h,
-        #         w)
-        # results['gt_masks'] = gt_masks
-        # results['mask_fields'].append('gt_masks')
-        # return results
+        from mmdet.core import BitmapMasks, PolygonMasks
+
+        h, w = results['img_info']['height'], results['img_info']['width']
+        gt_masks = results['ann_info']['masks']
+        if self.poly2mask:
+            gt_masks = BitmapMasks(
+                [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
+        else:
+            gt_masks = PolygonMasks(
+                [self.process_polygons(polygons) for polygons in gt_masks], h,
+                w)
+        results['gt_masks'] = gt_masks
+        results['mask_fields'].append('gt_masks')
+        return results
 
     def _load_semantic_seg(self, results):
         """Private function to load semantic segmentation annotations.

easycv/datasets/detection/raw.py

Lines changed: 4 additions & 0 deletions

@@ -70,6 +70,10 @@ def evaluate(self, results, evaluators=None, logger=None):
             self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
             for idx in range(len(results['img_metas']))
         ]
+        groundtruth_dict['groundtruth_instance_masks'] = [
+            self.data_source.get_ann_info(idx).get('masks', None)
+            for idx in range(len(results['img_metas']))
+        ]
 
         for evaluator in evaluators:
             eval_result.update(evaluator.evaluate(results, groundtruth_dict))
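After this change, the ground truth handed to each evaluator also carries per-image instance masks. A sketch of the resulting structure; the keys come from the diff, the values are dummies.

```python
# Sketch of the evaluator input after this change (real keys, dummy values).
groundtruth_dict = {
    'groundtruth_is_crowd': [
        [0, 0],                              # image 0: two instances, none crowd
    ],
    'groundtruth_instance_masks': [
        [[[0.5, 0.5, 3.5, 0.5, 3.5, 3.5]]],  # image 0: COCO-style segmentations, or None if absent
    ],
}
```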
