Burn 是一个新的全面的动态深度学习框架,以极高的灵活性、计算效率和可移植性为其主要目标。
性能
我们相信深度学习框架的目标是将计算转化为有用的智能,因此我们将性能视为Burn的核心支柱。
我们努力通过利用下面描述的多种优化技术来实现最高效率。
自动内核融合 💥
使用Burn意味着在任何后端对您的模型进行优化。
在可能的情况下,我们提供一种自动和动态创建自定义内 核的方式, 最小化在不同内存空间之间的数据重定位,当内存迁移是瓶颈时非常有用。
例如,您可以使用高级张量 API(请参见下面的Rust 代码片段)编写自己的GELU 激活函数。
fn gelu_custom < B : Backend , const D : usize > ( x : Tensor < B , D > ) -> Tensor < B , D > {
let x = x . clone ( ) * ( ( x / SQRT_2 ) . erf ( ) + 1 ) ;
x / 2
}
然后,在运行时, 将自动为您的特定实现创建一个自定义低级内核,并将与手工编写的 GPU 实现相媲美。
内核由大约 60 行 WGSL WebGPU 着色语言组成,这是一种非常冗长的低级着色语言,您可能不希望用其来编写深度学习模型!
就目前而言,我们的融合策略仅为我们自己的 WGPU 后端实现,并且仅支持部分操作。
我们计划很快添加更多操作,并将这种技术扩展到未来的其他内部后端。
异步执行 ❤️🔥
对由 Burn 团队从头开始开发的后端,使用了一种异步执行风格,这允许执行各种优化,例如先前提到的自动内核融合。 异步执行还确保框架的正常执行不会阻塞模型计算,这意味着框架的开销不会显著影响执行速度。 反之,模型中的强烈计算不会干扰框架的响应。 有关我们异步后端的更多信息,请参阅此博客文章。
线程安全的构建模块 🦞
Burn通过利用Rust的所有权系统来强调线程安全。 使用Burn,每个模块都是其权重的所有者。
因此,可以将一个模块发送到另一个线程来计算梯度,然后将梯度发送到可以聚合它们的主线程,这样,您就可以进行多设备训练。
这与PyTorch的做法截然不同,PyTorch中的反向传播实际上会改变每个张量参数的梯度属性。 这是一个非线程安全的操作,因此需要较低级别的同步原语,请参考分布式训练。 请注意,这仍然非常快速,但不兼容不同的后端,并且实现起来相当困难。
智能内存管理 🦀
深度学习框架的主要作用之一是减少运行模型所需的内存量。
处理内存的天真方式是每个张量都有自己的内存空间,在张量创建时分配,然后在张量超出范围时释放。
然而,分配和释放数据非常昂贵,因此通常需要内存池来实现良好的吞吐量。
Burn 提供了一个基础架构,可以轻松创建和选择后端的内存管理策略。
有关 Burn 中内存管理的更多详细信息,请参阅此博客文章。
Burn 的另一个非常重要的内存优化是,我们通过良好使用所有权系统,跟踪张量何时可以就地进行突变。
即使这在单独的内存优化方面相对较小,但在训练或运行较大模型进行推理时,会显着累积,并有助于进一步减少内存使用。
有关更多信息,请参阅有关张量处理的此博客文章。
自动内核选择🎯
一个优秀的深度学习框架应该确保模型在所有硬件上运行顺畅。 然而,并非所有硬件在执行速度方面都表现相同。
例如,矩阵乘法内核可以通过许多不同的参数启动,这些参数对矩阵的大小和硬件非常敏感。 使用错误的配置可能会大幅降低执行速度(在极端情况下甚至会降低10倍或更多),因此选择正确的内核成为优先事项。
通过我们自制的后端,我们会自动运行基准测试,并选择适用于当前硬件和矩阵大小的最佳配置,配合合理的缓存策略。 这会稍微增加热身执行时间,但在几次前向和后向传递之后很快稳定下来,从长远来看节省大量时间。 请注意,这个功能不是强制性的,当冷启动优先于优化吞吐量时可以禁用。
硬件特定功能🔥
深度学习主要依赖矩阵乘法作为其核心操作,因为这是全连接神经网络建模的方式。
越来越多的硬件制造商专门为矩阵乘法工作负载进行优化。 例如,Nvidia 拥有其 Tensor Cores,并且今天大多数手机都配备了人工智能专用芯片。
就目前而言,我们支持 Tensor Cores 与我们的 LibTorch 和 Candle 后端,但尚未支持其他加速器。 我们希望这个问题在某个时间点得到解决,以支持我们的 WGPU 后端。
自定义后端扩展 🎒
Burn 致力于成为最灵活的深度学习框架。 虽然保持与各种后端的兼容性至关重要,但 Burn 也提供了扩展后端实现功能以满足个人建模需求的能力。
这种多功能性在许多方面都具有优势,例如支持自定义操作,如快闪注意力,或手动编写特定后端的内核以增强性能。 请查看 Burn Book 🔥 中的此部分以获取更多详细信息。
训练和推理
使用 Burn,整个深度学习工作流程变得轻松,因为您可以通过符合人体工程学的仪表板监控培训进展, 并在嵌入式设备到大型 GPU 集群中的任何位置运行推理。
Burn 从头开始构建时考虑了培训和推理。 值得一提的是,与 PyTorch 等框架相比,Burn 如何简化了从训练到部署的过度,消除了代码更改的需要。
培训仪表板 📈
正如您在以前的视频中所看到 (单击图片!),
一个基于 Ratatui
框架的新的终端 UI 仪表板使用户能够轻松地跟踪他们的培训,无需连接到任何外部应用程序。
您可以实时查看培训和验证指标的更新,并仅使用箭头键分析任何已注册指标的终身进展或最近历史。 在不崩溃的情况下中断培训循环,使潜在的检查点能够完全编写或重要的代码片段可以在没有干扰的情况下完成 🛡
ONNX 支持 🐫
ONNX(Open Neural Network Exchange)是一种开放标准格式,可导出深度学习模型的架构和权重。
Burn 支持符合 ONNX 标准的模型的导入, 因此您可以轻松地将在另一个框架(如 TensorFlow 或 PyTorch)中编写的模型移植到 Burn, 以从我们的框架提供的所有优势中受益。
我们的 ONNX 支持在 Burn 书的这一部分中有进一步描述 🔥。
此 crate 正在积极开发中,目前支持有限的 ONNX 运算符。
导入 PyTorch 模型 🚚
支持将 PyTorch 模型权重加载到 Burn 的本地模型架构中,确保无缝集成。 参见 Burn 书 🔥 中有关导入 PyTorch 的部分。
浏览器内的推理 🌐
我们的几个后端可以编译成 Web Assembly:Candle 和 NdArray 用于 CPU,WGPU 用于 GPU。
这意味着您可以直接在浏览器内运行推理。我们提供了几个示例: