跳到主要内容

34 篇博文 含有标签「Rust」

查看所有标签

Burn 是一个新的全面的动态深度学习框架,以极高的灵活性计算效率可移植性为其主要目标。

性能

我们相信深度学习框架的目标是将计算转化为有用的智能,因此我们将性能视为Burn的核心支柱

我们努力通过利用下面描述的多种优化技术来实现最高效率。

自动内核融合 💥

使用Burn意味着在任何后端对您的模型进行优化

在可能的情况下,我们提供一种自动和动态创建自定义内核的方式, 最小化在不同内存空间之间的数据重定位,当内存迁移是瓶颈时非常有用。

例如,您可以使用高级张量 API(请参见下面的Rust 代码片段)编写自己的GELU 激活函数。

fn gelu_custom < B : Backend , const D : usize > ( x : Tensor < B , D > ) -> Tensor < B , D > {
let x = x . clone ( ) * ( ( x / SQRT_2 ) . erf ( ) + 1 ) ;
x / 2
}

然后,在运行时将自动为您的特定实现创建一个自定义低级内核,并将与手工编写的 GPU 实现相媲美

内核由大约 60 行 WGSL WebGPU 着色语言组成,这是一种非常冗长的低级着色语言,您可能不希望用其来编写深度学习模型!

就目前而言,我们的融合策略仅为我们自己的 WGPU 后端实现,并且仅支持部分操作。

我们计划很快添加更多操作,并将这种技术扩展到未来的其他内部后端。

异步执行 ❤️‍🔥

对由 Burn 团队从头开始开发的后端,使用了一种异步执行风格,这允许执行各种优化,例如先前提到的自动内核融合。 异步执行还确保框架的正常执行不会阻塞模型计算,这意味着框架的开销不会显著影响执行速度。 反之,模型中的强烈计算不会干扰框架的响应。 有关我们异步后端的更多信息,请参阅此博客文章

线程安全的构建模块 🦞

Burn通过利用Rust的所有权系统来强调线程安全使用Burn,每个模块都是其权重的所有者

因此,可以将一个模块发送到另一个线程来计算梯度,然后将梯度发送到可以聚合它们的主线程,这样,您就可以进行多设备训练。

这与PyTorch的做法截然不同,PyTorch中的反向传播实际上会改变每个张量参数的梯度属性。 这是一个非线程安全的操作,因此需要较低级别的同步原语,请参考分布式训练。 请注意,这仍然非常快速,但不兼容不同的后端,并且实现起来相当困难。

智能内存管理 🦀

深度学习框架的主要作用之一是减少运行模型所需的内存量

处理内存的天真方式是每个张量都有自己的内存空间,在张量创建时分配,然后在张量超出范围时释放

然而,分配和释放数据非常昂贵,因此通常需要内存池来实现良好的吞吐量

Burn 提供了一个基础架构,可以轻松创建和选择后端的内存管理策略

有关 Burn 中内存管理的更多详细信息,请参阅此博客文章

Burn 的另一个非常重要的内存优化是,我们通过良好使用所有权系统,跟踪张量何时可以就地进行突变

即使这在单独的内存优化方面相对较小,但在训练或运行较大模型进行推理时,会显着累积,并有助于进一步减少内存使用

有关更多信息,请参阅有关张量处理的此博客文章

自动内核选择🎯

一个优秀的深度学习框架应该确保模型在所有硬件上运行顺畅。 然而,并非所有硬件在执行速度方面都表现相同。

例如,矩阵乘法内核可以通过许多不同的参数启动,这些参数对矩阵的大小和硬件非常敏感。 使用错误的配置可能会大幅降低执行速度(在极端情况下甚至会降低10倍或更多),因此选择正确的内核成为优先事项。

通过我们自制的后端,我们会自动运行基准测试,并选择适用于当前硬件和矩阵大小的最佳配置,配合合理的缓存策略。 这会稍微增加热身执行时间,但在几次前向和后向传递之后很快稳定下来,从长远来看节省大量时间。 请注意,这个功能不是强制性的,当冷启动优先于优化吞吐量时可以禁用。

硬件特定功能🔥

深度学习主要依赖矩阵乘法作为其核心操作,因为这是全连接神经网络建模的方式

越来越多的硬件制造商专门为矩阵乘法工作负载进行优化。 例如,Nvidia 拥有其 Tensor Cores,并且今天大多数手机都配备了人工智能专用芯片。

就目前而言,我们支持 Tensor Cores 与我们的 LibTorch 和 Candle 后端,但尚未支持其他加速器。 我们希望这个问题在某个时间点得到解决,以支持我们的 WGPU 后端。

自定义后端扩展 🎒

Burn 致力于成为最灵活的深度学习框架。 虽然保持与各种后端的兼容性至关重要,但 Burn 也提供了扩展后端实现功能以满足个人建模需求的能力。

这种多功能性在许多方面都具有优势,例如支持自定义操作,如快闪注意力,或手动编写特定后端的内核以增强性能。 请查看 Burn Book 🔥 中的此部分以获取更多详细信息。

训练和推理

使用 Burn,整个深度学习工作流程变得轻松,因为您可以通过符合人体工程学的仪表板监控培训进展, 并在嵌入式设备到大型 GPU 集群中的任何位置运行推理。

Burn 从头开始构建时考虑了培训和推理。 值得一提的是,与 PyTorch 等框架相比,Burn 如何简化了从训练到部署的过度,消除了代码更改的需要。

培训仪表板 📈

正如您在以前的视频中所看到 (单击图片!), 一个基于 Ratatui 框架的新的终端 UI 仪表板使用户能够轻松地跟踪他们的培训,无需连接到任何外部应用程序。

您可以实时查看培训和验证指标的更新,并仅使用箭头键分析任何已注册指标的终身进展或最近历史。 在不崩溃的情况下中断培训循环,使潜在的检查点能够完全编写或重要的代码片段可以在没有干扰的情况下完成 🛡

ONNX 支持 🐫

ONNX(Open Neural Network Exchange)是一种开放标准格式,可导出深度学习模型的架构和权重。

Burn 支持符合 ONNX 标准的模型的导入, 因此您可以轻松地将在另一个框架(如 TensorFlow 或 PyTorch)中编写的模型移植到 Burn, 以从我们的框架提供的所有优势中受益。

我们的 ONNX 支持在 Burn 书的这一部分中有进一步描述 🔥。

备注

此 crate 正在积极开发中,目前支持有限的 ONNX 运算符。

导入 PyTorch 模型 🚚

支持将 PyTorch 模型权重加载到 Burn 的本地模型架构中,确保无缝集成。 参见 Burn 书 🔥 中有关导入 PyTorch 的部分

浏览器内的推理 🌐

我们的几个后端可以编译成 Web AssemblyCandleNdArray 用于 CPU,WGPU 用于 GPU。

这意味着您可以直接在浏览器内运行推理。我们提供了几个示例:

  • MNIST, 您可以在其中绘制数字,然后一个小 convnet 试图找出它是什么! 2️⃣ 7️⃣ 😰
  • 图像分类, 您可以上传图像并进行分类! 🌄

嵌入式:不支持no_std ⚙️

Burn的核心组件支持no_std。这意味着它可以在没有操作系统的裸机环境中运行,如嵌入式设备

备注

截至目前,只有NdArray后端可以在no_std环境中使用。

后端

Burn 旨在在尽可能多的硬件上尽可能快速,并具有强大的实现。 我们认为这种灵活性对于现代需求至关重要,在这里您可以在云中训练模型,然后在用户硬件上部署,这些硬件因用户而异。

与其他框架相比,Burn 对支持许多后端的方法有很大不同。 通过设计,大多数代码都是以 Backend 特质通用的,这使我们能够构建具有可互换后端的 Burn。 这使得构建后端变得可能,并通过增加额外功能,比如自动微分和自动核融合。

我们已经实现了很多后端,全部列在下面👇

WGPU(WebGPU):跨平台 GPU 后端 🌐

