[1] Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2021). Sharpness-Aware Minimization for Efficiently Improving Generalization. ICLR 2021.
[2] Du, J., Yan, H., Feng, J., Zhou, J. T., Zhen, L., Goh, R. S. M., & Tan, V. Y. F. (2022). Efficient Sharpness-Aware Minimization for Improved Training of Neural Networks. ICLR 2022.
[3] Ni, R., Chiang, P. Y., Geiping, J., Goldblum, M., Wilson, A. G., & Goldstein, T. (2022). K-SAM: Sharpness-Aware Minimization at the Speed of SGD.
sam-optimizers
新型优化算法的实现与分析 - SAM系列优化器
本项目实现了多种基于锐度感知最小化(Sharpness-Aware Minimization, SAM)的优化算法,并在CIFAR-10/100数据集上进行对比实验。
实验环境
硬件要求
软件环境
安装依赖
对于AMD ROCm用户:
对于NVIDIA CUDA用户:
安装其他依赖:
数据集下载
本项目使用CIFAR-10和CIFAR-100数据集,首次运行时会自动下载。
数据集会自动预加载到GPU显存中,以加速训练过程。
运行方式
1. 单优化器训练
使用SGD训练:
使用Adam训练:
使用SAM训练:
使用ESAM训练:
使用K-SAM训练:
2. 批量对比实验
运行所有优化器的对比实验:
3. 结果可视化
生成训练曲线和对比图表:
项目结构
优化器说明
1. SAM (Sharpness-Aware Minimization)
2. ESAM (Efficient Sharpness-Aware Minimization)
3. K-SAM (Top-K Sharpness-Aware Minimization)
实验结果
推荐参数设置
预期结果
在CIFAR-10 + ResNet-56的设置下,预期的测试准确率排序大致为: SAM > ESAM ≈ K-SAM > SGD > Adam
具体数值会因训练设置和随机种子而有所不同。
AMD ROCm 兼容性说明
本项目完全兼容AMD ROCm PyTorch。由于ROCm PyTorch使用与CUDA版本相同的API,代码无需修改即可运行。
注意事项:
显存优化
本项目实现了数据集全量加载到显存的功能,可以显著提升训练速度:
--gpu_preload参数控制(默认开启)参考文献
[1] Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2021). Sharpness-Aware Minimization for Efficiently Improving Generalization. ICLR 2021.
[2] Du, J., Yan, H., Feng, J., Zhou, J. T., Zhen, L., Goh, R. S. M., & Tan, V. Y. F. (2022). Efficient Sharpness-Aware Minimization for Improved Training of Neural Networks. ICLR 2022.
[3] Ni, R., Chiang, P. Y., Geiping, J., Goldblum, M., Wilson, A. G., & Goldstein, T. (2022). K-SAM: Sharpness-Aware Minimization at the Speed of SGD.