Skip to content

feat: add solutions to lc problem: No.1577 #3887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,165 @@ tags:

<!-- solution:start -->

### 方法一:哈希表
### 方法一:哈希表 + 枚举

我们用哈希表 `cnt1` 统计 `nums1` 中每个数出现的次数,用哈希表 `cnt2` 统计 `nums2` 中每个数出现的次数
我们用哈希表 $\textit{cnt1}$ 统计 $\textit{nums1}$ 中每个数对 $(\textit{nums}[j], \textit{nums}[k])$ 出现的次数,其中 $0 \leq j \lt k < m$,其中 $m$ 为数组 $\textit{nums1}$ 的长度。用哈希表 $\textit{cnt2}$ 统计 $\textit{nums2}$ 中每个数对 $(\textit{nums}[j], \textit{nums}[k])$ 出现的次数,其中 $0 \leq j \lt k < n$,其中 $n$ 为数组 $\textit{nums2}$ 的长度

然后我们双重循环遍历两个哈希表,记当前 `cnt1` 遍历到的键值对为 $(a, x)$,当前 `cnt2` 遍历到的键值对为 $(b, y)$。接下来分情况讨论:
接下来,我们枚举数组 $\textit{nums1}$ 中的每个数 $x$,计算 $\textit{cnt2}[x^2]$ 的值,即 $\textit{nums2}$ 中有多少对数 $(\textit{nums}[j], \textit{nums}[k])$ 满足 $\textit{nums}[j] \times \textit{nums}[k] = x^2$。同理,我们枚举数组 $\textit{nums2}$ 中的每个数 $x$,计算 $\textit{cnt1}[x^2]$ 的值,即 $\textit{nums1}$ 中有多少对数 $(\textit{nums}[j], \textit{nums}[k])$ 满足 $\textit{nums}[j] \times \textit{nums}[k] = x^2$,最后将两者相加返回即可。

- 如果 $a^2$ 能被 $b$ 整除,设 $c=\frac{a^2}{b}$,若 $b=c$,那么答案加上 $x \times y \times (y - 1)$,否则答案加上 $x \times y \times cnt2[c]$。
- 如果 $b^2$ 能被 $a$ 整除,设 $c=\frac{b^2}{a}$,若 $a=c$,那么答案加上 $x \times (x - 1) \times y$,否则答案加上 $x \times cnt1[c] \times y$。
时间复杂度 $O(m^2 + n^2 + m + n)$,空间复杂度 $O(m^2 + n^2)$。其中 $m$ 和 $n$ 分别为数组 $\textit{nums1}$ 和 $\textit{nums2}$ 的长度。

最后将答案除以 $2$ 返回即可。
<!-- tabs:start -->

#### Python3

```python
class Solution:
def numTriplets(self, nums1: List[int], nums2: List[int]) -> int:
def count(nums: List[int]) -> Counter:
cnt = Counter()
for j in range(len(nums)):
for k in range(j + 1, len(nums)):
cnt[nums[j] * nums[k]] += 1
return cnt

def cal(nums: List[int], cnt: Counter) -> int:
return sum(cnt[x * x] for x in nums)

cnt1 = count(nums1)
cnt2 = count(nums2)
return cal(nums1, cnt2) + cal(nums2, cnt1)
```

#### Java

```java
class Solution {
public int numTriplets(int[] nums1, int[] nums2) {
var cnt1 = count(nums1);
var cnt2 = count(nums2);
return cal(cnt1, nums2) + cal(cnt2, nums1);
}

private Map<Long, Integer> count(int[] nums) {
Map<Long, Integer> cnt = new HashMap<>();
int n = nums.length;
for (int j = 0; j < n; ++j) {
for (int k = j + 1; k < n; ++k) {
long x = (long) nums[j] * nums[k];
cnt.merge(x, 1, Integer::sum);
}
}
return cnt;
}

private int cal(Map<Long, Integer> cnt, int[] nums) {
int ans = 0;
for (int x : nums) {
long y = (long) x * x;
ans += cnt.getOrDefault(y, 0);
}
return ans;
}
}
```

时间复杂度 $O(n \times m)$,空间复杂度 $O(n + m)$。其中 $n$ 和 $m$ 分别为数组 `nums1` 和 `nums2` 的长度。
#### C++

