Skip to main content

One post tagged with "WebGPU"

View All Tags

背景

在深度学习和高性能计算中,矩阵乘法(Matmul)是核心操作之一, 也是现代 AI 模型如 GPT 和 Transformer 的基础计算单元。

随着 WebGPU 的发展,我们可以在浏览器中高效运行 GPU 计算,为前端机器学习应用带来了更多可能。

本文将通过五个阶段,从基础内核出发,逐步优化 WebGPU 矩阵乘法内核, 最终达到 超过 1TFLOPS 的性能,并探讨 WebGPU 与 CUDA 的区别及应用场景。


什么是 WebGPU?

WebGPU 是为浏览器设计的下一代 GPU 编程接口,原生支持计算着色器(Compute Shader), 通过使用 WGSL(WebGPU Shading Language) 编写 GPU 代码, 并支持多种硬件平台(如 Vulkan 和 Metal)。

优势

  1. 跨平台:兼容 Vulkan、Metal 和 DirectX。
  2. 高性能:原生支持并行计算,如矩阵乘法和深度学习。
  3. 便捷性:无需传统 WebGL 的复杂 hack,可直接进行机器学习计算。

WebGPU 与 CUDA 的区别

特性WebGPUCUDA
硬件支持跨平台(支持 Vulkan、Metal)NVIDIA 专用
并行模型线程、工作组(Workgroup)、网格(Grid)线程块(ThreadBlock)、网格
开发语言WGSLCUDA C
适用场景前端高性能计算、跨平台机器学习专业高性能计算、训练 AI 模型

WebGPU 计算着色器基础

  1. 线程(Thread):最小的并行执行单元。
  2. 工作组(Workgroup):线程的集合,支持组内存共享。
  3. 网格(Grid):多个工作组组成的并行执行结构。

示例:@workgroup_size(x, y, z) 定义每个工作组的线程数量为 (x \times y \times z)。

矩阵乘法优化的五个阶段

阶段 1:基础实现

Python 示例:

def matmul(a, b, c):
m, k, n = len(a), len(a[0]), len(b[0])
for i in range(m):
for j in range(n):
c[i][j] = sum(a[i][l] * b[l][j] for l in range(k))

WGSL 实现:

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let row = global_id.x / dimensions.N;
let col = global_id.x % dimensions.N;
if (row < dimensions.M && col < dimensions.N) {
var sum = 0.0;
for (var i: u32 = 0; i < dimensions.K; i++) {
sum += a[row * dimensions.K + i] * b[i * dimensions.N + col];
}
result[row * dimensions.N + col] = sum;
}
}

存在问题:

  • 每个线程仅计算一个结果,导致大量工作组启动开销高。
  • 每个工作组重复加载数据,没有利用缓存。

阶段 2:增加线程数量

通过提高每个工作组的线程数(如 @workgroup_size(256)),显著减少工作组的数量,从而降低启动开销。

阶段 3:二维工作组优化

通过将工作组从一维扩展到二维(如 (16 \times 16)),使每个工作组能够并行计算更多结果。

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let row = global_id.y;
let col = global_id.x;
...
}

阶段 4:内核平铺(Tiling)

采用平铺策略,每个线程一次计算多个结果(如 (1 \times 4)),进一步提升性能。

阶段 5:循环展开(Unrolling)

通过手动展开循环,减少 GPU 在运行时的循环控制开销,并利用指令级并行,性能大幅提升。

优化成果

  • 性能提升 超过 1000 倍,达到 1TFLOPS 的运算强度。
  • 有效利用 WebGPU 的多线程并行与缓存机制。
  • 实现了更高效的矩阵乘法内核,适用于前端高性能计算场景。

参考资料

鱼雪