可以运行在任何GPU上的首选后端。 基于最受欢迎和得到良好支持的Rust图形库 WGPU, 此后端自动以WebGPU着色语言WGSL为基础,针对Vulkan、OpenGL、Metal、Direct X11/12和WebGPU。 也可以编译成Web Assembly在浏览器中运行,同时利用GPU,可参见此演示。 有关此后端的更多信息,请参阅此博客。

WGPU后端是我们的第一个“内部后端”,这意味着我们完全控制其实现细节。 它充分优化了之前提到的性能特征,因为它是我们研究各种优化的实验场所。

有关更多详情,请查看WGPU后端README

Candle:使用 Candle 绑定的后端 🕯

基于 Hugging FaceCandle,这是一个专注于性能和易用性的极简主义 ML 框架, 该后端可以在支持 Web Assembly 的 CPU 上运行,也可以在支持 CUDA 的 Nvidia GPU 上运行。

有关更多详细信息, 请参阅 Candle 后端的 README

备注

该后端尚未完全完成,但在某些情况下(例如推断)可以工作。

LibTorch:使用 LibTorch 绑定的后端 🎆

在深度学习领域,PyTorch 无需介绍。 这个后端利用 PyTorch Rust 绑定,让您可以在 CPU、CUDA 和 Metal 上使用 LibTorch C++ 内核。

请查看 LibTorch 后端的 README 以了解更多详情。

NdArray:使用 NdArray 原始数据结构的后端 🦐

这个 CPU 后端确实不是我们最快的后端,但提供了极强的可移植性。 这是我们唯一支持 no_std 的后端。

查看 NdArray Backend README 以获取更多详细信息。

Autodiff:为任何后端带来反向传播的后端装饰器 🔄

与前述后端相反,Autodiff实际上是一个后端修饰器这意味着它不能单独存在;必须封装另一个后端

简单地使用Autodiff将基础后端包装起来,可以透明地为其提供自动微分支持,从而可以在模型上调用反向传播。

use burn::backend::{Autodiff, Wgpu}; 
use burn::tensor::{Distribution, Tensor};

fn main() {
type Backend = Autodiff<Wgpu>;
let x: Tensor<Backend, 2> = Tensor::random([32, 32], Distribution::Default);
let y: Tensor<Backend, 2> = Tensor::random([32, 32], Distribution::Default).require_grad();
let tmp = x.clone() + y.clone();
let tmp = tmp.matmul(x);
let tmp = tmp.exp();
let grads = tmp.backward();
let y_grad = y.grad(&grads).unwrap();
println!("{y_grad}");
}

请注意,对于不支持自动微分(用于推理)的后端,不可能在运行在此后端上的模型上调用backward这个方法, 因为这个方法只由Autodiff后端提供

请参阅Autodiff Backend README 获取更多详细信息。

Fusion:为支持内核融合的后端带来内核融合的后端装饰器 💥

此后端修饰符通过与内部后端支持的核融合增强后端

请注意,您可以将此后端与其他后端修饰符(如 Autodiff)组合使用。 目前,仅 WGPU 后端支持融合内核

use burn::backend::{Autodiff, Fusion, Wgpu};
use burn::tensor::{Distribution, Tensor};

fn main() {
type Backend = Autodiff<Fusion<Wgpu>>;

let x: Tensor<Backend, 2> = Tensor::random([32, 32], Distribution::Default);
let y: Tensor<Backend, 2> = Tensor::random([32, 32], Distribution::Default).require_grad();

let tmp = x.clone() + y.clone();
let tmp = tmp.matmul(x);
let tmp = tmp.exp();

let grads = tmp.backward();
let y_grad = y.grad(&grads).unwrap();
println!("{y_grad}");
}

值得注意的是,我们计划根据计算边界和内存边界操作实现自动梯度检查点, 这将与融合后端优雅地配合,使您的代码在训练期间运行得更快, 详细信息请参阅 Fusion Backend README

入门

刚听说过Burn吗?您来对地方了!继续阅读这一部分,希望您能很快上手。

The Burn Book 🔥

要有效地开始使用Burn,了解其关键组成部分和哲学至关重要。 这就是为什么我们强烈建议新用户阅读The Burn Book 🔥 的前几节。 它提供了详细的示例和解释,涵盖了框架的每个方面,包括张量模块优化器等构建模块, 一直到高级用法,例如编写自己的GPU内核。

备注

这个项目在不断发展,我们尽可能将这本书与新增内容保持最新。 然而,有时我们可能会漏掉一些细节,所以如果您发现有什么奇怪的地方, 请告诉我们!我们也很乐意接受Pull请求 😄

示例 🙏

让我们从一个代码片段开始,展示这个框架使用起来有多直观! 在以下内容中,我们声明一个神经网络模块,带有一些参数以及它的前向传递。

use burn::nn;
use burn::module::Module;
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: nn::Linear<B>,
linear_outer: nn::Linear<B>,
dropout: nn::Dropout,
gelu: nn::Gelu,
}

impl<B: Backend> PositionWiseFeedForward<B> {
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let x = self.linear_inner.forward(input);
let x = self.gelu.forward(x);
let x = self.dropout.forward(x);

self.linear_outer.forward(x)
}
}

我们在存储库中有相当多的示例, 展示了如何在不同情景中使用框架。 要获得更多实际见解,您可以克隆存储库并在计算机上直接运行任何一个示例!

预训练模型 🤖

我们保持一个更新和精选的使用Burn构建的模型和示例列表,有关更多详细信息, 请参阅 tracel-ai/models 存储库

找不到您想要的模型?不要犹豫,打开一个问题,我们可能会优先考虑。 使用Burn构建模型并想要分享吗?您还可以打开一个 Pull Request 并将您的模型添加到社区部分下!

为什么要使用 Rust 进行深度学习?🦀

深度学习是一种特殊形式的软件,在这种软件中,你需要非常高级的抽象,以及极快的执行时间

Rust 是这种用例的完美选择,因为它提供了零成本的抽象,可以轻松创建神经网络模块,并精细控制内存以优化每一个细节。

一个框架在高层次使用时易于使用,这样用户就可以专注于在人工智能领域进行创新。 但是,由于运行模型如此严重依赖计算,性能不能被忽视。

直到今天,解决这个问题的主流方法是在 Python 中提供 API,但依赖于诸如 C/C++ 等低级语言的绑定。 这降低了可移植性,增加了复杂性,并在研究人员和工程师之间产生摩擦。 我们认为 Rust 对于抽象的处理方式使其足够通用,可以解决这两种语言的困境。

Rust 还配备了 Cargo 软件包管理器,可以轻松构建、测试和部署任何环境的应用,通常在 Python 中很麻烦。

虽然 Rust 被认为是一种起初难以掌握的语言,但我们坚信经过一些练习后,它会带来更可靠、无 bug 的解决方案, 并且能够更快地构建(开玩笑😅)!

鱼雪

在Rust中,特性(features)是一种用于条件编译的机制

它们允许您根据需要启用或禁用某些代码块

本文将详细说明如何使用 cargo add --featurescargo run --features,并介绍它们的区别、优缺点以及示例代码。

1. cargo add --features

cargo add 是一个用来向项目的 Cargo.toml 文件中添加依赖项的命令。

通过 cargo add crate_name --features features_name,可以在添加依赖项时同时启用特定的特性。

使用方法

cargo add crate_name --features features_name

示例

cargo add serde --features derive

上述命令会将 serde 库添加到 Cargo.toml 文件的依赖项部分,并启用 serde 库的 derive 特性。

添加后的 Cargo.toml 可能如下所示:

[dependencies]
serde = { version = "1.0", features = ["derive"] }

优点

  • 便于管理:所有特性配置集中在 Cargo.toml 文件中,便于维护。
  • 自动启用:一旦在 Cargo.toml 中指定了特性,无需每次运行命令时都显式指定特性。

缺点

  • 需要修改 Cargo.toml 文件:如果只是临时需要某个特性,修改 Cargo.toml 文件可能显得繁琐。

