LitData: 优化数据集,加速AI模型训练
在当今数据驱动的AI时代,高效处理和利用大规模数据集变得至关重要。LitData应运而生,作为一个专门用于优化和转换大规模数据集的Python库,它为AI研究人员和工程师提供了强大的工具,可以显著提升数据处理效率和模型训练速度。
LitData的核心功能
LitData主要提供两大核心功能:
- 优化数据集以加速模型训练
- 大规模转换数据集
这两项功能使LitData成为处理大规模数据集的理想选择。让我们深入了解LitData的主要特性和使用方法。
优化数据集,加速模型训练
LitData通过优化数据集格式,实现了数据的高效流式处理,从而大幅提升了模型训练速度。以下是LitData在这方面的主要特性:
1. 流式处理云端大规模数据集
使用LitData,你可以直接从云存储中流式读取数据,无需将整个数据集下载到本地。这极大地节省了时间和存储空间。
from litdata import StreamingDataset, StreamingDataLoader
dataset = StreamingDataset('s3://my-bucket/my-data', shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=64)
for batch in dataloader:
process(batch) # 替换为你的数据处理逻辑
2. 支持多GPU、多节点训练
LitData优化后的数据集可以自动适配分布式训练环境,在多GPU或多节点上高效流式传输数据。
train_dataset = StreamingDataset('s3://my-bucket/my-train-data', shuffle=True, drop_last=True)
train_dataloader = StreamingDataLoader(train_dataset, batch_size=64)
val_dataset = StreamingDataset('s3://my-bucket/my-val-data', shuffle=False, drop_last=False)
val_dataloader = StreamingDataLoader(val_dataset, batch_size=64)
3. 支持多种云存储提供商
LitData支持从多种常见云存储提供商读取优化后的数据集,包括AWS S3、Google Cloud Storage和Azure Blob Storage。
import os
import litdata as ld
# 从AWS S3读取数据
aws_storage_options = {
"AWS_ACCESS_KEY_ID": os.environ['AWS_ACCESS_KEY_ID'],
"AWS_SECRET_ACCESS_KEY": os.environ['AWS_SECRET_ACCESS_KEY'],
}
dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)
# 从GCS读取数据
gcp_storage_options = {
"project": os.environ['PROJECT_ID'],
}
dataset = ld.StreamingDataset("gs://my-bucket/my-data", storage_options=gcp_storage_options)
# 从Azure读取数据
azure_storage_options = {
"account_url": f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net",
"credential": os.environ['AZURE_ACCOUNT_ACCESS_KEY']
}
dataset = ld.StreamingDataset("azure://my-bucket/my-data", storage_options=azure_storage_options)
4. 支持暂停和恢复数据流
LitData提供了有状态的StreamingDataLoader,允许你在长时间训练过程中暂停和恢复数据流,而不会丢失进度。
import os
import torch
from litdata import StreamingDataset, StreamingDataLoader
dataset = StreamingDataset("s3://my-bucket/my-data", shuffle=True)
dataloader = StreamingDataLoader(dataset, num_workers=os.cpu_count(), batch_size=64)
# 如果存在,恢复dataLoader状态
if os.path.isfile("dataloader_state.pt"):
state_dict = torch.load("dataloader_state.pt")
dataloader.load_state_dict(state_dict)
# 遍历数据
for batch_idx, batch in enumerate(dataloader):
# 每1000个批次保存一次状态
if batch_idx % 1000 == 0:
torch.save(dataloader.state_dict(), "dataloader_state.pt")
5. 优化LLM预训练数据
LitData为LLM预训练提供了高度优化的功能。首先需要对整个数据集进行分词,然后就可以高效地消费这些数据。
import json
from pathlib import Path
import zstandard as zstd
from litdata import optimize, TokensLoader
from tokenizer import Tokenizer
from functools import partial
# 1. 定义一个函数,将jsonl文件中的文本转换为tokens
def tokenize_fn(filepath, tokenizer=None):
with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
for row in f:
text = json.loads(row)["text"]
if json.loads(row)["meta"]["redpajama_set_name"] == "RedPajamaGithub":
continue # 排除与starcoder重叠的GitHub数据
text_ids = tokenizer.encode(text, bos=False, eos=True)
yield text_ids
if __name__ == "__main__":
# 2. 生成输入(我们将优化SlimPajama数据集中的所有压缩json文件)
input_dir = "./slimpajama-raw"
inputs = [str(file) for file in Path(f"{input_dir}/SlimPajama-627B/train").rglob("*.zst")]
# 3. 将优化后的数据存储在所需位置
outputs = optimize(
fn=partial(tokenize_fn, tokenizer=Tokenizer(f"{input_dir}/checkpoints/Llama-2-7b-hf")), # 注意:你可以使用HF tokenizer或其他tokenizer
inputs=inputs,
output_dir="./slimpajama-optimized",
chunk_size=(2049 * 8012),
# 这一点很重要,告诉LitData我们正在编码连续的1D数组(tokens)
# LitData会跳过为每个样本存储元数据,即所有tokens都被连接成一个大张量
item_loader=TokensLoader(),
)
6. 组合数据集
LitData允许你混合和匹配不同的数据集,以进行实验并创建更好的模型。
from litdata import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader, TokensLoader
from tqdm import tqdm
import os
train_datasets = [
StreamingDataset(
input_dir="s3://tinyllama-template/slimpajama/train/",
item_loader=TokensLoader(block_size=2048 + 1), # LLM使用的优化tokens加载器
shuffle=True,
drop_last=True,
),
StreamingDataset(
input_dir="s3://tinyllama-template/starcoder/",
item_loader=TokensLoader(block_size=2048 + 1), # LLM使用的优化tokens加载器
shuffle=True,
drop_last=True,
),
]
# 按以下比例混合SlimPajama数据和Starcoder数据:
weights = (0.693584, 0.306416)
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=42, weights=weights, iterate_over_all=False)
train_dataloader = StreamingDataLoader(combined_dataset, batch_size=8, pin_memory=True, num_workers=os.cpu_count())
# 遍历组合后的数据集
for batch in tqdm(train_dataloader):
pass
大规模转换数据集
除了优化数据集以加速模型训练,LitData还提供了强大的功能来转换大规模数据集。这些功能使得数据预处理、特征工程等任务变得更加高效和灵活。
1. 并行化数据转换(map操作)
LitData的map
操作允许你并行地对数据集的不同部分应用相同的转换,从而大大节省时间和精力。
from litdata import map
from PIL import Image
# 注意:inputs也可以直接引用s3上的文件
input_dir = "my_large_images"
inputs = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]
# resize_image函数接收一个输入(image_path)和输出目录
# 写入output_dir的文件会被持久化
def resize_image(image_path, output_dir):
output_image_path = os.path.join(output_dir, os.path.basename(image_path))
Image.open(image_path).resize((224, 224)).save(output_image_path)
map(
fn=resize_image,
inputs=inputs,
output_dir="s3://my-bucket/my_resized_images",
)
2. 支持S3兼容的云对象存储
LitData支持与S3兼容的对象存储服务集成,如MinIO,这为数据存储提供了灵活性和成本节省的选择。
# 设置环境变量以连接MinIO
export AWS_ACCESS_KEY_ID=access_key
export AWS_SECRET_ACCESS_KEY=secret_key
export AWS_ENDPOINT_URL=http://localhost:9000 # MinIO端点
# 或者在~/.aws/{credentials,config}中配置凭证和端点
mkdir -p ~/.aws && \
cat <<EOL >> ~/.aws/credentials
[default]
aws_access_key_id = access_key
aws_secret_access_key = secret_key
EOL
cat <<EOL >> ~/.aws/config
[default]
endpoint_url = http://localhost:9000 # MinIO端点
EOL
3. 支持数据加密和解密
LitData支持在chunk/sample级别对数据进行加密和解密,确保敏感信息在存储过程中受到保护。
from litdata import optimize
from litdata.utilities.encryption import FernetEncryption
import numpy as np
from PIL import Image
# 使用密码初始化FernetEncryption进行样本级加密
fernet = FernetEncryption(password="your_secure_password", level="sample")
data_dir = "s3://my-bucket/optimized_data"
def random_image(index):
"""生成随机图像用于演示"""
fake_img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8))
return {"image": fake_img, "class": index}
# 在应用加密的同时优化数据
optimize(
fn=random_image,
inputs=list(range(5)), # 示例输入: [0, 1, 2, 3, 4]
num_workers=1,
output_dir=data_dir,
chunk_bytes="64MB",
encryption=fernet,
)
# 将加密密钥保存到文件中以备后用
fernet.save("fernet.pem")
LitData的性能优势
LitData在数据处理和优化方面展现出了卓越的性能。以下是一些关键的性能指标:
-
数据流式处理速度: LitData优化后的数据集在流式处理速度上比未优化的数据快20倍,比其他流式解决方案快2倍。
-
数据优化速度: LitData在优化1.2百万张ImageNet图像时,比其他框架快3-5倍。
-
存储效率: LitData优化后的数据集大小与其他框架相当,但文件数量更少,有利于管理。
-
分布式处理能力: LitData支持在多台机器上并行处理大规模工作负载,可以将数据处理任务的完成时间从数周缩短到数分钟。
结语
LitData为处理和优化大规模数据集提供了一套强大而灵活的工具。无论是加速AI模型训练,还是进行大规模数据转换,LitData都能显著提升效率和性能。它支持多种云存储平台,提供了数据加密功能,并且可以轻松集成到现有的AI开发流程中。
对于需要处理大规模数据集的AI研究人员和工程师来说,LitData无疑是一个值得尝试的工具。它不仅可以加速数据处理和模型训练过程,还能帮助更有效地管理和利用大规模数据集,从而推动AI研究和应用的进步。
随着数据规模的不断增长和AI模型的日益复杂,像LitData这样的工具将在未来的AI开发中扮演越来越重要的角色。我们期待看到LitData在更多领域的应用,以及它如何继续evolve以满足AI社区不断变化的需求。