跳到主要内容

Rust机器学习框架-HuggingFace Candle

鱼雪

CandleRust 的极简 ML 框架,重点关注性能(包括 GPU 支持)和易用性。尝试我们的在线演示: whisperLLaMA2T5yoloSegment Anything

模块

Candle项目包括一些crates,如下:

  • candle-book: candle相关的文档
  • candle-core: 核心功能库,核心操作,设备,Tensor结构定义等。
  • candle-nn: 神经网络,构建真实模型的工具
  • candle-examples: 在实际环境中使用库的示例
  • candle-datasets: 数据集和数据加载
  • candle-transformers: Transformer相关实现工具
  • candle-flash-attn: Flash Attention v2实现
  • candle-kernels: CUDA加速实现
  • candle-pyo3: Rust提供的Python接口
  • candle-wasm-examples: Rust WASM示例

其它有用的库:

  • candle-lora: 提供了符合官方peft实现的LoRA实现

特点

  • 语法简单(看起来像PyTorch)
    • 支持模型训练
    • 支持用于自定义操作运算
  • 后端
    • 优化的CPU后端,具有针对x86的可选MKL支持和针对MacAccelerate支持
    • CUDA后端可以再GPU上高效运行,通过NCCL运行多GPU分配
    • WASM支持,在浏览器中运行模型
  • 包含的模型
    • 语言模型
      • LLaMA v1 and v2
      • FaIcon
      • StarCoder
      • Phi v1.5
      • T5
      • Bert
    • Whisper(多语言支持)
    • Stable Diffusion v1.5, v2.1, XL v1.0
    • Wurstchen v2
    • 计算机视觉
      • DINOv2
      • EfficientNet
      • yolo-v3
      • yolo-v8
      • Segmeng-Anything(SAM)
  • 文件格式
    • 加载模型支持的格式如下:
      • safetensors
      • npz
      • ggml
      • PyTorch files
  • 无服务部署
    • 小型且快速的部署
  • 使用llama.cpp量化类型的量化支持

基本用法介绍

  1. 创建张量

    Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?
    Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?
  2. 张量索引

    tensor.i((.., ..4))?
  3. 张量重塑

    tensor.reshape((2, 2))?
  4. 张量矩阵乘法

    a.matmul(&b)?
  5. 张量数据移动到特定设备

    tensor.to_device(&Device::new_cuda(0)?)?
  6. 更改张量数据类型

    tensor.to_dtype(&Device::F16)?
  7. 张量算术运算

    &a + &b
  8. 保存模型

    candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?
  9. 加载模型

    candle::safetensors::load("model.safetensors", &device)