# rust-ml: Machine Learning Development

## Domain Constraints

| Domain Rule | Design Constraint | Rust Implication |
|---|---|---|
| Large data | Efficient memory | Zero-copy, streaming |
| GPU acceleration | CUDA/Metal support | candle, tch-rs |
| Model portability | Standard formats | ONNX |
| Batch processing | Throughput over latency | Batched inference |
| Numerical precision | Float handling | ndarray, careful f32/f64 |
| Reproducibility | Deterministic | Seeded random, versioning |
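
For the reproducibility row, a minimal sketch of deterministic initialization with the `rand` crate (assuming rand 0.8; the `init_weights` helper and the seed value are illustrative, not a crate API):

```rust
use rand::{rngs::StdRng, Rng, SeedableRng};

// Derive every random source from a fixed seed so runs are repeatable.
fn init_weights(n: usize, seed: u64) -> Vec<f32> {
    let mut rng = StdRng::seed_from_u64(seed);
    (0..n).map(|_| rng.gen_range(-0.1f32..0.1)).collect()
}

fn main() {
    let a = init_weights(4, 42);
    let b = init_weights(4, 42);
    // Same seed, identical values on every run (pin the rand version for cross-version stability).
    assert_eq!(a, b);
    println!("{a:?}");
}
```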

## Critical Rules

- Avoid copying large tensors. Memory bandwidth is the bottleneck; use references, views, and in-place ops (see the sketch after this list).
- Batch operations for GPU efficiency. Every kernel launch has fixed overhead; batch to amortize it.
- Use standard model formats. Train in Python, deploy in Rust via ONNX.
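
A minimal `ndarray` sketch of the first rule: operate through views and in-place methods instead of cloning. The `normalize_rows` and `mean` helpers are illustrative:

```rust
use ndarray::{Array2, ArrayView2, ArrayViewMut2, Axis};

// Normalize each row in place through a mutable view; nothing is copied.
fn normalize_rows(mut x: ArrayViewMut2<f32>) {
    for mut row in x.axis_iter_mut(Axis(0)) {
        let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-12);
        row.mapv_inplace(|v| v / norm);
    }
}

// Read-only consumers borrow a view instead of taking an owned Array2.
fn mean(x: ArrayView2<f32>) -> f32 {
    x.mean().unwrap_or(0.0)
}

fn main() {
    let mut data = Array2::<f32>::ones((4, 3));
    normalize_rows(data.view_mut()); // borrows mutably, no allocation
    println!("mean = {}", mean(data.view()));
}
```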

## Use Case to Framework

| Use Case | Recommended | Why |
|---|---|---|
| Inference only | tract (ONNX) | Lightweight, portable |
| Training + inference | candle, burn | Pure Rust, GPU |
| PyTorch models | tch-rs | Direct bindings |
| Data pipelines | polars | Fast, lazy eval |
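
For the data-pipeline row, a sketch of polars' lazy evaluation, assuming a recent `polars` version with the `lazy` feature enabled; the column names are made up for the example:

```rust
use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let df = df!(
        "feature" => &[0.5f64, 1.2, 3.4, 2.2],
        "label"   => &[0i64, 1, 1, 0]
    )?;

    // Nothing executes until collect(); polars can push the filter down
    // and only materialize the columns the query actually needs.
    let out = df
        .lazy()
        .filter(col("label").eq(lit(1i64)))
        .select([col("feature"), (col("feature") * lit(2.0)).alias("scaled")])
        .collect()?;

    println!("{out}");
    Ok(())
}
```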

## Key Crates

| Purpose | Crate |
|---|---|
| Tensors | ndarray |
| ONNX inference | tract |
| ML framework | candle, burn |
| PyTorch bindings | tch-rs |
| Data processing | polars |
| Embeddings | fastembed |

## Inference Server Pattern

```rust
use std::sync::OnceLock;
use tract_onnx::prelude::*;

// Alias for tract's optimized, runnable model type.
type Model = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;

// Load the ONNX model once and share it across all requests.
static MODEL: OnceLock<Model> = OnceLock::new();

fn get_model() -> &'static Model {
    MODEL.get_or_init(|| {
        tract_onnx::onnx()
            .model_for_path("model.onnx")
            .expect("failed to read model.onnx")
            .into_optimized()
            .expect("failed to optimize model")
            .into_runnable()
            .expect("failed to build runnable plan")
    })
}

async fn predict(input: Vec<f32>) -> anyhow::Result<Vec<f32>> {
    let model = get_model();
    // Shape the flat input as a single-row batch: (1, len).
    let len = input.len();
    let tensor = tract_ndarray::Array2::from_shape_vec((1, len), input)?;
    let result = model.run(tvec!(Tensor::from(tensor).into()))?;
    Ok(result[0].to_array_view::<f32>()?.iter().copied().collect())
}
```
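
To serve `predict` over HTTP, a hedged sketch using axum and tokio (assuming axum 0.7; the route path and port are arbitrary):

```rust
use axum::{http::StatusCode, routing::post, Json, Router};

// Thin HTTP wrapper around predict(); the model singleton is shared by all requests.
async fn predict_handler(Json(input): Json<Vec<f32>>) -> Result<Json<Vec<f32>>, StatusCode> {
    predict(input)
        .await
        .map(Json)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}

#[tokio::main]
async fn main() {
    let app = Router::new().route("/predict", post(predict_handler));
    let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```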

## Batched Inference Pattern

```rust
// Reuses get_model() and the tract imports from the inference server pattern above.
async fn batch_predict(inputs: Vec<Vec<f32>>, batch_size: usize) -> anyhow::Result<Vec<Vec<f32>>> {
    let model = get_model();
    let mut results = Vec::with_capacity(inputs.len());

    for batch in inputs.chunks(batch_size) {
        // Stack the rows into one (batch, dim) tensor so the model runs once per chunk.
        let dim = batch[0].len();
        let flat: Vec<f32> = batch.iter().flatten().copied().collect();
        let tensor = tract_ndarray::Array2::from_shape_vec((batch.len(), dim), flat)?;
        let output = model.run(tvec!(Tensor::from(tensor).into()))?;

        // Split the (batch, out_dim) output back into one Vec per input row.
        for row in output[0].to_array_view::<f32>()?.outer_iter() {
            results.push(row.iter().copied().collect());
        }
    }

    Ok(results)
}
```

## Common Mistakes

| Mistake | Domain Violation | Fix |
|---|---|---|
| Cloning tensors | Memory waste | Use views |
| Single-item inference | GPU underutilized | Batch processing |
| Loading the model per request | Slow startup on every call | Singleton pattern |
| Synchronous data loading | GPU idle | Async pipeline |
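
For the last row, a hedged sketch of an async load/infer pipeline using tokio channels, assuming the `predict` function from the inference server pattern; the buffer size and `pipelined_inference` name are illustrative:

```rust
use tokio::sync::mpsc;

// Overlap preprocessing with inference so the model is never waiting on data.
async fn pipelined_inference(raw: Vec<Vec<f32>>) -> anyhow::Result<Vec<Vec<f32>>> {
    let (tx, mut rx) = mpsc::channel::<Vec<f32>>(32);

    // Producer task: load and preprocess samples, then hand them to the channel.
    tokio::spawn(async move {
        for sample in raw {
            // Real preprocessing (decode, normalize, tokenize, ...) would go here.
            if tx.send(sample).await.is_err() {
                break; // consumer dropped
            }
        }
    });

    // Consumer: run inference as soon as preprocessed samples arrive.
    let mut outputs = Vec::new();
    while let Some(sample) = rx.recv().await {
        outputs.push(predict(sample).await?);
    }
    Ok(outputs)
}
```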