You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
237 lines
9.7 KiB
237 lines
9.7 KiB
import os, oss2, sys, time |
|
import cv2 |
|
import numpy as np |
|
from tqdm import tqdm |
|
import matplotlib.pyplot as plt |
|
from PIL import Image, ImageEnhance |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
import config, libs |
|
|
|
def ps_color_scale_adjustment(image, shadow=0, highlight=255, midtones=1): |
|
''' |
|
模拟 PS 的色阶调整; 0 <= Shadow < Highlight <= 255 |
|
:param image: 传入的图片 |
|
:param shadow: 黑场(0-Highlight) |
|
:param highlight: 白场(Shadow-255) |
|
:param midtones: 灰场(9.99-0.01) |
|
:return: 图片 |
|
''' |
|
if highlight > 255: |
|
highlight = 255 |
|
if shadow < 0: |
|
shadow = 0 |
|
if shadow >= highlight: |
|
shadow = highlight - 2 |
|
if midtones > 9.99: |
|
midtones = 9.99 |
|
if midtones < 0.01: |
|
midtones = 0.01 |
|
image = np.array(image, dtype=np.float16) |
|
# 计算白场 黑场离差 |
|
Diff = highlight - shadow |
|
image = image - shadow |
|
image[image < 0] = 0 |
|
image = (image / Diff) ** (1 / midtones) * 255 |
|
image[image > 255] = 255 |
|
image = np.array(image, dtype=np.uint8) |
|
|
|
return image |
|
|
|
|
|
def show_histogram(image, image_id, save_hist_dir, min_threshold, max_threshold): |
|
''' |
|
画出直方图展示 |
|
:param image: 导入图片 |
|
:param image_id: 图片id编号 |
|
:param save_hist_dir: 保存路径 |
|
:param min_threshold: 最小阈值 |
|
:param max_threshold: 最大阈值 |
|
:return: 原图image,和裁剪原图直方图高低阈值后的图片image_change |
|
''' |
|
plt.rcParams['font.family'] = 'SimHei' |
|
plt.rcParams['axes.unicode_minus'] = False |
|
plt.hist(image.ravel(), 254, range=(2, 256), density=False) |
|
plt.hist(image.ravel(), 96, range=(2, 50), density=False) # 放大 range(0, 50),bins值最好是range的两倍,显得更稀疏,便于对比 |
|
plt.hist(image.ravel(), 110, range=(200, 255), density=False) # 放大 range(225, 255) |
|
plt.annotate('thresh1=' + str(min_threshold), # 文本内容 |
|
xy=(min_threshold, 0), # 箭头指向位置 # 阈值设定值! |
|
xytext=(min_threshold, 500000), # 文本位置 # 阈值设定值! |
|
arrowprops=dict(facecolor='black', width=1, shrink=5, headwidth=2)) # 箭头 |
|
plt.annotate('thresh2=' + str(max_threshold), # 文本内容 |
|
xy=(max_threshold, 0), # 箭头指向位置 # 阈值设定值! |
|
xytext=(max_threshold, 500000), # 文本位置 # 阈值设定值! |
|
arrowprops=dict(facecolor='black', width=1, shrink=5, headwidth=2)) # 箭头 |
|
# 在y轴上绘制一条直线 |
|
# plt.axhline(y=10000, color='r', linestyle='--', linewidth=0.5) |
|
plt.title(str(image_id)) |
|
# plt.show() |
|
# 保存直方图 |
|
save_hist_name = os.path.join(save_hist_dir, f'{image_id}_{min_threshold}&{max_threshold}.jpg') |
|
plt.savefig(save_hist_name) |
|
# 清空画布, 防止重叠展示 |
|
plt.clf() |
|
|
|
|
|
def low_find_histogram_range(image, target_frequency): |
|
''' |
|
循环查找在 target_frequency (y)频次限制下的直方图区间值(x) |
|
:param image: 导入图片 |
|
:param target_frequency: 直方图 y 频次限制条件 |
|
:return: 直方图区间 x,和 该区间频次 y |
|
''' |
|
# 计算灰度直方图 |
|
hist, bins = np.histogram(image, bins=256, range=[0, 256]) |
|
# 初始化区间和频次 |
|
interval = 2 |
|
frequency = hist[255] |
|
while frequency < target_frequency: |
|
# 更新区间和频次 |
|
interval += 1 |
|
# 检查直方图的频次是否为None,如果频次是None,则将其设为0,这样可以避免将None和int进行比较报错。 |
|
frequency = hist[interval] if hist[interval] is not None else 0 |
|
frequency += hist[interval] if hist[interval] is not None else 0 |
|
print(f'x={interval}, y={frequency}') |
|
# 如果频次接近10000则停止循环 |
|
if target_frequency - 2000 <= frequency <= target_frequency + 1000: |
|
break |
|
|
|
return interval, frequency |
|
|
|
|
|
def high_find_histogram_range(image, target_frequency): |
|
''' |
|
循环查找在 target_frequency (y)频次限制下的直方图区间值(x) |
|
:param image: 导入图片 |
|
:param target_frequency: 直方图 y 频次限制条件 |
|
:return: 直方图区间 x,和 该区间频次 y |
|
''' |
|
# 计算灰度直方图 |
|
hist, bins = np.histogram(image, bins=256, range=[0, 256]) |
|
# 初始化区间和频次 |
|
interval = 255 |
|
frequency = hist[255] |
|
while frequency < target_frequency: |
|
# 更新区间和频次 |
|
interval -= 1 |
|
# 检查直方图的频次是否为None,如果频次是None,则将其设为0,这样可以避免将None和int进行比较报错。 |
|
frequency = hist[interval] if hist[interval] is not None else 0 |
|
frequency += hist[interval] if hist[interval] is not None else 0 |
|
# 如果频次接近10000则停止循环 |
|
if target_frequency - 2000 <= frequency <= target_frequency + 2000: |
|
break |
|
|
|
return interval, frequency |
|
|
|
|
|
def sharpening_filter(image): |
|
''' |
|
锐化滤波器对图片进行锐化,增强图像中的边缘和细节 |
|
:param image: 导入图片 |
|
:return: 锐化后的图片 |
|
''' |
|
# sharp_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]) |
|
sharp_kernel = np.array([[0, -0.5, 0], [-0.5, 3, -0.5], [0, -0.5, 0]]) |
|
sharpened_image = cv2.filter2D(image, -1, sharp_kernel) |
|
return sharpened_image |
|
|
|
def reduce_sharpness(image, factor): |
|
''' |
|
使用PIL库减弱图像锐度 |
|
:param image: 图像 |
|
:param factor: 锐度因子,0表示最大程度减弱锐度,1表示原始图像 |
|
:return: 减弱锐度后的图像 |
|
''' |
|
# OpenCV 格式的图像转换为 PIL 的 Image 对象 |
|
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
|
pil_image = Image.fromarray(image_rgb) |
|
enhancer = ImageEnhance.Sharpness(pil_image) |
|
reduced_image = enhancer.enhance(factor) |
|
# PIL 的 Image 对象转换为 OpenCV 的图像格式 |
|
image_array = np.array(reduced_image) |
|
sharpened_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR) |
|
|
|
return sharpened_image |
|
|
|
def find_last_x(image, slope_threshold = 1000): |
|
x = [] |
|
y = [] |
|
hist, bins = np.histogram(image, bins=256, range=[0, 256]) |
|
for i in range(2, 50): |
|
x.append(i) |
|
y.append(hist[i]) |
|
slopes = [abs(y[i + 1] - y[i]) for i in range(len(x) - 1)] |
|
|
|
current_interval = [] |
|
max_interval = [] |
|
max_x = {} |
|
for i, slope in enumerate(slopes): |
|
current_interval.append(slope) |
|
if slope >= slope_threshold: |
|
if len(current_interval) > len(max_interval): |
|
max_interval = current_interval.copy() |
|
max_x[x[i]] = slope |
|
current_interval = [] |
|
|
|
print(max_x) |
|
last_x = list(max_x)[-1] |
|
last_y = max_x[last_x] |
|
return last_x, last_y |
|
|
|
def main(pid, print_id): |
|
texture_filename = f'{input_dir}{pid}Tex1.{print_id}.jpg' |
|
input_image = cv2.imread(texture_filename) |
|
# low_x_thresh, low_y_frequency = low_find_histogram_range(input_image, low_y_limit) |
|
low_x_thresh, low_y_frequency = find_last_x(input_image, 1000) |
|
high_x_thresh, high_y_frequency = high_find_histogram_range(input_image, high_y_limit) |
|
print(f"{low_x_thresh} 区间, {low_y_frequency} 频次") |
|
print(f"{high_x_thresh} 区间, {high_y_frequency} 频次") |
|
high_output_image = ps_color_scale_adjustment(input_image, shadow=low_x_thresh, highlight=high_x_thresh, midtones=1) |
|
# high_output_image = ps_color_scale_adjustment(low_ouput_image, shadow=0, highlight=high_x_thresh, midtones=1) |
|
|
|
# 人体贴图和黑色背景交界处不进行锐化 |
|
gray = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY) |
|
_, thresh = cv2.threshold(gray, 2, 255, cv2.THRESH_BINARY) |
|
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7)) |
|
gradient = cv2.morphologyEx(thresh, cv2.MORPH_GRADIENT, kernel) |
|
roi_gradient = cv2.bitwise_and(high_output_image, high_output_image, mask=gradient) |
|
|
|
# 锐化滤波器 |
|
# sharpened_image = sharpening_filter(high_output_image) |
|
sharpened_image = reduce_sharpness(high_output_image, factor=4) |
|
# 将原图边界替换锐化后的图片边界 |
|
sharpened_image[gradient != 0] = roi_gradient[gradient != 0] |
|
|
|
# 直方图标记并保存 |
|
# show_histogram(input_image, img_id, low_x_thresh, high_x_thresh) |
|
cv2.imwrite(texture_filename, sharpened_image, [cv2.IMWRITE_JPEG_QUALITY, 95]) # 保存图片的质量是原图的 95% |
|
|
|
|
|
if __name__ == "__main__": |
|
input_dir = "D:\\AI_pycharm\\Change_grayscale\\Texture_photos\\original_images\\" |
|
save_dir = "D:\\AI_pycharm\\Change_grayscale\\Texture_photos\\test\\" |
|
|
|
low_y_limit = 48000 |
|
high_y_limit = 13000 |
|
|
|
input_dir = '/data/datasets/texure_photos/' |
|
|
|
pids = '99724,99747,99762,99763,99777,99778,99807,99812,99823,99843,99405,99416,97984,97662,86153' |
|
for pid in pids.split(','): |
|
pid, print_id = pid.split('_') |
|
# 根据前缀获取文件列表 |
|
path = f'/data/datasets/texure_photos/' |
|
texture_filename = f'{path}{pid}Tex1.{print_id}.jpg' |
|
|
|
prefix = f'objs/print/{pid}/' |
|
if config.oss_bucket.object_exists(f'{prefix}{pid}Tex1.{print_id}.jpg'): |
|
print(f'{pid}Tex1.{print_id}.jpg 处理中...') |
|
if not os.path.exists(texture_filename): |
|
os.makedirs(path, exist_ok=True) |
|
config.oss_bucket.get_object_to_file(f'{prefix}{pid}Tex1.{print_id}.jpg', texture_filename) |
|
main(pid, print_id) |
|
else: |
|
print(f'文件已存在,直接上传') |
|
config.oss_bucket.put_object_from_file(f'objs/print/{pid}/{pid}Tex1.{print_id}.jpg', texture_filename) |
|
|
|
print(f'{pid}Tex1.{print_id}.jpg 处理完成') |
|
|