TorchGeo是一个PyTorch领域库,类似于torchvision,提供特定于地理空间数据的数据集、采样器、转换和预训练模型。
这个库的目标是简化以下两方面的工作:
- 让机器学习专家处理地理空间数据变得简单;
- 让遥感专家探索机器学习解决方案变得简单。
安装
推荐使用pip安装TorchGeo:
$ pip install torchgeo
文档
可以在ReadTheDocs上找到TorchGeo的文档。这里包括API文档、贡献指南和多个教程。详细信息请查看我们的论文、播客、教程视频和博客文章。
使用示例
以下部分介绍了使用TorchGeo可以做的一些基本操作。
首先,我们将导入以下部分中使用的各种类和函数:
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.datasets import CDL, Landsat7, Landsat8, VHR10, stack_samples
from torchgeo.samplers import RandomGeoSampler
from torchgeo.trainers import SemanticSegmentationTask
地理空间数据集和采样器
许多遥感应用涉及处理地理空间数据集——带有地理元数据的数据集。由于数据的多样性,这些数据集可能难以处理。地理空间影像通常是多光谱的,每颗卫星的光谱带和空间分辨率都不同。此外,每个文件可能使用不同的坐标参考系统(CRS),需要将数据重新投影到匹配的CRS中。
在这个示例中,我们演示了如何使用TorchGeo处理地理空间数据并从结合Landsat和农作物数据层(CDL)的数据中采样小的图像块。首先,我们假设用户已经下载了Landsat 7和8的影像。由于Landsat 8比Landsat 7有更多的光谱带,我们只使用两颗卫星共有的光谱带。通过将这两组数据集进行并集,我们创建了一个包括所有Landsat 7和8影像的单一数据集。
landsat7 = Landsat7(root="...", bands=["B1", ..., "B7"])
landsat8 = Landsat8(root="...", bands=["B2", ..., "B8"])
landsat = landsat7 | landsat8
接下来,我们将该数据集与CDL数据集进行交集。我们选择交集而不是并集,以确保只从同时具有Landsat和CDL数据的区域进行采样。注意,我们可以自动下载并校验CDL数据。另外,需要注意的是,这些数据集中的每个文件可能使用不同的坐标参考系统(CRS)或分辨率,但TorchGeo会自动确保使用匹配的CRS和分辨率。
cdl = CDL(root="...", download=True, checksum=True)
dataset = landsat & cdl
这个数据集现在可以与PyTorch数据加载器一起使用。与基准数据集不同,地理空间数据集通常包含非常大的图像。例如,CDL数据集由覆盖整个美国大陆的单一图像组成。为了使用地理坐标从这些数据集中采样,TorchGeo定义了多个采样器。在这个示例中,我们将使用返回256 x 256像素图像和每个纪元10,000个样本的随机采样器。我们还使用自定义的合并函数将每个样本字典合并为一个小批量样本。
sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)
这个数据加载器现在可以在你的正常训练/评估管道中使用。
for batch in dataloader:
image = batch["image"]
mask = batch["mask"]
# 训练模型,或使用预训练模型进行预测
许多应用涉及基于地理空间元数据智能地组合数据集。例如,用户可能希望:
- 结合多种图像源的数据集并将其视为等效(如Landsat 7和8)
- 结合不同地理位置的数据集(如Chesapeake NY和PA)
这些组合要求所有查询都至少存在于一个数据集中,并可以使用UnionDataset
创建。同样,用户可能希望:
- 结合图像和目标标签并同时对两者进行采样(如Landsat和CDL)
- 为多模态学习或数据融合结合多种图像源的数据集(如Landsat和Sentinel)
这些组合要求所有查询都存在于两个数据集中,并可以使用IntersectionDataset
创建。当你使用交集(&
)和并集(|
)运算符时,TorchGeo会自动为你组合这些数据集。
基准数据集
TorchGeo包括多个基准数据集——包含输入图像和目标标签的数据集。这些数据集包括图像分类、回归、语义分割、目标检测、实例分割、变化检测等任务的数据集。
如果你以前使用过torchvision,这些数据集应该会很熟悉。在这个示例中,我们将创建一个用于西北工业大学(NWPU)高分辨率十分类(VHR-10)地理空间目标检测数据集。这个数据集可以像torchvision一样自动下载、校验和解压。
from torch.utils.data import DataLoader
from torchgeo.datamodules.utils import collate_fn_detection
from torchgeo.datasets import VHR10
# 初始化数据集
dataset = VHR10(root="...", download=True, checksum=True)
# 使用自定义合并函数初始化dataloader
dataloader = DataLoader(
dataset,
batch_size=128,
shuffle=True,
num_workers=4,
collate_fn=collate_fn_detection,
)
# 训练循环
for batch in dataloader:
images = batch["image"] # 图像列表
boxes = batch["boxes"] # 包围框列表
labels = batch["labels"] # 标签列表
masks = batch["masks"] # 掩码列表
# 训练模型,或使用预训练模型进行预测
所有 TorchGeo 数据集都与 PyTorch 数据加载器兼容,使它们易于集成到现有的训练工作流程中。在 TorchGeo 中的基准数据集与 torchvision 中的类似数据集之间的唯一区别在于,每个数据集会返回带有每个 PyTorch Tensor
键的字典。
预训练权重
预训练权重在计算机视觉的迁移学习任务中被证明是非常有益的。实践者通常使用在包含 RGB 图像的 ImageNet 数据集上预训练的模型。然而,遥感数据往往超越 RGB,具有可以跨传感器变化的额外多光谱通道。TorchGeo 是第一个支持在不同多光谱传感器上预训练模型的库,并采用了 torchvision 的 多权重 API。目前可用权重的摘要可以在 文档 中查看。要创建一个在 Sentinel-2 图像上预训练权重的 timm Resnet-18 模型,可以进行以下操作:
import timm
from torchgeo.models import ResNet18_Weights
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
这些权重也可以直接用于以下章节中通过 weights
参数显示的 TorchGeo Lightning 模块中。有关笔记本示例,请参见此 教程。
使用 Lightning 实现可重复性
为了便于直接比较文献中发表的结果,并进一步减少运行 TorchGeo 数据集实验所需的样板代码,我们创建了具有定义良好的训练-验证-测试分割的 Lightning 数据模块 和用于分类、回归和语义分割等各种任务的 训练器。这些数据模块展示了如何从 kornia 库中合并增强,包括预处理变换(带有预计算的通道统计),并让用户轻松试验与数据本身相关的超参数(而不是建模过程)。在 Inria 航拍图像标注 数据集上训练语义分割模型就像几个导入和四行代码一样简单。
datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(
model="unet",
backbone="resnet50",
weights=True,
in_channels=3,
num_classes=2,
loss="ce",
ignore_index=None,
lr=0.1,
patience=6,
)
trainer = Trainer(default_root_dir="...")
trainer.fit(model=task, datamodule=datamodule)
TorchGeo 还支持使用 LightningCLI 的命令行界面训练。它可以通过以下两种方式调用:
# 如果已安装 torchgeo
torchgeo
# 如果已安装 torchgeo,或已将其克隆到当前目录
python3 -m torchgeo
它支持命令行配置或 YAML/JSON 配置文件。有效选项可以从帮助消息中找到:
# 查看有效阶段
torchgeo --help
# 查看有效训练器选项
torchgeo fit --help
# 查看有效模型选项
torchgeo fit --model.help ClassificationTask
# 查看有效数据选项
torchgeo fit --data.help EuroSAT100DataModule
使用以下配置文件:
trainer:
max_epochs: 20
model:
class_path: ClassificationTask
init_args:
model: "resnet18"
in_channels: 13
num_classes: 10
data:
class_path: EuroSAT100DataModule
init_args:
batch_size: 8
dict_kwargs:
download: true
我们可以看到脚本在运行:
# 训练和验证模型
torchgeo fit --config config.yaml
# 仅验证
torchgeo validate --config config.yaml
# 计算和报告测试精度
torchgeo test --config config.yaml --ckpt_path=...
如果需要扩展它以添加新功能,它还可以导入并在 Python 脚本中使用:
from torchgeo.main import main
main(["fit", "--config", "config.yaml"])
有关更多详细信息,请参阅 Lightning 文档。
引用
如果您在工作中使用此软件,请引用我们的 论文:
@inproceedings{Stewart_TorchGeo_Deep_Learning_2022,
address = {Seattle, Washington},
author = {Stewart, Adam J. and Robinson, Caleb and Corley, Isaac A. and Ortiz, Anthony and Lavista Ferres, Juan M. and Banerjee, Arindam},
booktitle = {Proceedings of the 30th International Conference on Advances in Geographic Information Systems},
doi = {10.1145/3557915.3560953},
month = nov,
pages = {1--12},
publisher = {Association for Computing Machinery},
series = {SIGSPATIAL '22},
title = {{TorchGeo}: Deep Learning With Geospatial Data},
url = {https://dl.acm.org/doi/10.1145/3557915.3560953},
year = {2022}
}
贡献
本项目欢迎贡献和建议。如果您想提交拉取请求,请参阅我们的 贡献指南 了解更多信息。
此项目已采用 Microsoft 开源行为准则。有关更多信息,请参阅 行为准则 FAQ 或联系 opencode@microsoft.com 获取任何其他问题或意见。