地理标记模型
该存储库旨在支持开发者构建和训练他们自己的地理标记模型。这里提供的地理标记模型架构允许进行定制和训练。此外,我们还发布了适用于不同地理位置检测场景训练的数据集。
当前模型在最相关的10%文本上达到30公里的哈弗辛距离中位数误差。存储库的问题部分开放了改进模型性能的挑战。
架构和训练
点击展开地理标记模型架构图。
%%{init:{'theme':'neutral'}}%%
flowchart TD
subgraph "ByT5分类器"
a("输入文本") --> b("输入ID")
subgraph "byt5(T5编码器模型)"
b("输入ID") --> c("byt5.encoder.inp_input_ids")
subgraph "byt5.encoder(T5堆栈)"
c("byt5.encoder.inp_input_ids") --> d("byt5.encoder.embed_tokens")
subgraph "byt5.encoder.embed_tokens (嵌入)"
d("byt5.encoder.embed_tokens") --> f("嵌入")
e("byt5.encoder.embed_tokens.inp_weights") --> f("嵌入") --> g("byt5.encoder.embed_tokens.out_0")
end
g("byt5.encoder.embed_tokens.out_0") --> h("byt5.encoder.dropout(丢弃)") --> i("byt5.encoder.block.0(T5块)") --> j("byt5.encoder.block.1(T5块)") & k("byt5.encoder.block.2-9(T5块)") & l("byt5.encoder.block.10(T5块)")
j("byt5.encoder.block.1(T5块)") --> k("byt5.encoder.block.2(T5块)<br><br> ...<br><br>byt5.encoder.block.10(T5块) ") --> l("byt5.encoder.block.11(T5块)") --> m("byt5.encoder.final_layer_norm(T5层归一化)")
m("byt5.encoder.final_layer_norm(T5层归一化)")-->n("byt5.encoder.dropout(丢弃)")--> o("byt5.encoder.out_0")
end
o("byt5.encoder.out_0") --> p("byt5.out_0")
end
p("byt5.out_0")-->q("(线性)")
end
q("(线性)") -->r("logits")
依赖项
确保在你的环境中安装以下依赖项以构建和训练你的地理标记模型:
transformers==4.29.1
tqdm==4.63.2
pandas==1.4.4
pytorch==1.7.1
要使用基于ByT5编码器的方法训练你的地理标记模型,执行以下脚本:
python train_model.py --train_input_file <训练文件> --test_input_file <测试文件> --do_train true --do_test true --load_clustering .
查看train_model.py
文件以获取可用参数的完整列表。
输出示例
{
"text":"这些小猫需要家,位于奥马哈地区!它们已经接种疫苗并绝育/节育。它们需要在1月1日之前离开!请转发以帮助传播消息!",
"geotagging":{
"lat":41.257160,
"lon":-95.995102,
"confidence":0.9950085878372192
}
}
{
"type": "FeatureCollection",
"features": [
{
"type": "Feature",
"id": 1,
"properties": {
"ID": 0
},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[-96.296363, 41.112793],
[-96.296363, 41.345177],
[-95.786877, 41.345177],
[-95.786877, 41.112793],
[-96.296363, 41.112793]
]
]
}
},
{
"type": "Feature",
"id": 2,
"properties": {
"ID": 0
},
"geometry": {
"type": "Point",
"coordinates": [-95.995102, 41.257160]
}
}
]
}
数据集
我们的团队为两种不同的训练方法策划了两个全面的数据集。这些数据集旨在用于训练和验证模型。在存储库的问题部分分享你的训练结果。
地区方法的目标是研究世界上人口最多的地区的数据集。
- 是一个包含50万条文本及其相应地理坐标的标注语料库
- 覆盖123个地区
- 每个地点包含5000条推文
季节方法的目标是识别帖子的时间/日期、内容和位置之间的相关性。应分析并利用时区差异以及事件的季节性来预测位置。例如:雪更可能出现在北半球,尤其是在12月。摇滚音乐会更可能在晚上和大城市举行,因此应使用关于音乐会的帖子时间来确定作者的时区并缩小潜在位置的范围。
- 是一个包含超过60万条文本的.json文件
- 收集时间跨度为12个月
- 覆盖15个不同时区
- 聚焦于6个国家(古巴、伊朗、俄罗斯、朝鲜、叙利亚、委内瑞拉)
您的自定义数据。 地理标记模型支持在自定义数据集上进行训练和测试。请准备CSV格式的数据,包含以下列:text
、lat
和lon
。
置信度和预测
地理标记模型融入了置信度估计,以评估预测坐标的可靠性。输出中的相关性字段表示预测置信度,范围从0.0
到1.0
。数值越高表示置信度越高。
有关置信度估计和如何使用模型进行地理标记预测的详细信息,请参阅inference.py
文件。该文件提供了一个示例脚本,演示了模型架构和置信度估计的集成。
欢迎!
Fork用户
请随意探索代码,根据您的具体需求进行调整,并将其集成到您的项目中。如果您有任何问题或需要帮助,请随时联系我们。我们非常感谢您的反馈,并致力于不断改进地理标记模型。