steven liu

Big parameters, small GPUs

July 11, 2025

I recently gave my first talk at Sonoma AI, a local AI meetup. The talk was about how Transformers and Diffusers reduce the memory required to load large models on consumer GPUs.

This post recaps and summarizes the talk with some additional details.

memory maths

Llama 3.1 8B Instruct was downloaded over 5.5M times in the past month, making it the most downloaded text generation model on Hugging Face. But how much GPU memory is required to load this popular model for inference?

You can get a pretty good estimate by multiplying the number of parameters by the number of bytes per parameter (plus a little extra for the forward pass).

Llama 3.1 8B Instruct has 8B parameters and is stored in bfloat16 (half-precision), which is 2 bytes per parameter.

8B parameters * 2 bytes/parameter = 16GB
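As a quick sanity check, here is the same estimate as a tiny Python sketch (the parameter count and bytes per parameter are the only inputs; the small forward-pass overhead is left out):

# Rough GPU memory estimate: parameters * bytes per parameter.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5}

def estimate_weight_memory_gb(num_params: float, dtype: str) -> float:
    """Return the approximate memory needed to hold the weights, in GB."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

print(estimate_weight_memory_gb(8e9, "bf16"))  # ~16.0 GB for Llama 3.1 8B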

The problem is, many free-tier or consumer GPUs don't have that much memory. A free T4 GPU instance on Colaboratory has 16GB of GPU memory. But only 15GB of it is actually available. And buying a sufficiently powerful GPU can be expensive.

This is not very accessible.

Transformers and Diffusers lower the barrier to fitting large models into GPU memory.

Big Model Inference

A model is typically loaded like this.

  1. Create the model with randomly initialized weights (16GB).
  2. Load the weights in memory (16GB).
  3. Load the weights in the empty model.
  4. Move the model to the device for inference.
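In plain PyTorch, that flow looks roughly like the sketch below (illustrative only, not the exact Transformers internals; the checkpoint filename is hypothetical). Notice that it briefly holds two full copies of the weights in memory.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# 1. Create the model with randomly initialized weights (~16GB).
model = AutoModelForCausalLM.from_config(config)

# 2. Load the checkpoint weights into memory (~16GB more).
state_dict = torch.load("model.bin", map_location="cpu")  # hypothetical single-file checkpoint

# 3. Copy the weights into the model, then 4. move it to the GPU.
model.load_state_dict(state_dict)
model.to("cuda")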

Big Model Inference (BMI) loads a model like this.

Shards load weights into the empty model on the meta device. Each shard is discarded after its weights are placed.
  1. Create an empty model on the PyTorch meta device. Tensors on the meta device have a shape and dtype but no data attached, so they can be any size without worrying about memory constraints (see the sketch after this list).

    Transformers instantiates a model directly on the meta device. This avoids loading a model into memory twice.

  2. The device_map optimally distributes model weights. This is automatic, but you can design your own device_map by assigning each module/layer to a device. From the shape and dtype of each tensor on the meta device, you can figure out how much memory the actual weights require.

    Transformers tries to fit as many weights as possible on your fastest device (GPU) first. If they don't all fit, it places the remaining weights on the CPU. And if that still doesn't fit, the rest of the weights are offloaded to disk.

    It even accounts for certain layers that shouldn't be split, like layers with residual connections.

  3. Load model shards into memory instead of loading the entire model. Once a shard is loaded, the weights are placed in the model and moved to the appropriate device. The shard is discarded, and the next shard is loaded.

    You only need enough CPU memory to load the biggest shard rather than the entire model.

  4. Place the shard's weights in the empty model.

  5. Move those weights to their assigned device for inference.

  6. Repeat from step 3 until all the weights are loaded.
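Here is a small sketch of the meta-device idea in plain PyTorch (not the Transformers internals): the tensor takes no memory, but its shape and dtype are enough to compute how much memory the real weights will need.

import torch

# A 4096x4096 bf16 linear layer on the meta device: shape and dtype only, no data.
with torch.device("meta"):
    layer = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16)

weight = layer.weight
print(weight.device)  # meta

# Memory the real weights would need: elements * bytes per element.
size_bytes = weight.nelement() * weight.element_size()
print(size_bytes / 1e6)  # ~33.6 MB for this one layer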

device_map

For multiple GPUs, device_map can split the model weights using different strategies.

Set the device_map argument in from_pretrained to distribute model weights across GPUs.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    dtype=torch.bfloat16,
    device_map="auto"
)
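device_map="auto" is usually all you need, but as mentioned above you can also pass your own mapping from module names to devices. A hypothetical sketch (keys are module-name prefixes, every module must be assigned to a GPU index, "cpu", or "disk", and the exact names depend on the model architecture):

import torch
from transformers import AutoModelForCausalLM

# Hypothetical manual placement: the transformer body on GPU 0,
# the LM head offloaded to CPU.
device_map = {
    "model": 0,
    "lm_head": "cpu",
}

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    dtype=torch.bfloat16,
    device_map=device_map
)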

Transformers uses hooks to make sure weights are correctly moved and placed.

This is slower than tensor parallelism because the GPUs are used sequentially, leaving some of them idle.

With device_map, GPUs process layers sequentially. Only one GPU is active at a time while the others idle.

dtype

The dtype is the data type of the elements in a tensor. It affects how much memory is required and what kind of numerical values a tensor can represent.

Tensor values are calculated from the sign, exponent, and significand (mantissa).

fp32: 1 sign bit, 8 exponent bits, 23 significand bits.

fp32 is considered full precision and takes up 32 bits in memory. 1 bit for the sign, 8 bits for the exponent, and 23 bits for the significand.

A lower-precision dtype has fewer bits and requires less memory to store.

fp16: 1 sign bit, 5 exponent bits, 10 significand bits.

fp16 is half-precision and takes up 16 bits in memory. 1 bit for the sign, 5 bits for the exponent, and 10 bits for the significand.

bf16 is also half-precision but represents a wider range of values. 1 bit for the sign, 8 bits for the exponent, and 7 bits for the significand.
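The difference in range is easy to check with torch.finfo. A quick sketch of the per-element sizes and representable ranges:

import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # bits per element, largest representable value, smallest normal value
    print(dtype, info.bits, info.max, info.tiny)

# fp16 tops out around 65504, while bf16 reaches ~3.4e38 (the same range as fp32)
# at the cost of fewer significand bits.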

PyTorch loads a model in fp32 by default, even if the model weights are stored in fp16, because you can't inspect the checkpoint until after you've loaded it with from_pretrained().

Loading a model in fp32 and then converting it to fp16 wastes memory. Use the dtype argument in from_pretrained to set the dtype explicitly and avoid this.

I recommend using the "auto" option to let Transformers automatically pick the optimal dtype from the model weights.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto"
)

quantization

Quantization takes the idea of dtypes to an even lower level, usually from floating points to integers.

int8: 1 sign bit, 7 significand bits.

int8 takes up 8 bits in memory. 1 bit for the sign and 7 bits for the significand.

The original range of values is quantized to a lower range.

Asymmetric quantization maps the fp32 min/max to the int8 range, using a scale and zero-point to preserve the original distribution. For example, an fp32 range of [-3.2, 5.8] gives scale = (max - min) / 255 = (5.8 - (-3.2)) / 255 ≈ 0.035, and weights are dequantized as x = scale × (q - zero_point).
  1. Map the min/max values from fp32 to int8.
  2. The min/max values have different distances to 0. 0 in fp32 doesn't equal 0 in int8.
  3. Calculate a scaling factor to get a linear mapping for the remaining values and adjust them with the zero-point value to account for the different distances to 0.
  4. Dequantize weights with scaling factor and zero-point so you can perform computations with your inputs (presumably in fp16/bf16).
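A minimal sketch of these steps in plain PyTorch, using the example range above (illustrative only; real backends like bitsandbytes are more sophisticated, and the zero-point convention can vary):

import torch

def asymmetric_quantize(x, qmin=-128, qmax=127):
    # 1. Map the fp32 min/max onto the int8 range.
    scale = (x.max() - x.min()) / (qmax - qmin)
    # 2./3. The zero-point shifts the mapping so fp32 0 lines up with an int8 value.
    zero_point = (qmin - x.min() / scale).round().clamp(qmin, qmax)
    q = (x / scale + zero_point).round().clamp(qmin, qmax).to(torch.int8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    # 4. Recover approximate fp32 values for computation: x ≈ scale * (q - zero_point).
    return scale * (q.float() - zero_point)

x = torch.tensor([-3.2, 0.0, 1.7, 5.8])
q, scale, zp = asymmetric_quantize(x)
print(q, dequantize(q, scale, zp))  # the round trip is close to x, but lossy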

The quantization and dequantization steps may decrease inference speed and be lossy, especially for lower quantization levels like int4.

With Transformers, choose and configure a quantization backend. Then plug the quantization_config into from_pretrained to quantize a model.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=quantization_config
)

offloading

Offloading moves weights off the GPU to another device when they're not in use. This is useful for large models like Flux.1 [dev].

Flux.1 [dev] requires ~9GB of memory for the two text encoders and ~22GB for the transformer model. Loading and generating an image uses ~33GB in bf16.

Diffusers offers three offloading options: model offloading, CPU offloading, and group offloading.

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()
Flux.1 [dev] uses 22.6GB with model offloading
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload()
Flux.1 [dev] uses 2.4GB with CPU offloading
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
)
apply_group_offloading(
    pipeline.transformer,
    offload_type="block_level",
    num_blocks_per_group=2,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
Flux.1 [dev] uses 4.41GB with group offloading

tensor parallelism

Tensor parallelism distributes model weights (tensors) across multiple GPUs. This helps you fit large models into memory that wouldn't otherwise fit on a single GPU.

Tensor parallelism shards weights across GPUs. Each GPU computes on its slice in parallel, then the results are synced with an all-reduce.

It is faster because each GPU can perform computations in parallel and sync the results at the end to return the final output.

There is a bit of communication overhead between GPUs, so it works best on a single machine with multiple GPUs connected by fast intra-node links.
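Here is a toy single-process sketch of the idea (plain PyTorch on one device, no real multi-GPU communication): the weight matrix is split by rows, each shard computes a partial result, and summing the partials plays the role of the all-reduce.

import torch

torch.manual_seed(0)
x = torch.randn(1, 8)        # input activation
W = torch.randn(8, 4)        # full weight matrix

# Shard W by rows (and x by the matching columns) across 2 "GPUs".
W_shards = torch.chunk(W, 2, dim=0)   # two (4, 4) shards
x_shards = torch.chunk(x, 2, dim=1)   # two (1, 4) slices

# Each device computes a partial output in parallel...
partials = [x_i @ W_i for x_i, W_i in zip(x_shards, W_shards)]

# ...and an all-reduce (here just a sum) produces the final result.
y_parallel = sum(partials)
y_full = x @ W
print(torch.allclose(y_parallel, y_full, atol=1e-6))  # True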

Set the tp_plan argument in from_pretrained to use tensor parallelism.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8",
    dtype=torch.bfloat16,
    tp_plan="auto"
)

kv cache

Decoder models predict one token at a time. The predicted token is dependent on all of the previous context. Every time the model predicts a new token, it ends up performing some of the same calculations again.

Performing the same calculations every time is wasteful and slows down inference.

A key-value (kv) cache stores the previously calculated kv values and reuses them to avoid recomputation. At each step, you're only calculating the kv value for the current token rather than all the previous ones.

Without caching, K and V are recomputed for all tokens at each step. With caching, only the new token's K and V are computed.

However, storing the kv values requires memory that grows linearly with sequence length.
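To get a feel for how the cache grows, here is a rough size estimate. The config values are my assumptions for Llama 3.1 8B (32 layers, 8 key-value heads, head dimension 128, bf16); the point is just that the cost scales linearly with sequence length.

# Rough kv cache size: 2 (K and V) * layers * kv heads * head dim * bytes per value.
num_layers, num_kv_heads, head_dim, bytes_per_value = 32, 8, 128, 2

bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
print(bytes_per_token / 1024)  # ~128 KB per token

for seq_len in (1024, 8192, 32768):
    print(seq_len, bytes_per_token * seq_len / 1e9, "GB")  # grows linearly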

Transformers provides two memory-optimized cache types, an offloaded cache and a quantized cache.

Configure the cache_implementation argument in generate to use either cache type.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    dtype=torch.bfloat16,
    device_map="auto"
)
inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=23,
    cache_implementation="offloaded"
)

resources