2. cargo run --features

cargo run 是一个用来编译和运行当前项目的命令。

通过 cargo run --features features_name,可以在运行时启用特定的特性

使用方法

cargo run --features features_name

示例

假设项目的 Cargo.toml 文件如下:

[dependencies]
serde = "1.0"

[features]
special_feature = ["serde/derive"]

运行命令:

cargo run --features special_feature

此命令会在编译和运行时启用 special_feature 特性。

优点

  • 灵活性:无需修改 Cargo.toml 文件,便可以在运行时灵活启用或禁用特性。
  • 临时配置:适用于临时需要某些特性而不希望永久修改配置的情况。

缺点

  • 需显式指定:每次运行命令时都需要显式指定特性,可能会显得繁琐。

3. 区别

  • 配置方式cargo add --features 是在 Cargo.toml 中配置特性,cargo run --features 是在运行命令时临时指定特性。
  • 持久性cargo add --features 使特性配置永久保存于 Cargo.toml 中,cargo run --features 是临时配置。
  • 使用场景cargo add --features 适用于需要长期启用的特性,cargo run --features 适用于临时启用的特性。

4. 示例代码

使用 cargo add --features

Cargo.toml 中添加依赖并指定特性:

cargo add serde --features derive

Cargo.toml 文件:

[dependencies]
serde = { version = "1.0", features = ["derive"] }

Rust代码:

// 由于在 Cargo.toml 中已经指定了特性,因此无需显式导入 serde_derive
use serde::{Serialize, Deserialize};

#[derive(Serialize, Deserialize)]
struct MyStruct {
name: String,
age: u32,
}

fn main() {
let my_struct = MyStruct {
name: "Alice".to_string(),
age: 30,
};

let serialized = serde_json::to_string(&my_struct).unwrap();
println!("Serialized: {}", serialized);

println!("Hello, world!");
}

使用 cargo run --features

Cargo.toml 文件:

[dependencies]
serde = "1.0"

[features]
special_feature = ["serde/derive"]

Rust代码:

// 在这里显式导入特性相关的内容,因为特性是在运行时指定的
#[cfg(feature = "special_feature")]
#[macro_use]
extern crate serde_derive;

#[cfg(feature = "special_feature")]
#[derive(Serialize, Deserialize)]
struct MyStruct {
name: String,
age: u32,
}

fn main() {
#[cfg(feature = "special_feature")]
{
let my_struct = MyStruct {
name: "Alice".to_string(),
age: 30,
};

let serialized = serde_json::to_string(&my_struct).unwrap();
println!("Serialized: {}", serialized);
}

println!("Hello, world!");
}

运行命令:

cargo run --features special_feature

5. 总结

  • cargo add --featurescargo run --features 都可以用来启用Rust项目中的特性,但它们的使用场景不同。
  • cargo add --features 适用于需要长期启用的特性,通过在 Cargo.toml 中配置,特性将自动启用。
  • cargo run --features 适用于临时启用的特性,通过在运行时指定,避免了修改 Cargo.toml 文件的繁琐。

通过合理使用这两种方法,可以更加灵活和高效地管理Rust项目中的特性。

鱼雪

什么是模块

模块是用来分割与组织代码的一种方式。 在开发中,代码往往会变得越来越复杂,模块可以帮助我们更好地组织代码,提高代码的可读性和可维护性。 代码超过一屏幕时,就应该考虑将其拆分为模块。

在Rust中,组织代码的基本单元是模块。

模块是Rust中的一个重要概念,它可以帮助我们更好地组织代码,提高代码的可读性和可维护性。

在Rust中,模块中可以包含函数、结构体、枚举、trait等等,甚至模块中也可以包含其他模块。

首先我们来看,如何创建模块。

在Rust中哪些方式可以创建模块

  1. 使用mod关键字,mod关键字后面跟着模块名
  2. 单个文件,文件名即为模块名
  3. 包含mod.rs文件的目录,包含mod.rs文件的目录即为一个模块,目录名即为模块名

使用mod关键字

mod my_module {
pub fn my_function() {
println!("Hello, world!");
}
}

单个文件

// src/my_module.rs
pub fn my_function() {
println!("Hello, world!");
}

包含mod.rs文件的目录

// src/my_module/mod.rs
pub fn my_function() {
println!("Hello, world!");
}

以上三种方式都可以创建一个名为my_module的模块,其中包含一个名为my_function的函数。

那么创建好模块之后如何使用模块呢?

如何使用模块

创建好模块之后,就是怎么来使用模块,便于我们调用模块中的函数,或者使用模块中的结构体、枚举等。

首先就是需要声明模块,然后就可以使用模块中的函数了。 声明模块的方法是在文件中使用mod关键字,后面跟着模块名,跟着分号;

声明模块

mod my_module;

导入模块中的中的内容

可以导入模块中的函数、结构体、枚举等,使用use关键字,后面跟着模块名,再跟着::,再跟着要导入的内容。

如要要导入模块中的函数、结构体、枚举等,需要在模块中使用pub关键字,表示对外公开。 否则,模块中的内容默认是私有的,无法在其他模块中使用。

如果需要层层导出的话,可以使用pub use关键字,可以让这个模块引入的内容再被其他模块引入。

use my_module::my_function;
// 或者
// pub use my_module::my_function;

my_function();

引入模块有两种方式

  1. 绝对路径导入
  • 对于当前项目中:使用绝对路径导入,需要从crate根开始,以crate关键字开头
    • 对于第三方库:使用绝对路径导入,crate_name后跟::,再跟着模块名或者需要导入的内容
  1. 相对路径导入
    • 使用super关键字表示父模块,使用父模块下的内容,一般使用场景是在子模块中使用父模块的内容,比如写测试用例时
    • 使用self关键字表示当前模块,使用当前模块下的内容,可以省略self,一般使用场景是如果有同名的模块和函数时,可以使用self来区分

总结

在Rust中,使用模块大体分为两步:

  1. 创建模块,有三种方式可以创建模块,即:
  • mod my_module { ... }
    • 单个文件(src/my_module.rs)
    • 包含mod.rs文件的目录(src/my_module/mod.rs)
  1. 使用模块
  • 声明模块:mod my_module;,使用模块之前需要先声明模块
    • 导入模块中的内容:use my_module::my_function;,使用use关键字导入模块中的内容,使用pub use关键字可以让这个模块引入的内容再被其他模块引入
  1. 引入模块有两种方式:

    • 绝对路径导入:使用crate关键字(针对于当前项目的crate)或者crate_name模块名(针对当前crate之外的crate,即依赖的crate)
    • 相对路径导入:使用super关键字(一般用来写测试模块使用较多)或者self关键字(一般用来区分同名模块和函数)
  2. Rust中的mod.rs和Python中的__init__.py类似, 都是用来标识目录为一个模块的文件,mod.rs文件中可以包含模块的内容,也可以导入其他模块。 也可以使用pub use关键字逐层导出,方便使用,将所有需要用到的内容逐层导出到lib.rs,别的模块使用时不必写很长的路径,也更容易找。

鱼雪

在Rust的错误处理生态系统中,从标准库的std::error::Erroranyhowthiserrorsnafu, 每个库都在用法和功能上进行了不同程度的改进和演变。 下面是对这些库的改进和设计的详细说明。

标准库的std::error::Error

特点

  • 基础特性:定义了一个通用的错误trait,所有错误类型都可以实现这个trait。
  • 手动实现:需要手动实现DisplayError trait,比较繁琐。

用法示例

use std::fmt;

#[derive(Debug)]
struct MyError {
details: String,
}

impl fmt::Display for MyError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.details)
}
}

impl std::error::Error for MyError {}

小结

要使用Rust标准库的Error trait实现自定义错误,那么需要做到以下两点:

  • 实现std::fmt::Display trait,以便将错误信息显示给用户。
  • 实现std::error::Error trait,以便将错误信息传递给调用者。

