from typing import List


class Node:
    """One segment-tree node: a closed 1-based range ``[l, r]`` and its sum ``v``.

    A freshly constructed node has ``l == r == v == 0``; it only becomes
    meaningful once ``SegmentTree.build`` has assigned it a range.
    """

    # Each tree allocates 4 * n nodes per row, so skip the per-instance __dict__.
    __slots__ = ("l", "r", "v")

    def __init__(self):
        self.l = 0
        self.r = 0
        self.v = 0


class SegmentTree:
    """Sum segment tree over ``nums`` addressed with 1-based positions.

    Nodes are stored in ``self.tr`` in heap layout: the root is index 1 and
    the children of node ``u`` are ``u << 1`` and ``u << 1 | 1``.
    """

    def __init__(self, nums):
        """Allocate ``4 * len(nums)`` nodes and build the tree over ``nums``.

        An empty ``nums`` yields an empty tree: every query on it returns 0
        and every modification raises ``IndexError``.
        """
        n = len(nums)
        self.nums = nums
        self.tr = [Node() for _ in range(4 * n)]
        # With n == 0 there is no root node to build into.
        if n:
            self.build(1, 1, n)

    def build(self, u, l, r):
        """Recursively initialise node ``u`` to cover positions ``[l, r]``."""
        node = self.tr[u]
        node.l = l
        node.r = r
        if l == r:
            # Positions are 1-based, ``nums`` is 0-based.
            node.v = self.nums[l - 1]
            return
        mid = (l + r) >> 1
        self.build(u << 1, l, mid)
        self.build(u << 1 | 1, mid + 1, r)
        self.pushup(u)

    def modify(self, u, x, v):
        """Set position ``x`` to ``v`` within the subtree rooted at ``u``.

        Raises:
            IndexError: if ``u`` is not a node of this tree or ``x`` lies
                outside the range covered by node ``u``.
        """
        node = self.tr[u] if 0 < u < len(self.tr) else None
        # Reject out-of-range positions up front. Without this check x == 0
        # matches an unbuilt placeholder node (l == r == 0) below a leaf, and
        # the following pushup overwrites the real leaf with placeholder sums.
        if node is None or x < 1 or not node.l <= x <= node.r:
            raise IndexError(
                f"position {x} is outside the range covered by node {u}"
            )
        if node.l == node.r:
            node.v = v
            return
        mid = (node.l + node.r) >> 1
        if x <= mid:
            self.modify(u << 1, x, v)
        else:
            self.modify(u << 1 | 1, x, v)
        self.pushup(u)

    def query(self, u, l, r):
        """Return the sum of positions ``[l, r]`` within the subtree at ``u``.

        The part of ``[l, r]`` that falls outside the node's range contributes
        nothing, so an empty tree, an inverted range (``l > r``) or a range
        that misses the node entirely all yield 0.
        """
        if not self.tr:
            return 0
        node = self.tr[u]
        # Disjoint from this node: stop here instead of walking into the
        # unbuilt placeholder nodes that sit below the leaves.
        if r < node.l or node.r < l:
            return 0
        if l <= node.l and node.r <= r:
            return node.v
        mid = (node.l + node.r) >> 1
        total = 0
        if l <= mid:
            total += self.query(u << 1, l, r)
        if r > mid:
            total += self.query(u << 1 | 1, l, r)
        return total

    def pushup(self, u):
        """Recompute the sum of node ``u`` from its two children."""
        self.tr[u].v = self.tr[u << 1].v + self.tr[u << 1 | 1].v


class NumMatrix:
    """Mutable 2D matrix with rectangle-sum queries, one segment tree per row."""

    def __init__(self, matrix: List[List[int]]):
        """Build an independent ``SegmentTree`` for every row of ``matrix``."""
        self.trees = [SegmentTree(row) for row in matrix]

    def update(self, row: int, col: int, val: int) -> None:
        """Set the cell at 0-based ``(row, col)`` to ``val``.

        Raises:
            IndexError: if ``col`` is not a valid column of that row.
        """
        # The trees use 1-based positions, hence the + 1.
        self.trees[row].modify(1, col + 1, val)

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        """Return the sum of the inclusive rectangle (row1, col1)-(row2, col2).

        Each row in ``[row1, row2]`` is queried on its own tree for the
        column span ``[col1, col2]`` and the per-row sums are added up.
        """
        lo, hi = col1 + 1, col2 + 1
        return sum(
            self.trees[row].query(1, lo, hi) for row in range(row1, row2 + 1)
        )


# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)