Skip to content

Latest commit

 

History

History
416 lines (353 loc) · 10.1 KB

File metadata and controls

416 lines (353 loc) · 10.1 KB

English Version

题目描述

(这个问题与 尽量减少恶意软件的传播 是一样的,不同之处用粗体表示。)

在节点网络中,只有当 graph[i][j] = 1 时,每个节点 i 能够直接连接到另一个节点 j

一些节点 initial 最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。

假设 M(initial) 是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。

我们可以从初始列表中删除一个节点,并完全移除该节点以及从该节点到任何其他节点的任何连接。如果移除这一节点将最小化 M(initial), 则返回该节点。如果有多个节点满足条件,就返回索引最小的节点。

 

示例 1:

输出:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
输入:0

示例 2:

输入:graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1]
输出:1

示例 3:

输入:graph = [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], initial = [0,1]
输出:1

 

提示:

  1. 1 < graph.length = graph[0].length <= 300
  2. 0 <= graph[i][j] == graph[j][i] <= 1
  3. graph[i][i] = 1
  4. 1 <= initial.length < graph.length
  5. 0 <= initial[i] < graph.length

解法

并查集。

模板 1——朴素并查集:

# 初始化,p存储每个点的父节点
p = list(range(n))

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        # 路径压缩
        p[x] = find(p[x])
    return p[x]

# 合并a和b所在的两个集合
p[find(a)] = find(b)

模板 2——维护 size 的并查集:

# 初始化,p存储每个点的父节点,size只有当节点是祖宗节点时才有意义,表示祖宗节点所在集合中,点的数量
p = list(range(n))
size = [1] * n

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        # 路径压缩
        p[x] = find(p[x])
    return p[x]

# 合并a和b所在的两个集合
if find(a) != find(b):
    size[find(b)] += size[find(a)]
    p[find(a)] = find(b)

模板 3——维护到祖宗节点距离的并查集:

# 初始化,p存储每个点的父节点,d[x]存储x到p[x]的距离
p = list(range(n))
d = [0] * n

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        t = find(p[x])
        d[x] += d[p[x]]
        p[x] = t
    return p[x]

# 合并a和b所在的两个集合
p[find(a)] = find(b)
d[find(a)] = distance

对于本题,先遍历所有未被感染的节点(即不在 initial 列表的节点),构造并查集,并且在集合根节点维护 size,表示当前集合的节点数。

然后找到只被一个 initial 节点感染的集合,求得感染节点数的最小值。

被某个 initial 节点感染的集合,节点数越多,若移除此 initial 节点,感染的节点数就越少。

Python3

class Solution:
    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        n = len(graph)
        p = list(range(n))
        size = [1] * n

        def find(x):
            if p[x] != x:
                p[x] = find(p[x])
            return p[x]

        clean = [True] * n
        for i in initial:
            clean[i] = False

        for i in range(n):
            if not clean[i]:
                continue
            for j in range(i + 1, n):
                if not clean[j]:
                    continue
                if graph[i][j] == 1:
                    pa, pb = find(i), find(j)
                    if pa == pb:
                        continue
                    p[pa] = pb
                    size[pb] += size[pa]

        cnt = Counter()
        mp = {}
        for i in initial:
            s = set()
            for j in range(n):
                if not clean[j]:
                    continue
                if graph[i][j] == 1:
                    s.add(find(j))
            for e in s:
                cnt[e] += 1
            mp[i] = s

        mx = -1
        res = 0
        for i, s in mp.items():
            t = 0
            for e in s:
                if cnt[e] == 1:
                    t += size[e]
            if mx < t or (mx == t and i < res):
                mx = t
                res = i
        return res

Java

class Solution {
    private int[] p;