anyhow

该库提供了 anyhow::Error,一种基于特征对象(trait object)的错误类型, 用于在 Rust 应用程序中轻松地进行惯用错误处理。

改进

  • 简化错误处理:提供了一个基于trait对象的通用错误类型anyhow::Error,简化了错误的传播和处理。
  • 上下文信息:可以为错误添加上下文信息,帮助调试。
  • 自动回溯:自动捕获和打印回溯信息(在Rust 1.65及以上)。

用法示例

use anyhow::{Context, Result};

fn read_file(path: &str) -> Result<String> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read file at path: {}", path))?;
Ok(content)
}

小结

anyhow库通过提供anyhow::Error类型,简化了Rust中的错误处理。 它提供了一种简单的方式来处理和传播错误,同时支持添加上下文信息和自动回溯信息。

  • 在anyhow中,经常使用的就是anyhow::Resultanyhow::anyhow
  • 可以使用?符号进行错误传播,同时可以解包未出错的值

thiserror

这个库为标准库的 std::error::Error trait 提供了一个方便的派生宏。

改进

  • 派生宏:通过派生宏简化了自定义错误类型的定义。
  • 自动实现:自动实现DisplayError trait,减少了手动编码的负担。
  • 与其他错误类型集成:可以轻松地将其他错误类型转换为自定义错误类型。

用法示例

use thiserror::Error;