```cpp
class Solution {
public:
int numTriplets(vector<int>& nums1, vector<int>& nums2) {
auto cnt1 = count(nums1);
auto cnt2 = count(nums2);
return cal(cnt1, nums2) + cal(cnt2, nums1);
}

unordered_map<long long, int> count(vector<int>& nums) {
unordered_map<long long, int> cnt;
for (int i = 0; i < nums.size(); i++) {
for (int j = i + 1; j < nums.size(); j++) {
cnt[(long long) nums[i] * nums[j]]++;
}
}
return cnt;
}

int cal(unordered_map<long long, int>& cnt, vector<int>& nums) {
int ans = 0;
for (int x : nums) {
ans += cnt[(long long) x * x];
}
return ans;
}
};
```

#### Go

```go
func numTriplets(nums1 []int, nums2 []int) int {
cnt1 := count(nums1)
cnt2 := count(nums2)
return cal(cnt1, nums2) + cal(cnt2, nums1)
}

func count(nums []int) map[int]int {
cnt := map[int]int{}
for j, x := range nums {
for _, y := range nums[j+1:] {
cnt[x*y]++
}
}
return cnt
}

func cal(cnt map[int]int, nums []int) (ans int) {
for _, x := range nums {
ans += cnt[x*x]
}
return
}
```

#### TypeScript

```ts
function numTriplets(nums1: number[], nums2: number[]): number {
const cnt1 = count(nums1);
const cnt2 = count(nums2);
return cal(cnt1, nums2) + cal(cnt2, nums1);
}

function count(nums: number[]): Map<number, number> {
const cnt: Map<number, number> = new Map();
for (let j = 0; j < nums.length; ++j) {
for (let k = j + 1; k < nums.length; ++k) {
const x = nums[j] * nums[k];
cnt.set(x, (cnt.get(x) || 0) + 1);
}
}
return cnt;
}

function cal(cnt: Map<number, number>, nums: number[]): number {
return nums.reduce((acc, x) => acc + (cnt.get(x * x) || 0), 0);
}
```

<!-- tabs:end -->

<!-- solution:end -->

<!-- solution:start -->

### 方法二:哈希表 + 枚举优化

我们用哈希表 $\textit{cnt1}$ 统计 $\textit{nums1}$ 中每个数出现的次数,用哈希表 $\textit{cnt2}$ 统计 $\textit{nums2}$ 中每个数出现的次数。

接下来,我们枚举数组 $\textit{nums1}$ 中的每个数 $x$,然后枚举 $\textit{cnt2}$ 中的每个数对 $(y, v1)$,其中 $y$ 为 $\textit{cnt2}$ 的键,$v1$ 为 $\textit{cnt2}$ 的值。我们计算 $z = x^2 / y$,如果 $y \times z = x^2$,此时如果 $y = z$,说明 $y$ 和 $z$ 是同一个数,那么 $v1 = v2$,从 $v1$ 个数中任选两个数的方案数为 $v1 \times (v1 - 1) = v1 \times (v2 - 1)$;如果 $y \neq z$,那么 $v1$ 个数中任选两个数的方案数为 $v1 \times v2$。最后将所有方案数相加并除以 $2$ 即可。这里除以 $2$ 是因为我们统计的是对数对 $(j, k)$ 的方案数,而实际上 $(j, k)$ 和 $(k, j)$ 是同一种方案。

时间复杂度 $O(m \times n)$,空间复杂度 $O(m + n)$。其中 $m$ 和 $n$ 分别为数组 $\textit{nums1}$ 和 $\textit{nums2}$ 的长度。

<!-- tabs:start -->

Expand All @@ -96,95 +243,83 @@ tags:
```python
class Solution:
def numTriplets(self, nums1: List[int], nums2: List[int]) -> int:
def cal(nums: List[int], cnt: Counter) -> int:
ans = 0
for x in nums:
for y, v1 in cnt.items():
z = x * x // y
if y * z == x * x:
v2 = cnt[z]
ans += v1 * (v2 - int(y == z))
return ans // 2

cnt1 = Counter(nums1)
cnt2 = Counter(nums2)
ans = 0
for a, x in cnt1.items():
for b, y in cnt2.items():
if a * a % b == 0:
c = a * a // b
if b == c:
ans += x * y * (y - 1)
else:
ans += x * y * cnt2[c]
if b * b % a == 0:
c = b * b // a
if a == c:
ans += x * (x - 1) * y
else:
ans += x * y * cnt1[c]
return ans >> 1
return cal(nums1, cnt2) + cal(nums2, cnt1)
```

