背景
在深度学习和高性能计算中,矩阵乘法(Matmul)是核心操作之一, 也是现代 AI 模型如 GPT 和 Transformer 的基础计算单元。
随着 WebGPU 的发展,我们可以在浏览器中高效运行 GPU 计算,为前端机器学习应用带来了更多可能。
本文将通过五个阶段,从基础内核出发,逐步优化 WebGPU 矩阵乘法内核, 最终达到 超过 1TFLOPS 的性能,并探讨 WebGPU 与 CUDA 的区别及应用场景。
什么是 WebGPU?
WebGPU 是为浏览器设计的下一代 GPU 编程接口,原生支持计算着色器(Compute Shader), 通过使用 WGSL(WebGPU Shading Language) 编写 GPU 代码, 并支持多种硬件平台(如 Vulkan 和 Metal)。
优势
- 跨平台:兼容 Vulkan、Metal 和 DirectX。
- 高性能:原生支持并行计算,如矩阵乘法和深度学习。
- 便捷性:无需传统 WebGL 的复杂 hack,可直接进行机器学习计算。
WebGPU 与 CUDA 的区别
特性 | WebGPU | CUDA |
---|---|---|
硬件支持 | 跨平台(支持 Vulkan、Metal) | NVIDIA 专用 |
并行模型 | 线程、工作组(Workgroup)、网格(Grid) | 线程块(ThreadBlock)、网格 |
开发语言 | WGSL | CUDA C |
适用场景 | 前端高性能计算、跨平台机器学习 | 专业高性能计算、训练 AI 模型 |
WebGPU 计算着色器基础
- 线程(Thread):最小的并行执行单元。
- 工作组(Workgroup):线程的集合,支持组内存共享。
- 网格(Grid):多个工作组组成的并行执行结构。
示例:@workgroup_size(x, y, z)
定义每个工作组的线程数量为 (x \times y \times z)。
矩阵乘法优化的五个阶段
阶段 1:基础实现
Python 示例:
def matmul(a, b, c):
m, k, n = len(a), len(a[0]), len(b[0])
for i in range(m):
for j in range(n):
c[i][j] = sum(a[i][l] * b[l][j] for l in range(k))
WGSL 实现:
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let row = global_id.x / dimensions.N;
let col = global_id.x % dimensions.N;
if (row < dimensions.M && col < dimensions.N) {
var sum = 0.0;
for (var i: u32 = 0; i < dimensions.K; i++) {
sum += a[row * dimensions.K + i] * b[i * dimensions.N + col];
}
result[row * dimensions.N + col] = sum;
}
}
存在问题:
- 每个线程仅计算一个结果,导致大量工作组启动开销高。
- 每个工作组重复加载数据,没有利用缓存。