#[derive(Error, Debug)]
pub enum MyError {
#[error("IO error")]
Io(#[from] std::io::Error),
#[error("Parse error")]
Parse(#[from] std::num::ParseIntError),
}

snafu

SNAFU 是一个可以轻松生成错误并向底层错误添加信息的库, 特别是当相同的底层错误类型可能发生在不同的上下文中时。

改进

  • 错误类型生成:通过宏生成错误类型和相关的上下文信息。
  • 上下文支持:更强大的上下文信息支持,提供了详细的错误信息。
  • 源错误追踪:内置对源错误的追踪和显示。

用法示例

use snafu::{Snafu, ResultExt};

#[derive(Debug, Snafu)]
enum MyError {
#[snafu(display("Failed to open file {}: {}", filename, source))]
OpenFile { filename: String, source: std::io::Error },
#[snafu(display("Failed to parse integer: {}", source))]
ParseInt { source: std::num::ParseIntError },
}

fn read_file(path: &str) -> Result<String, MyError> {
let content = std::fs::read_to_string(path).context(OpenFile { filename: path.to_string() })?;
Ok(content)
}

总结

从Rust标准库的std::error::Erroranyhowthiserrorsnafu,经历了一系列的改进和设计:

  • 从手动到自动:从标准库需要手动实现错误处理,到thiserrorsnafu利用宏自动生成代码,减少了开发者的负担。
  • 上下文信息anyhowsnafu增强了对上下文信息的支持,使调试更加容易。
  • 错误传播简化anyhow通过统一的错误类型简化了错误传播,而snafu提供了强大的错误上下文支持。
  • 回溯信息anyhow自动捕获回溯信息,为调试提供了更多有用的信息。

这些改进和设计使得Rust的错误处理更加简洁、高效和易于维护。

参考

鱼雪

thiserror 库为 Rust 的 std::error::Error trait 提供了一个方便的派生宏,使得定义错误类型变得更加简单和高效。下面是对 thiserror 中不同属性的用法和含义的总结,以及它们的使用情况。

Thiserror 思维导图

#[error("...")]

  • 用法: 用于为错误类型或枚举的每个变体提供一个显示格式(Display)。
  • 含义: 定义了当错误被打印时显示的消息格式。
  • 适用情况: 任何需要向用户展示错误信息的场景。

#[from]

  • 用法: 用于自动实现从源错误类型到当前错误类型的 From trait。
  • 含义: 允许一个错误类型自动转换("from")另一个错误类型。
  • 适用情况: 当你想要从一个特定的错误类型自动转换到你定义的错误类型时。这通常用在错误链的上下文中,允许底层错误被包装成更高层的抽象错误。

#[source]

  • 用法: 标记一个字段作为源错误(即导致当前错误的底层错误)。
  • 含义: 该字段会被 Error trait 的 source() 方法返回,用于错误链的追踪。
  • 适用情况: 当你的错误是由另一个错误导致的,并且你想要保留这种因果关系时。这对于调试和错误报告非常有用。

#[backtrace]

  • 用法: 标记一个字段为回溯(backtrace)信息。
  • 含义: 允许捕获和存储错误发生时的调用栈信息。
  • 适用情况: 在需要调试或详细了解错误发生上下文时非常有用。这通常用于复杂系统中,通过回溯可以更容易地定位问题源头。

#[error(transparent)]

  • 用法: 使得当前错误类型在显示和源链处理上透明地代理到内部的错误类型。
  • 含义: 当前错误类型的 Displaysource 方法将直接委托给它包装的错误类型。
  • 适用情况: 当你定义的错误类型仅仅是对另一个错误类型的简单封装,而你不希望添加任何额外信息或行为时。这在定义通用或透明的错误包装时特别有用。

通过这些属性,thiserror 库大大简化了 Rust 中错误处理的复杂性,使得定义丰富而又具有表达力的错误类型变得非常简单。 无论是简单的直接转换,还是更复杂的错误链处理和调试信息捕获,thiserror 都提供了强大而灵活的工具,以支持各种不同的使用场景。 相关信息可以在其官方文档GitHub 仓库中找到。

鱼雪

anyhow库为Rust应用程序提供了一种基于trait对象的错误类型anyhow::Error,以便于进行简便和惯用的错误处理。

Anyhow思维导图

主要特性

  • 简化错误传播:通过使用?操作符,可以轻松地传播实现了std::error::Error trait的任何错误。
  • 上下文添加:允许为错误添加上下文,帮助调试时理解错误发生的具体环节。这是通过Context trait和相关方法(如.context().with_context())实现的。
  • 错误下转型:支持将anyhow::Error下转型为具体的错误类型,以便进行更精确的错误处理或信息获取。
  • 自动捕获回溯信息:在Rust版本≥1.65时,如果底层错误类型没有提供自己的回溯信息,anyhow会自动捕获并打印错误的回溯信息。通过环境变量可以控制回溯信息的显示。
  • 与任何错误类型兼容anyhow可以与任何实现了std::error::Error的错误类型一起工作,不需要特定的derive宏来实现相关trait。
  • 宏支持:提供了几个宏来简化错误处理,例如anyhow!用于创建一个即时的错误消息,bail!用于提前返回一个错误,以及ensure!用于在条件不满足时返回错误。

使用场景

  • 函数返回类型:对于可能失败的函数,推荐使用Result<T, anyhow::Error>(或等价的anyhow::Result<T>)作为返回类型。
  • 错误传播:在函数内部,使用?来简化错误的传播。
  • 添加错误上下文:在可能导致调试困难的低级错误上添加上下文信息,以提供更多关于错误发生时上下文的信息。
  • 处理特定错误:通过错误下转型来处理特定类型的错误。
  • 自定义错误类型:虽然anyhow不直接提供derive宏,但可以与如thiserror库结合使用,来定义和实现自定义错误类型。
  • 即时错误消息:通过anyhow!bail!宏来快速创建和返回错误。

适用性

由于其灵活性和简便性,anyhow库适用于大多数Rust应用程序中的错误处理。它特别适合那些需要简单、直接且灵活处理各种可能错误的应用程序。 对于需要在库中暴露具体错误类型的情况,可能需要结合使用如thiserror之类的库来提供更精细的错误定义和处理。

鱼雪

axum::Router结构体

pub struct Router<S = ()> { /* private fields */}

用于组合处理程序和服务的路由器类型。

实现

impl<S> Router<S>
where
S: Clone + Send + Sync + 'static,

新建路由器

pub fn new() -> Self

创建一个新的路由器,除非您添加额外的路由,否则将对所有请求响应404未找到。

添加另一个路由到路由器

pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self
  • path: 是由/分割的路径段字符串。每个段可能是静态的、捕获的或者是通配符。
  • method_router: 是一个MethodRouter,它将请求方法映射到处理程序。 method_router通常会是类似于get的方法路由器中的处理程序。

静态路径

例如:

  • /
  • /foo
  • /foo/bar

如果传入的请求路径完全匹配,则将调用相应的服务。

捕获

例如:

  • /:key
  • /foo/:key
  • /users/:id/tweets

路径可以包含类似于/:key的段,它匹配任何单个段,并将存储在key处捕获的值。 捕获的值可以是零长度,除了无效路径//

捕获可以使用Path进行提取。

MatchedPath可以用于提取匹配路径,而不是实际路径。

通配符

路径可以以/*key结尾,匹配所有段并捕获的段存储在key中。

例如:

  • /*key
  • /users/*path
  • /:id/:repo/*tree

请注意,/*key 不匹配空段。因此:

  • /*key 不匹配 /,但匹配 /a/a/ 等。
  • /x/*key 不匹配 /x/x/,但匹配 /x/a/x/a/ 等。

还可以使用 Path 来提取通配符捕获。 请注意,不包括前导斜杠,即对于路由 /foo/*rest 和路径 /foo/bar/bazrest 的值将是 bar/baz

接受多种方法

要接受同一路由的多个方法,您可以同时添加所有处理程序。

use axum::{Router, routing::{get, delete}, extract::Path};

let app = Router::new().route(
"/",
get(get_root).post(post_root).delete(delete_root),
);

async fn get_root() {}
async fn post_root() {}
async fn delete_root() {}

或者你也可以一一添加:

let app = Router::new()
.route("/", get(get_root))
.route("/", post(post_root))
.route("/", delete(delete_root));

更多例子

use axum::{Router, routing::{get, delete}, extract::Path};

let app = Router::new()
.route("/", get(root))
.route("/users", get(list_users).post(create_user))
.route("/users/:id", get(show_user))
.route("/api/:version/users/:id/action", delete(do_users_action))
.route("/assets/*path", get(serve_asset));

async fn root() {}

async fn list_users() {}

async fn create_user() {}

async fn show_user(Path(id): Path<u64>) {}

async fn do_users_action(Path((version, id)): Path<(String, u64)>) {}

async fn serve_asset(Path(path): Path<String>) {}

Panics

如果路径与另一个路由重叠,则会发生panic

use axum::{routing::get, Router};

let app = Router::new()
.route("/", get(|| async {}))
.route("/", get(|| async {}));

静态路由 /foo 和动态路由 /:key 不被视为重叠,并且 /foo 将优先。 如果路径为空,也会引发 panic。

路由服务

添加另一个路由到路由器调用一个服务

pub fn route_service<T>(self, path: &str, service: T) -> Self
where
T: Service<Request, Error=Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,

示例:

use axum::{
Router,
body::Body,
routing::{any_service, get_service},
extract::Request,
http::StatusCode,
error_handling::HandleErrorLayer,
};
use tower_http::services::ServeFile;
use http::Response;
use std::{convert::Infallible, io};
use tower::service_fn;

let app = Router::new()
.route(
"/",
any_service(service_fn(|_: Request| async {
let res = Response::new(Body::from("Hi from `GET /`"));
}))
)
.route_service(
"/foo",
service_fn(|req: Request| async move {
let body = Body::from(format!("Hi from `{}` /foo", req.method()))
let res = Response::new(body);
Ok::<_, Infallible>(res)
})
)
.route_service(
"/static/Cargo.toml",
ServeFile::new("Cargo.toml"),
);

以这种方式路由到任意服务会对背压(Service::poll_ready)产生复杂性。 有关更多详细信息,请参阅服务路由和背压模块。

由于相同的原因而出现panic,或者尝试将路由到Router时也会发生panic。

use axum::{routing::get, Router};

let app = Router::new().route_service(
"/",
Router::new().route("/foo", get(|| async {})),
);

使用Router::nest替换

在某个路径上嵌套一个路由器。

这样可以将应用程序分解成更小的部分,并将它们组合在一起。

pub fn nest(self, path: &str, router: Router<S>) -> Self

示例:

use axum::{
routing::{get, post},
Router,
};

let user_routes = Router::new().route("/:id", get(|| async {}));
let team_routes = Router::new().route("/", post(|| async {}));

let api_routes = Router::new()
.nest("/users", user_routes)
.nest("/teams", team_routes);

let app = Router::new().nest("/api", api_routes);

// Our app now accepts
// - GET /api/users/:id
// - POST /api/teams

URI如何变化

请注意,嵌套路由将无法看到原始请求URI,而是会剥去匹配的前缀。 这对于像静态文件服务之类的服务工作是必要的。 如果需要原始请求URI,请使·OriginalUri

外部路由的捕获

在使用嵌套动态路由时要小心,因为嵌套还会从外部路由中捕获:

use axum::{
extract::Path,
routing::get,
Router,
};
use std::collections::HashMap;

async fn users_get(Path(params): Path<HashMap<String, String>>) {
// Both `version` and `id` were captured even though `users_api` only
// explicitly captures `id`.
let version = params.get("version");
let id = params.get("id");
}

let users_api = Router::new().route("/users/:id", get(users_get));

let app = Router::new().nest("/:version/api", users_api);

与通配符路由的区别

嵌套路由类似于通配符路由。 不同之处在于通配符路由仍然可以看到整个 URI,而嵌套路由将会去掉前缀:

use axum::{routing::get, http::Uri, Router};

let nested_router = Router::new()
.route("/", get(|uri: Uri| async {
// `uri` will _not_ contain `/bar`
}));

let app = Router::new()
.route("/foo/*rest", get(|uri: Uri| async {
// `uri` will contain `/foo`
}))
.nest("/bar", nested_router);

后备方案

如果嵌套路由器没有自己的回退,则将从外部路由器继承回退:

use axum::{routing::get, http::StatusCode, handler::Handler, Router};

async fn fallback() -> (StatusCode, &'static str) {
(StatusCode::NOT_FOUND, "Not Found")
}

let api_routes = Router::new().route("/users", get(|| async {}));

let app = Router::new()
.nest("/api", api_routes)
.fallback(fallback);

在这里,像 GET /api/not-found 这样的请求将进入 api_routes, 但由于它没有匹配的路由,也没有自己的回退,它将调用外部路由器的回退,即回退功能。 如果嵌套路由器有自己的回退,则外部回退将不会被继承:

use axum::{
routing::get,
http::StatusCode,
handler::Handler,
Json,
Router,
};

async fn fallback() -> (StatusCode, &'static str) {
(StatusCode::NOT_FOUND, "Not Found")
}

async fn api_fallback() -> (StatusCode, Json<serde_json::Value>) {
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "status": "Not Found" })),
)
}

let api_routes = Router::new()
.route("/users", get(|| async {}))
.fallback(api_fallback);

let app = Router::new()
.nest("/api", api_routes)
.fallback(fallback);

在这里,像 GET /api/not-found 这样的请求将转到 api_fallback

用状态嵌套路由器

当使用此方法将Router组合时,每个Router必须具有相同类型的状态。 如果您的路由器具有不同类型,您可以使用Router::with_state来提供状态并使类型匹配:

use axum::{
Router,
routing::get,
extract::State,
};

#[derive(Clone)]
struct InnerState {}

#[derive(Clone)]
struct OuterState {}

async fn inner_handler(state: State<InnerState>) {}

let inner_router = Router::new()
.route("/bar", get(inner_handler))
.with_state(InnerState {});

async fn outer_handler(state: State<OuterState>) {}

let app = Router::new()
.route("/", get(outer_handler))
.nest("/foo", inner_router)
.with_state(OuterState {});

请注意,内部路由器仍将继承外部路由器的后备机制。

恐慌

  • 如果路由与另一个路由重叠。有关详细信息,请参阅Router::route
  • 如果路由包含通配符(*)。
  • 如果path为空。

nest类似,但接受任何服务

pub fn nest_service<T>(self, path: &str, service: T) -> Self
where
T: Service<Request, Error=Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,

将两个路由器的路径(path)和回退(fallbacks)合并到一个路由器

pub fn merge<R>(self, other: R) -> Self
where
R: Into<Router<S>>,

这对于将应用程序分成更小的部分并将它们组合成一个非常有用。

use axum::{
routing::get,
Router,
};

let user_routes = Router::new()
.route("/users", get(users_list))
.route("/users/:id", get(users_show));

let team_routes = Router::new()
.route("/teams", get(teams_list));

let app = Router::new()
.merge(user_routes)
.merge(team_routes);

// 也可以执行 `user_routes.merge(team_routes)`

// 我们的应用程序现在接受
// - GET /users
// - GET /users/:id
// - GET /teams

合并路由器的状态

使用此方法合并 Router 时,每个 Router 必须具有相同类型的状态。 如果您的 routers 具有不同类型,可以使用 Router::with_state 来提供状态并使类型匹配:

use axum::{
Router,
routing::get,
extract::State,
};

#[derive(Clone)]
struct InnerState {}

#[derive(Clone)]
struct OuterState {}

async fn inner_handler(state: State<InnerState>) {}

let inner_router = Router::new()
.route("/bar", get(inner_handler))
.with_state(InnerState {});

async fn outer_handler(state: State<OuterState>) {}

let app = Router::new()
.route("/", get(outer_handler))
.merge(inner_router)
.with_state(OuterState {});

合并具有回退的路由器

使用此方法合并 Router 时,后备(fallbacks)也会合并。但是只能有一个路由器有后备

tower::Layer 应用于路由器中的所有路由。

pub fn layer<L>(self, layer: L) -> Router<S>
where
L: Layer<Route> + Clone + Send + 'static,
L::Service: Service<Request> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static,

这可以用于为一组路由的请求添加额外的处理。

注意,中间件只应用于现有路由。

因此,您必须首先添加您的路由(和/或回退(fallbacks)),然后调用层(layer)。 在调用层之后添加的额外路由将不会添加中间件。

如果要将中间件添加到单个处理程序,可以使用 MethodRouter::layerHandler::layer

示例

添加tower_http::trace::TraceLayer:

use axum::{
routing::get,
Router,
};
use tower_http::trace::TraceLayer;

let app = Router::new()
.route("/foo", get(|| async {}))
.route("/bar", get(|| async {}))
.layer(TraceLayer::new_for_http());

如果您需要编写自己的中间件,请参阅“编写中间件”以获取不同的选项。

如果您只想在某些路由上使用中间件,可以使用Router::merge

use axum::{
routing::get,
Router,
};
use tower_http::{trace::TraceLayer, compression::CompressionLayer};

let with_tracing = Router::new()
.route("/foo", get(|| async {}))
.layer(TraceLayer::new_for_http());

let with_compression = Router::new()
.route("/bar", get(|| async {}))
.layer(CompressionLayer::new());

let app = Router::new()
.merge(with_tracing)
.merge(with_compression);

多中间件

当应用多个中间件时,建议使用tower::ServiceBuilder。有关更多详细信息,请参阅中间件。

路由之后运行

使用此方法添加的中间件将在路由之后运行,因此无法用于重写请求URI。 有关更多详细信息和解决方法,请参见“在中间件中重写请求URI”。

错误处理

请参阅有关错误处理影响中间件的详细信息。

向路由器应用一个 tower::Layer,只有当请求匹配路由时才会运行

pub fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route> + Clone + Send + 'static,
L::Service: Service<Request> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static,

请注意,中间件仅应用于现有路由。 因此,您必须首先添加您的路由(和/或回调),然后在之后调用 route_layer。 调用 route_layer 后添加的额外路由将不会添加中间件。

这与 Router::layer 类似,不同之处在于只有当请求匹配路由时中间件才会运行。 这对于提前返回的中间件非常有用(例如授权), 否则可能会将 404 Not Found 转换为 401 Unauthorized。

示例

use axum::{
routing::get,
Router,
};
use tower_http::validate_request::ValidateRequestHeaderLayer;

let app = Router::new()
.route("/foo", get(|| async {}))
.route_layer(ValidateRequestHeaderLayer::bearer("password"));

// `GET /foo` 使用有效令牌将接收 `200 OK`
// `GET /foo` 使用无效令牌将接收 `401 未经授权`
// `GET /not-found` 使用无效令牌将接收 `404 未找到`

向路由器添加一个回退处理程序

pub fn fallback<H, T>(self, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static

如果没有任何路由匹配传入的请求,将调用此服务。

use axum::{
routing::get,
Router,
handler::Handler,
response::IntoResponse,
http::{StatusCode, Uri},
};

let app = Router::new()
.route("/foo", get(|| async { "foo" }))
.fallback(fallback);

async fn fallback(uri: Uri) -> (StatusCode, String) {
(StatusCode::NOT_FOUND, format!("No route for {uri}"))
}

仅在路由器中没有匹配任何内容的路由时才适用回退。 如果处理程序被请求匹配但返回 404,则不会调用回退。

处理所有没有其他路由的请求

如果没有其他路由,使用Router::new().fallback(...)来接受所有请求, 无论路径或方法如何,这并不是最佳选择

use axum::Router;

async fn handler() {}

let app = Router::new().fallback(handler);

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();

直接运行处理程序会更快,因为它避免了路由的开销:

use axum::handler::HandlerWithoutStateExt;

async fn handler() {}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, handler.into_make_service()).await.unwrap();

向路由器添加一个回退服务

pub fn fallback_service<T>(self, service: T) -> Self
where
T: Service<Request, Error=Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,

查看Router::fallback以获取更多详细信息。

为路由器提供状态

pub fn with_state<S2>(self, state: S) -> Router<S2>
use axum::{Router, routing::get, extract::State};

#[derive(Clone)]
struct AppState {}

let routes = Router::new()
.route("/", get(|State(state): State<AppState>| async {
// use state
}))
.with_state(AppState {});

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

从函数中返回带状态的路由器

在从函数返回 Router 时,通常建议不直接设置状态

use axum::{Router, routing::get, extract::State};

#[derive(Clone)]
struct AppState {}

// 不要在这里调用 `Router::with_state`
fn routes() -> Router<AppState> {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
}

// 在运行服务器之前执行
let routes = routes().with_state(AppState {});

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

如果确实需要提供状态,并且您没有将路由嵌套/合并到另一个路由器中,则返回不带任何类型参数的 Router:

// 不要返回 `Router<AppState>`
fn routes(state: AppState) -> Router {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state)
}

let routes = routes(AppState {});

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

这是因为我们只能在 Router<()> 上调用 Router::into_make_service, 而不能在 Router<AppState> 上调用。有关原因的更多详细信息,请参见下文。

请注意,状态默认为(),所以RouterRouter<()>是一样的。

如果您需要嵌套/合并路由器,建议在结果路由器上使用通用状态类型:

fn routes<S>(state: AppState) -> Router<S> {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state)
}

let routes = Router::new().nest("/api", routes(AppState {}));

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

状态是路由器内的全局状态

传递给此方法的状态将用于该路由器接收的所有请求。 这意味着它不适合保存从请求中派生的状态,比如在中间件中提取的授权数据。 请改用 Extension 来存储此类数据。

Router<S>中的S代表什么

Router<S>表示一个缺少类型为S的状态以处理请求的路由器。 它并不意味着具有类型为S的状态的路由器。

例如:

// 需要`AppState`来处理请求的路由器
let router: Router<AppState> = Router::new()
.route("/", get(|_: State<AppState>| async {}));

// 一旦我们调用 `Router::with_state` 方法,路由器就不再缺少状态了,因为我们刚刚提供了它
//
// 因此,路由器类型变为`Router<()>`,即一个不缺少任何状态的路由器。
let router: Router<()> = router.with_state(AppState {});

// 只有 `Router<()>` 具有 `into_make_service` 方法。
//
// 因为它仍然缺少 `AppState`,所以不能在 `Router<AppState>` 上调用 `into_make_service`。
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();

或许有点反直觉,Router::with_state 并不总是返回一个 Router<()>。 相反,您可以选择新的缺失状态类型是什么:

let router: Router<AppState> = Router::new()
.route("/", get(|_: State<AppState>| async {}));

// 当我们调用`with_state`时,我们可以选择下一个丢失的状态类型是什么。在这里我们选择`String`
let string_router: Router<String> = router.with_state(AppState {});

// 这允许我们添加使用`String`作为状态类型的新路由
let string_router = string_router
.route("/needs-string", get(|_: State<String>| async {}));

// 提供`String`,并选择 `()` 作为新的缺失状态。
let final_router: Router<()> = string_router.with_state("foo".to_owned());

// 既然我们有一个`Router<()>`,我们可以运行它。
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, final_router).await.unwrap();

为什么在调用with_state后返回Router<AppState>不起作用?

// 这不会起作用,因为我们正在返回 `Router<AppState>`
// 即,我们在说我们仍然缺少一个 `AppState`
fn routes(state: AppState) -> Router<AppState> {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state)
}

let app = routes(AppState {});

// 我们只能在 `Router<()>` 上调用 `Router::into_make_service` 方法,
// 而 `app` 是 `Router<AppState>`
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();

由于我们提供了所需的所有状态,因此请返回 Router<()>

// 我们已经提供了所有必要的状态,因此返回 `Router<()>`。
fn routes(state: AppState) -> Router<()> {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state)
}

let app = routes(AppState {});

// 我们现在可以调用 `Router::into_make_service`。
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();

关于性能的说明

如果您需要一个实现 Service 但不需要任何状态的 Router(也许您正在制作一个在内部使用 axum 的库), 那么建议在开始提供请求之前调用该方法:

use axum::{Router, routing::get};

let app = Router::new()
.route("/", get(|| async { /* ... */ }))
// 即使我们不需要任何状态,也要调用`with_state(())`。
.with_state(());

