LeetCode 803. 打砖块

题目描述

有一个 mxnm x n 的二元网格,其中 11 表示砖块,00 表示空白。砖块 稳定(不会掉落)的前提是:

一块砖直接连接到网格的顶部,或者
至少有一块相邻(44 个方向之一)砖块 稳定 不会掉落时
给你一个数组 hitshits ,这是需要依次消除砖块的位置。每当消除 hits[i]=(rowi,coli)hits[i] = (rowi, coli) 位置上的砖块时,对应位置的砖块(若存在)会消失,然后其他的砖块可能因为这一消除操作而掉落。一旦砖块掉落,它会立即从网格中消失(即,它不会落在其他稳定的砖块上)。

返回一个数组 resultresult ,其中 result[i]result[i] 表示第 ii 次消除操作对应掉落的砖块数目。

注意,消除可能指向是没有砖块的空白位置,如果发生这种情况,则没有砖块掉落。

示例 1

输入:grid = [[1,0,0,0],[1,1,1,0]], hits = [[1,0]]
输出:[2]
解释:

网格开始为:
[[1,0,0,0],
 [1,1,1,0]]
消除 (1,0) 处加粗的砖块,得到网格:
[[1,0,0,0]
 [0,1,1,0]]
两个加粗的砖不再稳定,因为它们不再与顶部相连,也不再与另一个稳定的砖相邻,因此它们将掉落。得到网格:
[[1,0,0,0],
 [0,0,0,0]]
因此,结果为 [2] 。

示例 2

输入:grid = [[1,0,0,0],[1,1,0,0]], hits = [[1,1],[1,0]]
输出:[0,0]
解释:

网格开始为:
[[1,0,0,0],
 [1,1,0,0]]
消除 (1,1) 处加粗的砖块,得到网格:
[[1,0,0,0],
 [1,0,0,0]]
剩下的砖都很稳定,所以不会掉落。网格保持不变:
[[1,0,0,0], 
 [1,0,0,0]]
接下来消除 (1,0) 处加粗的砖块,得到网格:
[[1,0,0,0],
 [0,0,0,0]]
剩下的砖块仍然是稳定的,所以不会有砖块掉落。
因此,结果为 [0,0] 。

题解

题目大概意思是:对于图中有多个砖块,砖块稳定的前提是砖块与天花板相连或者与其他稳定的砖块相连。按照顺序消去多个砖块,求每个砖块被消去时,会导致多少其他砖块不稳定被消去。

因为是并查集之月 因为是稳定的前提是旁边的砖块稳定,或者 x=0x=0,每次删去一个砖块会导致相连的一群砖块不稳定,因此可以很容易想到并查集。

我们可以将与天花板连接称为 强稳定,与稳定的砖块相连称为 弱稳定。那么弱稳定可以构成一个稳定分类。部分砖块的移除会导致稳定分类划分成两个不同的分类(一个稳定分类,一个被消去分类)

但是,并查集没法解决删除问题,只能解决插入问题。因此这道题不能直接使用并查集,但是很显然删除和插入是对立的,只需要反向思考即可,把问题转换为插入问题。

把问题反向后,新的问题为:图中有多个砖块,图中有多个砖块,砖块稳定的前提是砖块与天花板相连或者与其他稳定的砖块相连。按照顺序插入多个砖块,求每个砖块被插入时,会导致多少个原本不稳定的砖块变为稳定

再来看就可以很方便地进行并查集来解决了,把原本图中所有的砖块建立并查集,如果他们与天花板相连或间接相连,就将其插入到天花板所在的并查集内。每次插入新的砖块,就将其四周的砖块插入到并查集内。

由于这里需要统计个数,因此需要维护一下每个并查集的个数,每次在插入前读一下天花板所在并查集的个数,插入后再读一次,做个减法即可得到每个砖块插入后稳定的砖块个数。
尽管并查集存在压缩特性,但是实际上只有connect时,才会修改并查集的代表,路径压缩只会使得每个节点可以在一步内访问到代表。
因此修改connect函数,将两个集合的代表存储的个数加起来并存储到新的代表对应的位置即可。
需要特别注意,如果新的砖块使得其他砖块稳定,并查集个数增加的也包含当前砖块,因此需要将结果减一。同时有需要考虑当前砖块本身不稳定以及只稳定了自己的情况,因此还需要判断不能结果不能是负数。