    public int minMalwareSpread(int[][] graph, int[] initial) {
        int n = graph.length;
        p = new int[n];
        int[] size = new int[n];
        for (int i = 0; i < n; ++i) {
            p[i] = i;
            size[i] = 1;
        }
        boolean[] clean = new boolean[n];
        Arrays.fill(clean, true);
        for (int i : initial) {
            clean[i] = false;
        }
        for (int i = 0; i < n; ++i) {
            if (!clean[i]) {
                continue;
            }
            for (int j = i + 1; j < n; ++j) {
                if (!clean[j]) {
                    continue;
                }
                if (graph[i][j] == 1) {
                    int pa = find(i), pb = find(j);
                    if (pa == pb) {
                        continue;
                    }
                    p[pa] = pb;
                    size[pb] += size[pa];
                }
            }
        }
        int[] cnt = new int[n];
        Map<Integer, Set<Integer>> mp = new HashMap<>();
        for (int i : initial) {
            Set<Integer> s = new HashSet<>();
            for (int j = 0; j < n; ++j) {
                if (!clean[j]) {
                    continue;
                }
                if (graph[i][j] == 1) {
                    s.add(find(j));
                }
            }
            for (int e : s) {
                cnt[e] += 1;
            }
            mp.put(i, s);
        }
        int mx = -1;
        int res = 0;
        for (Map.Entry<Integer, Set<Integer>> entry : mp.entrySet()) {
            int i = entry.getKey();
            int t = 0;
            for (int e : entry.getValue()) {
                if (cnt[e] == 1) {
                    t += size[e];
                }
            }
            if (mx < t || (mx == t && i < res)) {
                mx = t;
                res = i;
            }
        }
        return res;
    }

    private int find(int x) {
        if (p[x] != x) {
            p[x] = find(p[x]);
        }
        return p[x];
    }
}

C++

class Solution {
public:
    vector<int> p;

    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        int n = graph.size();
        vector<int> size(n, 1);
        for (int i = 0; i < n; ++i) p.push_back(i);
        vector<bool> clean(n, true);
        for (int i : initial) clean[i] = false;
        for (int i = 0; i < n; ++i)
        {
            if (!clean[i]) continue;
            for (int j = i + 1; j < n; ++j)
            {
                if (!clean[j]) continue;
                if (graph[i][j])
                {
                    int pa = find(i), pb = find(j);
                    if (pa == pb) continue;
                    p[pa] = pb;
                    size[pb] += size[pa];
                }
            }
        }
        vector<int> cnt(n, 0);
        unordered_map<int, unordered_set<int>> mp;
        for (int i : initial)
        {
            unordered_set<int> s;
            for (int j = 0; j < n; ++j)
            {
                if (!clean[j]) continue;
                if (graph[i][j]) s.insert(find(j));
            }
            for (int e : s) ++cnt[e];
            mp[i] = s;
        }
        int mx = -1;
        int res = 0;
        for (auto item : mp)
        {
            int i = item.first;
            int t = 0;
            for (int e : item.second)
            {
                if (cnt[e] == 1) t += size[e];
            }
            if (mx < t || (mx == t && i < res))
            {
                mx = t;
                res = i;
            }
        }
        return res;
    }

    int find(int x) {
        if (p[x] != x) p[x] = find(p[x]);
        return p[x];
    }
};

Go

var p []int

func minMalwareSpread(graph [][]int, initial []int) int {
	n := len(graph)
	p = make([]int, n)
	size := make([]int, n)
	clean := make([]bool, n)
	for i := 0; i < n; i++ {
		p[i] = i
		size[i] = 1
		clean[i] = true
	}
	for _, i := range initial {
		clean[i] = false
	}
	for i := 0; i < n; i++ {
		if !clean[i] {
			continue
		}
		for j := i + 1; j < n; j++ {
			if !clean[j] {
				continue
			}
			if graph[i][j] == 1 {
				pa, pb := find(i), find(j)
				if pa == pb {
					continue
				}
				p[pa] = pb
				size[pb] += size[pa]
			}
		}
	}
	cnt := make([]int, n)
	mp := make(map[int]map[int]bool)
	for _, i := range initial {
		s := make(map[int]bool)
		for j := 0; j < n; j++ {
			if !clean[j] {
				continue
			}
			if graph[i][j] == 1 {
				s[find(j)] = true
			}
		}
		for e, _ := range s {
			cnt[e]++
		}
		mp[i] = s
	}
	mx, res := -1, 0
	for i, s := range mp {
		t := 0
		for e, _ := range s {
			if cnt[e] == 1 {
				t += size[e]
			}
		}
		if mx < t || (mx == t && i < res) {
			mx, res = t, i
		}
	}
	return res
}

func find(x int) int {
	if p[x] != x {
		p[x] = find(p[x])
	}
	return p[x]
}

...