目录
目录README.md

SwinIR 图像超分辨率项目

本项目使用 SwinIR 模型对低分辨率图像进行超分辨率处理,提供简单易用的 Python API 接口。

环境配置

1. 创建 Python 环境

建议使用 Python 3.8-3.12 版本:

conda create -n swinir python=3.12
conda activate swinir

2. 安装依赖库

pip install -r requirements.txt

或手动安装:

pip install torch torchvision opencv-python numpy timm requests

注意:如果需要 GPU 加速,请根据您的 CUDA 版本安装对应的 PyTorch。访问 PyTorch 官网 获取正确的安装命令。

快速开始

基本用法

from main_test_swinir import hr
import cv2

# 方式1: 使用图像路径
output = hr("input.png")
cv2.imwrite("output.png", output)

# 方式2: 使用numpy数组
input_img = cv2.imread("input.png")
output = hr(input_img)
cv2.imwrite("output.png", output)

API 文档

函数签名

def hr(input_img, 
       task='real_sr',
       scale=4,
       model_path='model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth',
       tile=None,
       tile_overlap=32,
       large_model=False,
       noise=15,
       jpeg=40,
       training_patch_size=128):
    """
    超分辨率图像处理函数
    
    参数:
        input_img: 输入图像
                   - str: 图像文件路径
                   - numpy.ndarray: 图像数组 (HWC格式, BGR或RGB, uint8或float32)
        
        task: 任务类型,默认'real_sr' (真实图像超分辨率)
              - 'classical_sr': 经典超分辨率
              - 'lightweight_sr': 轻量级超分辨率
              - 'real_sr': 真实图像超分辨率 (推荐)
              - 'gray_dn': 灰度图像去噪
              - 'color_dn': 彩色图像去噪
              - 'jpeg_car': JPEG压缩伪影去除
              - 'color_jpeg_car': 彩色JPEG压缩伪影去除
        
        scale: 放大倍数,默认4 (支持: 1, 2, 3, 4, 8)
        
        model_path: 模型文件路径
        
        tile: 分块处理大小,None表示整图处理
              - 对于大图像,建议设置为400-800以节省显存
              - 小图像可以设置为None
        
        tile_overlap: 分块重叠大小,默认32
        
        large_model: 是否使用大模型,默认False
        
        noise: 去噪等级,默认15 (用于去噪任务,可选: 15, 25, 50)
        
        jpeg: JPEG质量,默认40 (用于JPEG伪影去除,可选: 10, 20, 30, 40)
        
        training_patch_size: 训练时的patch大小,默认128
    
    返回:
        numpy.ndarray: 处理后的高分辨率图像 (HWC格式, BGR, uint8)
    """

使用示例

1. 基本超分辨率处理

from main_test_swinir import hr
import cv2

# 处理单张图像
output = hr("input.png")
cv2.imwrite("output.png", output)

2. 批量处理图像

from main_test_swinir import hr
import cv2
import os

input_folder = "fig_lr"
output_folder = "results_hr"
os.makedirs(output_folder, exist_ok=True)

for filename in os.listdir(input_folder):
    if filename.endswith(('.png', '.jpg', '.jpeg')):
        input_path = os.path.join(input_folder, filename)
        output_path = os.path.join(output_folder, filename)
        
        try:
            output = hr(input_path, scale=4, tile=512)
            cv2.imwrite(output_path, output)
            print(f"✓ {filename} 处理完成")
        except Exception as e:
            print(f"✗ {filename} 处理失败: {e}")

3. 处理大尺寸图像(使用tile分块)

from main_test_swinir import hr

# 对于大图像,使用tile参数避免显存溢出
output = hr(
    "large_image.png",
    tile=512,          # 分块大小
    tile_overlap=32    # 重叠区域
)

4. 使用numpy数组作为输入

from main_test_swinir import hr
import cv2

# 读取图像
img = cv2.imread("input.png")

# 可以进行预处理
img = cv2.GaussianBlur(img, (3, 3), 0)

# 超分辨率处理
output = hr(img)

# 保存结果
cv2.imwrite("output.png", output)

5. 不同任务类型

真实图像超分辨率(默认)

output = hr("input.png", task='real_sr', scale=4)

图像去噪

output = hr("noisy_image.png", task='color_dn', noise=25)

JPEG伪影去除

output = hr("compressed.jpg", task='color_jpeg_car', jpeg=30)

模型文件

自动下载

如果模型文件不存在,函数会自动从GitHub下载。默认模型:

  • 路径: model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth
  • 任务: 真实图像超分辨率 x4

可用模型

模型 任务 放大倍数 说明
003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth real_sr 4x 真实图像超分(推荐)
003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x2_GAN.pth real_sr 2x 真实图像超分 2x
003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth real_sr 4x 大模型,效果更好
001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth classical_sr 2x 经典超分 2x
002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth lightweight_sr 2x 轻量级超分

使用其他模型

output = hr(
    "input.png",
    task='classical_sr',
    scale=2,
    model_path='model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth'
)

注意事项

  1. 显存占用:

    • 对于大图像(>2000x2000),建议设置tile=512或更小
    • 如果显存不足,减小tile值
  2. 输入图像格式:

    • 支持路径字符串或numpy数组
    • numpy数组支持uint8或float32类型
    • 自动处理BGR/RGB转换
  3. 输出格式:

    • 返回numpy数组,HWC格式,BGR通道顺序,uint8类型
    • 可直接使用cv2.imwrite()保存
  4. 性能优化:

    • 首次运行会下载模型,需要网络连接
    • 有GPU时自动使用CUDA加速
    • 模型加载后会保持在内存中

项目结构

.
├── main_test_swinir.py          # 主程序(包含hr()函数)
├── example_usage.py             # 使用示例
├── models/
│   └── network_swinir.py        # SwinIR 网络模型
├── utils/
│   └── util_calculate_psnr_ssim.py  # 图像质量评估工具
├── model_zoo/
│   └── swinir/                  # 预训练模型存放目录
├── fig_lr/                      # 输入:低分辨率图像文件夹
└── results/                     # 输出:超分辨率结果文件夹

常见问题

Q: 如何处理非常大的图像?

output = hr("huge_image.png", tile=400, tile_overlap=32)

Q: 如何在循环中处理多张图像?
只需要多次调用hr()函数,模型会自动复用。

Q: 输出图像尺寸是多少?
输出尺寸 = 输入尺寸 × scale (例如: 100x100输入,scale=4,输出400x400)

Q: 支持哪些图像格式?
支持OpenCV imread能读取的所有格式: PNG, JPG, JPEG, BMP, TIFF等。

Q: 如何知道是否使用了 GPU?
程序会自动检测并使用 GPU(如果可用)。运行时查看 nvidia-smi 命令确认 GPU 使用情况。

Q: 处理速度很慢怎么办?
确保安装了支持 CUDA 的 PyTorch 版本。如果仍然慢,可能是 CPU 模式,请检查 CUDA 安装。

Q: 输出图像质量不理想?
尝试使用 large_model=True 参数切换到大模型,或调整 tile 参数(更大的 tile 通常效果更好,但需要更多显存)。

参考文献

SwinIR: Image Restoration Using Swin Transformer
论文链接:https://arxiv.org/abs/2108.10257
项目主页:https://github.com/JingyunLiang/SwinIR

关于
42.2 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号