        # 5-level nested loop (x/y/z grid chunks, cascade level, camera batch) -- forgive me...
        # Loop-invariant setup, hoisted out of the nested loop below.
        # Frustum slopes for the in-view test: a point counts as covered when
        # |x| < (cx / fx) * z + margin (and likewise for y). `cx / fx * z` parses
        # as `(cx / fx) * z`, so precomputing the ratios yields identical values.
        x_slope = cx / fx
        y_slope = cy / fy

        # Camera translations and rotations, sliced once into batches of at most
        # S cameras (rather than re-sliced for every grid chunk and cascade).
        # Batching bounds the size of the [b, N, 3] camera-space tensors built
        # below, to avoid OOM.
        cam_batches = []
        for head in range(0, B, S):
            tail = min(head + S, B)
            cam_batches.append(
                (poses[head:tail, :3, 3].unsqueeze(1), poses[head:tail, :3, :3])
            )

        # Running start index of the current chunk along each grid axis. These are
        # accumulated from the actual chunk lengths instead of `chunk_index * S`, so
        # the write-back into `count` stays correct even if X/Y/Z were not split
        # into chunks of exactly S (S is also the camera batch size above).
        x_start = 0
        for xs in X:
            lx = len(xs)
            y_start = 0
            for ys in Y:
                ly = len(ys)
                z_start = 0
                for zs in Z:
                    lz = len(zs)

                    # construct the N = lx * ly * lz points of this chunk
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    world_xyzs = (
                        torch.cat(
                            [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                            dim=-1,
                        )
                        .unsqueeze(0)
                        .to(count.device)
                    )  # [1, N, 3]

                    # region of `count` this chunk writes to (cascade index added below)
                    x_sl = slice(x_start, x_start + lx)
                    y_sl = slice(y_start, y_start + ly)
                    z_sl = slice(z_start, z_start + lz)

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2**cas, self.bound)
                        half_grid_size = bound / resolution
                        margin = half_grid_size * 2
                        # scale to current cascade's resolution
                        cas_world_xyzs = world_xyzs * (bound - half_grid_size)

                        for cam_t, cam_rot in cam_batches:
                            # world -> camera: poses are c2w, so the inverse rotation
                            # is R^T; for row-vector points, (p - t) @ R applies R^T
                            # (the batched-matmul transpose cancels the inverse's).
                            cam_xyzs = (cas_world_xyzs - cam_t) @ cam_rot  # [b, N, 3]
                            depth = cam_xyzs[:, :, 2]

                            # covered = positive depth and inside the margin-padded
                            # frustum of that camera
                            mask_z = depth > 0  # [b, N]
                            mask_x = (
                                torch.abs(cam_xyzs[:, :, 0]) < x_slope * depth + margin
                            )
                            mask_y = (
                                torch.abs(cam_xyzs[:, :, 1]) < y_slope * depth + margin
                            )
                            # number of cameras in this batch covering each point
                            mask = (
                                (mask_z & mask_x & mask_y).sum(0).reshape(lx, ly, lz)
                            )  # [N] --> [lx, ly, lz]

                            # update count
                            count[cas, x_sl, y_sl, z_sl] += mask
                    z_start += lz
                y_start += ly
            x_start += lx
By Anonymous, 2022-05-13 14:03:28