You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
117 lines
3.3 KiB
117 lines
3.3 KiB
''' |
|
获取Mesh的深度图 |
|
|
|
required: |
|
- open3d |
|
- numpy |
|
- numba |
|
''' |
|
|
|
import open3d as o3d |
|
import numpy as np |
|
from numba import njit |
|
|
|
@njit |
|
def _process_all_points(Z, width, height, cx, cy, fx, fy, c2w): |
|
"""JIT优化的所有点处理函数""" |
|
# 手动创建像素网格(替代np.meshgrid) |
|
u = np.zeros((height, width), dtype=np.float64) |
|
v = np.zeros((height, width), dtype=np.float64) |
|
|
|
for y in range(height): |
|
for x in range(width): |
|
u[y, x] = x |
|
v[y, x] = y |
|
|
|
# 创建输出数组 |
|
Z_all = np.zeros_like(Z) |
|
|
|
# 手动处理有效点 |
|
for y in range(height): |
|
for x in range(width): |
|
if Z[y, x] > 0 and Z[y, x] < 1e9 and not np.isnan(Z[y, x]): |
|
Z_all[y, x] = Z[y, x] |
|
else: |
|
Z_all[y, x] = -10000 |
|
|
|
# 计算所有点的3D坐标 |
|
X_all = (u - cx) * Z_all / fx |
|
Y_all = (v - cy) * Z_all / fy |
|
|
|
# 创建齐次坐标 |
|
points_3d = np.zeros((height * width, 4)) |
|
points_3d[:, 0] = X_all.reshape(-1) |
|
points_3d[:, 1] = Y_all.reshape(-1) |
|
points_3d[:, 2] = Z_all.reshape(-1) |
|
points_3d[:, 3] = 1.0 |
|
|
|
# 坐标变换 |
|
return points_3d @ c2w.T |
|
|
|
@njit |
|
def _process_valid_points(Z, width, height, cx, cy, fx, fy, c2w): |
|
"""JIT优化的有效点处理函数""" |
|
# 首先计算有效点的数量 |
|
valid_count = 0 |
|
for y in range(height): |
|
for x in range(width): |
|
if Z[y, x] > 0 and Z[y, x] < 1e9 and not np.isnan(Z[y, x]): |
|
valid_count += 1 |
|
|
|
# 创建输出数组 |
|
X_valid = np.zeros(valid_count) |
|
Y_valid = np.zeros(valid_count) |
|
Z_valid = np.zeros(valid_count) |
|
|
|
# 填充有效点数据 |
|
idx = 0 |
|
for y in range(height): |
|
for x in range(width): |
|
if Z[y, x] > 0 and Z[y, x] < 1e9 and not np.isnan(Z[y, x]): |
|
Z_valid[idx] = Z[y, x] |
|
X_valid[idx] = (x - cx) * Z_valid[idx] / fx |
|
Y_valid[idx] = (y - cy) * Z_valid[idx] / fy |
|
idx += 1 |
|
|
|
# 创建齐次坐标 |
|
points_3d = np.zeros((valid_count, 4)) |
|
points_3d[:, 0] = X_valid |
|
points_3d[:, 1] = Y_valid |
|
points_3d[:, 2] = Z_valid |
|
points_3d[:, 3] = 1.0 |
|
|
|
# 坐标变换 |
|
return points_3d @ c2w.T |
|
|
|
def get_world_points(depth_map: o3d.geometry.Image, K: np.ndarray, c2w: np.ndarray, all_point: bool = False) -> np.ndarray: |
|
''' |
|
获取深度图对应的世界坐标点云 |
|
|
|
param: |
|
depth_map: 深度图 |
|
K: 相机内参矩阵 |
|
c2w: 相机坐标系到世界坐标系的变换矩阵 |
|
all_point: 是否返回所有点(包括无效点) |
|
|
|
return: |
|
np.ndarray: 点云坐标(世界坐标系) |
|
- 当all_point=False时:只包含有效深度值的点云 |
|
- 当all_point=True时:包含所有点的点云,无效点深度值设为10000 |
|
''' |
|
# 转换为numpy数组 |
|
Z = np.asarray(depth_map) |
|
height, width = Z.shape |
|
|
|
# 获取内参 |
|
fx, fy = K[0, 0], K[1, 1] |
|
cx, cy = K[0, 2], K[1, 2] |
|
|
|
# 确保数据类型兼容性(Numba要求) |
|
Z = Z.astype(np.float64) |
|
c2w = c2w.astype(np.float64) |
|
|
|
if all_point: |
|
return _process_all_points(Z, width, height, cx, cy, fx, fy, c2w) |
|
else: |
|
return _process_valid_points(Z, width, height, cx, cy, fx, fy, c2w) |
|
|
|
|