
Add static quant flow to NVFP4 #2572

@drisspg

Description

Summary

We have all the features built out for statically calibrated double-quant scaling for NVFP4; we just don't have a front-end API for it.
I had Claude generate this script, which looks pretty good:

#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0

"""
Script for calibrating NVFP4 models with per-tensor amax collection.
Collects calibration data for static quantization.
"""

import random
import numpy as np
import torch
import torch.nn.functional as F
import time
import json
from pathlib import Path
from typing import Optional, Literal, Dict

from transformers import AutoModelForCausalLM, AutoTokenizer
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_primitives import MappingType
from jsonargparse import CLI
from rich import print


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class NVFP4ObservedLinear(torch.nn.Linear):
    """Linear layer with observers for NVFP4 calibration"""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: torch.Tensor):
        # Observe activations and weights during forward pass
        observed_input = self.act_obs(input)
        observed_weight = self.weight_obs(self.weight)
        return F.linear(observed_input, observed_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        """Convert regular Linear to observed Linear"""
        observed_linear = cls(
            float_linear.in_features,
            float_linear.out_features,
            act_obs,
            weight_obs,
            float_linear.bias is not None,
            device=float_linear.weight.device,
            dtype=float_linear.weight.dtype,
        )
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
        return observed_linear


def create_nvfp4_observer():
    """Create observer for NVFP4 calibration (symmetric, per-tensor)"""
    return AffineQuantizedMinMaxObserver(
        mapping_type=MappingType.SYMMETRIC,  # NVFP4 is symmetric
        target_dtype=torch.float8_e4m3fn,  # Use as proxy for scale computation
        granularity=PerTensor(),  # Per-tensor amax
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float32,
        zero_point_dtype=torch.float32,
    )


def insert_nvfp4_observers(model):
    """Inject NVFP4 observers into all Linear modules"""

    def is_linear(module, fqn):
        return isinstance(module, torch.nn.Linear)

    def create_observed_linear(module):
        act_obs = create_nvfp4_observer()
        weight_obs = create_nvfp4_observer()
        return NVFP4ObservedLinear.from_float(module, act_obs, weight_obs)

    _replace_with_custom_fn_if_matches_filter(model, create_observed_linear, is_linear)
    num_observed = sum(
        isinstance(m, NVFP4ObservedLinear) for m in model.modules()
    )
    print(f"Inserted observers into {num_observed} Linear modules")


def collect_calibration_data(model) -> Dict[str, torch.Tensor]:
    """Extract per-tensor amax data from observed modules"""
    calibration_data = {}

    for name, module in model.named_modules():
        if isinstance(module, NVFP4ObservedLinear):
            # calculate_qparams returns the per-tensor scale derived from the observed amax
            weight_scale, _ = module.weight_obs.calculate_qparams()
            act_scale, _ = module.act_obs.calculate_qparams()

            # Store with FQN for later lookup
            calibration_data[f"{name}.weight_amax"] = weight_scale.cpu()
            calibration_data[f"{name}.activation_amax"] = act_scale.cpu()

    print(f"Collected calibration data for {len(calibration_data)//2} modules")
    return calibration_data


def run_calibration_loop(
    model, tokenizer, num_samples: int = 100, max_length: int = 512
):
    """Run calibration with random text samples"""
    print(f"Running calibration with {num_samples} samples...")

    # Sample calibration prompts (you can customize this)
    calibration_prompts = [
        "The quick brown fox jumps over the lazy dog.",
        "In the beginning was the Word, and the Word was with God.",
        "To be or not to be, that is the question.",
        "It was the best of times, it was the worst of times.",
        "Call me Ishmael. Some years ago—never mind how long precisely.",
        "In a hole in the ground there lived a hobbit.",
        "Space: the final frontier. These are the voyages of the starship Enterprise.",
        "Mr. and Mrs. Dursley of number four, Privet Drive, were proud to say that they were perfectly normal.",
        "It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife.",
        "Happy families are all alike; every unhappy family is unhappy in its own way.",
    ]

    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            # Cycle through prompts and add some randomness
            prompt = calibration_prompts[i % len(calibration_prompts)]
            if i > 0:
                prompt = f"Sample {i}: {prompt}"

            # Tokenize and run forward pass
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
                padding=True,
            ).to(model.device)

            # Forward pass to collect statistics
            model(**inputs)  # outputs are discarded; the forward pass only feeds the observers

            if (i + 1) % 10 == 0:
                print(f"  Processed {i + 1}/{num_samples} samples")


def save_calibration_data(calibration_data: Dict[str, torch.Tensor], output_path: Path):
    """Save calibration data to JSON"""
    # Convert tensors to lists for JSON serialization
    json_data = {}
    for key, tensor in calibration_data.items():
        json_data[key] = tensor.tolist()

    with open(output_path, "w") as f:
        json.dump(json_data, f, indent=2)

    print(f"Saved calibration data to {output_path}")


def load_calibration_data(input_path: Path) -> Dict[str, torch.Tensor]:
    """Load calibration data from JSON"""
    with open(input_path, "r") as f:
        json_data = json.load(f)

    # Convert lists back to tensors
    calibration_data = {}
    for key, tensor_list in json_data.items():
        calibration_data[key] = torch.tensor(tensor_list)

    return calibration_data


def main(
    model_name: str = "facebook/opt-125m",
    output_dir: Optional[str] = None,
    num_calibration_samples: int = 100,
    max_sequence_length: int = 512,
    device_map: str = "cuda",
    save_format: Literal["json", "pt"] = "json",
):
    """
    Calibrate NVFP4 model and collect per-tensor amax data.

    Args:
        model_name: Model to calibrate (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m)
        output_dir: Directory to save calibration data
        num_calibration_samples: Number of samples to use for calibration
        max_sequence_length: Maximum sequence length for calibration
        device_map: Device mapping strategy
        save_format: Format to save calibration data ('json' or 'pt')
    """
    set_seed(42)

    # Set default output directory
    if output_dir is None:
        model_base_name = model_name.split("/")[-1]
        output_dir = f"data/nvfp4-calibration-{model_base_name}"

    print(f"🔧 Calibrating model: {model_name}")
    print(f"📊 Using {num_calibration_samples} calibration samples")
    print(f"📏 Max sequence length: {max_sequence_length}")

    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model and tokenizer
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Insert observers
    print("Inserting NVFP4 observers...")
    insert_nvfp4_observers(model)

    # Run calibration
    start_time = time.time()
    run_calibration_loop(
        model,
        tokenizer,
        num_samples=num_calibration_samples,
        max_length=max_sequence_length,
    )
    calibration_time = time.time() - start_time
    print(f"⏱️  Calibration completed in {calibration_time:.2f} seconds")

    # Collect calibration data
    print("Collecting calibration data...")
    calibration_data = collect_calibration_data(model)

    # Save calibration data
    if save_format == "json":
        output_file = output_dir / "nvfp4_calibration_data.json"
        save_calibration_data(calibration_data, output_file)
    else:
        output_file = output_dir / "nvfp4_calibration_data.pt"
        torch.save(calibration_data, output_file)
        print(f"Saved calibration data to {output_file}")

    # Save metadata
    metadata = {
        "model_name": model_name,
        "num_calibration_samples": num_calibration_samples,
        "max_sequence_length": max_sequence_length,
        "calibration_time_seconds": calibration_time,
        "num_modules_calibrated": len(calibration_data) // 2,
        "device_map": device_map,
    }

    metadata_file = output_dir / "calibration_metadata.json"
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"📋 Saved metadata to {metadata_file}")

    # Print summary
    print("\n📊 Calibration Summary:")
    print(f"  Model: {model_name}")
    print(f"  Modules calibrated: {len(calibration_data) // 2}")
    print(f"  Calibration samples: {num_calibration_samples}")
    print(f"  Time taken: {calibration_time:.2f}s")
    print(f"  Output directory: {output_dir}")

    # Show sample of collected data
    print("\n🔍 Sample calibration data:")
    for key, value in list(calibration_data.items())[:4]:
        print(f"  {key}: {value.item():.6f}")
    if len(calibration_data) > 4:
        print(f"  ... and {len(calibration_data) - 4} more entries")

    print("\n✅ NVFP4 calibration completed successfully!")
    print(
        "💡 Use this calibration data with NVFP4InferenceStaticConfig for static quantization"
    )


if __name__ == "__main__":
    CLI(main)

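With jsonargparse, every parameter of main becomes a CLI flag, so a run would look something like this (calibrate_nvfp4.py is a placeholder filename):

python calibrate_nvfp4.py --model_name facebook/opt-125m --num_calibration_samples 200 --save_format pt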
We should add an e2e config for doing this and ensure it works in vLLM / diffusers, etc.
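As a rough sketch of what that front end could look like: NVFP4InferenceStaticConfig below is the hypothetical config this issue is asking for (only the dynamic NVFP4InferenceConfig exists in torchao today), while quantize_ is the existing entry point. The idea is to load the calibration data produced by the script above and hand it to a single config object:

import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_
# from torchao.prototype.mx_formats import NVFP4InferenceStaticConfig  # does not exist yet

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.bfloat16, device_map="cuda"
)

# Per-tensor amax data keyed by module FQN, as saved by the calibration script
calibration_data = torch.load("data/nvfp4-calibration-opt-125m/nvfp4_calibration_data.pt")

# Hypothetical config: carries the static per-tensor scales so no runtime
# amax computation is needed at inference time
config = NVFP4InferenceStaticConfig(calibration_data=calibration_data)
quantize_(model, config)

That would keep the static flow symmetric with the existing dynamic one and give vLLM / diffusers a single config object to consume.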
