Skip to content

Latest commit

 

History

History

1569.Number of Ways to Reorder Array to Get Same BST

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

English Version

题目描述

给你一个数组 nums 表示 1 到 n 的一个排列。我们按照元素在 nums 中的顺序依次插入一个初始为空的二叉搜索树(BST)。请你统计将 nums 重新排序后,统计满足如下条件的方案数:重排后得到的二叉搜索树与 nums 原本数字顺序得到的二叉搜索树相同。

比方说,给你 nums = [2,1,3],我们得到一棵 2 为根,1 为左孩子,3 为右孩子的树。数组 [2,3,1] 也能得到相同的 BST,但 [3,2,1] 会得到一棵不同的 BST 。

请你返回重排 nums 后,与原数组 nums 得到相同二叉搜索树的方案数。

由于答案可能会很大,请将结果对 10^9 + 7 取余数。

 

示例 1:

输入:nums = [2,1,3]
输出:1
解释:我们将 nums 重排, [2,3,1] 能得到相同的 BST 。没有其他得到相同 BST 的方案了。

示例 2:

输入:nums = [3,4,5,1,2]
输出:5
解释:下面 5 个数组会得到相同的 BST:
[3,1,2,4,5]
[3,1,4,2,5]
[3,1,4,5,2]
[3,4,1,2,5]
[3,4,1,5,2]

示例 3:

输入:nums = [1,2,3]
输出:0
解释:没有别的排列顺序能得到相同的 BST 。

 

提示:

  • 1 <= nums.length <= 1000
  • 1 <= nums[i] <= nums.length
  • nums 中所有数 互不相同 。

解法

方法一:组合计数 + 递归

我们设计一个函数 $dfs(nums)$,它的功能是计算以 $nums$ 为节点构成的二叉搜索树的方案数。那么答案就是 $dfs(nums)-1$,因为 $dfs(nums)$ 计算的是以 $nums$ 为节点构成的二叉搜索树的方案数,而题目要求的是重排后与原数组 $nums$ 得到相同二叉搜索树的方案数,因此答案需要减去一。

接下来,我们来看一下 $dfs(nums)$ 的计算方法。

对于一个数组 $nums$,它的第一个元素是根节点,那么它的左子树的元素都小于它,右子树的元素都大于它。因此,我们可以将数组分为三部分,第一部分是根节点,第二部分是左子树的元素,记为 $left$,第三部分是右子树的元素,记为 $right$。那么,左子树的元素个数为 $m$,右子树的元素个数为 $n$,那么 $left$$right$ 的方案数分别为 $dfs(left)$$dfs(right)$。我们可以在数组 $nums$$m + n$ 个位置中选择 $m$ 个位置放置左子树的元素,剩下的 $n$ 个位置放置右子树的元素,这样就能保证重排后与原数组 $nums$ 得到相同二叉搜索树。因此,$dfs(nums)$ 的计算方法为:

$$ dfs(nums) = C_{m+n}^m \times dfs(left) \times dfs(right) $$

其中 $C_{m+n}^m$ 表示从 $m + n$ 个位置中选择 $m$ 个位置的方案数,我们可以通过预处理得到。

注意答案的取模运算,因为 $dfs(nums)$ 的值可能会很大,所以我们需要在计算过程中对每一步的结果取模,最后再对整个结果取模。

时间复杂度 $O(n^2)$,空间复杂度 $O(n^2)$。其中 $n$ 是数组 $nums$ 的长度。

class Solution:
    def numOfWays(self, nums: List[int]) -> int:
        def dfs(nums):
            if len(nums) < 2:
                return 1
            left = [x for x in nums if x < nums[0]]
            right = [x for x in nums if x > nums[0]]
            m, n = len(left), len(right)
            a, b = dfs(left), dfs(right)
            return (((c[m + n][m] * a) % mod) * b) % mod

        n = len(nums)
        mod = 10**9 + 7
        c = [[0] * n for _ in range(n)]
        c[0][0] = 1
        for i in range(1, n):
            c[i][0] = 1
            for j in range(1, i + 1):
                c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod
        return (dfs(nums) - 1 + mod) % mod
