# 5-level loop, forgive me...
# Levels: x-chunk -> y-chunk -> z-chunk -> cascade -> camera batch.
# For every grid cell of every cascade, count how many of the first B camera
# poses have the cell (widened by a one-cell margin) inside their viewing
# frustum, accumulating the result into `count[cas, ix, iy, iz]`.
# NOTE(review): `S` is used both as the spatial chunk length (the `xi * S`
# offsets in the count update) and as the camera batch size in the `while`
# loop. The offsets are only correct if every chunk in X / Y / Z except the
# last has exactly S entries. Presumably X / Y / Z come from `.split(S)`;
# confirm at the definition site.
for xi, xs in enumerate(X):
    for yi, ys in enumerate(Y):
        for zi, zs in enumerate(Z):
            lx, ly, lz = len(xs), len(ys), len(zs)
            # construct points
            # Flatten this chunk's coordinate grid into N points. The later
            # `.reshape(lx, ly, lz)` requires N == lx * ly * lz.
            xx, yy, zz = custom_meshgrid(xs, ys, zs)
            world_xyzs = (
                torch.cat(
                    [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                    dim=-1,
                )
                # Leading size-1 axis broadcasts against the camera batch axis.
                .unsqueeze(0)
                .to(count.device)
            ) # [1, N, 3]
            # cascading
            for cas in range(self.cascade):
                # Extent grows as 2**cas, capped at self.bound.
                bound = min(2**cas, self.bound)
                # Presumably half the width of one cell: `resolution` cells
                # spanning 2 * bound per axis. TODO confirm the grid layout.
                half_grid_size = bound / resolution
                # scale to current cascade's resolution
                # NOTE(review): if X / Y / Z span [-1, 1] (assumed, not visible
                # here), this maps the endpoints to +/-(bound - half_grid_size),
                # i.e. the centres of the outermost cells.
                cas_world_xyzs = world_xyzs * (bound - half_grid_size)
                # split batch to avoid OOM
                # Walk poses[0:B] in slices of at most S cameras.
                head = 0
                while head < B:
                    tail = min(head + S, B)
                    # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
                    # Row-vector form of R^T (p - t): subtract the camera
                    # translation, then right-multiply by the 3x3 rotation block.
                    cam_xyzs = cas_world_xyzs - poses[
                        head:tail, :3, 3
                    ].unsqueeze(1)
                    cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
                    # query if point is covered by any camera
                    # "In front of the camera" test. Assumes the camera looks
                    # along +z in camera space; TODO confirm the pose convention.
                    mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
                    # Frustum side tests: |x| < (cx / fx) * z, widened by
                    # 2 * half_grid_size (one full cell) so cells straddling
                    # the frustum edge are still counted.
                    # NOTE(review): the symmetric bound assumes the principal
                    # point is at the image centre (cx ~ W/2, cy ~ H/2).
                    mask_x = (
                        torch.abs(cam_xyzs[:, :, 0])
                        < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
                    )
                    mask_y = (
                        torch.abs(cam_xyzs[:, :, 1])
                        < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
                    )
                    # Summing the boolean mask over the camera axis gives the
                    # number of cameras in this batch that see each point,
                    # laid back out in the chunk's grid shape.
                    mask = (
                        (mask_z & mask_x & mask_y).sum(0).reshape(lx, ly, lz)
                    ) # [N] --> [lx, ly, lz]
                    # update count
                    # Chunk (xi, yi, zi) occupies the sub-block starting at
                    # (xi * S, yi * S, zi * S) of cascade `cas`.
                    count[
                        cas,
                        xi * S : xi * S + lx,
                        yi * S : yi * S + ly,
                        zi * S : zi * S + lz,
                    ] += mask
                    head += S