Skip to content

Latest commit

 

History

History

0928.Minimize Malware Spread II

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 
 
 

English Version

题目描述

(这个问题与 尽量减少恶意软件的传播 是一样的,不同之处用粗体表示。)

在节点网络中,只有当 graph[i][j] = 1 时,每个节点 i 能够直接连接到另一个节点 j

一些节点 initial 最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。

假设 M(initial) 是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。

我们可以从初始列表中删除一个节点,并完全移除该节点以及从该节点到任何其他节点的任何连接。如果移除这一节点将最小化 M(initial), 则返回该节点。如果有多个节点满足条件,就返回索引最小的节点。

 

示例 1:

输出:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
输入:0

示例 2:

输入:graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1]
输出:1

示例 3:

输入:graph = [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], initial = [0,1]
输出:1

 

提示:

  1. 1 < graph.length = graph[0].length <= 300
  2. 0 <= graph[i][j] == graph[j][i] <= 1
  3. graph[i][i] = 1
  4. 1 <= initial.length < graph.length
  5. 0 <= initial[i] < graph.length

解法

并查集。

模板 1——朴素并查集:

# 初始化,p存储每个点的父节点
p = list(range(n))

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        # 路径压缩
        p[x] = find(p[x])
    return p[x]

# 合并a和b所在的两个集合
p[find(a)] = find(b)

模板 2——维护 size 的并查集:

# 初始化,p存储每个点的父节点,size只有当节点是祖宗节点时才有意义,表示祖宗节点所在集合中,点的数量
p = list(range(n))
size = [1] * n

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        # 路径压缩
        p[x] = find(p[x])
    return p[x]

# 合并a和b所在的两个集合
if find(a) != find(b):
    size[find(b)] += size[find(a)]
    p[find(a)] = find(b)

模板 3——维护到祖宗节点距离的并查集:

# 初始化,p存储每个点的父节点,d[x]存储x到p[x]的距离
p = list(range(n))
d = [0] * n

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        t = find(p[x])
        d[x] += d[p[x]]
        p[x] = t
    return p[x]

# 合并a和b所在的两个集合
p[find(a)] = find(b)
d[find(a)] = distance

对于本题,先遍历所有未被感染的节点(即不在 initial 列表的节点),构造并查集,并且在集合根节点维护 size,表示当前集合的节点数。

然后找到只被一个 initial 节点感染的集合,求得感染节点数的最小值。

被某个 initial 节点感染的集合,节点数越多,若移除此 initial 节点,感染的节点数就越少。

Python3

class Solution:
    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        n = len(graph)
        p = list(range(n))
        size = [1] * n

        def find(x):
            if p[x] != x:
                p[x] = find(p[x])
            return p[x]

        clean = [True] * n
        for i in initial:
            clean[i] = False

        for i in range(n):
            if not clean[i]:
                continue
            for j in range(i + 1, n):
                if not clean[j]:
                    continue
                if graph[i][j] == 1:
                    pa, pb = find(i), find(j)
                    if pa == pb:
                        continue
                    p[pa] = pb
                    size[pb] += size[pa]

        cnt = collections.Counter()
        mp = {}
        for i in initial:
            s = set()
            for j in range(n):
                if not clean[j]:
                    continue
                if graph[i][j] == 1:
                    s.add(find(j))
            for e in s:
                cnt[e] += 1
            mp[i] = s

        mx = -1
        res = 0
        for i, s in mp.items():
            t = 0
            for e in s:
                if cnt[e] == 1:
                    t += size[e]
            if mx < t or (mx == t and i < res):
                mx = t
                res = i
        return res

Java

class Solution {
    private int[] p;