这不是必需的,但它让 axum 有机会更新路由器中的一些内部内容,可能会影响性能并减少分配。

备注

请注意,Router::into_make_serviceRouter::into_make_service_with_connect_info 会自动执行此操作。

将路由器转换为带有固定请求体类型的借用服务,以辅助类型推断

pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S>

在某些情况下,在 Router 上调用 tower::ServiceExt 的方法时,可能会出现类似以下内容的类型推断错误

let response = router.ready().await?.call(request).await?;
^^^^^ cannot infer type for type parameter `B`

这是因为 Router 使用 impl<B> Service<Request<B>> for Router<()> 实现了 Service

例如:

use axum::{
Router,
routing::get,
http::Request,
body::Body,
};
use tower::{Service, ServiceExt};

let mut router = Router::new().route("/", get(|| async {}));
let request = Request::new(Body::empty());
let response = router.ready().await?.call(request).await?;

调用 Router::as_service 可以解决此问题。

use axum::{
Router,
routing::get,
http::Request,
body::Body,
};
use tower::{Service, ServiceExt};

let mut router = Router::new().route("/", get(|| async {}));
let request = Request::new(Body::empty());
let response = router.as_service().ready().await?.call(request).await?;

这主要是在测试时调用路由时使用的。当通过 Router::into_make_service 正常运行路由时,这是不必要的。

