用于WasmEdge WASI-NN的MediaPipe任务的Rust库
介绍
- 易于使用:类似mediapipe-python的低代码API。
- 低开销:处理过程中无不必要的数据复制、分配和释放。
- 灵活:用户可以使用自定义媒体字节作为输入。
- 对于TfLite模型,该库不仅支持从[MediaPipe Solutions]下载的所有模型,还支持带有基本信息的**TF Hub模型和自定义模型**。
状态
- 物体检测
- 图像分类
- 图像分割
- 交互式图像分割
- 手势识别
- 手部关键点检测
- 图像嵌入
- 人脸检测
- 人脸关键点检测
- 姿势关键点检测
- 音频分类
- 文本分类
- 文本嵌入
- 语言检测
任务API
每个任务有三种类型:XxxBuilder
、Xxx
、XxxSession
。(Xxx
是任务名称)
-
XxxBuilder
用于创建任务实例Xxx
,有多个选项可设置。示例:使用
ImageClassifierBuilder
构建ImageClassifier
任务。let classifier = ImageClassifierBuilder::new() .max_results(3) // 设置最大结果数 .category_deny_list(vec!["denied label".into()]) // 设置拒绝列表 .gpu() // 设置运行设备 .build_from_file(model_path)?; // 创建图像分类器
-
Xxx
是任务实例,包含任务信息和模型信息。示例:使用
ImageClassifier
创建新的ImageClassifierSession
let classifier_session = classifier.new_session()?;
-
XxxSession
是执行预处理、推理和后处理的运行会话,有缓冲区存储中间结果。示例:使用
ImageClassifierSession
运行图像分类任务并返回分类结果:let classification_result = classifier_session.classify(&image::open(img_path)?)?;
注意:可重复使用会话以提高速度,如果代码只使用会话一次,可以使用任务的包装函数来简化。
// let classifier_session = classifier.new_session()?; // let classification_result = classifier_session.classify(&image::open(img_path)?)?; // 上面两行代码等同于: let classification_result = classifier.classify(&image::open(img_path)?)?;
可用任务
- 视觉:
- 手势识别:
GestureRecognizerBuilder
->GestureRecognizer
->GestureRecognizerSession
- 手部检测:
HandDetectorBuilder
->HandDetector
->HandDetectorSession
- 图像分类:
ImageClassifierBuilder
->ImageClassifier
->ImageClassifierSession
- 图像嵌入:
ImageEmbedderBuilder
->ImageEmbedder
->ImageEmbedderSession
- 图像分割:
ImageSegmenterBuilder
->ImageSegmenter
->ImageSegmenterSession
- 物体检测:
ObjectDetectorBuilder
->ObjectDetector
->ObjectDetectorSession
- 手势识别:
- 音频:
- 音频分类:
AudioClassifierBuilder
->AudioClassifier
->AudioClassifierSession
- 音频分类:
- 文本:
- 文本分类:
TextClassifierBuilder
->TextClassifier
->TextClassifierSession
- 文本分类:
示例
图像分类
use mediapipe_rs::tasks::vision::ImageClassifierBuilder;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let (model_path, img_path) = parse_args()?;
let classification_result = ImageClassifierBuilder::new()
.max_results(3) // 设置最大结果数
.build_from_file(model_path)? // 创建图像分类器
.classify(&image::open(img_path)?)?; // 执行推理并生成结果
// 显示格式化的结果信息
println!("{}", classification_result);
Ok(())
}
示例输入:(图像下载自 https://yellow-cdn.veclightyear.com/835a84d5/eda7dfa2-8039-406e-84ff-c1b833fff26b.jpg)
控制台示例输出:
$ cargo run --release --example image_classification -- ./assets/models/image_classification/efficientnet_lite0_fp32.tflite ./assets/testdata/img/burger.jpg
已完成 release [优化] 目标在 0.01 秒内
运行 `/mediapipe-rs/./scripts/wasmedge-runner.sh target/wasm32-wasi/release/examples/image_classification.wasm ./assets/models/image_classification/efficientnet_lite0_fp32.tflite ./assets/testdata/img/burger.jpg`
分类结果:
分类 #0:
类别 #0:
类别名称: "芝士汉堡"
显示名称: 无
得分: 0.70625573
索引: 933
物体检测
use mediapipe_rs::postprocess::utils::draw_detection;
use mediapipe_rs::tasks::vision::ObjectDetectorBuilder;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let (model_path, img_path, output_path) = parse_args()?;
let mut input_img = image::open(img_path)?;
let detection_result = ObjectDetectorBuilder::new()
.max_results(2) // 设置最大结果数
.build_from_file(model_path)? // 创建物体检测器
.detect(&input_img)?; // 进行推理并生成结果
// 显示格式化的结果信息
println!("{}", detection_result);
if let Some(output_path) = output_path {
// 在图像上绘制检测结果
draw_detection(&mut input_img, &detection_result);
// 保存输出图像
input_img.save(output_path)?;
}
Ok(())
}
示例输入:(图片下载自 https://yellow-cdn.veclightyear.com/835a84d5/3c8fc365-b65d-4b4e-a71e-2c840b48416b.jpg)
控制台输出示例:
$ cargo run --release --example object_detection -- ./assets/models/object_detection/efficientdet_lite0_fp32.tflite ./assets/testdata/img/cat_and_dog.jpg
已完成 release [优化] 目标在 0.00 秒内
运行 `/mediapipe-rs/./scripts/wasmedge-runner.sh target/wasm32-wasi/release/examples/object_detection.wasm ./assets/models/object_detection/efficientdet_lite0_fp32.tflite ./assets/testdata/img/cat_and_dog.jpg`
检测结果:
检测 #0:
边界框: (左: 0.12283102, 上: 0.38476586, 右: 0.51069236, 下: 0.851197)
类别 #0:
类别名称: "猫"
显示名称: 无
得分: 0.8460574
索引: 16
检测 #1:
边界框: (左: 0.47926134, 上: 0.06873521, 右: 0.8711677, 下: 0.87927735)
类别 #0:
类别名称: "狗"
显示名称: 无
得分: 0.8375256
索引: 17
输出示例:
文本分类
fn main() -> Result<(), Box<dyn std::error::Error>> {
let model_path = parse_args()?;
let text_classifier = TextClassifierBuilder::new()
.max_results(1) // 设置最大结果数
.build_from_file(model_path)?; // 创建文本分类器
let positive_str = "我非常喜欢编程!";
let negative_str = "我不喜欢下雨。";
// 分类并显示格式化的结果信息
let result = text_classifier.classify(&positive_str)?;
println!("`{}` -- {}", positive_str, result);
let result = text_classifier.classify(&negative_str)?;
println!("`{}` -- {}", negative_str, result);
Ok(())
}
控制台输出示例(使用 bert 模型):
$ cargo run --release --example text_classification -- ./assets/models/text_classification/bert_text_classifier.tflite
已完成 release [优化] 目标在 0.01 秒内
运行 `/mediapipe-rs/./scripts/wasmedge-runner.sh target/wasm32-wasi/release/examples/text_classification.wasm ./assets/models/text_classification/bert_text_classifier.tflite`
`我非常喜欢编程!` -- 分类结果:
分类 #0:
类别 #0:
类别名称: "积极"
显示名称: 无
得分: 0.99990463
索引: 1
`我不喜欢下雨。` -- 分类结果:
分类 #0:
类别 #0:
类别名称: "消极"
显示名称: 无
得分: 0.99541473
索引: 0
手势识别
use mediapipe_rs::tasks::vision::GestureRecognizerBuilder;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let (model_path, img_path) = parse_args()?;
let gesture_recognition_results = GestureRecognizerBuilder::new()
.num_hands(1) // 设置只识别一只手
.max_results(1) // 设置最大结果数
.build_from_file(model_path)? // 创建任务实例
.recognize(&image::open(img_path)?)?; // 进行推理并生成结果
for g in gesture_recognition_results {
println!("{}", g.gestures.classifications[0].categories[0]);
}
Ok(())
}
示例输入:(图片下载自 https://yellow-cdn.veclightyear.com/835a84d5/f3e27d24-b3d5-452a-83a3-3474a4e6b18a.jpg)
控制台输出示例:
$ cargo run --release --example gesture_recognition -- ./assets/models/gesture_recognition/gesture_recognizer.task ./assets/testdata/img/gesture_recognition_google_samples/victory.jpg
Finished release [optimized] target(s) in 0.02s
Running `/mediapipe-rs/./scripts/wasmedge-runner.sh target/wasm32-wasi/release/examples/gesture_recognition.wasm ./assets/models/gesture_recognition/gesture_recognizer.task ./assets/testdata/img/gesture_recognition_google_samples/victory.jpg`
类别名称: "Victory"
显示名称: None
得分: 0.9322255
索引: 6
音频输入
任何实现了AudioData
特征的音频媒体都可以作为音频任务的输入。
目前,该库内置了支持symphonia
、ffmpeg
和原始音频数据作为输入的实现。
音频分类示例:
use mediapipe_rs::tasks::audio::AudioClassifierBuilder;
#[cfg(feature = "ffmpeg")]
use mediapipe_rs::preprocess::audio::FFMpegAudioData;
#[cfg(not(feature = "ffmpeg"))]
use mediapipe_rs::preprocess::audio::SymphoniaAudioData;
#[cfg(not(feature = "ffmpeg"))]
fn read_audio_using_symphonia(audio_path: String) -> SymphoniaAudioData {
let file = std::fs::File::open(audio_path).unwrap();
let probed = symphonia::default::get_probe()
.format(
&Default::default(),
symphonia::core::io::MediaSourceStream::new(Box::new(file), Default::default()),
&Default::default(),
&Default::default(),
)
.unwrap();
let codec_params = &probed.format.default_track().unwrap().codec_params;
let decoder = symphonia::default::get_codecs()
.make(codec_params, &Default::default())
.unwrap();
SymphoniaAudioData::new(probed.format, decoder)
}
#[cfg(feature = "ffmpeg")]
fn read_video_using_ffmpeg(audio_path: String) -> FFMpegAudioData {
ffmpeg_next::init().unwrap();
FFMpegAudioData::new(ffmpeg_next::format::input(&audio_path.as_str()).unwrap()).unwrap()
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let (model_path, audio_path) = parse_args()?;
#[cfg(not(feature = "ffmpeg"))]
let audio = read_audio_using_symphonia(audio_path);
#[cfg(feature = "ffmpeg")]
let audio = read_video_using_ffmpeg(audio_path);
let classification_results = AudioClassifierBuilder::new()
.max_results(3) // 设置最大结果数
.build_from_file(model_path)? // 创建任务实例
.classify(audio)?; // 进行推理并生成结果
// 显示格式化的结果消息
for c in classification_results {
println!("{}", c);
}
Ok(())
}
使用会话加速
会话包括推理会话(如TfLite解释器)、输入和输出缓冲区等。 显式使用会话可以重用这些资源以加快速度。
示例:文本分类
原始方法:
use mediapipe_rs::tasks::text::TextClassifier;
use mediapipe_rs::postprocess::ClassificationResult;
use mediapipe_rs::Error;
fn inference(
text_classifier: &TextClassifier,
inputs: &Vec<String>
) -> Result<Vec<ClassificationResult>, Error> {
let mut res = Vec::with_capacity(inputs.len());
for input in inputs {
// text_classifier 每次都会创建新的会话
res.push(text_classifier.classify(input.as_str())?);
}
Ok(res)
}
使用会话加速:
use mediapipe_rs::tasks::text::TextClassifier;
use mediapipe_rs::postprocess::ClassificationResult;
use mediapipe_rs::Error;
fn inference(
text_classifier: &TextClassifier,
inputs: &Vec<String>
) -> Result<Vec<ClassificationResult>, Error> {
let mut res = Vec::with_capacity(inputs.len());
// 只创建一个会话并重用会话中的资源
let mut session = text_classifier.new_session()?;
for input in inputs {
res.push(session.classify(input.as_str())?);
}
Ok(res)
}
使用FFMPEG功能处理视频和音频
使用cargo构建带有ffmpeg
功能的库时,用户必须设置以下环境变量:
FFMPEG_DIR
:预构建的FFmpeg库路径。您可以从以下地址下载: https://github.com/yanghaku/ffmpeg-wasm32-wasi/releasesWASI_SDK
或(WASI_SYSROOT
和CLANG_RT
),您可以从以下地址下载: https://github.com/WebAssembly/wasi-sdk/releasesBINDGEN_EXTRA_CLANG_ARGS
:为libclang设置sysroot、target和函数可见性。 (sysroot必须是绝对路径)。
示例:
export FFMPEG_DIR=/path/to/ffmpeg/library
export WASI_SDK=/opt/wasi-sdk
export BINDGEN_EXTRA_CLANG_ARGS="--sysroot=/opt/wasi-sdk/share/wasi-sysroot --target=wasm32-wasi -fvisibility=default"
# 然后运行cargo
GPU和TPU支持
默认设备是CPU,用户可以使用API选择要使用的设备:
use mediapipe_rs::tasks::vision::ObjectDetectorBuilder;
fn create_gpu(model_blob: Vec<u8>) {
let detector_gpu = ObjectDetectorBuilder::new()
.gpu()
.build_from_buffer(model_blob)
.unwrap();
}
fn create_tpu(model_blob: Vec<u8>) {
let detector_tpu = ObjectDetectorBuilder::new()
.tpu()
.build_from_buffer(model_blob)
.unwrap();
}
注意
本项目由Google在Mediapipe上的工作使之成为可能。
相关链接
许可证
本项目采用Apache 2.0许可证。更多详情请参阅LICENSE。