    public int minMalwareSpread(int[][] graph, int[] initial) {
        int n = graph.length;
        p = new int[n];
        int[] size = new int[n];
        for (int i = 0; i < n; ++i) {
            p[i] = i;
            size[i] = 1;
        }
        boolean[] clean = new boolean[n];
        Arrays.fill(clean, true);
        for (int i : initial) {
            clean[i] = false;
        }
        for (int i = 0; i < n; ++i) {
            if (!clean[i]) {
                continue;
            }
            for (int j = i + 1; j < n; ++j) {
                if (!clean[j]) {
                    continue;
                }
                if (graph[i][j] == 1) {
                    int pa = find(i), pb = find(j);
                    if (pa == pb) {
                        continue;
                    }
                    p[pa] = pb;
                    size[pb] += size[pa];
                }
            }
        }
        int[] cnt = new int[n];
        Map<Integer, Set<Integer>> mp = new HashMap<>();
        for (int i : initial) {
            Set<Integer> s = new HashSet<>();
            for (int j = 0; j < n; ++j) {
                if (!clean[j]) {
                    continue;
                }
                if (graph[i][j] == 1) {
                    s.add(find(j));
                }
            }
            for (int e : s) {
                cnt[e] += 1;
            }
            mp.put(i, s);
        }
        int mx = -1;
        int res = 0;
        for (Map.Entry<Integer, Set<Integer>> entry : mp.entrySet()) {
            int i = entry.getKey();
            int t = 0;
            for (int e : entry.getValue()) {
                if (cnt[e] == 1) {
                    t += size[e];
                }
            }
            if (mx < t || (mx == t && i < res)) {
                mx = t;
                res = i;
            }
        }
        return res;
    }

    private int find(int x) {
        if (p[x] != x) {
            p[x] = find(p[x]);
        }
        return p[x];
    }
}

C++

class Solution {
public:
    vector<int> p;

    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        int n = graph.size();
        vector<int> size(n, 1);
        for (int i = 0; i < n; ++i) p.push_back(i);
        vector<bool> clean(n, true);
        for (int i : initial) clean[i] = false;
        for (int i = 0; i < n; ++i)
        {
            if (!clean[i]) continue;
            for (int j = i + 1; j < n; ++j)
            {
                if (!clean[j]) continue;
                if (graph[i][j])
                {
                    int pa = find(i), pb = find(j);
                    if (pa == pb) continue;
                    p[pa] = pb;
                    size[pb] += size[pa];
                }
            }
        }
        vector<int> cnt(n, 0);
        unordered_map<int, unordered_set<int>> mp;
        for (int i : initial)
        {
            unordered_set<int> s;
            for (int j = 0; j < n; ++j)
            {
                if (!clean[j]) continue;
                if (graph[i][j]) s.insert(find(j));
            }
            for (int e : s) ++cnt[e];
            mp[i] = s;
        }
        int mx = -1;
        int res = 0;
        for (auto item : mp)
        {
            int i = item.first;
            int t = 0;
            for (int e : item.second)
            {
                if (cnt[e] == 1) t += size[e];
            }
            if (mx < t || (mx == t && i < res))
            {
                mx = t;
                res = i;
            }
        }
        return res;
    }

    int find(int x) {
        if (p[x] != x) p[x] = find(p[x]);
        return p[x];
    }
};

Go

var p []int

func minMalwareSpread(graph [][]int, initial []int) int {
	n := len(graph)
	p = make([]int, n)
	size := make([]int, n)
	clean := make([]bool, n)
	for i := 0; i < n; i++ {
		p[i] = i
		size[i] = 1
		clean[i] = true
	}
	for _, i := range initial {
		clean[i] = false
	}
	for i := 0; i < n; i++ {
		if !clean[i] {
			continue
		}
		for j := i + 1; j < n; j++ {
			if !clean[j] {
				continue
			}
			if graph[i][j] == 1 {
				pa, pb := find(i), find(j)
				if pa == pb {
					continue
				}
				p[pa] = pb
				size[pb] += size[pa]
			}
		}
	}
	cnt := make([]int, n)
	mp := make(map[int]map[int]bool)
	for _, i := range initial {
		s := make(map[int]bool)
		for j := 0; j < n; j++ {
			if !clean[j] {
				continue
			}
			if graph[i][j] == 1 {
				s[find(j)] = true
			}
		}
		for e, _ := range s {
			cnt[e]++
		}
		mp[i] = s
	}
	mx, res := -1, 0
	for i, s := range mp {
		t := 0
		for e, _ := range s {
			if cnt[e] == 1 {
				t += size[e]
			}
		}
		if mx < t || (mx == t && i < res) {
			mx, res = t, i
		}
	}
	return res
}

func find(x int) int {
	if p[x] != x {
		p[x] = find(p[x])
	}
	return p[x]
}

...