#### Java

```java
class Solution {
public int numTriplets(int[] nums1, int[] nums2) {
Map<Integer, Integer> cnt1 = new HashMap<>();
Map<Integer, Integer> cnt2 = new HashMap<>();
for (int v : nums1) {
cnt1.put(v, cnt1.getOrDefault(v, 0) + 1);
}
for (int v : nums2) {
cnt2.put(v, cnt2.getOrDefault(v, 0) + 1);
var cnt1 = count(nums1);
var cnt2 = count(nums2);
return cal(cnt1, nums2) + cal(cnt2, nums1);
}

private Map<Integer, Integer> count(int[] nums) {
Map<Integer, Integer> cnt = new HashMap<>();
for (int x : nums) {
cnt.merge(x, 1, Integer::sum);
}
return cnt;
}

private int cal(Map<Integer, Integer> cnt, int[] nums) {
long ans = 0;
for (var e1 : cnt1.entrySet()) {
long a = e1.getKey(), x = e1.getValue();
for (var e2 : cnt2.entrySet()) {
long b = e2.getKey(), y = e2.getValue();
if ((a * a) % b == 0) {
long c = a * a / b;
if (b == c) {
ans += x * y * (y - 1);
} else {
ans += x * y * cnt2.getOrDefault((int) c, 0);
}
}
if ((b * b) % a == 0) {
long c = b * b / a;
if (a == c) {
ans += x * (x - 1) * y;
} else {
ans += x * y * cnt1.getOrDefault((int) c, 0);
}
for (int x : nums) {
for (var e : cnt.entrySet()) {
int y = e.getKey(), v1 = e.getValue();
int z = (int) (1L * x * x / y);
if (y * z == x * x) {
int v2 = cnt.getOrDefault(z, 0);
ans += v1 * (y == z ? v2 - 1 : v2);
}
}
}
return (int) (ans >> 1);
return (int) (ans / 2);
}
}
```

#### Go

```go
func numTriplets(nums1 []int, nums2 []int) (ans int) {
cnt1 := map[int]int{}
cnt2 := map[int]int{}
for _, v := range nums1 {
cnt1[v]++
}
for _, v := range nums2 {
cnt2[v]++
func numTriplets(nums1 []int, nums2 []int) int {
cnt1 := count(nums1)
cnt2 := count(nums2)
return cal(cnt1, nums2) + cal(cnt2, nums1)
}

func count(nums []int) map[int]int {
cnt := map[int]int{}
for _, x := range nums {
cnt[x]++
}
for a, x := range cnt1 {
for b, y := range cnt2 {
if a*a%b == 0 {
c := a * a / b
if b == c {
ans += x * y * (y - 1)
} else {
ans += x * y * cnt2[c]
}
}
if b*b%a == 0 {
c := b * b / a
if a == c {
ans += x * (x - 1) * y
} else {
ans += x * y * cnt1[c]
return cnt
}

func cal(cnt map[int]int, nums []int) (ans int) {
for _, x := range nums {
for y, v1 := range cnt {
z := x * x / y
if y*z == x*x {
if v2, ok := cnt[z]; ok {
if y == z {
v2--
}
ans += v1 * v2
}
}
}
Expand All @@ -194,6 +329,38 @@ func numTriplets(nums1 []int, nums2 []int) (ans int) {
}
```

#### TypeScript

```ts
function numTriplets(nums1: number[], nums2: number[]): number {
const cnt1 = count(nums1);
const cnt2 = count(nums2);
return cal(cnt1, nums2) + cal(cnt2, nums1);
}

function count(nums: number[]): Map<number, number> {
const cnt: Map<number, number> = new Map();
for (const x of nums) {
cnt.set(x, (cnt.get(x) || 0) + 1);
}
return cnt;
}

function cal(cnt: Map<number, number>, nums: number[]): number {
let ans: number = 0;
for (const x of nums) {
for (const [y, v1] of cnt) {
const z = Math.floor((x * x) / y);
if (y * z == x * x) {
const v2 = cnt.get(z) || 0;
ans += v1 * (y === z ? v2 - 1 : v2);
}
}
}
return ans / 2;
}
```

<!-- tabs:end -->

<!-- solution:end -->
Expand Down
Loading
Loading