将路由器转换为具有固定请求正文类型的所有服务,以帮助类型推断

pub fn into_service<B>(self) -> RouterIntoService<B, S>

这与 Router::as_service 相同,只是它返回一个拥有的服务。有关更多详细信息,请参见该方法。

impl Router

pub fn into_make_service(self) -> IntoMakeService<Self>

将此路由器转换为一个 MakeService,它是一个响应是另一个服务的服务。

use axum::{
routing::get,
Router,
};

let app = Router::new().route("/", get(|| async { "Hi!" }));

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();

仅在tokio特性上可用

pub fn into_make_service_with_connect_info<C>(
self
) -> IntoMakeServiceWithConnectInfo<Self, C>

将这个路由器转换为一个MakeService,它将把C的相关ConnectInfo存储在一个请求扩展中, 以便ConnectInfo可以提取它。

这使得可以提取类似客户端远程地址的信息。

提取std::net::SocketAddr 是开箱即用的。

use axum::{
extract::ConnectInfo,
routing::get,
Router,
};
use std::net::SocketAddr;

let app = Router::new().route("/", get(handler));

async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
format!("Hello {addr}")
}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await.unwrap();

您可以这样实现自定义连接(Connected):

use axum::{
extract::connect_info::{ConnectInfo, Connected},
routing::get,
serve::IncomingStream,
Router,
};

let app = Router::new().route("/", get(handler));

async fn handler(
ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>,
) -> String {
format!("Hello {my_connect_info:?}")
}

#[derive(Clone, Debug)]
struct MyConnectInfo {
// ...
}

impl Connected<IncomingStream<'_>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_>) -> Self {
MyConnectInfo {
// ...
}
}
}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<MyConnectInfo>()).await.unwrap();

Trait实现

impl<S> Clone for Router<S>

  • fn clone(&self) -> Self: 返回一个克隆值。
  • fn clone_from(&mut self, source: &Self): 从源(source)执行复制赋值。

