跳到主要内容

1 篇博文 含有标签「Numpy」

查看所有标签

JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。 由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解

  • 使用 import jax.numpy 访问 NumPy 函数,使用 import jax.scipy 访问 SciPy 函数。
  • 通过使用 @jax.jit 进行装饰,可以加快即时编译速度。
  • 使用 jax.grad 求导数
  • 使用 jax.vmap 进行矢量化,并使用 jax.pmap 跨设备进行并行化。

JAX 遵循函数式编程理念。 这意味着您的函数必须是自包含的,或者说是纯函数:不允许有副作用。

从本质上讲,纯函数就像数学函数。拥有输入,输出,但不与外界通信。

  • 以下代码片段是一个非纯函数式的示例
import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
total = x + bias
return total
备注

注意 impure_example 之外的偏差(bias)。

在编译期间,偏差(bias)可能会被缓存,因此不再反映偏差(bias)的变化。

  • 这是一个纯函数的例子。
def pure_example(x, weights, bias):
activation = weights @ x + bias
return activation

在这里,pure_example 是独立的:所有参数都作为参数传递

鱼雪