single tools scripts
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

203 lines
7.2 KiB

import cv2
import os
import argparse
import numpy as np
from ultralytics import YOLO
# ===================== 模型路径配置 =====================
# 请确保此路径下有模型文件,如果没有会自动下载
YOLO_FACE_MODEL_PATH = "/root/code/lenovo_qinzhe_ssh/sam3/sam3_new/face/yolov8n-face.pt"
# =======================================================
# 加载YOLO人脸检测模型
yolo_model = YOLO(YOLO_FACE_MODEL_PATH)
# 支持的图片格式(包含大小写)
SUPPORTED_FORMATS = {
'jpg': 'jpg',
'jpeg': 'jpeg',
'png': 'png',
'bmp': 'bmp',
'JPG': 'jpg',
'JPEG': 'jpeg',
'PNG': 'png',
'BMP': 'bmp'
}
def detect_face_bbox(image, conf_threshold=0.5):
"""
使用YOLO检测人脸边界框
:param image: 输入BGR图像
:param conf_threshold: 置信度阈值
:return: 人脸边界框列表 [(x1, y1, x2, y2), ...],若无检测结果返回空列表
"""
results = yolo_model(image, conf=conf_threshold)
bboxes = []
if results and len(results) > 0:
for box in results[0].boxes:
if box.conf[0] >= conf_threshold:
x1, y1, x2, y2 = map(int, box.xyxy[0])
bboxes.append((x1, y1, x2, y2))
return bboxes
def is_supported_image(file_path):
"""
检查文件是否为支持的图片格式
"""
if not os.path.isfile(file_path):
return False
file_ext = os.path.splitext(file_path)[1].lstrip('.')
return file_ext in SUPPORTED_FORMATS
def get_output_path(input_img_path, output_arg):
"""
根据输出参数确定最终的输出路径
:param input_img_path: 输入图片路径
:param output_arg: 命令行传入的-o参数(目录或文件路径)
:return: 最终的输出图片路径
"""
# 如果输出参数是目录
if os.path.isdir(output_arg):
img_name = os.path.basename(input_img_path)
return os.path.join(output_arg, img_name)
# 如果输出参数是文件路径(且目录存在)
output_dir = os.path.dirname(output_arg)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# 检查输出文件格式是否支持
output_ext = os.path.splitext(output_arg)[1].lstrip('.')
if output_ext not in SUPPORTED_FORMATS:
# 使用输入文件的格式
input_ext = os.path.splitext(input_img_path)[1].lstrip('.')
input_format = SUPPORTED_FORMATS.get(input_ext, 'jpg')
output_arg = f"{os.path.splitext(output_arg)[0]}.{input_format}"
return output_arg
def crop_face_center_1024(input_path, output_arg, conf_threshold=0.5):
"""
以检测到的人脸为中心裁剪出1024×1024的图片
:param input_path: 输入图片文件路径
:param output_arg: 输出图片路径或目录路径
:param conf_threshold: YOLO人脸检测置信度阈值
:return: 裁剪是否成功
"""
# 读取图片(确保以彩色模式读取)
image = cv2.imread(input_path, cv2.IMREAD_COLOR)
if image is None:
print(f"❌ 无法读取图片:{input_path}")
return False
h, w = image.shape[:2]
# ========== YOLO人脸检测 ==========
face_bboxes = detect_face_bbox(image, conf_threshold)
if not face_bboxes:
print(f" 未检测到人脸:{input_path}")
return False
# 取第一个检测到的人脸
x1, y1, x2, y2 = face_bboxes[0]
# 计算人脸中心坐标
face_center_x = (x1 + x2) // 2
face_center_y = (y1 + y2) // 2
# 计算1024×1024裁剪区域的边界(以人脸为中心)
crop_size = 1024
half_size = crop_size // 2
# 计算裁剪区域的左上角和右下角坐标
crop_x1 = face_center_x - half_size
crop_y1 = face_center_y - half_size
crop_x2 = crop_x1 + crop_size
crop_y2 = crop_y1 + crop_size
# 处理边界越界问题(如果裁剪区域超出图片范围)
pad_left = max(0, -crop_x1)
pad_top = max(0, -crop_y1)
pad_right = max(0, crop_x2 - w)
pad_bottom = max(0, crop_y2 - h)
# 调整裁剪坐标到图片范围内
crop_x1_clamped = max(0, crop_x1)
crop_y1_clamped = max(0, crop_y1)
crop_x2_clamped = min(w, crop_x2)
crop_y2_clamped = min(h, crop_y2)
# 裁剪图片
cropped_image = image[crop_y1_clamped:crop_y2_clamped, crop_x1_clamped:crop_x2_clamped]
# 如果有越界,用黑色填充(保持图片维度)
if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0:
cropped_image = cv2.copyMakeBorder(
cropped_image,
pad_top, pad_bottom, pad_left, pad_right,
cv2.BORDER_CONSTANT,
value=[0, 0, 0] # 黑色填充(BGR格式)
)
# 强制调整为1024×1024(确保输出尺寸准确)
cropped_image = cv2.resize(cropped_image, (1024, 1024), interpolation=cv2.INTER_AREA)
# 确定最终输出路径
output_path = get_output_path(input_path, output_arg)
# 保存图片(根据格式选择合适的编码)
file_ext = os.path.splitext(output_path)[1].lstrip('.')
file_format = SUPPORTED_FORMATS.get(file_ext, 'jpg')
save_success = False
if file_format == 'jpg' or file_format == 'jpeg':
# JPG格式:95%质量保存
save_success = cv2.imwrite(output_path, cropped_image, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
elif file_format == 'png':
# PNG格式:无损压缩
save_success = cv2.imwrite(output_path, cropped_image, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
else:
# 其他格式默认保存
save_success = cv2.imwrite(output_path, cropped_image)
if save_success:
print(f"✅ 裁剪完成:{output_path} (尺寸:1024×1024,人脸中心:({face_center_x}, {face_center_y}))")
return True
else:
print(f"❌ 保存图片失败:{output_path}")
return False
def main():
# 创建参数解析器
parser = argparse.ArgumentParser(description='以人脸为中心裁剪1024×1024图片的工具')
# 添加命令行参数
parser.add_argument('-i', '--input', required=True,
help='输入图片文件路径(仅支持单张图片,格式:jpg/jpeg/png/bmp)')
parser.add_argument('-o', '--output', required=True,
help='输出图片路径(如:/path/to/output.jpg)或输出目录路径(如:/path/to/output/)')
parser.add_argument('-c', '--conf', type=float, default=0.5,
help='人脸检测置信度阈值(默认0.5,值越高检测越严格)')
# 解析参数
args = parser.parse_args()
# 验证输入文件
if not os.path.exists(args.input):
print(f"❌ 错误:输入文件不存在:{args.input}")
return
if not is_supported_image(args.input):
print(f"❌ 错误:输入文件不是支持的图片格式(支持:{', '.join(set(SUPPORTED_FORMATS.values()))}")
return
# 执行裁剪
success = crop_face_center_1024(args.input, args.output, args.conf)
if success:
print("\n🎉 图片裁剪任务完成!")
else:
print("\n❌ 图片裁剪任务失败!")
if __name__ == "__main__":
main()