diff --git a/fill_all_empty_faces_v1.2.py b/fill_all_empty_faces_v1.2.py index 1df0cd7..81cd892 100644 --- a/fill_all_empty_faces_v1.2.py +++ b/fill_all_empty_faces_v1.2.py @@ -1,547 +1,766 @@ -import torch +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +fill_all_empty_faces_v2_harmonic_atlas.py + +目标:替换原来的“一个缺色孤岛 -> 一个均值色块”的逻辑。 + +核心逻辑: +1. 正常 face 保持原有 UV,只因为 texture 高度变大而 remap v 坐标。 +2. missing face 不再按 region 写一个颜色点,也不再整组 face 指向同一个 vt。 +3. 在 3D mesh 拓扑上,对 missing region 的“顶点颜色”做 harmonic/Laplace 扩散: + - 边界顶点颜色来自相邻正常 face 的原始 texture 采样; + - 内部顶点颜色由图拉普拉斯扩散得到; + - 所以颜色从边缘向内部连续变化。 +4. 每个 missing face 在底部新增 atlas 中分配一个小三角 tile。 + - tile 内不是纯色,而是按三个顶点颜色做 barycentric 插值; + - 相邻 missing face 即使在 atlas 里不相邻,只要共享 3D 顶点,边上的插值颜色也连续; + - tile 周围做 padding,避免纹理过滤时采到白底/其他 tile。 + +输入输出参数保持原脚本一致: + --input_obj + --input_texture + --missing_faces + --output_obj + --output_texture + +额外参数都是可选。 +""" + import argparse -import numpy as np -import cv2 +import os import time -from collections import defaultdict +from collections import defaultdict, deque +from typing import Dict, List, Tuple, Set + +import cv2 +import numpy as np import tqdm -from multiprocessing import Pool, cpu_count -from typing import List, Tuple, Dict # 如果还未导入 - - -def read_vertices(obj_path): - vertices = [] - with open(obj_path, 'r') as file: - lines = file.readlines() - for line in lines: - if line.startswith('v '): # 顶点坐标 - vertices.append(list(map(float, line.split()[1:4]))) - vertices = torch.tensor(vertices) - return vertices - -def read_uvs(obj_path): - uv_coordinates = [] - with open(obj_path, 'r') as file: - lines = file.readlines() - for line in lines: - if line.startswith('vt '): # UV 坐标 - uv_coordinates.append(list(map(float, line.split()[1:3]))) - uv_coordinates = torch.tensor(uv_coordinates) - return uv_coordinates - -def read_faces(obj_path): - vertex_indices = [] - uv_indices = [] - with open(obj_path, 'r') as file: - lines = file.readlines() - - for line in lines: - if line.startswith('f '): # 面 - parts = line.split()[1:] - v_indices = [] - uv_indices_temp = [] - for face in parts: - v, vt = map(int, face.split('/')[:2]) - v_indices.append(v - 1) - uv_indices_temp.append(vt - 1) - vertex_indices.append(v_indices) - uv_indices.append(uv_indices_temp) - vertex_indices = torch.tensor(vertex_indices) - uv_indices = torch.tensor(uv_indices) - return vertex_indices, uv_indices - -def read_missing_faces(missing_faces_path): - with open(missing_faces_path, 'r') as file: - lines = file.readlines() - missing_color_faces = torch.tensor( - [int(line.strip()) for line in lines] + + +# ============================================================ +# OBJ / texture IO +# ============================================================ + +def _parse_obj_index(s: str, current_len: int) -> int: + """OBJ index: positive is 1-based, negative is relative to current list end.""" + idx = int(s) + if idx > 0: + return idx - 1 + if idx < 0: + return current_len + idx + raise ValueError("OBJ index 不能为 0") + + +def read_obj_basic(obj_path: str): + vertices: List[List[float]] = [] + uvs: List[List[float]] = [] + vertex_indices: List[List[int]] = [] + uv_indices: List[List[int]] = [] + + mtllib_lines: List[str] = [] + usemtl_name = "material_0" + + with open(obj_path, "r", encoding="utf-8", errors="ignore") as f: + for line in f: + if line.startswith("mtllib "): + mtllib_lines.append(line.strip()) + elif line.startswith("usemtl "): + parts = line.strip().split(maxsplit=1) + if len(parts) == 2: + usemtl_name = parts[1] + elif line.startswith("v "): + parts = line.strip().split() + if len(parts) < 4: + continue + vertices.append([float(parts[1]), float(parts[2]), float(parts[3])]) + elif line.startswith("vt "): + parts = line.strip().split() + if len(parts) < 3: + continue + uvs.append([float(parts[1]), float(parts[2])]) + elif line.startswith("f "): + parts = line.strip().split()[1:] + if len(parts) != 3: + raise ValueError(f"当前脚本只支持三角面,发现非三角 face: {line.strip()}") + + v_row = [] + vt_row = [] + for token in parts: + sub = token.split("/") + if len(sub) < 2 or sub[1] == "": + raise ValueError(f"face 缺少 vt,无法处理: {line.strip()}") + v_row.append(_parse_obj_index(sub[0], len(vertices))) + vt_row.append(_parse_obj_index(sub[1], len(uvs))) + + vertex_indices.append(v_row) + uv_indices.append(vt_row) + + vertices_np = np.asarray(vertices, dtype=np.float32) + uvs_np = np.asarray(uvs, dtype=np.float32) + vertex_indices_np = np.asarray(vertex_indices, dtype=np.int64) + uv_indices_np = np.asarray(uv_indices, dtype=np.int64) + + if len(vertices_np) == 0 or len(uvs_np) == 0 or len(vertex_indices_np) == 0: + raise ValueError("OBJ 内没有读到完整的 v / vt / f 数据。") + + return vertices_np, uvs_np, vertex_indices_np, uv_indices_np, mtllib_lines, usemtl_name + + +def read_missing_faces(path: str, index_base: int = 0) -> np.ndarray: + arr = [] + with open(path, "r", encoding="utf-8", errors="ignore") as f: + for line in f: + s = line.strip() + if not s: + continue + arr.append(int(s) - index_base) + return np.asarray(arr, dtype=np.int64) + + +def read_texture_rgb(path: str) -> np.ndarray: + bgr = cv2.imread(path, cv2.IMREAD_COLOR) + if bgr is None: + raise FileNotFoundError(f"无法读取贴图: {path}") + return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + + +def write_texture_rgb(path: str, rgb: np.ndarray): + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + ok = cv2.imwrite(path, bgr, [cv2.IMWRITE_PNG_COMPRESSION, 3]) + if not ok: + raise IOError(f"写入贴图失败: {path}") + + +def write_obj_with_uv_coordinates( + filename: str, + vertices: np.ndarray, + uvs: np.ndarray, + vertex_indices: np.ndarray, + uv_indices: np.ndarray, + mtllib_lines: List[str] = None, + usemtl_name: str = "material_0", +): + os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True) + + if not mtllib_lines: + mtllib_lines = ["mtllib mesh.mtl"] + + estimated_size = len(vertices) * 48 + len(uvs) * 32 + len(vertex_indices) * 64 + buffer_size = min(max(int(estimated_size * 1.2), 64 * 1024 * 1024), 1024 * 1024 * 1024) + + lines: List[str] = [] + lines.extend(mtllib_lines) + + for v in vertices: + lines.append("v %.6f %.6f %.6f" % (v[0], v[1], v[2])) + + lines.append("") + for uv in uvs: + lines.append("vt %.8f %.8f" % (uv[0], uv[1])) + + lines.append("") + lines.append(f"usemtl {usemtl_name}") + + for v_idx, vt_idx in zip(vertex_indices, uv_indices): + lines.append( + "f %d/%d %d/%d %d/%d" % ( + int(v_idx[0]) + 1, int(vt_idx[0]) + 1, + int(v_idx[1]) + 1, int(vt_idx[1]) + 1, + int(v_idx[2]) + 1, int(vt_idx[2]) + 1, + ) ) - return missing_color_faces - -def read_uv_map(input_texture_path): - uv_map = cv2.imread(input_texture_path) - uv_map = cv2.cvtColor(uv_map, cv2.COLOR_BGR2RGB) - uv_map = torch.from_numpy(uv_map) - return uv_map - -def parse_obj_file_and_uv_map(obj_path, missing_faces_path, input_texture_path, device): - print(f"Reading OBJ file: {obj_path}") - - # vertices = [] - # uv_coordinates = [] - # vertex_indices = [] - # uv_indices = [] - # multiprocessing.set_start_method('spawn', force=True) - # multiprocessing.freeze_support() - start_time = time.time() - - - p = Pool(5) - uv_map_result = p.apply_async(read_uv_map, (input_texture_path,)) - vertices_result = p.apply_async(read_vertices, (obj_path,)) - uv_coordinates_result = p.apply_async(read_uvs, (obj_path,)) - faces_result = p.apply_async(read_faces, (obj_path,)) - missing_faces_result = p.apply_async(read_missing_faces, (missing_faces_path,)) - - p.close() - p.join() - - vertices = vertices_result.get() - uv_coordinates = uv_coordinates_result.get() - vertex_indices, uv_indices = faces_result.get() - missing_color_faces = missing_faces_result.get() - uv_map = uv_map_result.get() - - vertices = vertices.to(device) - uv_coordinates = uv_coordinates.to(device) - vertex_indices = vertex_indices.to(device) - uv_indices = uv_indices.to(device) - missing_color_faces = missing_color_faces.to(device) - uv_map = uv_map.to(device) - - end_time = time.time() - print(f"using: {end_time - start_time} seconds") - - # exit() - print("Converting to tensors...") - - return vertices, uv_coordinates, vertex_indices, uv_indices, missing_color_faces, uv_map - -def write_obj_with_uv_coordinates(filename, vertices, uvs, vertex_indices, uv_indices): - """ - 高性能OBJ文件写入函数 - - Parameters: - filename (str): 输出OBJ文件路径 - vertices (np.ndarray): 顶点数组 - uvs (np.ndarray): UV坐标数组 - vertex_indices (np.ndarray): 面的顶点索引 - uv_indices (np.ndarray): 面的UV索引 - """ - # 估算数据大小(以字节为单位) - estimated_size = ( - len(vertices) * 40 + # 每个顶点约40字节 (v x.xxxxxx y.xxxxxx z.xxxxxx\n) - len(uvs) * 30 + # 每个UV坐标约30字节 (vt x.xxxxxx y.xxxxxx\n) - len(vertex_indices) * 40 # 每个面约40字节 (f v1/vt1 v2/vt2 v3/vt3\n) - ) - - # 设置缓冲区大小为估算大小的1.2倍,最小256MB,最大1GB - buffer_size = min(max(int(estimated_size * 1.2), 256 * 1024 * 1024), 1024 * 1024 * 1024) - - # 使用格式化字符串和列表推导式优化字符串生成 - vertex_lines = ['v %.6f %.6f %.6f' % (v[0], v[1], v[2]) for v in vertices] - uv_lines = ['vt %.6f %.6f' % (uv[0], uv[1]) for uv in uvs] - - # 优化face数据处理 - face_lines = [] - face_format = 'f %d/%d %d/%d %d/%d' - for v_idx, uv_idx in zip(vertex_indices, uv_indices): - face_lines.append(face_format % ( - v_idx[0] + 1, uv_idx[0] + 1, - v_idx[1] + 1, uv_idx[1] + 1, - v_idx[2] + 1, uv_idx[2] + 1 - )) - - # 使用join一次性构建完整内容 - content = ['mtllib mesh.mtl'] + vertex_lines + [''] + uv_lines + [''] + ['usemtl material_0'] + face_lines - - # 一次性写入所有数据 - with open(filename, 'w', buffering=buffer_size) as f: - f.write('\n'.join(content)) - -def load_regions(filename): - regions = [] - with open(filename, 'r') as file: - for line in file: - parts = line.split(";") - if len(parts) != 2: - continue # Skip any lines that don't have exactly two parts - - first_set = set(int(x) for x in parts[0].strip().split()) - second_set = set(int(x) for x in parts[1].strip().split()) - regions.append((first_set, second_set)) - return regions + with open(filename, "w", buffering=buffer_size, encoding="utf-8") as f: + f.write("\n".join(lines)) + +# ============================================================ +# Mesh topology +# ============================================================ -def build_face_adjacency(vertices, faces): +def build_face_adjacency_and_vertex_faces(faces: np.ndarray, num_vertices: int): """ - 构建面的邻接关系,基于共享边 - - Args: - vertices: 顶点数组 - faces: 面片索引数组 (N x 3) - - Returns: - dict: 面片邻接关系字典,key为面片索引,value为邻接面片索引列表 + 基于共享边构建 face adjacency,同时构建 vertex -> faces。 + 返回:face_adjacency, vertex_faces """ - # 将faces转换为numpy数组以加快处理速度 - faces = np.asarray(faces) + faces = np.asarray(faces, dtype=np.int64) num_faces = len(faces) - - # 为每个面创建所有边 (Nx3x2) + + # vertex -> faces + vertex_faces: List[List[int]] = [[] for _ in range(num_vertices)] + for fi, row in enumerate(faces): + vertex_faces[int(row[0])].append(fi) + vertex_faces[int(row[1])].append(fi) + vertex_faces[int(row[2])].append(fi) + + # edge -> faces,向量化排序 edges = np.stack([ - np.column_stack((faces[:, i], faces[:, (i + 1) % 3])) - for i in range(3) - ], axis=1) - - # 确保边的方向一致 (较小的顶点索引在前) - edges.sort(axis=2) - - # 将边展平为 (Nx3, 2) 的形状 - edges = edges.reshape(-1, 2) - - # 创建边到面的映射 - edge_faces = np.repeat(np.arange(num_faces), 3) - - # 使用复合键对边进行排序 - edge_keys = edges[:, 0] * vertices.shape[0] + edges[:, 1] - sort_idx = np.argsort(edge_keys) - edges = edges[sort_idx] - edge_faces = edge_faces[sort_idx] - - # 找到重复的边(共享边) - same_edges = (edge_keys[sort_idx][1:] == edge_keys[sort_idx][:-1]) - edge_start_idx = np.where(same_edges)[0] - - # 构建邻接字典 - face_adjacency = defaultdict(list) - for idx in edge_start_idx: - face1, face2 = edge_faces[idx], edge_faces[idx + 1] - face_adjacency[face1].append(face2) - face_adjacency[face2].append(face1) - - return dict(face_adjacency) - -def find_groups_and_subgroups(face_adjacency, missing_faces): - """ - 找到相连的面组和它们的邻接面 - 返回: - regions: 列表,每个元素是一个元组 (missing_faces_set, adjacent_faces_set), - 与 load_regions() 函数返回格式保持一致 - """ - missing_faces_set = set(missing_faces.cpu().numpy()) - unused_faces = set(missing_faces.cpu().numpy()) - regions = [] - - total_faces = len(unused_faces) - with tqdm.tqdm(total=total_faces, desc="Processing faces") as pbar: - while unused_faces: - start_face = unused_faces.pop() - current_group = {start_face} - current_subgroup = set() - - stack = [start_face] - while stack: - face_idx = stack.pop() - - for neighbor in face_adjacency.get(face_idx, []): - if neighbor in unused_faces: - current_group.add(neighbor) - unused_faces.remove(neighbor) - stack.append(neighbor) - elif neighbor not in missing_faces_set: - current_subgroup.add(neighbor) - - regions.append((current_group, current_subgroup)) - pbar.update(total_faces - len(unused_faces) - pbar.n) - - # 输出统计信息 - print(f"\nTotal regions: {len(regions)}") - print(f"Average missing faces group size: {sum(len(g[0]) for g in regions)/len(regions):.2f}") - print(f"Largest missing faces group size: {max(len(g[0]) for g in regions)}") - print(f"Smallest missing faces group size: {min(len(g[0]) for g in regions)}") - - # 检查每个组是否都有邻接面 - for i, (group, subgroup) in enumerate(regions): - if not subgroup: - print(f"Warning: Region {i} with {len(group)} missing faces has no adjacent faces!") - + np.column_stack((faces[:, 0], faces[:, 1])), + np.column_stack((faces[:, 1], faces[:, 2])), + np.column_stack((faces[:, 2], faces[:, 0])), + ], axis=1).reshape(-1, 2) + edges.sort(axis=1) + edge_faces = np.repeat(np.arange(num_faces, dtype=np.int64), 3) + + key = edges[:, 0] * num_vertices + edges[:, 1] + order = np.argsort(key) + key = key[order] + edge_faces = edge_faces[order] + + adj = defaultdict(list) + same = key[1:] == key[:-1] + pos = np.where(same)[0] + for p in pos: + f0 = int(edge_faces[p]) + f1 = int(edge_faces[p + 1]) + if f0 != f1: + adj[f0].append(f1) + adj[f1].append(f0) + + return dict(adj), vertex_faces + + +def find_missing_regions(face_adjacency: Dict[int, List[int]], missing_faces: np.ndarray, total_faces: int): + missing_all = set(int(x) for x in missing_faces if 0 <= int(x) < total_faces) + unused = set(missing_all) + regions: List[Tuple[Set[int], Set[int]]] = [] + + pbar = tqdm.tqdm(total=len(unused), desc="Finding missing islands") + done = 0 + + while unused: + start = unused.pop() + group = {start} + boundary = set() + stack = [start] + + while stack: + f = stack.pop() + for nb in face_adjacency.get(f, []): + if nb in unused: + unused.remove(nb) + group.add(nb) + stack.append(nb) + elif nb not in missing_all: + boundary.add(nb) + + regions.append((group, boundary)) + new_done = len(missing_all) - len(unused) + pbar.update(new_done - done) + done = new_done + + pbar.close() + + if regions: + sizes = [len(x[0]) for x in regions] + print(f"Total missing islands: {len(regions)}") + print(f"Island size: min={min(sizes)}, max={max(sizes)}, avg={sum(sizes)/len(sizes):.2f}") + else: + print("No missing islands found.") + return regions -def compute_regions_face_colors( - regions: List[Tuple[set, set]], - uv_map: torch.Tensor, - uvs: torch.Tensor, - face_uv_indices: torch.Tensor, - device: str -) -> Dict[int, torch.Tensor]: + +# ============================================================ +# Texture sampling +# ============================================================ + +def sample_texture_bilinear_rgb(texture_rgb: np.ndarray, uv: np.ndarray) -> np.ndarray: + """uv: [2], v is OBJ convention, image y is flipped.""" + H, W = texture_rgb.shape[:2] + u = float(np.clip(uv[0], 0.0, 1.0)) + v = float(np.clip(uv[1], 0.0, 1.0)) + + x = u * (W - 1) + y = (1.0 - v) * (H - 1) + + x0 = int(np.floor(x)) + y0 = int(np.floor(y)) + x1 = min(x0 + 1, W - 1) + y1 = min(y0 + 1, H - 1) + + dx = x - x0 + dy = y - y0 + + c00 = texture_rgb[y0, x0].astype(np.float32) + c10 = texture_rgb[y0, x1].astype(np.float32) + c01 = texture_rgb[y1, x0].astype(np.float32) + c11 = texture_rgb[y1, x1].astype(np.float32) + + c0 = c00 * (1.0 - dx) + c10 * dx + c1 = c01 * (1.0 - dx) + c11 * dx + return c0 * (1.0 - dy) + c1 * dy + + +def sample_face_corner_color( + face_id: int, + corner: int, + texture_rgb: np.ndarray, + uvs: np.ndarray, + face_uv_indices: np.ndarray, + corner_inset: float = 0.90, +) -> np.ndarray: """ - 根据每个区域的边缘面UV坐标计算加权平均的颜色, - 当无有效采样时更新对应face_uv_indices。 - - 参数: - regions (List[Tuple[set, set]]): 每个区域为 (缺失面集合, 邻接面集合) - uv_map (torch.Tensor): 原始纹理贴图,RGB格式 - uvs (torch.Tensor): 原始UV坐标 - face_uv_indices (torch.Tensor): 每个面对应的UV索引 - device (str): 使用的设备("cuda"或"cpu") - - 返回: - Dict[int, torch.Tensor]: 键为区域索引,值为该区域加权平均计算得到的颜色(uint8) + 在正常 face 的某个顶点附近采样。 + 不直接采顶点本身,而是往三角形内部缩一点,避免 seam 边界采样不稳。 + corner_inset 越接近 1,越贴近该顶点。 """ - regions_face_color: Dict[int, torch.Tensor] = {} - for r_index, region in enumerate(tqdm.tqdm(regions, desc="Processing regions")): - region_faces_indexes = torch.tensor(list(region[0]), device=device) - region_edge_faces_indexes = torch.tensor(list(region[1]), device=device) - - if len(region_edge_faces_indexes) == 0: - continue - - # 获取边缘面的UV索引 - edge_face_uv_indices = face_uv_indices[region_edge_faces_indexes] - # 使用三角形的质心UV坐标来采样颜色 - triangle_uvs = uvs[edge_face_uv_indices] # shape: [num_faces, 3, 2] - centroid_uvs = triangle_uvs.mean(dim=1) # shape: [num_faces, 2] - - # 将UV坐标转换为像素坐标 - scale_tensor = torch.tensor([uv_map.shape[1] - 1, uv_map.shape[0] - 1], device=device) - pixel_coords = torch.round(centroid_uvs * scale_tensor) - pixel_coords[:, 1] = uv_map.shape[0] - 1 - pixel_coords[:, 1] - pixel_coords = pixel_coords.long().clamp(0, uv_map.shape[0] - 1) - - # 直接采样质心位置的颜色 - colors = uv_map[pixel_coords[:, 1], pixel_coords[:, 0]] # shape: [num_faces, 3] - - # 使用面积加权平均来计算最终颜色 - areas = torch.abs( - (triangle_uvs[:, 1, 0] - triangle_uvs[:, 0, 0]) * (triangle_uvs[:, 2, 1] - triangle_uvs[:, 0, 1]) - - (triangle_uvs[:, 2, 0] - triangle_uvs[:, 0, 0]) * (triangle_uvs[:, 1, 1] - triangle_uvs[:, 0, 1]) - ) * 0.5 - - if len(colors) > 0: - weighted_color = (colors.float() * areas.unsqueeze(1)).sum(dim=0) / areas.sum() - regions_face_color[r_index] = weighted_color.round().clamp(0, 255).to(torch.uint8) - else: - # 如果没有有效的采样点,使用第一个相邻面的UV坐标更新face_uv_indices - face_uv_indices[region_faces_indexes] = face_uv_indices[region_edge_faces_indexes[0]].unsqueeze(dim=0).clone() + tri_uv = uvs[face_uv_indices[face_id]] # [3,2] + w = np.full(3, (1.0 - corner_inset) / 2.0, dtype=np.float32) + w[corner] = corner_inset + uv = w @ tri_uv + return sample_texture_bilinear_rgb(texture_rgb, uv) + + +def median_color(colors: List[np.ndarray]) -> np.ndarray: + if len(colors) == 0: + return np.asarray([128, 128, 128], dtype=np.float32) + return np.median(np.stack(colors, axis=0).astype(np.float32), axis=0) - return regions_face_color +# ============================================================ +# Harmonic vertex color diffusion +# ============================================================ -def update_uv_map_and_indices( - uv_map: torch.Tensor, - uvs: torch.Tensor, - face_uv_indices: torch.Tensor, - regions: List[Tuple[set, set]], - regions_face_color: Dict[int, torch.Tensor], - device: str -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def build_region_vertex_graph(region_faces: Set[int], faces: np.ndarray): + """ + 对一个 missing island 构建 local vertex graph。 + 返回:unique_vertices, vertex_to_local, unique_edges, local_face_ids + """ + face_ids = np.asarray(sorted(region_faces), dtype=np.int64) + tri_v = faces[face_ids] # [F,3] + unique_vertices = np.unique(tri_v.reshape(-1)) + v2local = {int(v): i for i, v in enumerate(unique_vertices)} + + local_faces = np.vectorize(lambda x: v2local[int(x)], otypes=[np.int64])(tri_v) + + edges = np.concatenate([ + local_faces[:, [0, 1]], + local_faces[:, [1, 2]], + local_faces[:, [2, 0]], + ], axis=0) + edges.sort(axis=1) + edges = np.unique(edges, axis=0) + + return face_ids, unique_vertices, v2local, edges, local_faces + + +def compute_boundary_vertex_constraints( + unique_vertices: np.ndarray, + vertex_faces: List[List[int]], + faces: np.ndarray, + face_uv_indices: np.ndarray, + uvs: np.ndarray, + texture_rgb: np.ndarray, + missing_mask: np.ndarray, + corner_inset: float = 0.90, +): + """ + 对 region 内的每个顶点,如果它邻接任何非 missing face,就从这些正常 face 的对应 corner 采颜色。 + 这比“对边界 face 求一个均值”更合理,因为边界不同位置的颜色会保留下来。 """ - 根据计算得到的区域颜色,更新UV贴图及对应的UV坐标,并批量更新face_uv_indices。 - - 参数: - uv_map (torch.Tensor): 原始纹理贴图,RGB格式 - uvs (torch.Tensor): 原始UV坐标 - face_uv_indices (torch.Tensor): 原始面的UV索引 - regions (List[Tuple[set, set]]): 每个区域为 (缺失面集合, 邻接面集合) - regions_face_color (Dict[int, torch.Tensor]): 每个区域计算得到的颜色 - device (str): 使用的设备("cuda"或"cpu") - - 返回: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - new_uv_map: 更新后的UV贴图 - uvs_updated: 更新后的UV坐标(拼接上新计算的UV) - face_uv_indices: 更新后的face UV索引 + n = len(unique_vertices) + constrained = np.zeros(n, dtype=bool) + constraint_colors = np.zeros((n, 3), dtype=np.float32) + + for local_i, gv in enumerate(unique_vertices): + samples = [] + for f in vertex_faces[int(gv)]: + if missing_mask[f]: + continue + row = faces[f] + corners = np.where(row == gv)[0] + if len(corners) == 0: + continue + for c in corners: + samples.append(sample_face_corner_color( + f, + int(c), + texture_rgb, + uvs, + face_uv_indices, + corner_inset=corner_inset, + )) + + if samples: + constrained[local_i] = True + constraint_colors[local_i] = median_color(samples) + + return constrained, constraint_colors + + +def initialize_vertex_colors_by_bfs(edges: np.ndarray, constrained: np.ndarray, constraint_colors: np.ndarray) -> np.ndarray: + n = len(constrained) + colors = np.zeros((n, 3), dtype=np.float32) + + if constrained.any(): + colors[constrained] = constraint_colors[constrained] + fallback = np.mean(constraint_colors[constrained], axis=0) + else: + fallback = np.asarray([128, 128, 128], dtype=np.float32) + colors[:] = fallback + return colors + + adj = [[] for _ in range(n)] + for a, b in edges: + adj[int(a)].append(int(b)) + adj[int(b)].append(int(a)) + + q = deque(np.where(constrained)[0].tolist()) + visited = np.zeros(n, dtype=bool) + visited[constrained] = True + + while q: + cur = q.popleft() + for nb in adj[cur]: + if visited[nb]: + continue + colors[nb] = colors[cur] + visited[nb] = True + q.append(nb) + + colors[~visited] = fallback + return colors + + +def solve_harmonic_vertex_colors( + edges: np.ndarray, + constrained: np.ndarray, + constraint_colors: np.ndarray, + smooth_iters: int = 80, + self_weight: float = 0.0, +) -> np.ndarray: """ - total_regions = len(regions_face_color) - grid_size = uv_map.shape[1] // 3 - all_c = torch.div(torch.arange(total_regions, device=device), grid_size, rounding_mode='floor') - all_r = torch.remainder(torch.arange(total_regions, device=device), grid_size) - - # 创建新的颜色UV贴图 - color_uv_map = torch.full((int(uv_map.shape[0] / 2), uv_map.shape[1], 3), - 255, dtype=torch.uint8, device=device) - # 调整原始uvs的纵坐标 - uvs[:, 1] = uvs[:, 1] * (2 / 3) + 1 / 3 - - # 批量创建所有颜色块的坐标 - c_indices = all_c.unsqueeze(1).repeat(1, 9) * 3 + torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], - device=device).unsqueeze(0) - r_indices = all_r.unsqueeze(1).repeat(1, 9) * 3 + torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], - device=device).unsqueeze(0) - - # 批量设置颜色 - colors = torch.stack([color for _, color in sorted(regions_face_color.items(), key=lambda x: x[0])]) - colors_repeated = colors.unsqueeze(1).repeat(1, 9, 1) - color_uv_map[c_indices.flatten(), r_indices.flatten()] = colors_repeated.reshape(-1, 3) - - # 批量计算新的UV坐标 - pixels = torch.stack([ - all_r * 3 + 1, - uv_map.shape[0] + all_c * 3 + 1 - ], dim=1).to(device) - u_new = pixels[:, 0].float() / (uv_map.shape[1] - 1) - new_height = int(uv_map.shape[0] + uv_map.shape[0] / 2) - v_new = (new_height - 1 - pixels[:, 1].float()) / (new_height - 1) - new_uvs = torch.stack([u_new, v_new], dim=1) - - # 更新UV坐标:拼接新计算的UV - uvs_updated = torch.cat([uvs, new_uvs], dim=0) - uv_coordinates_start = uvs_updated.shape[0] - total_regions - - # 批量更新face_uv_indices - for i, (region_index, _) in enumerate(sorted(regions_face_color.items(), key=lambda x: x[0])): - region_faces_indexes = torch.tensor(list(regions[region_index][0]), device=device) - face_uv_indices[region_faces_indexes] = torch.full((1, 3), uv_coordinates_start + i, device=device) - - # 合并原始UV贴图和新的颜色UV贴图 - new_uv_map = torch.cat((uv_map, color_uv_map), dim=0) - - return new_uv_map, uvs_updated, face_uv_indices - -def group_regions_by_y_axis( - regions: List[Tuple[set, set]], - vertices: torch.Tensor, - triangle_vertex_indices: torch.Tensor, - device: str, - interval_size: float = 0.1 -) -> Dict[int, List[int]]: + 对 region 顶点颜色做 Dirichlet boundary harmonic extension。 + constrained 顶点保持边界颜色;非 constrained 顶点反复取邻居平均。 """ - 将区域按照y轴高度分组 + n = len(constrained) + colors = initialize_vertex_colors_by_bfs(edges, constrained, constraint_colors) + + if n == 0 or len(edges) == 0: + return colors + + e0 = edges[:, 0].astype(np.int64) + e1 = edges[:, 1].astype(np.int64) + + deg = np.zeros(n, dtype=np.float32) + np.add.at(deg, e0, 1.0) + np.add.at(deg, e1, 1.0) + denom = deg[:, None] + float(self_weight) + denom = np.maximum(denom, 1e-8) + + for _ in range(max(0, int(smooth_iters))): + acc = np.zeros_like(colors) + np.add.at(acc, e0, colors[e1]) + np.add.at(acc, e1, colors[e0]) + + if self_weight > 0: + acc += colors * float(self_weight) + + new_colors = acc / denom + # Dirichlet boundary:边界顶点锁住,不允许被内部均化掉。 + new_colors[constrained] = constraint_colors[constrained] + colors = new_colors - Args: - regions: 区域列表,每个区域为(缺失面集合, 邻接面集合)的元组 - vertices: 顶点坐标张量 - triangle_vertex_indices: 三角形顶点索引张量 - device: 计算设备 ('cuda' 或 'cpu') - interval_size: y轴分组的间隔大小,默认为0.1 + return np.clip(colors, 0, 255).astype(np.float32) - Returns: - Dict[int, List[int]]: 以y轴区间为键,区域索引列表为值的字典 + +# ============================================================ +# Per-face triangular atlas baking +# ============================================================ + +def precompute_triangle_tile(tile_size: int, pad: int): + """ + 预计算一个标准三角 tile 的 barycentric weights 和 padding 最近邻映射。 + 后续每个 missing face 只需要 weights @ 三个顶点颜色。 + """ + tile_size = int(tile_size) + pad = int(pad) + if tile_size < 4: + raise ValueError("tile_size 至少要 >= 4") + if pad < 1: + pad = 1 + if pad * 2 + 2 >= tile_size: + pad = max(1, tile_size // 4) + + # 标准右三角,三个 UV 角点都留 padding,避免双线性采样采到 tile 外。 + p0 = np.asarray([pad, pad], dtype=np.float32) + p1 = np.asarray([tile_size - pad - 1, pad], dtype=np.float32) + p2 = np.asarray([pad, tile_size - pad - 1], dtype=np.float32) + + yy, xx = np.meshgrid(np.arange(tile_size, dtype=np.float32), + np.arange(tile_size, dtype=np.float32), indexing="ij") + pts = np.stack([xx + 0.5, yy + 0.5], axis=-1) # [T,T,2], pixel centers + + # barycentric + v0 = p1 - p0 + v1 = p2 - p0 + v2 = pts - p0 + d00 = float(np.dot(v0, v0)) + d01 = float(np.dot(v0, v1)) + d11 = float(np.dot(v1, v1)) + denom = d00 * d11 - d01 * d01 + if abs(denom) < 1e-8: + raise ValueError("退化 tile triangle") + + d20 = v2[..., 0] * v0[0] + v2[..., 1] * v0[1] + d21 = v2[..., 0] * v1[0] + v2[..., 1] * v1[1] + w1 = (d11 * d20 - d01 * d21) / denom + w2 = (d00 * d21 - d01 * d20) / denom + w0 = 1.0 - w1 - w2 + weights = np.stack([w0, w1, w2], axis=-1).astype(np.float32) + + inside = (weights[..., 0] >= -1e-4) & (weights[..., 1] >= -1e-4) & (weights[..., 2] >= -1e-4) + + # padding:tile 外部/三角外部像素复制最近的三角内部像素颜色。 + inside_coords = np.argwhere(inside) + nearest_y = np.zeros((tile_size, tile_size), dtype=np.int32) + nearest_x = np.zeros((tile_size, tile_size), dtype=np.int32) + for y in range(tile_size): + for x in range(tile_size): + if inside[y, x]: + nearest_y[y, x] = y + nearest_x[y, x] = x + else: + d2 = (inside_coords[:, 0] - y) ** 2 + (inside_coords[:, 1] - x) ** 2 + k = int(np.argmin(d2)) + nearest_y[y, x] = int(inside_coords[k, 0]) + nearest_x[y, x] = int(inside_coords[k, 1]) + + padded_weights = weights[nearest_y, nearest_x] + + # UV 角点使用像素中心,更稳定。 + uv_points = np.stack([p0, p1, p2], axis=0).astype(np.float32) + + return padded_weights, uv_points + + +def build_harmonic_atlas_texture( + texture_rgb: np.ndarray, + original_uvs: np.ndarray, + faces: np.ndarray, + face_uv_indices: np.ndarray, + missing_faces: np.ndarray, + face_corner_colors: Dict[int, np.ndarray], + tile_size: int = 12, + tile_pad: int = 2, +): """ - y_intervals = defaultdict(list) - for r_index, region in enumerate(regions): - region_faces_indexes = torch.tensor(list(region[0]), device=device) - # 计算面组的平均y轴位置 - face_vertices = vertices[triangle_vertex_indices[region_faces_indexes]] - avg_y = face_vertices[:, :, 1].mean(dim=(0, 1)) - - # 根据y轴位置分配到对应区间 - interval_key = int(avg_y // interval_size) - y_intervals[interval_key].append(r_index) - - return dict(y_intervals) - -def align_regions_colors( - regions_face_color: Dict[int, torch.Tensor], - y_intervals: Dict[int, List[int]], - regions: List[Tuple[set, set]] -) -> Dict[int, torch.Tensor]: + 给每个 missing face 分配一个小三角 tile。 + tile 内按三个顶点颜色插值,不是纯色块。 """ - 对齐区间内的颜色 + H, W = texture_rgb.shape[:2] + missing_faces_sorted = np.asarray(sorted(set(int(x) for x in missing_faces)), dtype=np.int64) + n_missing = len(missing_faces_sorted) + + cols = max(1, W // tile_size) + rows = int(np.ceil(n_missing / cols)) + atlas_h = rows * tile_size + new_H = H + atlas_h + + print(f"atlas tiles: {n_missing}, cols={cols}, rows={rows}, atlas_h={atlas_h}") + + new_texture = np.full((new_H, W, 3), 255, dtype=np.uint8) + new_texture[:H, :, :] = texture_rgb + + # 原始 UV remap:保持原图像素位置不变,只是贴图高度变大。 + remapped_uvs = original_uvs.copy().astype(np.float32) + old_y = (1.0 - remapped_uvs[:, 1]) * (H - 1) + remapped_uvs[:, 1] = 1.0 - old_y / max(new_H - 1, 1) + + new_uvs: List[List[float]] = remapped_uvs.tolist() + new_face_uv_indices = face_uv_indices.copy() + + weights_map, uv_points = precompute_triangle_tile(tile_size, tile_pad) + weights_flat = weights_map.reshape(-1, 3) # [T*T,3] + + for i, f in enumerate(tqdm.tqdm(missing_faces_sorted, desc="Baking missing face tiles")): + col = i % cols + row = i // cols + x0 = col * tile_size + y0 = H + row * tile_size + + corner_colors = face_corner_colors.get(int(f)) + if corner_colors is None: + corner_colors = np.full((3, 3), 128, dtype=np.float32) + corner_colors = np.asarray(corner_colors, dtype=np.float32) + + tile = weights_flat @ corner_colors # [T*T,3] + tile = np.clip(np.rint(tile), 0, 255).astype(np.uint8).reshape(tile_size, tile_size, 3) + + new_texture[y0:y0 + tile_size, x0:x0 + tile_size, :] = tile + + new_vt = [] + for p in uv_points: + px = x0 + float(p[0]) + py = y0 + float(p[1]) + u = px / max(W - 1, 1) + v = 1.0 - py / max(new_H - 1, 1) + new_uvs.append([u, v]) + new_vt.append(len(new_uvs) - 1) + + new_face_uv_indices[int(f)] = np.asarray(new_vt, dtype=new_face_uv_indices.dtype) + + return new_texture, np.asarray(new_uvs, dtype=np.float32), new_face_uv_indices + + +# ============================================================ +# Main process +# ============================================================ + +def process( + input_obj_path: str, + input_texture_path: str, + missing_faces_path: str, + output_obj_path: str, + output_texture_path: str, + missing_index_base: int = 0, + tile_size: int = 12, + tile_pad: int = 2, + smooth_iters: int = 80, + self_weight: float = 0.0, + corner_inset: float = 0.90, +): + start = time.time() + + print("Reading input files...") + vertices, uvs, faces, face_uv_indices, mtllib_lines, usemtl_name = read_obj_basic(input_obj_path) + texture_rgb = read_texture_rgb(input_texture_path) + missing_faces = read_missing_faces(missing_faces_path, index_base=missing_index_base) + + total_faces = len(faces) + missing_faces = missing_faces[(missing_faces >= 0) & (missing_faces < total_faces)] + missing_faces = np.unique(missing_faces) + missing_mask = np.zeros(total_faces, dtype=bool) + missing_mask[missing_faces] = True + + print(f"vertices: {len(vertices)}") + print(f"uvs: {len(uvs)}") + print(f"faces: {len(faces)}") + print(f"missing faces: {len(missing_faces)}") + print(f"texture: {texture_rgb.shape[1]} x {texture_rgb.shape[0]}") + + if len(missing_faces) == 0: + write_obj_with_uv_coordinates(output_obj_path, vertices, uvs, faces, face_uv_indices, mtllib_lines, usemtl_name) + write_texture_rgb(output_texture_path, texture_rgb) + return + + t0 = time.time() + print("Building adjacency...") + face_adjacency, vertex_faces = build_face_adjacency_and_vertex_faces(faces, len(vertices)) + print(f"adjacency using: {time.time() - t0:.2f}s") + + t0 = time.time() + regions = find_missing_regions(face_adjacency, missing_faces, total_faces) + print(f"regions using: {time.time() - t0:.2f}s") + + # face -> three corner colors + face_corner_colors: Dict[int, np.ndarray] = {} + + for region_idx, (region_faces, _boundary_faces) in enumerate(tqdm.tqdm(regions, desc="Solving region colors")): + face_ids, unique_vertices, v2local, edges, local_faces = build_region_vertex_graph(region_faces, faces) + + constrained, constraint_colors = compute_boundary_vertex_constraints( + unique_vertices=unique_vertices, + vertex_faces=vertex_faces, + faces=faces, + face_uv_indices=face_uv_indices, + uvs=uvs, + texture_rgb=texture_rgb, + missing_mask=missing_mask, + corner_inset=corner_inset, + ) - Args: - regions_face_color: 每个区域的颜色 - y_intervals: 每个y轴区间的区域索引列表 + if not constrained.any(): + # 极少数情况:region 没有任何正常面顶点约束。 + # 这种输入本身没有颜色来源,只能给灰色,避免崩。 + print(f"Warning: region {region_idx} has no boundary vertex color constraints, using gray.") + vertex_colors = np.full((len(unique_vertices), 3), 128, dtype=np.float32) + else: + vertex_colors = solve_harmonic_vertex_colors( + edges=edges, + constrained=constrained, + constraint_colors=constraint_colors, + smooth_iters=smooth_iters, + self_weight=self_weight, + ) + + # 写成每个 missing face 的三个角颜色。 + for local_fi, f in enumerate(face_ids): + lf = local_faces[local_fi] + face_corner_colors[int(f)] = vertex_colors[lf] + + print("Building harmonic atlas...") + new_texture, new_uvs, new_face_uv_indices = build_harmonic_atlas_texture( + texture_rgb=texture_rgb, + original_uvs=uvs, + faces=faces, + face_uv_indices=face_uv_indices, + missing_faces=missing_faces, + face_corner_colors=face_corner_colors, + tile_size=tile_size, + tile_pad=tile_pad, + ) - Returns: - Dict[int, torch.Tensor]: 以y轴区间为键,颜色为值的字典 - """ - # aligned_regions_face_color = {} - large_group_threshold_min = 5000 - large_group_threshold_max = 100000 - for interval_key, region_indices in y_intervals.items(): - large_groups = [] - # normal_groups = [] - for r_index in region_indices: - region = regions[r_index] - if len(region[0]) >= large_group_threshold_min and len(region[0]) <= large_group_threshold_max: - large_groups.append((r_index, len(region[0]), regions_face_color[r_index])) - - # 查找 large_groups 中 len(region[0]) 最大的组,并获取其颜色 - if large_groups: - largest_group = max(large_groups, key=lambda x: x[1]) - color: torch.Tensor = largest_group[2] - for large_group in large_groups: - regions_face_color[large_group[0]] = color - - return regions_face_color - - -def process(input_obj_path, input_texture_path, missing_faces_path, output_obj_path, output_texture_path): - start_time = time.time() - - device = 'cuda' if torch.cuda.is_available() else 'cpu' - vertices, uvs, triangle_vertex_indices, face_uv_indices, missing_color_faces, uv_map = parse_obj_file_and_uv_map( - input_obj_path, missing_faces_path, input_texture_path, device=device) - - # 构建面的邻接关系和找到区域 - start_face_adjacency_time = time.time() - face_adjacency = build_face_adjacency(vertices.cpu().numpy(), triangle_vertex_indices.cpu().numpy()) - end_face_adjacency_time = time.time() - print(f"face_adjacency using: {end_face_adjacency_time - start_face_adjacency_time} seconds") - - start_find_groups_time = time.time() - regions = find_groups_and_subgroups(face_adjacency, missing_color_faces) - end_find_groups_time = time.time() - print(f"find_groups_and_subgroups using: {end_find_groups_time - start_find_groups_time} seconds") - - start_texture_map_time = time.time() - # 使用新封装的函数计算每个区域的加权平均颜色 - regions_face_color = compute_regions_face_colors(regions, uv_map, uvs, face_uv_indices, device) - end_texture_map_time = time.time() - print(f"texture_mapping_to_triangle using: {end_texture_map_time - start_texture_map_time} seconds") - - # 按y轴区间分组 - y_intervals = group_regions_by_y_axis( - regions, + print("Writing outputs...") + write_obj_with_uv_coordinates( + output_obj_path, vertices, - triangle_vertex_indices, - device + new_uvs, + faces, + new_face_uv_indices, + mtllib_lines=mtllib_lines, + usemtl_name=usemtl_name, ) + write_texture_rgb(output_texture_path, new_texture) + + print(f"output texture: {new_texture.shape[1]} x {new_texture.shape[0]}") + print(f"new uv count: {len(new_uvs)}") + print(f"Total using: {time.time() - start:.2f}s") - # 对齐区间内的颜色 - regions_face_color = align_regions_colors(regions_face_color, y_intervals, regions) - - # 更新UV贴图和面索引 - start_color_map_time = time.time() - new_uv_map, uvs, face_uv_indices = update_uv_map_and_indices(uv_map, uvs, face_uv_indices, regions, - regions_face_color, device) - end_color_map_time = time.time() - print(f"color_mapping_to_triangle using: {end_color_map_time - start_color_map_time} seconds") - - end_time = time.time() - print(f"using: {end_time - start_time} seconds") - - # 写入OBJ和纹理贴图 - start_write_time = time.time() - - vertices_cpu = vertices.cpu().numpy() - uvs_cpu = uvs.cpu().numpy() - triangle_vertex_indices_cpu = triangle_vertex_indices.cpu().numpy() - face_uv_indices_cpu = face_uv_indices.cpu().numpy() - new_uv_map_cpu = new_uv_map.cpu().numpy() - new_uv_map_bgr = cv2.cvtColor(new_uv_map_cpu, cv2.COLOR_RGB2BGR) - - with Pool(2) as p: - # 异步执行OBJ和纹理图写入操作 - obj_future = p.apply_async(write_obj_with_uv_coordinates, - (output_obj_path, vertices_cpu, uvs_cpu, - triangle_vertex_indices_cpu, face_uv_indices_cpu)) - - img_future = p.apply_async(cv2.imwrite, - (output_texture_path, new_uv_map_bgr, - [cv2.IMWRITE_PNG_COMPRESSION, 3])) - - obj_future.get() - img_future.get() - - end_write_time = time.time() - end_time = time.time() - print(f"Total file writing time: {end_write_time - start_write_time:.2f} seconds") - print(f"using: {end_time - start_time} seconds") def main(): - parser = argparse.ArgumentParser(description='Process OBJ files to fix missing color faces.') - parser.add_argument('--input_obj', type=str, required = True, help='Path to the input OBJ file') - parser.add_argument('--input_texture', type=str, required = True, help='Path to the texture file') - parser.add_argument('--missing_faces', type=str, required = True, help='Path to the file with indices of missing color faces') - parser.add_argument('--output_obj', type=str, required = True, help='Path to the output OBJ file') - parser.add_argument('--output_texture', type=str, required = True, help='Path to the texture file') - + parser = argparse.ArgumentParser(description="Fill missing color faces with harmonic vertex-color atlas baking.") + + # 原始接口保持一致 + parser.add_argument("--input_obj", type=str, required=True, help="Path to the input OBJ file") + parser.add_argument("--input_texture", type=str, required=True, help="Path to the input texture file") + parser.add_argument("--missing_faces", type=str, required=True, help="Path to missing face index file") + parser.add_argument("--output_obj", type=str, required=True, help="Path to the output OBJ file") + parser.add_argument("--output_texture", type=str, required=True, help="Path to the output texture file") + + # 可选参数 + parser.add_argument("--missing_index_base", type=int, default=0, + help="missing_faces.txt 的索引基准。原脚本默认 0-based;如果文件是 1-based,传 1。") + parser.add_argument("--tile_size", type=int, default=12, + help="每个 missing face 在新增 atlas 中的 tile 尺寸。越大越细,贴图越大。推荐 8~14。") + parser.add_argument("--tile_pad", type=int, default=2, + help="每个 tile 内三角形到边界的 padding。用于减少纹理过滤串色。推荐 2。") + parser.add_argument("--smooth_iters", type=int, default=80, + help="顶点颜色 harmonic 扩散迭代次数。越大越平滑,稍慢。推荐 50~120。") + parser.add_argument("--self_weight", type=float, default=0.0, + help="扩散时保留当前颜色的权重。一般用 0。") + parser.add_argument("--corner_inset", type=float, default=0.90, + help="边界顶点采样时往正常 face 内部缩进的比例。0.85~0.95 比较稳。") + args = parser.parse_args() - process(args.input_obj, args.input_texture, args.missing_faces, args.output_obj, args.output_texture) -if __name__ == '__main__': - main() \ No newline at end of file + process( + input_obj_path=args.input_obj, + input_texture_path=args.input_texture, + missing_faces_path=args.missing_faces, + output_obj_path=args.output_obj, + output_texture_path=args.output_texture, + missing_index_base=args.missing_index_base, + tile_size=args.tile_size, + tile_pad=args.tile_pad, + smooth_iters=args.smooth_iters, + self_weight=args.self_weight, + corner_inset=args.corner_inset, + ) + + +if __name__ == "__main__": + main()