-
Notifications
You must be signed in to change notification settings - Fork 119
/
Copy pathpoolers.py
133 lines (115 loc) · 4.45 KB
/
poolers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn
from mega_core.layers import ROIAlign
from .utils import cat
class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    The level is computed from the square root of each box area, relative to a
    canonical scale / canonical level pair, and then clamped to
    ``[k_min, k_max]``.
    """

    def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
        """
        Arguments:
            k_min (int or float): lowest level a RoI can be assigned to; it is
                also subtracted from the result so returned ids start at 0
            k_max (int or float): highest level a RoI can be assigned to
            canonical_scale (int): box scale (sqrt of the area) that is mapped
                to ``canonical_level``
            canonical_level (int): level assigned to a box of ``canonical_scale``
            eps (float): small constant added to the scale ratio before the
                log2 is taken
        """
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists):
        """
        Arguments:
            boxlists (list[BoxList]): each element must provide ``area()``

        Returns:
            Tensor (int64): one entry per box, over all boxlists concatenated
            in order, holding the level index relative to ``k_min``
        """
        # Scale of every box: square root of its area.
        s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)

        # k_min may be a Python float (e.g. obtained through ``.item()``); with
        # type promotion an int64 tensor minus a float scalar becomes a float
        # tensor, so cast again to guarantee integer level indices.
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)
class Pooler(nn.Module):
    """
    Pooler for Detection with or without FPN.
    It currently hard-codes ROIAlign in the implementation,
    but that can be made more generic later on.
    Also, the requirement of passing the scales is not strictly necessary, as they
    can be inferred from the size of the feature map / size of original image,
    which is available thanks to the BoxList.
    """

    def __init__(self, output_size, scales, sampling_ratio):
        """
        Arguments:
            output_size (tuple[int] or list[int] or int): output size for the
                pooled region, forwarded unchanged to every ROIAlign
            scales (list[float]): scales for each Pooler; one ROIAlign is
                created per entry, and the first / last entries define the
                lowest / highest level used for the RoI-to-level mapping
            sampling_ratio (int): sampling ratio for ROIAlign
        """
        super(Pooler, self).__init__()
        self.poolers = nn.ModuleList(
            [
                ROIAlign(
                    output_size, spatial_scale=scale, sampling_ratio=sampling_ratio
                )
                for scale in scales
            ]
        )
        self.output_size = output_size
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.map_levels = LevelMapper(lvl_min, lvl_max)

    def convert_to_roi_format(self, boxes):
        """
        Concatenate the boxes of all BoxLists into a single RoI tensor.

        Arguments:
            boxes (list[BoxList]): each element must provide ``bbox`` and ``len()``

        Returns:
            rois (Tensor): one row per box; the first column holds the index
            of the BoxList the box came from (stored with the same dtype as
            the boxes), the remaining columns are the box coordinates.
        """
        concat_boxes = cat([b.bbox for b in boxes], dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        # One (len(b), 1) column per BoxList, filled with that BoxList's index.
        ids = cat(
            [
                torch.full((len(b), 1), i, dtype=dtype, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def forward(self, x, boxes):
        """
        Arguments:
            x (list[Tensor]): feature maps for each level
            boxes (list[BoxList]): boxes to be used to perform the pooling operation.
        Returns:
            result (Tensor): with a single level, the output of the only
            pooler; otherwise a tensor of shape
            (num_rois, num_channels, out_h, out_w) with the dtype and device
            of ``x[0]``.
        """
        num_levels = len(self.poolers)
        rois = self.convert_to_roi_format(boxes)
        if num_levels == 1:
            return self.poolers[0](x[0], rois)

        levels = self.map_levels(boxes)

        num_rois = len(rois)
        num_channels = x[0].shape[1]
        # Accept a plain int as well as an (h, w) pair instead of assuming a
        # square window taken from the first entry only. Indexing with
        # [0] / [-1] keeps every previously accepted sequence working.
        size = self.output_size
        if isinstance(size, int):
            out_h = out_w = size
        else:
            out_h, out_w = size[0], size[-1]

        dtype, device = x[0].dtype, x[0].device
        result = torch.zeros(
            (num_rois, num_channels, out_h, out_w),
            dtype=dtype,
            device=device,
        )
        for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
            # Indices of the RoIs assigned to this level.
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]
            # The pooled output is cast to the feature dtype before being
            # scattered back to the positions of the RoIs it belongs to.
            result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)

        return result
def make_pooler(cfg, head_name):
    """Build the Pooler described by the ``cfg.MODEL[head_name]`` section.

    Arguments:
        cfg: configuration object exposing ``MODEL[head_name]`` with the
            attributes ``POOLER_RESOLUTION``, ``POOLER_SCALES`` and
            ``POOLER_SAMPLING_RATIO``
        head_name: key of the head section inside ``cfg.MODEL``

    Returns:
        Pooler: pooler with a square (resolution, resolution) output
    """
    head_cfg = cfg.MODEL[head_name]
    side = head_cfg.POOLER_RESOLUTION
    return Pooler(
        output_size=(side, side),
        scales=head_cfg.POOLER_SCALES,
        sampling_ratio=head_cfg.POOLER_SAMPLING_RATIO,
    )