impl<S> Debug for Router<S>

  • fn fmt(&self, f: &mut Formatter<'_>) -> Result: 使用给定的格式器格式化值

impl<S> Default for Router<S> where S: Clone + Send + Sync + 'static

  • fn default() -> Self: 返回一个默认值。

impl Service<IncomingStream<'_>> for Router<()>

  • 仅在 tokio crate 功能和 (crate 功能 http1http2) 中可用。

type Response = Router

  • 由服务提供的响应。

type Error = Infallible

  • 由服务生成的错误

type Future = Ready<Result<<Router as Service<IncomingStream<'_>>>::Response, <Router as Service<IncomingStream<'_>>>::Error>>

  • 功能响应值

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>

  • 当服务能够处理请求时,返回 Poll::Ready(Ok(()))

fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future

  • 处理请求并异步返回响应。

impl<B> Service<Request<B>> for Router<()>
where
B: HttpBody<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,

type Response = Response<Body>

  • 由服务提供的响应。

type Error = Infallible

  • 由服务生成的错误

type Future = RouteFuture<Infallible>

  • 功能响应值

fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>>

  • 当服务能够处理请求时,返回 Poll::Ready(Ok(()))

fn call(&mut self, req: Request<B>) -> Self::Future

  • 异步处理请求并返回响应

自动特性实现

  • impl<S> Freeze for Router<S>
  • impl<S = ()> !RefUnwindSafe for Router<S>
  • impl<S> Send for Router<S>
  • impl<S> Sync for Router<S>
  • impl<S> Unpin for Router<S>
  • impl<S = ()> !UnwindSafe for Router<S>

通用实现

impl<T> Any for T
where
T: 'static + ?Sized,

impl<T> Borrow<T> for T
where
T: ?Sized,

impl<T> BorrowMut<T> for T
where
T: ?Sized,

impl<T> From<T> for T

impl<T> FromRef<T> for T
where
T: Clone,

impl<T> Instrument for T

impl<T, U> Into<U> for T
where
U: From<T>,

impl<M, S, Target, Request> MakeService<Target, Request> for M
where
M: Service<Target, Response = S>,
S: Service<Request>,

impl<T> PolicyExt for T
where
T: ?Sized,

impl<T> Same for T

impl<S, R> ServiceExt<R> for S
where
S: Service<R>,

impl<T, Request> ServiceExt<Request> for T
where
T: Service<Request> + ?Sized,

impl<T> ToOwned for T
where
T: Clone,

impl<T, U> TryFrom<U> for T
where
U: Into<T>,

impl<T, U> TryInto<U> for T
where
U: TryFrom<T>,

impl<V, T> VZip<V> for T
where
V: MultiLane<T>,

impl<T> WithSubscriber for T
鱼雪

目录

  • 高级功能
  • 兼容性
  • Hello World
  • 路由
  • 处理程序
  • 提取器
  • 响应
  • 错误处理
  • 中间件
  • 与处理程序共享状态
  • Axum集成
  • 必需依赖
  • 示例
  • 功能标志

高级功能

  • 使用无宏(macro-free)的API路由(router)请求(request)到处理程序(handler)
  • 使用提取器声明式解析请求
  • 简单且可预测的错误处理模型
  • 生成局域最少样板的响应
  • 充分利用towertower-http生态系统的中间件、服务和使用程序

特别是,最后一点是区分axum与其它框架的地方。 axum没有自己的中间件系统,而是使用tower::Service。 这意味着axum获得超时追踪压缩授权等功能, 而且是免费的。 它还能让您与使用hypertonic编写的应用程序共享中间件。

兼容性

axum旨在与tokiohyper协同工作。 至少目前来看,并不追求运行时和传输层独立性。

Hello World

use axum::{
routing::get,
Router,
};

#[tokio::main]
async fn main() {
let app = Router::new().route("/", get(|| async { "Hello World!" }));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();
}
备注

使用#[tokio::main]需要您启用tokio的宏和rt-multi-thread功能, 或者只需启用所有功能(cargo add tokio --features macro, rt-multi-thread)

路由

路由器用于设置哪些路径指向那些服务。

use axum::{Router, routing::get};

let app = Router::new
.route("/", get(root))
.route("/foo", get(get_foo).post(post_foo))
.route("/foo/bar", get(foo_bar));

async fn root() {}
async fn get_foo() {}
async fn post_foo() {}
async fn foo_bar() {}

处理程序

axum中,处理程序是一个异步函数,它接受零个或多个提取器作为参数, 并返回可以转换为响应的东西。 处理程序是您的应用程序逻辑所在的地方,并且axum应用程序是通过在处理程序 之间进行路由而构建的。

提取器

提取器是实现了FromRequestFromRequestsParts接口的类型。 提取器是您拆解传入请求以获取处理程序所需部分的方式。

use axum::extract::{Path, Query, Json};
use std::collections::HashMap;

// `Path`会提供路径参数并对其进行反序列化
async fn path(Path(user_id): Path<u32>) {}

// `Query`为您提供查询参数并将其反序列化
async fn query(Query(params): Query<HashMap<String, String>>) {}

// 将请求正文缓冲并将其反序列化为JSON到`serde_json::Value`
// `Json`支持任何实现`serde::Deserialize`的类型
async fn json(Json(payload): Json<serde_json:Value>) {}

响应

处理程序返回的可以实现IntoResponse接口的任何内容

use axum::{
body::Body,
routing::get,
response::Json,
Router,
};
use serde_json::{Value, json};

async fn plain_text() -> &'static str {
"foo"
}

async fn json() -> Json<Value> {
Json(json!({ "data": 42 }))
}

let app = Router::new()
.route("/plain_text", get(plain_text))
.route("/json", get(json));

错误处理

axum旨在拥有一个简单且可预测的错误处理模型。 这意味着将错误转换为响应变得简单, 并且可以保证所有错误都得到处理。

参见error_handling

处理程序之间共享状态

处理程序之间共享状态是很常见的。 例如,可能需要共享数据库连接池或其他服务的客户端。 实现这一点的三种常见方式是:

  • 使用State提取器
  • 使用请求扩展
  • 使用闭包捕获
  1. 使用State提取器
use axum::{
extract::State,
routing::get,
Router,
};
use std::sync::Arc;

struct AppState {
//...
}

let shared_state = Arc::new(AppState { /* ... */ });

let app = Router.new()
.route("/", get(handler))
.with_state(shared_state);

async fn hander(
State(state): State<Arc<AppState>>,
) {
// ...
}

如果可能的话,您应该更倾向于使用State,因为它更具有类型安全性。 不足之处是,它比请求扩展更少的动态性。

  1. 使用请求扩展

在处理程序中提取状态的另一种方式是使用Extension作为中间层和提取器。

use axum::{
extract::Extension,
routing::get,
Router,
};
use std::sync::Arc;

struct AppState {
// ...
}

let shared_state = Arc::new(AppState { /* ... */ });

let app = Router::new()
.route("/", get(handler))
.layer(Extension(shared_state));

async fn hander(
Extension(state): Extension<Arc<AppState>>
) {
// ...
}

这种方法的缺点是,如果你尝试提取一个不存在的扩展, 你将会得到运行时错误(500内部服务器错误响应)

  1. 使用闭包捕获

状态也可以通过闭包捕获直接传递给处理程序。

use axum::{
Json,
extract::{Extension, Path},
routing::{get, post},
Router,
};
use std::sync::Arc;
use serde::Deserialize;

struct AppState {
// ...
}

let shared_state = Arc::new(AppState { /* ... */ });

let app = Router::new()
.route(
"/users",
post({
let shared_state = Arc::clone(&shared_state);
move |body| create_user(body, shared_state)
}),
)
.route(
"/users/:id",
get({
let shared_state = Arc::clone(&shared_state);
move |path| get_user(path, shared_state)
}),
);

async fn get_user(Path(user_id): Path<String>, state: Arc<AppState>) {
// ...
}

async fn create_user(Json(payload): Json<CreateUserPayload>, state: Arc<AppState>) {
// ...
}

#[derive(Deserialize)]
struct CreateUserPayload {
// ...
}

这种方法的缺点在于,它比使用StateExtensions要冗长一些。

为axum构建集成

系统提供FromRequest, FromRequestPartsIntoResponse实现的库 作者应该依赖于axum-core包,而不是axumaxum-core包含核心类型和特性,并且不太可能出现破坏性变化。

必需依赖

要使用axum,你还需要引入一些依赖:

[dependencies]
axum = "<latest-version>"
tokio = { version = "<latest-version>", features = ["full"] }
tower = "<latest-version>"

为了开始使用,完整功能对于 tokio 并不是必需的,但却是最简单的方法。 Tower 也不是绝对必要的,但在测试时很有帮助。 请参考存储库中的测试示例,了解有关测试 axum 应用程序的更多信息。

特性标志

axum使用一组功能标志来减少编译和可选依赖项的数量

以下可选项特性可用:

名称描述是否默认
http1启用hyper的http1特性
http2启用hyper的http2特性
json启用Json类型和一些类似的便利功能
macro启动可选工具宏
matched-path启用了对每个请求的路由路径进行捕获,并使用MatchedPath提取器
multipart启用Multipart解析multipart/form-data请求
original-uri启用捕获每个请求的原始URI和OriginalUri提取器
tokio启用tokio作为依赖和axum::serve,SSE和extract::connect_info类型
tower-log启用tower的日志特性
tracing从内置提取器记录日志
ws通过extract::ws启用Websocket支持
form启用表单提取器
query启用查询提取器

模块

  • body: HTTP请求体工具
  • error_handling: 错误处理模型和工具
  • extract: 从请求为类型和特型提取数据
  • handler: 可以用来处理请求的异步函数
  • middleware: 写中间件的工具
  • response: 生成响应的类型和特型
  • routing: 在Service和处理之间的路由
  • serve: 提供服务

结构体

  • Error: 在使用axum时可能发生的错误
  • Extension: 提取器和扩展响应
  • Form: URL编码的提取器和响应。
  • Json: JSON提取器 / 响应。
  • Router: 用于组合处理程序和服务的路由器类型。

Traits(特型)

  • RequestExt: 扩展特性,为Request添加额外方法
  • RequestPartExt: 扩展特性,为Parts添加额外的方法
  • ServiceExt: 想任何服务添加附加方法的扩展特性

函数

  • serve: tokio和(http1或http2),使用提供的监听器提供服务

属性宏

  • debug_handler: 宏在应用处理函数时生成更好的错误消息。
鱼雪