目录

sam-optimizers

新型优化算法的实现与分析 - SAM系列优化器

本项目实现了多种基于锐度感知最小化(Sharpness-Aware Minimization, SAM)的优化算法,并在CIFAR-10/100数据集上进行对比实验。

实验环境

硬件要求

  • AMD GPU (支持ROCm) 或 NVIDIA GPU (支持CUDA)
  • 至少8GB显存(推荐12GB以上用于完整训练)

软件环境

  • Python 3.8+
  • PyTorch (ROCm版本或CUDA版本)
  • torchvision
  • numpy
  • matplotlib (可选,用于可视化)

安装依赖

对于AMD ROCm用户:

# 安装ROCm版本的PyTorch
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/rocm5.7

对于NVIDIA CUDA用户:

# 安装CUDA版本的PyTorch
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu118

安装其他依赖:

pip install numpy matplotlib

数据集下载

本项目使用CIFAR-10和CIFAR-100数据集,首次运行时会自动下载。

  • CIFAR-10: 10个类别,50000张训练图片,10000张测试图片
  • CIFAR-100: 100个类别,50000张训练图片,10000张测试图片

数据集会自动预加载到GPU显存中,以加速训练过程。

运行方式

1. 单优化器训练

使用SGD训练:

python train.py --optimizer sgd --dataset cifar10 --model resnet56 --epochs 200 --lr 0.1

使用Adam训练:

python train.py --optimizer adam --dataset cifar10 --model resnet56 --epochs 200 --lr 0.01

使用SAM训练:

python train.py --optimizer sam --dataset cifar10 --model resnet56 --epochs 200 --lr 0.1 --rho 0.05

使用ESAM训练:

python train.py --optimizer esam --dataset cifar10 --model resnet56 --epochs 200 --lr 0.1 --rho 0.05 --beta 0.5 --gamma 0.5

使用K-SAM训练:

python train.py --optimizer ksam --dataset cifar10 --model resnet56 --epochs 200 --lr 0.1 --rho 0.05 --k1 64 --k2 64

2. 批量对比实验

运行所有优化器的对比实验:

python run_comparison.py --dataset cifar10 --model resnet56 --epochs 200

3. 结果可视化

生成训练曲线和对比图表:

python visualize.py --results_dir ./results

项目结构

sam_optimizers/
├── optimizers/           # 优化器实现
│   ├── __init__.py
│   ├── sam.py           # SAM (Sharpness-Aware Minimization)
│   ├── esam.py          # ESAM (Efficient SAM)
│   └── ksam.py          # K-SAM (Top-K SAM)
├── models/               # 模型定义
│   ├── __init__.py
│   └── resnet.py        # ResNet系列模型
├── utils/                # 工具函数
│   ├── __init__.py
│   ├── data_loader.py   # 数据加载(支持显存预加载)
│   └── train_utils.py   # 训练工具函数
├── train.py             # 主训练脚本
├── run_comparison.py    # 批量对比实验脚本
├── visualize.py         # 结果可视化脚本
└── README.md            # 本文件

优化器说明

1. SAM (Sharpness-Aware Minimization)

  • 论文: Foret et al., “Sharpness-Aware Minimization for Efficiently Improving Generalization”, ICLR 2021
  • 核心思想: 同时最小化损失值和损失锐度,寻找平坦的最小值点
  • 特点: 需要两次前向和反向传播,计算开销约为SGD的2倍
  • 超参数: rho (邻域大小)

2. ESAM (Efficient Sharpness-Aware Minimization)

  • 论文: Du et al., “Efficient Sharpness-Aware Minimization for Improved Training of Neural Networks”, ICLR 2022
  • 核心思想: 通过两种策略提高SAM的效率
    • 随机权重扰动 (SWP): 只随机选择部分权重进行扰动
    • 锐度敏感数据选择 (SDS): 只选择对锐度最敏感的样本进行更新
  • 特点: 计算开销约为SGD的1.4倍,同时保持甚至提升性能
  • 超参数: beta (权重选择概率), gamma (样本选择比例)

3. K-SAM (Top-K Sharpness-Aware Minimization)

  • 论文: Ni et al., “K-SAM: Sharpness-Aware Minimization at the Speed of SGD”, 2022
  • 核心思想: 在SAM的两个阶段都只使用损失最大的K个样本
  • 特点: 可以达到与SGD相当的速度,同时获得接近SAM的泛化提升
  • 超参数: k1 (上升阶段样本数), k2 (下降阶段样本数)

实验结果

推荐参数设置

优化器 学习率 rho 其他参数
SGD 0.1 - momentum=0.9, weight_decay=5e-4
Adam 0.01 - weight_decay=5e-4
SAM 0.1 0.05 base_optimizer=SGD
ESAM 0.1 0.05 beta=0.5, gamma=0.5
K-SAM 0.1 0.05 k1=64, k2=64

预期结果

在CIFAR-10 + ResNet-56的设置下,预期的测试准确率排序大致为: SAM > ESAM ≈ K-SAM > SGD > Adam

具体数值会因训练设置和随机种子而有所不同。

AMD ROCm 兼容性说明

本项目完全兼容AMD ROCm PyTorch。由于ROCm PyTorch使用与CUDA版本相同的API,代码无需修改即可运行。

注意事项:

  1. 确保安装了正确的ROCm版本PyTorch
  2. 数据集预加载功能在ROCm上同样有效
  3. 所有优化器实现均使用标准PyTorch API,与硬件平台无关

显存优化

本项目实现了数据集全量加载到显存的功能,可以显著提升训练速度:

  • 避免了每次迭代时的数据从CPU到GPU的传输开销
  • 对于CIFAR-10/100这类小数据集,显存占用很小(约几百MB)
  • 可以通过--gpu_preload参数控制(默认开启)

参考文献

[1] Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2021). Sharpness-Aware Minimization for Efficiently Improving Generalization. ICLR 2021.

[2] Du, J., Yan, H., Feng, J., Zhou, J. T., Zhen, L., Goh, R. S. M., & Tan, V. Y. F. (2022). Efficient Sharpness-Aware Minimization for Improved Training of Neural Networks. ICLR 2022.

[3] Ni, R., Chiang, P. Y., Geiping, J., Goldblum, M., Wilson, A. G., & Goldstein, T. (2022). K-SAM: Sharpness-Aware Minimization at the Speed of SGD.

关于

深度神经网络的泛化能力是机器学习领域的核心问题之一。传统优化算法如随机梯度下降(SGD)和Adam仅最小化训练损失值,容易导致模型收敛到尖锐的最小值点,从而影响泛化性能。近年来,锐度感知最小化(Sharpness-Aware Minimization, SAM)通过同时最小化损失值和损失锐度,有效提升了模型的泛化能力,但其计算开销约为传统优化器的两倍。本文系统研究了SAM及其两种高效变体——ESA

58.0 KB
邀请码