OpenGraph:迈向开放图基础模型
为实现这一目标,OpenGraph解决了几个关键技术挑战:
- 我们提出了统一的图标记器,使我们的图模型能够很好地适应未见过的图数据,即使底层图属性与训练中遇到的情况有显著不同。
- 我们开发了一个可扩展的图transformer作为基础编码器,能有效高效地捕捉全局拓扑语境中的节点间依赖关系。
- 我们引入了由大型语言模型(LLM)增强的数据增强机制,以缓解现实场景中数据稀缺的局限性。
大量实验验证了我们框架的有效性。通过使OpenGraph适应新的图特征并理解不同图的细微差别,我们的方法在各种设置和领域中实现了卓越的零样本图学习性能。
环境设置
您需要解压datasets/
中的一些数据文件。使用Models/readme
中的链接下载预训练模型。我们的实验是在以下软件包版本下进行的:
- python==3.10.13
- torch==1.13.0
- numpy==1.23.4
- scipy==1.9.3
简要代码结构
以下是代码结构的简要概述。每个目录的说明都被括在引号中(##...##)。有关更详细的版本,请参阅本readme末尾列出的完整版本。
./
│ └── README.md
│ ├── History/ ## 预训练模型的训练历史 ##
│ ├── Models/ ## 预训练模型 ##
│ ├── datasets/
│ ├── graph_generation/ ## 图生成的代码和示例 ##
│ ├── imgs/ ## readme中使用的图片 ##
│ ├── link_prediction/ ## 链接预测和预训练的代码 ##
│ │ ├── data_handler.py
│ │ ├── main.py
│ │ ├── model.py
│ │ └── params.py
│ │ ├── Utils/
│ │ │ └── TimeLogger.py
│ ├── node_classification/ ## 节点分类测试的代码 ##
│ │ ├── data_handler.py
│ │ ├── main.py
│ │ ├── model.py
│ │ └── params.py
│ │ ├── Utils/
│ │ │ └── TimeLogger.py
使用方法
要复现论文中报告的测试性能,请运行以下命令行:
cd link_prediction/
python main.py --load pretrn_gen1 --epoch 0 # 在OGBL-Collab、ML-1M、ML-10M上测试
python main.py --load pretrn_gen0 --tstdata amazon-book --epoch 0 # 在Amazon-Book上测试
python main.py --load pretrn_gen2 --tstdata ddi --epoch 0 # 在OGBL-ddi上测试
cd ../node_classification/
python main.py --load pretrn_gen1 --tstdata cora # 在Cora上测试
python main.py --load pretrn_gen1 --tstdata citeseer # 在Citeseer上测试
python main.py --load pretrn_gen1 --tstdata pubmed # 在Pubmed上测试
要自行重新预训练OpenGraph,请运行以下命令行:
cd ../link_prediction/
python main.py --save pretrn_gen1
python main.py --trndata gen0 --tstdata amazon-book --save pretrn_gen0
python main.py --trndata gen2 --tstdata ddi --save pretrn_gen2
要探索使用多个不同的预训练和测试数据集进行预训练,请修改link_prediction/main.py
第241行的trn_datasets
和tst_datasets
。
图数据生成
图生成代码位于graph_generation/
中。提供了一个小规模的玩具数据集。您需要先在Utils.py
和itemCollecting_dfsIterator.py
中填写您的OpenAI密钥。要生成您的数据集,请修改descs
和hyperparams
字典,并按以下步骤进行:
cd graph_generation/
python itemCollecting_dfsIterator.py
python instance_number_estimation_hierarchical.py
python embedding_generation.py
python human_item_generation_gibbsSampling_embedEstimation.py
python make_adjs.py
下面展示了我们的提示模板,以及提示配置和生成节点的示例。
评估结果
整体泛化性能
OpenGraph在零样本设置下达到最佳性能,优于使用1-shot和5-shot数据训练/微调的基线模型。
预训练数据集研究
我们研究了使用不同预训练数据集的影响。以下结果表明:
- 生成技术(Norm、Loc、Topo)对性能有积极影响。
- 真实世界数据集(Yelp2018、Gowalla)可能比我们生成的数据集产生更差的结果。
- 相关的预训练数据集(ML-10M用于测试数据ML-1M和ML-10M)会带来更优越的性能。
图分词器研究
我们通过调整图平滑的顺序,并用替代方案替换我们的拓扑感知投影,调整了统一图分词器的配置。我们的发现包括:
- 邻接平滑很重要,因为0阶平滑的OpenGraph性能较差。
- 拓扑感知投影在性能上更优。替代方案包括独热编码,它为所有数据集学习一个大的统一表示表;随机,对节点间关系不做任何假设并均匀分布;度,这是非属性图常用的方法,似乎适用于跨图场景。
采样技术研究
我们对图转换器中的两种采样技术进行了消融实验,下面展示了它们对内存和时间成本的积极影响。令人惊讶的是,令牌序列采样对模型性能有积极影响。
引用
如果您认为这项工作对您的研究有用,请考虑引用我们的论文:
@article{xia2024opengraph,
title={OpenGraph: Towards Open Graph Foundation Models},
author={Xia, Lianghao and Kao, Ben and Huang, Chao},
journal={arXiv preprint arXiv:2403.01121},
year={2024}
}
详细代码结构
./
│ └── README.md
│ ├── History/ ## 预训练模型的训练历史 ##
│ │ ├── pretrn_gen0.his
│ │ ├── pretrn_gen2.his
│ │ └── pretrn_gen1.his
│ ├── Models/ ## 预训练模型 ##
│ │ └── readme ## 使用内部链接下载预训练模型 ##
│ ├── datasets/
│ │ ├── amazon-book/
│ │ │ ├── fewshot_mat_1.pkl
│ │ │ ├── trn_mat.pkl.zip ## 需手动解压 ##
│ │ │ ├── tst_mat.pkl
│ │ │ └── fewshot_mat_5.pkl
│ │ ├── citeseer/
│ │ │ ├── adj_-1.pkl
│ │ │ ├── adj_1.pkl
│ │ │ ├── adj_5.pkl
│ │ │ ├── feats.pkl.zip ## 需手动解压 ##
│ │ │ ├── label.pkl
│ │ │ ├── mask_-1.pkl
│ │ │ ├── mask_1.pkl
│ │ │ └── mask_5.pkl
│ │ ├── collab/
│ │ │ ├── fewshot_mat_5.pkl
│ │ │ ├── trn_mat.pkl.zip ## 需手动解压 ##
│ │ │ ├── tst_mat.pkl
│ │ │ ├── val_mat.pkl
│ │ │ └── fewshot_mat_1.pkl
│ │ ├── cora/
│ │ │ ├── adj_-1.pkl
│ │ │ ├── adj_1.pkl
│ │ │ ├── adj_5.pkl
│ │ │ ├── feats.pkl
│ │ │ ├── label.pkl
│ │ │ ├── mask_-1.pkl
│ │ │ ├── mask_1.pkl
│ │ │ └── mask_5.pkl
│ │ ├── ddi/
│ │ │ ├── fewshot_mat_1.pkl
│ │ │ ├── trn_mat.pkl.zip ## 需手动解压 ##
│ │ │ ├── tst_mat.pkl
│ │ │ ├── val_mat.pkl
│ │ │ └── fewshot_mat_5.pkl
│ │ ├── gen0/
│ │ │ ├── trn_mat.pkl
│ │ │ ├── val_mat.pkl
│ │ │ └── tst_mat.pkl
│ │ ├── gen1/
│ │ │ ├── trn_mat.pkl
│ │ │ ├── tst_mat.pkl
│ │ │ └── val_mat.pkl
│ │ ├── gen2/
│ │ │ ├── trn_mat.pkl
│ │ │ ├── val_mat.pkl
│ │ │ └── tst_mat.pkl
│ │ ├── ml10m/
│ │ │ ├── fewshot_mat_1.pkl
│ │ │ ├── trn_mat.pkl.zip ## 需手动解压 ##
│ │ │ ├── tst_mat.pkl.zip ## 需手动解压 ##
│ │ │ └── fewshot_mat_5.pkl
│ │ ├── ml1m/
│ │ │ ├── fewshot_mat_5.pkl
│ │ │ ├── trn_mat.pkl
│ │ │ ├── tst_mat.pkl
│ │ │ └── fewshot_mat_1.pkl
│ │ ├── pubmed/
│ │ │ ├── adj_-1.pkl
│ │ │ ├── adj_1.pkl
│ │ │ ├── feats.pkl.zip ## 需手动解压 ##
│ │ │ ├── label.pkl
│ │ │ ├── mask_-1.pkl
│ │ │ ├── mask_1.pkl
│ │ │ ├── mask_5.pkl
│ │ │ └── adj_5.pkl
│ ├── graph_generation/ ## 图生成的代码和示例 ##
│ │ ├── embedding_generation.py ## 节点嵌入生成 ##
│ │ ├── human_item_generation_gibbsSampling_embedEstimation.py ## 边生成 ##
│ │ ├── instance_number_estimation_hierarchical.py ## 估计每个节点的数量。论文中未提及。 ##
│ │ ├── itemCollecting_dfsIterator.py ## 节点生成 ##
│ │ ├── make_adjs.py ## 为生成的图制作数据集 ##
│ │ └── Utils.py
│ │ ├── Exp_Utils/
│ │ │ ├── Emailer.py ## 用于实验的警告邮件发送工具 ##
│ │ │ └── TimeLogger.py
│ │ ├── gen_results/
│ │ │ ├── tree_wInstanceNum_products_e-commerce platform like Amazon.pkl ## 树形数据结构 ##
│ │ │ └── products_e-commerce platform like Amazon.txt ## 节点列表 ##
│ │ │ ├── datasets/
│ │ │ │ ├── gen_data_ecommerce/
│ │ │ │ │ ├── embedding_dict.pkl
│ │ │ │ │ ├── item_list.pkl
│ │ │ │ │ └── interaction_base-0_iter-0.pkl ## 生成的边 ##
│ │ │ │ │ ├── res/
│ │ │ │ │ │ ├── iter-0_imap.pkl ## 节点的ID映射 ##
│ │ │ │ │ │ ├── iter-0_test.pkl
│ │ │ │ │ │ ├── iter-0_train.pkl
│ │ │ │ │ │ ├── iter-0_valid.pkl
│ │ │ │ │ │ └── interaction_fuse_iter-0.pkl
│ │ │ ├── tem/ ## 节点生成的临时文件 ##
│ │ │ │ ├── e-commerce platform like Amazon_depth1_products
│ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Automotive
│ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Baby
│ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Beauty
│ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Books
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,服装
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,电子产品
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,手工制品
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,健康与个人护理
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,家居装修
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,工业与科学
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,珠宝
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,乐器
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,办公用品
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,宠物用品
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,工具与家居装修
│ │ │ │ ├── 类似亚马逊的电商平台_深度2_产品,玩具
│ │ │ │ └── 类似亚马逊的电商平台_深度2_产品,运动户外
│ ├── imgs/ ## readme中使用的图片 ##
│ │ ├── framework.png
│ │ ├── intro.png
│ │ ├── performance.png
│ │ └── article cover.jpg
│ ├── link_prediction/ ## 链接预测和预训练的代码 ##
│ │ ├── data_handler.py
│ │ ├── main.py
│ │ ├── model.py
│ │ └── params.py
│ │ ├── Utils/
│ │ │ └── TimeLogger.py
│ ├── node_classification/ ## 节点分类测试的代码 ##
│ │ ├── data_handler.py
│ │ ├── main.py
│ │ ├── model.py
│ │ └── params.py
│ │ ├── Utils/
│ │ │ └── TimeLogger.py