由于砖块使用二维坐标描述,因此使用 idx=x×n+y+1idx = x \times n + y + 1 作为其在并查集中的下标。其中 +1+1 的原因是,为了方便处理,将 00 作为天花板

首先将原本的图深拷贝一遍(可以用 copy.deepcopy()),将要消去的砖块在新图中删去。这样就得到了反向问题的初始输入。
接下来,遍历所有砖块,将其与其上下左右的砖块连接到同一个并查集(实际上,如果是从左往右从上往下遍历,只需要将砖块与右面和下面的连接),如果砖块 x=0x=0 则与天花板相连。
接下来反向遍历消去数组,与预处理的操作相同,与四周的砖块连接到同一并查集,并计算个数差值。


虽然题目本身描述、思维都很麻烦,但是理清楚后,实际上理解和实现都很简单。
考虑到并查集的路径压缩,尝试把获取个数也做了路径压缩,然而却导致耗时飙升至超时。

虽然超时,但是结果是对的。所以有必要考虑下为什么会导致这个问题。

def getC(t):
    c[t] = c[t] if t == f[t] else getC(f[t])
    return c[t]

在这里,只要是return c[t]return c[f[t]],都会超时,而return c[getF[t]]则不会超时

无论是c[t]c[f[t]]还是c[getF(t)]本质上并没有不同,导致超时的实际上是路径压缩的问题。由于前面的压缩只包括c的压缩,并未压缩f。在后面插入节点时,并未对这些节点进行过getF(),使得他们并未被压缩,而getC()又未压缩f,因此每次获取都需要大量的时间向上慢慢回溯。

这时,直接getF(t)反而更快,因为getF()可以认为必定只有一层。

def getC(t):
    return c[getF(t)]

代码

class Solution:
    def hitBricks(self, grid: List[List[int]], hits: List[List[int]]) -> List[int]:
        if len(grid) == 0:
            return [0] * len(hits)

        # import time
        # start = time.time()
        
        n = len(grid)
        m = len(grid[0])

        ngrid = [[grid[x][y] for y in range(m)] for x in range(n)]
        for h in hits:
            x, y = h
            ngrid[x][y] = 0

        # print(time.time() - start)

        # for x in range(n):
        #     for y in range(m):
        #         print(ngrid[x][y], end=" ")
        #     print()

        f = [i for i in range(m * n + 1)]
        c = [1 for i in range(m * n + 1)]
        def getIdx(x, y):
            return x * m + y + 1   
        def getF(t):
            f[t] = f[t] if t == f[t] else getF(f[t])
            return f[t]
        def getC(t):
            # c[t] = c[t] if t == f[t] else getC(f[t])
            return c[getF(t)]
        def connect(t1, t2):
            root1 = getF(t1)
            root2 = getF(t2)
            if root1 != root2:
                c[root2] = c[root1] + c[root2]
                c[root1] = c[root2]
                f[root1] = root2

        for x in range(n):
            for y in range(m):
                if ngrid[x][y] == 1:
                    t = getIdx(x, y)
                    for xx, yy in [(x+1,y), (x,y+1)]:
                        if 0 <= xx and xx < n and 0 <= yy and yy < m and ngrid[xx][yy] == 1:
                            tt = getIdx(xx, yy)
                            connect(t, tt)
                            # print("connect", t,x,y, tt,xx,yy)
                    if x == 0:
                        connect(0, t)

        # print(time.time() - start)

        # for x in range(n):
        #     for y in range(m):
        #         print(getF(getIdx(x, y)), end=" ")
        #     print()

        # for x in range(n):
        #     for y in range(m):
        #         print(getC(getIdx(x, y)), end=" ")
        #     print()


        res = [0] * len(hits)
        for i, hit in enumerate(reversed(hits)):
            x, y = hit
            if grid[x][y] == 1:
                ngrid[x][y] = 1
                beforeNum = getC(0)
                # print("insert",x,y)

                t = getIdx(x, y)
                for xx, yy in [(x-1,y),(x+1,y),(x,y-1),(x,y+1)]:
                    if 0 <= xx and xx < n and 0 <= yy and yy < m and ngrid[xx][yy] == 1:
                        tt = getIdx(xx, yy)
                        connect(t, tt)
                        # print("connect2", t,x,y, tt,xx,yy)
                if x == 0:
                    connect(0, t)

                nowNum = getC(0)
                # print("before", beforeNum, "now", nowNum)
                res[i] = max(nowNum - beforeNum - 1, 0)

        # print(time.time() - start)

        res.reverse()
        return res