GPEN训练损失不下降?数据对质量检查实战方法
本镜像基于GPEN人像修复增强模型构建,预装了完整的深度学习开发环境,集成了推理及评估所需的所有依赖,开箱即用。
1. 镜像环境说明
| 组件 | 版本 |
|---|---|
| 核心框架 | PyTorch 2.5.0 |
| CUDA 版本 | 12.4 |
| Python 版本 | 3.11 |
| 推理代码位置 | /root/GPEN |
主要依赖库:
facexlib: 用于人脸检测与对齐basicsr: 基础超分框架支持opencv-python,numpy<2.0,datasets==2.21.0,pyarrow==12.0.1sortedcontainers,addict,yapf
2. 快速上手
2.1 激活环境
conda activate torch252.2 模型推理 (Inference)
进入代码目录并使用预置脚本进行推理测试:
cd /root/GPEN使用下面命令进行推理测试,可以通过命令行参数灵活指定输入图片。
# 场景 1:运行默认测试图 # 输出将保存为: output_Solvay_conference_1927.png python inference_gpen.py # 场景 2:修复自定义图片 # 输出将保存为: output_my_photo.jpg python inference_gpen.py --input ./my_photo.jpg # 场景 3:直接指定输出文件名 # 输出将保存为: custom_name.png python inference_gpen.py -i test.jpg -o custom_name.png推理结果将自动保存在项目根目录下,测试结果如下:
3. 已包含权重文件
为保证开箱即用及离线推理能力,镜像内已预下载以下模型权重(如果没有运行推理脚本会自动下载):
- ModelScope 缓存路径:
~/.cache/modelscope/hub/iic/cv_gpen_image-portrait-enhancement - 包含内容:完整的预训练生成器、人脸检测器及对齐模型。
4. 训练损失不下降?常见原因分析
在使用 GPEN 进行人像修复模型训练时,开发者常遇到“训练损失长时间不下降”或“判别器/生成器损失震荡剧烈”的问题。这类现象往往并非由模型结构缺陷导致,而是源于数据准备阶段的质量控制不足。
4.1 数据对齐错误是首要诱因
GPEN 是一种基于 GAN Prior 的监督式图像增强方法,其训练依赖于高质量的(低质, 高质)图像对。若这对图像在语义、空间或像素级未对齐,则模型无法学习到有效的映射关系。
例如:
- 使用不同角度的人脸作为配对样本
- 合成低质图时进行了裁剪但高质原图未同步处理
- 降质过程引入了非均匀噪声或局部形变
这会导致模型在优化过程中不断“自我矛盾”,从而表现为损失停滞。
4.2 降质方式不合理
许多用户直接采用模糊、压缩等方式生成低质图像,但这些操作可能与真实退化场景差异较大,且缺乏多样性。
推荐做法:
- 使用BSRGAN或RealESRGAN提供的 degradation pipeline 生成更贴近现实的低分辨率图像
- 引入多种退化模式(如 JPEG 压缩、高斯噪声、运动模糊等)提升泛化能力
4.3 数据分布偏差大
如果训练集中包含大量极端光照、遮挡严重或姿态异常的图像,而验证集以正脸清晰为主,则会出现“训练损失下降但验证指标差”的情况。
建议:
- 对人脸关键点进行检测,筛选出偏转角过大(>30°)的样本用于特定任务微调而非主训练集
- 统一亮度、对比度范围,避免模型过度关注颜色校正而非纹理恢复
5. 数据对质量检查实战方法
要解决训练损失不下降的问题,必须从源头——数据对质量——入手。以下是可落地的数据质检流程。
5.1 可视化比对:最直观的方法
编写一个简单的可视化脚本,将高低质量图像并排显示:
import cv2 import numpy as np import os def visualize_pairs(hr_dir, lr_dir, num_samples=5): hr_files = sorted(os.listdir(hr_dir))[:num_samples] for f in hr_files: hr_path = os.path.join(hr_dir, f) lr_path = os.path.join(lr_dir, f) # 假设文件名一致 if not os.path.exists(lr_path): print(f"Missing LR file: {lr_path}") continue hr_img = cv2.imread(hr_path) lr_img = cv2.imread(lr_path) # Resize LR to match HR size for comparison lr_resized = cv2.resize(lr_img, (hr_img.shape[1], hr_img.shape[0]), interpolation=cv2.INTER_NEAREST) # Concatenate horizontally concat = np.hstack([lr_resized, hr_img]) cv2.imshow("LR (left) vs HR (right)", concat) cv2.waitKey(0) cv2.destroyAllWindows() # 调用示例 visualize_pairs("/data/ffhq_hr_512", "/data/ffhq_lr_512")提示:重点关注五官是否对齐、发型轮廓是否一致、是否有明显错位或扭曲。
5.2 关键点一致性检测
利用facexlib中的 DFLFaceAlignment 模块提取人脸关键点,计算两图之间的关键点误差(L2 distance)。
from facexlib.detection import init_detection_model, detect_faces from facexlib.alignment import init_alignment_model, get_face_landmarks def check_landmark_consistency(img1, img2, thres=10.0): face_detector = init_detection_model('retinaface_resnet50', half=False) landmark_model = init_alignment_model('dlib', half=False) bboxes1 = detect_faces(face_detector, img1) bboxes2 = detect_faces(face_detector, img2) if len(bboxes1) == 0 or len(bboxes2) == 0: return False, "No face detected" # 取最大人脸 bbox1 = max(bboxes1, key=lambda x: x[2]*x[3]) bbox2 = max(bboxes2, key=lambda x: x[2]*x[3]) landmarks1 = get_face_landmarks(img1, [bbox1], eye_dist=True)[0] landmarks2 = get_face_landmarks(img2, [bbox2], eye_dist=True)[0] dist = np.linalg.norm(landmarks1 - landmarks2) return dist < thres, f"Keypoint L2 distance: {dist:.2f}"设定阈值(如 10.0),过滤掉关键点偏移过大的样本。
5.3 PSNR/SSIM 初步筛选
虽然 GPEN 处理的是非成对退化问题,但在构建训练集时仍可用 PSNR 和 SSIM 作为初步筛选工具:
from basicsr.metrics import calculate_psnr, calculate_ssim psnr = calculate_psnr(lr_img.astype(np.float32), hr_img.astype(np.float32)) ssim = calculate_ssim(lr_img.astype(np.float32), hr_img.astype(np.float32)) print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}")注意:极低 PSNR/SSIM 不一定代表不能训练,但过高(如 PSNR > 35dB)则说明退化不足,可能导致模型学不到有效特征。
5.4 自动化质检流水线建议
建立如下自动化流程:
# 1. 批量生成低质图像 python generate_degraded.py --source /raw/ffhq --target /data/degraded --method bsr-gan # 2. 对齐检查 python align_check.py --hr /raw/ffhq --lr /data/degraded --output /data/filter_list.txt # 3. 过滤不合格样本 python filter_dataset.py --list /data/filter_list.txt --src-dir /data/degraded --dst-dir /data/final_lr # 4. 开始训练 python train_gpen.py --dataroot /data --load_size 512 --crop_size 5126. 训练调优建议
在确保数据质量的前提下,进一步优化训练策略:
6.1 学习率设置
- 初始学习率建议设为
2e-4(Adam 优化器) - 使用 Cosine 衰减策略,总 epoch 数建议不少于 200
- 可先冻结判别器训练生成器前 10–20 个 epoch
6.2 损失函数监控
除了总损失外,应分别记录:
- L1 Loss(像素级重建)
- Perceptual Loss(感知相似性)
- GAN Loss(对抗损失)
观察各分支变化趋势,判断是否某一部分主导或抑制其他部分。
6.3 使用 TensorBoard 实时监控
from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter(log_dir='./logs') for step, data in enumerate(dataloader): # ... training ... writer.add_scalar('Loss/L1', l1_loss.item(), step) writer.add_scalar('Loss/GAN_G', gan_g_loss.item(), step) writer.add_scalar('Loss/GAN_D', gan_d_loss.item(), step)及时发现异常波动,辅助定位问题。
7. 总结
训练损失不下降是 GPEN 模型训练中最常见的问题之一,其根本原因往往不在模型本身,而在训练数据对的质量缺陷。本文系统梳理了三大典型问题:数据未对齐、降质方式不合理、分布偏差大,并提供了四套可落地的数据质检方法——可视化比对、关键点一致性检测、PSNR/SSIM 筛选和自动化流水线构建。
通过严格的前期数据清洗与质量控制,结合合理的学习率调度和损失监控机制,可以显著提升 GPEN 模型的收敛速度与最终修复效果。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。