JonLuca's Blog

May 3, 2025

Optimizing AI Model Load Times

At Weights, we handle millions of inference and training requests every day. Users frequently use and train LoRA models, resulting in nearly a petabyte of user-generated models.

Efficient infrastructure is critical for supporting these AI/ML workloads. One key metric is how quickly models can be loaded into memory, since it directly affects latency during scaling or recovery events.

We recently encountered slow initialization issues with our PyTorch microservices running in Docker containers orchestrated by Kubernetes. Our largest models took up to 12 minutes to load from NAS storage and around 2 minutes from ephemeral filesystems, creating unacceptable downtime.

Here's how we addressed this, ultimately reducing load times by over 80%.

Identifying the Bottlenecks

Typically, model loading in a containerized environment involves:

  • Reading model files from persistent storage, often network-attached
  • Reading the entire model file into memory at once
  • Overhead from Docker’s overlay filesystem
  • I/O throttling due to container resource constraints
  • Non-optimized default PyTorch loading mechanisms

These factors contributed significantly to loading delays, particularly with large models like our 32GB transformer.
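
A quick way to confirm where the time goes is to time the raw read separately from deserialization. A minimal profiling sketch (the path and checkpoint format here are assumptions, not our exact setup):

import io
import time
import torch


def profile_load(path: str):
    # Time the raw read from storage (NAS, overlay, or local disk)
    start = time.perf_counter()
    with open(path, "rb") as f:
        data = f.read()
    read_s = time.perf_counter() - start

    # Time deserialization of the already-buffered bytes
    start = time.perf_counter()
    state_dict = torch.load(io.BytesIO(data), map_location="cpu")
    load_s = time.perf_counter() - start

    print(f"read: {read_s:.1f}s, deserialize: {load_s:.1f}s")
    return state_dict

Splitting the two makes it obvious whether to attack the storage layer or the deserialization path first.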

Utilizing Memory-Based Filesystems (/dev/shm)

Our initial optimization involved leveraging Linux's RAM-based filesystem (/dev/shm). Using a tmpfs mount significantly speeds up model loading:

import os
import shutil
import torch
import time
import subprocess


def download_weights(url: str, dest: str):
    # pget fetches the file in parallel chunks; -x also extracts tar archives
    if url.endswith("tar"):
        subprocess.check_call(["pget", "--log-level=WARNING", "-x", url, dest], close_fds=False)
    else:
        subprocess.check_call(["pget", "--log-level=WARNING", url, dest], close_fds=False)


def fast_load_model(model_s3_path, model_name):
    # /dev/shm is a tmpfs, so everything written here lives in RAM
    shm_path = f"/dev/shm/{model_name}"
    os.makedirs(shm_path, exist_ok=True)

    # Download to shared memory, skipping the download if the file is already there
    local_path = f"{shm_path}/model.pth"
    if not os.path.exists(local_path):
        download_weights(model_s3_path, local_path)

    # Load model from shared memory (MyModel stands in for the actual model class)
    model = MyModel()
    model.load_state_dict(torch.load(local_path))
    return model
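
A typical call site would look something like this (the URL and model name are purely illustrative):

model = fast_load_model(
    "https://example-bucket.s3.amazonaws.com/loras/my-model.pth",  # hypothetical weights URL
    "my-model",
)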

Switching to this approach decreased loading times from approximately 90 seconds to 30 seconds - a 66% improvement.

We also ensured sufficient /dev/shm allocation by configuring Docker containers and Kubernetes deployments to match instance RAM, typically exceeding 120GB on GPU instances like A100 and H100.
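
Because /dev/shm now holds multi-gigabyte checkpoints, it is worth failing fast when the tmpfs is too small. A small guard along these lines (the 32GB figure is just our largest model from earlier):

import shutil


def ensure_shm_capacity(required_bytes: int, shm_dir: str = "/dev/shm"):
    # tmpfs reports its size like any other mount, so shutil.disk_usage works
    usage = shutil.disk_usage(shm_dir)
    if usage.free < required_bytes:
        raise RuntimeError(
            f"{shm_dir} has {usage.free / 1e9:.1f}GB free, "
            f"need {required_bytes / 1e9:.1f}GB; increase the pod's shm size"
        )


# e.g. before downloading a 32GB checkpoint
ensure_shm_capacity(required_bytes=32 * 1024**3)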

Distributed Model Storage with Ceph

However, having each pod individually download models wasted bandwidth and added latency. With help from our good friends at Northflank, we implemented Ceph, a distributed filesystem, which helped us achieve:

  • Cluster-wide model caching
  • High-speed access from any node
  • Redundancy and durability
  • Pre-warmed caches before deployments
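
Pre-warming is nothing fancy: a deploy-time job simply makes sure the checkpoints a release needs are already sitting on the Ceph mount. A rough sketch, assuming a manifest that maps model names to source URLs (the manifest itself is hypothetical):

def prewarm_ceph_cache(manifest: dict[str, str], ceph_root: str = "/mnt/models"):
    # manifest maps a model name to its source URL
    for name, url in manifest.items():
        dest = os.path.join(ceph_root, name, "model.pth")
        if os.path.exists(dest):
            continue
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        download_weights(url, dest)  # same pget helper as above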

Our model loading logic now checks whether the model is already cached in /mnt/models (the Ceph mount) before downloading it to /dev/shm. This hybrid approach lets us keep the speed of memory-based filesystems while ensuring models are available across all nodes:

def load_model(model_s3_path: str):
    ceph_path = "/mnt/models/model.pth"
    shm_path = "/dev/shm/model.pth"

    if not os.path.exists(shm_path):
        if os.path.exists(ceph_path):
            # Cluster-wide Ceph cache hit: copy straight into the RAM-backed tmpfs
            shutil.copy(ceph_path, shm_path)
        else:
            # Cache miss: fall back to downloading directly into /dev/shm
            download_weights(model_s3_path, shm_path)

    # MyModel stands in for the actual model class
    model = MyModel()
    model.load_state_dict(torch.load(shm_path))
    return model

This hybrid strategy combined efficient model distribution with rapid memory-based loading. Unfortunately, we still occasionally hit slow load times due to the sheer size of the models and the performance of the underlying disks. One interesting side note that I'll cover in a future post: keeping a filesystem in memory and serving it over the network significantly outperforms Ceph, and a transparent S3 proxy that caches models in memory can also yield double-digit performance improvements.

Optimizing PyTorch Model Loading

Beyond storage improvements, we optimized PyTorch loading methods:

  1. Safetensors: Switched to the safetensors format, whose Rust-backed, memory-mapped loading gave us a ~2x speed improvement.
  2. Lazy Loading: Deferred loading of model components until necessary.
  3. Compilation and Quantization: Leveraged PyTorch’s compile functionality and experimented with quantization to shrink models.
  4. Artifact Caching: Employed PyTorch 2.7’s caching mechanism for faster reloading:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32


def fn(x, y):
    return x.sin() @ y


a = torch.rand(100, 100, dtype=dtype, device=device)
b = torch.rand(100, 100, dtype=dtype, device=device)

# Compile and run once so the compiler actually produces cache artifacts
compiled_fn = torch.compile(fn, mode="max-autotune")
result = compiled_fn(a, b)

# Collect everything torch.compile cached during this process
artifacts = torch.compiler.save_cache_artifacts()

assert artifacts is not None
artifact_bytes, cache_info = artifacts


# Later, in a fresh process, hot-load the caches before compiling again
torch.compiler.load_cache_artifacts(artifact_bytes)

This ensured models came up quickly at startup or dynamically on demand. A regular compilation run for our image models could easily take 8 minutes, whereas loading the compiled artifacts and re-running the compilation took less than 30 seconds.

Storing these compiled artifacts in Redis let us transparently load them into memory when needed, significantly reducing the time to first byte.
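
A rough sketch of that flow with a standard redis-py client (the endpoint and cache key are hypothetical):

import redis
import torch

r = redis.Redis(host="redis.internal", port=6379)  # hypothetical Redis endpoint
CACHE_KEY = "compile-artifacts:image-model-v1"  # hypothetical per-model key


def save_compile_cache():
    # After a full compile, persist the artifacts so other pods can skip it
    artifacts = torch.compiler.save_cache_artifacts()
    if artifacts is not None:
        artifact_bytes, _cache_info = artifacts
        r.set(CACHE_KEY, artifact_bytes)


def warm_compile_cache():
    # On startup, hot-load any previously saved artifacts before compiling
    cached = r.get(CACHE_KEY)
    if cached is not None:
        torch.compiler.load_cache_artifacts(cached)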

Summary

By strategically leveraging memory-based filesystems, distributed storage via Ceph, and PyTorch-specific optimizations, we significantly improved model loading performance.

These enhancements improved our cold start times, optimized our GPU utilization rates, and provided a robust infrastructure capable of rapidly scaling our AI/ML services.

We now have 10 Kubernetes clusters across the three major cloud providers, spanning 6 availability zones, with hundreds of GPUs and thousands of CPU cores, all running from a single codebase.

JonLuca at 22:28

To get notified when I publish a new essay, please subscribe here.