55
55
56
56
<!-- 这里可写通用的实现逻辑 -->
57
57
58
- 并查集。
58
+ 逆向思维并查集。对于本题,先遍历所有未被感染的节点(即不在 initial 列表的节点),构造并查集,并且在集合根节点维护 size,表示当前集合的节点数。
59
+
60
+ 然后找到只被一个 initial 节点感染的集合,求得感染节点数的最小值。
61
+
62
+ > 被某个 initial 节点感染的集合,节点数越多,若移除此 initial 节点,感染的节点数就越少。
63
+
64
+ 以下是并查集的几个常用模板。
59
65
60
66
模板 1——朴素并查集:
61
67
@@ -114,12 +120,6 @@ p[find(a)] = find(b)
114
120
d[find(a)] = distance
115
121
```
116
122
117
- 对于本题,先遍历所有未被感染的节点(即不在 initial 列表的节点),构造并查集,并且在集合根节点维护 size,表示当前集合的节点数。
118
-
119
- 然后找到只被一个 initial 节点感染的集合,求得感染节点数的最小值。
120
-
121
- > 被某个 initial 节点感染的集合,节点数越多,若移除此 initial 节点,感染的节点数就越少。
122
-
123
123
<!-- tabs:start -->
124
124
125
125
### ** Python3**
@@ -129,56 +129,43 @@ d[find(a)] = distance
129
129
``` python
130
130
class Solution :
131
131
def minMalwareSpread (self , graph : List[List[int ]], initial : List[int ]) -> int :
132
- n = len (graph)
133
- p = list (range (n))
134
- size = [1 ] * n
135
-
136
132
def find (x ):
137
133
if p[x] != x:
138
134
p[x] = find(p[x])
139
135
return p[x]
140
136
137
+ def union (a , b ):
138
+ pa, pb = find(a), find(b)
139
+ if pa != pb:
140
+ size[pb] += size[pa]
141
+ p[pa] = pb
142
+
143
+ n = len (graph)
144
+ p = list (range (n))
145
+ size = [1 ] * n
141
146
clean = [True ] * n
142
147
for i in initial:
143
148
clean[i] = False
144
-
145
149
for i in range (n):
146
150
if not clean[i]:
147
151
continue
148
152
for j in range (i + 1 , n):
149
- if not clean[j]:
150
- continue
151
- if graph[i][j] == 1 :
152
- pa, pb = find(i), find(j)
153
- if pa == pb:
154
- continue
155
- p[pa] = pb
156
- size[pb] += size[pa]
157
-
153
+ if clean[j] and graph[i][j] == 1 :
154
+ union(i, j)
158
155
cnt = Counter()
159
156
mp = {}
160
157
for i in initial:
161
- s = set ()
162
- for j in range (n):
163
- if not clean[j]:
164
- continue
165
- if graph[i][j] == 1 :
166
- s.add(find(j))
167
- for e in s:
168
- cnt[e] += 1
158
+ s = {find(j) for j in range (n) if clean[j] and graph[i][j] == 1 }
159
+ for root in s:
160
+ cnt[root] += 1
169
161
mp[i] = s
170
162
171
- mx = - 1
172
- res = 0
163
+ mx, ans = - 1 , 0
173
164
for i, s in mp.items():
174
- t = 0
175
- for e in s:
176
- if cnt[e] == 1 :
177
- t += size[e]
178
- if mx < t or (mx == t and i < res):
179
- mx = t
180
- res = i
181
- return res
165
+ t = sum (size[root] for root in s if cnt[root] == 1 )
166
+ if mx < t or mx == t and i < ans:
167
+ mx, ans = t, i
168
+ return ans
182
169
```
183
170
184
171
### ** Java**
@@ -188,11 +175,12 @@ class Solution:
188
175
``` java
189
176
class Solution {
190
177
private int [] p;
178
+ private int [] size;
191
179
192
180
public int minMalwareSpread (int [][] graph , int [] initial ) {
193
181
int n = graph. length;
194
182
p = new int [n];
195
- int [] size = new int [n];
183
+ size = new int [n];
196
184
for (int i = 0 ; i < n; ++ i) {
197
185
p[i] = i;
198
186
size[i] = 1 ;
@@ -207,16 +195,8 @@ class Solution {
207
195
continue ;
208
196
}
209
197
for (int j = i + 1 ; j < n; ++ j) {
210
- if (! clean[j]) {
211
- continue ;
212
- }
213
- if (graph[i][j] == 1 ) {
214
- int pa = find(i), pb = find(j);
215
- if (pa == pb) {
216
- continue ;
217
- }
218
- p[pa] = pb;
219
- size[pb] += size[pa];
198
+ if (clean[j] && graph[i][j] == 1 ) {
199
+ union(i, j);
220
200
}
221
201
}
222
202
}
@@ -225,34 +205,31 @@ class Solution {
225
205
for (int i : initial) {
226
206
Set<Integer > s = new HashSet<> ();
227
207
for (int j = 0 ; j < n; ++ j) {
228
- if (! clean[j]) {
229
- continue ;
230
- }
231
- if (graph[i][j] == 1 ) {
208
+ if (clean[j] && graph[i][j] == 1 ) {
232
209
s. add(find(j));
233
210
}
234
211
}
235
- for (int e : s) {
236
- cnt[e ] += 1 ;
212
+ for (int root : s) {
213
+ cnt[root ] += 1 ;
237
214
}
238
215
mp. put(i, s);
239
216
}
240
217
int mx = - 1 ;
241
- int res = 0 ;
218
+ int ans = 0 ;
242
219
for (Map . Entry<Integer , Set<Integer > > entry : mp. entrySet()) {
243
220
int i = entry. getKey();
244
221
int t = 0 ;
245
- for (int e : entry. getValue()) {
246
- if (cnt[e ] == 1 ) {
247
- t += size[e ];
222
+ for (int root : entry. getValue()) {
223
+ if (cnt[root ] == 1 ) {
224
+ t += size[root ];
248
225
}
249
226
}
250
- if (mx < t || (mx == t && i < res )) {
227
+ if (mx < t || (mx == t && i < ans )) {
251
228
mx = t;
252
- res = i;
229
+ ans = i;
253
230
}
254
231
}
255
- return res ;
232
+ return ans ;
256
233
}
257
234
258
235
private int find (int x ) {
@@ -261,6 +238,15 @@ class Solution {
261
238
}
262
239
return p[x];
263
240
}
241
+
242
+ private void union (int a , int b ) {
243
+ int pa = find(a);
244
+ int pb = find(b);
245
+ if (pa != pb) {
246
+ size[pb] += size[pa];
247
+ p[pa] = pb;
248
+ }
249
+ }
264
250
}
265
251
```
266
252
@@ -270,100 +256,101 @@ class Solution {
270
256
class Solution {
271
257
public:
272
258
vector<int > p;
259
+ vector<int > size;
273
260
274
261
int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
275
262
int n = graph.size();
276
- vector<int> size(n, 1);
277
- for (int i = 0; i < n; ++i) p.push_back(i);
263
+ p.resize(n);
264
+ size.resize(n);
265
+ for (int i = 0; i < n; ++i) p[i] = i;
266
+ fill(size.begin(), size.end(), 1);
278
267
vector<bool> clean(n, true);
279
268
for (int i : initial) clean[i] = false;
280
269
for (int i = 0; i < n; ++i)
281
270
{
282
271
if (!clean[i]) continue;
283
272
for (int j = i + 1; j < n; ++j)
284
- {
285
- if (!clean[j]) continue;
286
- if (graph[i][j])
287
- {
288
- int pa = find(i), pb = find(j);
289
- if (pa == pb) continue;
290
- p[pa] = pb;
291
- size[pb] += size[pa];
292
- }
293
- }
273
+ if (clean[j] && graph[i][j] == 1) merge(i, j);
294
274
}
295
275
vector<int > cnt (n, 0);
296
276
unordered_map<int, unordered_set<int >> mp;
297
277
for (int i : initial)
298
278
{
299
279
unordered_set<int > s;
300
280
for (int j = 0; j < n; ++j)
301
- {
302
- if (!clean[ j] ) continue;
303
- if (graph[ i] [ j ] ) s.insert(find(j));
304
- }
281
+ if (clean[ j] && graph[ i] [ j ] == 1) s.insert(find(j));
305
282
for (int e : s) ++cnt[ e] ;
306
283
mp[ i] = s;
307
284
}
308
- int mx = -1;
309
- int res = 0;
310
- for (auto item : mp)
285
+ int mx = -1, ans = 0;
286
+ for (auto& [ i, s] : mp)
311
287
{
312
- int i = item.first;
313
288
int t = 0;
314
- for (int e : item.second)
315
- {
316
- if (cnt[ e] == 1) t += size[ e] ;
317
- }
318
- if (mx < t || (mx == t && i < res))
289
+ for (int root : s)
290
+ if (cnt[ root] == 1)
291
+ t += size[ root] ;
292
+ if (mx < t || (mx == t && i < ans))
319
293
{
320
294
mx = t;
321
- res = i;
295
+ ans = i;
322
296
}
323
297
}
324
- return res ;
298
+ return ans ;
325
299
}
326
300
327
301
int find(int x) {
328
302
if (p[x] != x) p[x] = find(p[x]);
329
303
return p[x];
330
304
}
305
+
306
+ void merge(int a, int b) {
307
+ int pa = find(a), pb = find(b);
308
+ if (pa != pb)
309
+ {
310
+ size[pb] += size[pa];
311
+ p[pa] = pb;
312
+ }
313
+ }
331
314
};
332
315
```
333
316
334
317
### **Go**
335
318
336
319
```go
337
- var p []int
338
-
339
320
func minMalwareSpread(graph [][]int, initial []int) int {
340
321
n := len(graph)
341
- p = make([]int, n)
322
+ p : = make([]int, n)
342
323
size := make([]int, n)
343
324
clean := make([]bool, n)
344
325
for i := 0; i < n; i++ {
345
- p[i] = i
346
- size[i] = 1
347
- clean[i] = true
326
+ p[i], size[i], clean[i] = i, 1, true
348
327
}
349
328
for _, i := range initial {
350
329
clean[i] = false
351
330
}
331
+
332
+ var find func(x int) int
333
+ find = func(x int) int {
334
+ if p[x] != x {
335
+ p[x] = find(p[x])
336
+ }
337
+ return p[x]
338
+ }
339
+ union := func(a, b int) {
340
+ pa, pb := find(a), find(b)
341
+ if pa != pb {
342
+ size[pb] += size[pa]
343
+ p[pa] = pb
344
+ }
345
+ }
346
+
352
347
for i := 0; i < n; i++ {
353
348
if !clean[i] {
354
349
continue
355
350
}
356
351
for j := i + 1; j < n; j++ {
357
- if !clean[j] {
358
- continue
359
- }
360
- if graph[i][j] == 1 {
361
- pa, pb := find(i), find(j)
362
- if pa == pb {
363
- continue
364
- }
365
- p[pa] = pb
366
- size[pb] += size[pa]
352
+ if clean[j] && graph[i][j] == 1 {
353
+ union(i, j)
367
354
}
368
355
}
369
356
}
@@ -372,38 +359,28 @@ func minMalwareSpread(graph [][]int, initial []int) int {
372
359
for _, i := range initial {
373
360
s := make(map[int]bool)
374
361
for j := 0; j < n; j++ {
375
- if !clean[j] {
376
- continue
377
- }
378
- if graph[i][j] == 1 {
362
+ if clean[j] && graph[i][j] == 1 {
379
363
s[find(j)] = true
380
364
}
381
365
}
382
- for e , _ := range s {
383
- cnt[e ]++
366
+ for root , _ := range s {
367
+ cnt[root ]++
384
368
}
385
369
mp[i] = s
386
370
}
387
- mx, res := -1, 0
371
+ mx, ans := -1, 0
388
372
for i, s := range mp {
389
373
t := 0
390
- for e , _ := range s {
391
- if cnt[e ] == 1 {
392
- t += size[e ]
374
+ for root , _ := range s {
375
+ if cnt[root ] == 1 {
376
+ t += size[root ]
393
377
}
394
378
}
395
- if mx < t || (mx == t && i < res ) {
396
- mx, res = t, i
379
+ if mx < t || (mx == t && i < ans ) {
380
+ mx, ans = t, i
397
381
}
398
382
}
399
- return res
400
- }
401
-
402
- func find(x int) int {
403
- if p[x] != x {
404
- p[x] = find(p[x])
405
- }
406
- return p[x]
383
+ return ans
407
384
}
408
385
```
409
386
0 commit comments