JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。 由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解
- 使用
import jax.numpy
访问 NumPy 函数,使用import jax.scipy
访问 SciPy 函数。 - 通过使用
@jax.jit
进行装饰,可以加快即时编译速度。 - 使用
jax.grad
求导数 - 使用
jax.vmap
进行矢量化,并使用jax.pmap
跨设备进行并行化。
JAX 遵循函数式编程理念。 这意味着您的函数必须是自包含的,或者说是纯函数:不允许有副作用。
从本质上讲,纯函数就像数学函数。拥有输入,输出,但不与外界通信。
- 以下代码片段是一个非纯函数式的示例
import jax.numpy as jnp
bias = jnp.array(0)
def impure_example(x):
total = x + bias
return total
备注
注意 impure_example
之外的偏差(bias
)。
在编译期间,偏差(bias
)可能会被缓存,因此不再反映偏差(bias
)的变化。
- 这是一个纯函数的例子。
def pure_example(x, weights, bias):
activation = weights @ x + bias
return activation
在这里,pure_example
是独立的:所有参数都作为参数传递。