目录
目录README.md

<pad><pad><eos>-计图风格迁移挑战赛

image

简介

本项目包含了第四届计图挑战赛赛道二风格迁移比赛的代码实现。在基础DreamBooth方法上实现了诸多优化,从而能够获得较高质量的风格迁移效果。github链接

本项目在B榜上获得0.5166的评分,排名第2。

环境配置

硬件要求 本项目的训练和推理可在 1 张 A800 显卡上完成。

性能分析 据实测,训练时显存占用约16GB。每个风格,在unet lora rank为64,text encoder lora rank为1的条件下训练4个epoch需要约15min。推理时:显存占用约47GB,每个风格需要约3min。为了提高推理速度,目前的推理方式为25张图片一起推理。若因显存占用太大而无法推理,可尝试在run_one.py中将推理方式修改为单张图片推理。

运行环境

  • linux
  • python == 3.9
  • jittor == 1.3.10

依赖安装 参考 baseline 完成基本的环境配置。本项目的requirements.txtenvironment.yaml已经列在根目录下。

项目代码

目录结构如下

|-- created_prior
|-- configs
|-- data
|-- weights
|-- diffusers
|-- environment.yaml
|-- requirements.txt
|-- run_one.py
|-- test.py
|-- train.py
|-- train.sh
|-- train_all.py
|-- create_prior.py
|-- utils.py
`-- readme.md

created_prior为无风格图像,由create_prior.py脚本生成,以节省训练阶段的时间开销。

configs应存放训练和推理的配置文件。在这里我们给出推理的配置。

diffusers文件夹存放有修改后的diffusers_jittor的内容。请以该文件夹替换掉原文件,以确保推理过程顺利进行。

data文件夹是我们战队在原始的数据集上对caption稍作调整后得到的。

weights存放有28个风格的模型权重,每个风格有1~3个模型不等。

dataweights可以在这里下载

推理

在确保dataweightsconfigs\infer_config.jsonl就绪的前提下,使用如下指令一键完成所有28个风格的推理:

python ./test.py

生成结果在result文件夹中。所有推理阶段的种子都固定为0,只要按照configs中的设置进行推理,其结果就与我们在B榜的最好提交保持一致。

训练

在确保data文件放置在正确的位置以后,可以通过如下指令来生成无风格图像:

python ./create_prior.py

生成结果被保存在created_prior文件夹中。

确保datacreated_prior以及configs\train_config.jsonl就绪以后,可通过如下指令训练:

python ./train_all.py

训练所得到的lora权重将被保存在weights文件夹中。

团队成员

陈亦逍 陈纪东 申君皓

单位:清华大学

QQ:1789757099

关于

第四届计图挑战赛赛道二,队伍“<pad><pad><eos>”的代码实现

7.6 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号