(这个问题与 尽量减少恶意软件的传播 是一样的,不同之处用粗体表示。)
在节点网络中,只有当 graph[i][j] = 1
时,每个节点 i
能够直接连接到另一个节点 j
。
一些节点 initial
最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。
假设 M(initial)
是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。
我们可以从初始列表中删除一个节点,并完全移除该节点以及从该节点到任何其他节点的任何连接。如果移除这一节点将最小化 M(initial)
, 则返回该节点。如果有多个节点满足条件,就返回索引最小的节点。
示例 1:
输出:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1] 输入:0
示例 2:
输入:graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1] 输出:1
示例 3:
输入:graph = [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], initial = [0,1] 输出:1
提示:
1 < graph.length = graph[0].length <= 300
0 <= graph[i][j] == graph[j][i] <= 1
graph[i][i] = 1
1 <= initial.length < graph.length
0 <= initial[i] < graph.length
并查集。
模板 1——朴素并查集:
# 初始化,p存储每个点的父节点
p = list(range(n))
# 返回x的祖宗节点
def find(x):
if p[x] != x:
# 路径压缩
p[x] = find(p[x])
return p[x]
# 合并a和b所在的两个集合
p[find(a)] = find(b)
模板 2——维护 size 的并查集:
# 初始化,p存储每个点的父节点,size只有当节点是祖宗节点时才有意义,表示祖宗节点所在集合中,点的数量
p = list(range(n))
size = [1] * n
# 返回x的祖宗节点
def find(x):
if p[x] != x:
# 路径压缩
p[x] = find(p[x])
return p[x]
# 合并a和b所在的两个集合
if find(a) != find(b):
size[find(b)] += size[find(a)]
p[find(a)] = find(b)
模板 3——维护到祖宗节点距离的并查集:
# 初始化,p存储每个点的父节点,d[x]存储x到p[x]的距离
p = list(range(n))
d = [0] * n
# 返回x的祖宗节点
def find(x):
if p[x] != x:
t = find(p[x])
d[x] += d[p[x]]
p[x] = t
return p[x]
# 合并a和b所在的两个集合
p[find(a)] = find(b)
d[find(a)] = distance
对于本题,先遍历所有未被感染的节点(即不在 initial 列表的节点),构造并查集,并且在集合根节点维护 size,表示当前集合的节点数。
然后找到只被一个 initial 节点感染的集合,求得感染节点数的最小值。
被某个 initial 节点感染的集合,节点数越多,若移除此 initial 节点,感染的节点数就越少。
class Solution:
def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
n = len(graph)
p = list(range(n))
size = [1] * n
def find(x):
if p[x] != x:
p[x] = find(p[x])
return p[x]
clean = [True] * n
for i in initial:
clean[i] = False
for i in range(n):
if not clean[i]:
continue
for j in range(i + 1, n):
if not clean[j]:
continue
if graph[i][j] == 1:
pa, pb = find(i), find(j)
if pa == pb:
continue
p[pa] = pb
size[pb] += size[pa]
cnt = Counter()
mp = {}
for i in initial:
s = set()
for j in range(n):
if not clean[j]:
continue
if graph[i][j] == 1:
s.add(find(j))
for e in s:
cnt[e] += 1
mp[i] = s
mx = -1
res = 0
for i, s in mp.items():
t = 0
for e in s:
if cnt[e] == 1:
t += size[e]
if mx < t or (mx == t and i < res):
mx = t
res = i
return res
class Solution {
private int[] p;
public int minMalwareSpread(int[][] graph, int[] initial) {
int n = graph.length;
p = new int[n];
int[] size = new int[n];
for (int i = 0; i < n; ++i) {
p[i] = i;
size[i] = 1;
}
boolean[] clean = new boolean[n];
Arrays.fill(clean, true);
for (int i : initial) {
clean[i] = false;
}
for (int i = 0; i < n; ++i) {
if (!clean[i]) {
continue;
}
for (int j = i + 1; j < n; ++j) {
if (!clean[j]) {
continue;
}
if (graph[i][j] == 1) {
int pa = find(i), pb = find(j);
if (pa == pb) {
continue;
}
p[pa] = pb;
size[pb] += size[pa];
}
}
}
int[] cnt = new int[n];
Map<Integer, Set<Integer>> mp = new HashMap<>();
for (int i : initial) {
Set<Integer> s = new HashSet<>();
for (int j = 0; j < n; ++j) {
if (!clean[j]) {
continue;
}
if (graph[i][j] == 1) {
s.add(find(j));
}
}
for (int e : s) {
cnt[e] += 1;
}
mp.put(i, s);
}
int mx = -1;
int res = 0;
for (Map.Entry<Integer, Set<Integer>> entry : mp.entrySet()) {
int i = entry.getKey();
int t = 0;
for (int e : entry.getValue()) {
if (cnt[e] == 1) {
t += size[e];
}
}
if (mx < t || (mx == t && i < res)) {
mx = t;
res = i;
}
}
return res;
}
private int find(int x) {
if (p[x] != x) {
p[x] = find(p[x]);
}
return p[x];
}
}
class Solution {
public:
vector<int> p;
int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
int n = graph.size();
vector<int> size(n, 1);
for (int i = 0; i < n; ++i) p.push_back(i);
vector<bool> clean(n, true);
for (int i : initial) clean[i] = false;
for (int i = 0; i < n; ++i)
{
if (!clean[i]) continue;
for (int j = i + 1; j < n; ++j)
{
if (!clean[j]) continue;
if (graph[i][j])
{
int pa = find(i), pb = find(j);
if (pa == pb) continue;
p[pa] = pb;
size[pb] += size[pa];
}
}
}
vector<int> cnt(n, 0);
unordered_map<int, unordered_set<int>> mp;
for (int i : initial)
{
unordered_set<int> s;
for (int j = 0; j < n; ++j)
{
if (!clean[j]) continue;
if (graph[i][j]) s.insert(find(j));
}
for (int e : s) ++cnt[e];
mp[i] = s;
}
int mx = -1;
int res = 0;
for (auto item : mp)
{
int i = item.first;
int t = 0;
for (int e : item.second)
{
if (cnt[e] == 1) t += size[e];
}
if (mx < t || (mx == t && i < res))
{
mx = t;
res = i;
}
}
return res;
}
int find(int x) {
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
};
var p []int
func minMalwareSpread(graph [][]int, initial []int) int {
n := len(graph)
p = make([]int, n)
size := make([]int, n)
clean := make([]bool, n)
for i := 0; i < n; i++ {
p[i] = i
size[i] = 1
clean[i] = true
}
for _, i := range initial {
clean[i] = false
}
for i := 0; i < n; i++ {
if !clean[i] {
continue
}
for j := i + 1; j < n; j++ {
if !clean[j] {
continue
}
if graph[i][j] == 1 {
pa, pb := find(i), find(j)
if pa == pb {
continue
}
p[pa] = pb
size[pb] += size[pa]
}
}
}
cnt := make([]int, n)
mp := make(map[int]map[int]bool)
for _, i := range initial {
s := make(map[int]bool)
for j := 0; j < n; j++ {
if !clean[j] {
continue
}
if graph[i][j] == 1 {
s[find(j)] = true
}
}
for e, _ := range s {
cnt[e]++
}
mp[i] = s
}
mx, res := -1, 0
for i, s := range mp {
t := 0
for e, _ := range s {
if cnt[e] == 1 {
t += size[e]
}
}
if mx < t || (mx == t && i < res) {
mx, res = t, i
}
}
return res
}
func find(x int) int {
if p[x] != x {
p[x] = find(p[x])
}
return p[x]
}