目录

jittor-手工智能-时序图链接预测

计图挑战赛 赛道一:基于图学习的动态推荐任务
战队:手工智能 | A 榜 MRR:1.084 | 排名:110
框架:Jittor 1.3.11 + JittorGeometric | 模型:CRAFT (改进版)


1. 项目简介

本仓库实现了基于 CRAFT(Continuous-time Representation learning with Attention for Future interaction prediction)模型的时序图链接预测方案,用于计图挑战赛赛道一。

核心改进

  • 更大的 Transformer 编码器(hidden 128/96, layers 3/2, heads 4/2)
  • Cosine 学习率预热 + 退火调度
  • 断点续训支持
  • 针对 6GB 显存的参数自适应

数据集:两个场景——非二部图(dataset1)和二部图(dataset2),详见赛题说明。


2. 环境安装

  • Python:3.11
  • Jittor:1.3.11
  • CUDA:12.4(推荐)
# 创建 conda 环境
conda create -n jittor python=3.11 -y
conda activate jittor

# 安装 Jittor
pip install jittor==1.3.11

# 安装 JittorGeometric
pip install git+https://github.com/AlgRUC/JittorGeometric.git

# 安装其他依赖
pip install -r requirements.txt

3. 数据准备

数据需按以下结构放置:

data/
├── dataset1/
│   ├── train.csv          # 训练数据 (src, dst, time)
│   └── test.csv           # 测试数据 (src, time, c1~c100)
└── dataset2/
    ├── train.csv
    └── test.csv

训练数据格式:src,dst,time(源节点, 目标节点, 时间戳)
测试数据格式:src,time,c1,c2,...,c100(源节点, 时间戳, 100 个候选目标)

数据通过比赛平台下载,不包含在本仓库中。


4. 训练

方式一:使用脚本(推荐)

# dataset1(非二部图)
bash scripts/train_dataset1.sh

# dataset2(二部图,显存受限环境)
bash scripts/train_dataset2.sh

方式二:直接调用

# dataset1 — 改进版 CRAFT (128/3/4)
python src/train.py \
    --dataset dataset1 \
    --hidden_size 128 --n_layers 3 --n_heads 4 \
    --num_neighbors 32 --loss_type BPR \
    --epochs 50 --batch_size 200 --lr 0.0001

# dataset2 — 轻量版 (96/2/2,适配 6GB 显存)
python src/train.py \
    --dataset dataset2 \
    --hidden_size 96 --n_layers 2 --n_heads 2 \
    --num_neighbors 24 --loss_type BPR \
    --epochs 50 --batch_size 128 --lr 0.0001

断点续训

训练中断后,从 checkpoint 继续:

python src/resume.py --epochs 29 --start_epoch 21

模型权重自动保存在 outputs/ 目录下。


5. 推理

# 预测并生成提交文件
python src/predict.py --dataset dataset1 --ckpt outputs/dataset1_CRAFT_best.pkl
python src/predict.py --dataset dataset2 --ckpt outputs/dataset2_CRAFT_best.pkl

# 打包提交
python tools/pack_results.py

6. 结果说明

数据集 模型 参数量 Epoch
dataset1 CRAFT (128/3/4) 6.1M 50
dataset2 CRAFT (96/2/2) 13.6M 35 (early stop)

评测指标:MRR(Mean Reciprocal Rank)
A 榜得分:1.084(排名 110/??)

MRR 计算方式:对每个测试样本的 100 个候选节点按预测概率降序排列,正样本排名为 k 则得分 1/k,最终 MRR = 所有样本得分的均值。


7. 项目结构

├── README.md
├── LICENSE
├── .gitignore
├── requirements.txt
├── configs/
│   └── default.yaml           # 默认训练配置
├── src/
│   ├── train.py               # 训练主脚本
│   ├── resume.py              # 断点续训脚本
│   ├── predict.py             # 推理脚本
│   ├── model.py               # CRAFT 模型封装
│   └── utils.py               # 工具函数(LR 调度等)
├── scripts/
│   ├── train_dataset1.sh      # dataset1 训练脚本
│   └── train_dataset2.sh      # dataset2 训练脚本
├── tools/
│   ├── pack_results.py        # 打包提交文件
│   └── analyze_results.py     # 结果分析
├── data/
│   └── README.md              # 数据格式说明
└── outputs/                   # 模型权重、日志(不提交)

8. 引用

9. 许可证

MIT License. 详见 LICENSE

关于
54.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号