目录

JittorGeometric GCN Cora Node Classification

本项目为计图挑战赛赛道一热身赛的开源实现,基于 Jittor 和 JittorGeometric 在 Cora 引文网络数据集上训练两层 GCN 模型,完成节点分类 任务,并生成测试集预测文件 result.json

环境安装

推荐环境:

  • Python 3.10
  • Jittor 1.3.11.0
  • JittorGeometric 2.0.0
  • CUDA 环境可用时使用 GPU 训练

建议创建独立 conda 环境:

conda create -n jittor-hot python=3.10
conda activate jittor-hot
pip install -r requirements.txt

Jittor 和 JittorGeometric 的官方安装说明可参考:

数据准备

赛题数据文件为 data/cora.pkl,目录结构如下:

data/
  cora.pkl

cora.pkl 为 pickle 格式,包含:

字段 类型 说明
x numpy array, (2708, 1433) 节点特征矩阵
y numpy array, (2708,) 节点标签,测试集标签为 -1
edge_index numpy array, (2, num_edges) COO 边列表
train_mask numpy bool array, (2708,) 训练集掩码
val_mask numpy bool array, (2708,) 验证集掩码
test_mask numpy bool array, (2708,) 测试集掩码
num_classes int 类别数,固定为 7
num_features int 特征维度,固定为 1433

开源仓库中不建议提交原始数据文件。如需复现,请从比赛发布包中获取 cora.pkl 并放入 data/ 目录。

训练与推理

运行以下命令会完成训练、验证和测试集预测:

python gcn.py

脚本会执行 200 个 epoch,并在当前目录生成:

result.json

result.json 是提交所需的测试集预测结果,格式为:

{
  "1708": 0,
  "1709": 1,
  "1710": 3
}

其中 key 为测试节点编号的字符串形式,value 为预测类别编号,范围为 0-6

如果使用 conda 环境直接运行:

conda run -n jittor-hot python gcn.py

提交打包

比赛要求提交 result.zip,根目录包含模型代码和预测结果:

result.zip
  gcn.py
  result.json

生成命令:

zip result.zip gcn.py result.json

提交前可检查预测文件格式:

python - <<'PY'
import json

with open("result.json", "r", encoding="utf-8") as f:
    result = json.load(f)

print("num_predictions:", len(result))
print("label_range:", min(result.values()), max(result.values()))
print("first_items:", list(result.items())[:5])
PY

期望结果:

  • num_predictions 为 1000
  • label_range0-6 范围内

结果说明

本项目使用验证集准确率作为本地训练过程中的参考指标。一次运行结果示例:

最终结果: Val Acc: 0.8080
共预测 1000 个测试节点

线上评测会加载 result.json,与隐藏测试集真实标签计算 Accuracy。赛题通过标准 为测试集 Accuracy 不低于 0.70。

实现说明

gcn.py 的主要流程:

  1. 读取 data/cora.pkl 并转换为 Jittor 张量。
  2. 对节点特征做行归一化。
  3. 使用 gcn_norm 加自环并计算 GCN 边权重。
  4. 将 COO 边表示转换为 CSC/CSR 稀疏格式。
  5. 使用两层 GCNConv 完成节点分类。
  6. 使用训练集节点计算交叉熵损失,在验证集上记录最佳准确率。
  7. 对测试集节点生成预测并保存到 result.json

随机种子在脚本中固定为 42,便于复现。

第三方依赖与声明

本项目依赖 Jittor 和 JittorGeometric:

二者均采用 Apache License 2.0。

License

This project is licensed under the Apache License 2.0. See LICENSE for details.

关于
35.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号