class Solution {
    private int[][] c;
    private final int mod = (int) 1e9 + 7;

    public int numOfWays(int[] nums) {
        int n = nums.length;
        c = new int[n][n];
        c[0][0] = 1;
        for (int i = 1; i < n; ++i) {
            c[i][0] = 1;
            for (int j = 1; j <= i; ++j) {
                c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
            }
        }
        List<Integer> list = new ArrayList<>();
        for (int x : nums) {
            list.add(x);
        }
        return (dfs(list) - 1 + mod) % mod;
    }

    private int dfs(List<Integer> nums) {
        if (nums.size() < 2) {
            return 1;
        }
        List<Integer> left = new ArrayList<>();
        List<Integer> right = new ArrayList<>();
        for (int x : nums) {
            if (x < nums.get(0)) {
                left.add(x);
            } else if (x > nums.get(0)) {
                right.add(x);
            }
        }
        int m = left.size(), n = right.size();
        int a = dfs(left), b = dfs(right);
        return (int) ((long) a * b % mod * c[m + n][n] % mod);
    }
}
class Solution {
public:
    int numOfWays(vector<int>& nums) {
        int n = nums.size();
        const int mod = 1e9 + 7;
        int c[n][n];
        memset(c, 0, sizeof(c));
        c[0][0] = 1;
        for (int i = 1; i < n; ++i) {
            c[i][0] = 1;
            for (int j = 1; j <= i; ++j) {
                c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
            }
        }
        function<int(vector<int>)> dfs = [&](vector<int> nums) -> int {
            if (nums.size() < 2) {
                return 1;
            }
            vector<int> left, right;
            for (int& x : nums) {
                if (x < nums[0]) {
                    left.push_back(x);
                } else if (x > nums[0]) {
                    right.push_back(x);
                }
            }
            int m = left.size(), n = right.size();
            int a = dfs(left), b = dfs(right);
            return c[m + n][m] * 1ll * a % mod * b % mod;
        };
        return (dfs(nums) - 1 + mod) % mod;
    }
};
func numOfWays(nums []int) int {
	n := len(nums)
	const mod = 1e9 + 7
	c := make([][]int, n)
	for i := range c {
		c[i] = make([]int, n)
	}
	c[0][0] = 1
	for i := 1; i < n; i++ {
		c[i][0] = 1
		for j := 1; j <= i; j++ {
			c[i][j] = (c[i-1][j] + c[i-1][j-1]) % mod
		}
	}
	var dfs func(nums []int) int
	dfs = func(nums []int) int {
		if len(nums) < 2 {
			return 1
		}
		var left, right []int
		for _, x := range nums[1:] {
			if x < nums[0] {
				left = append(left, x)
			} else {
				right = append(right, x)
			}
		}
		m, n := len(left), len(right)
		a, b := dfs(left), dfs(right)
		return c[m+n][m] * a % mod * b % mod
	}
	return (dfs(nums) - 1 + mod) % mod
}
function numOfWays(nums: number[]): number {
    const n = nums.length;
    const mod = 1e9 + 7;
    const c = new Array(n).fill(0).map(() => new Array(n).fill(0));
    c[0][0] = 1;
    for (let i = 1; i < n; ++i) {
        c[i][0] = 1;
        for (let j = 1; j <= i; ++j) {
            c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
        }
    }
    const dfs = (nums: number[]): number => {
        if (nums.length < 2) {
            return 1;
        }
        const left: number[] = [];
        const right: number[] = [];
        for (let i = 1; i < nums.length; ++i) {
            if (nums[i] < nums[0]) {
                left.push(nums[i]);
            } else {
                right.push(nums[i]);
            }
        }
        const m = left.length;
        const n = right.length;
        const a = dfs(left);
        const b = dfs(right);
        return Number((BigInt(c[m + n][m]) * BigInt(a) * BigInt(b)) % BigInt(mod));
    };
    return (dfs(nums) - 1 + mod) % mod;
}