Summary
We have all the features built out for the statically calibrated double-quant scaling for NVFP4; we just don't have a front-end API to do it.
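For context, NVFP4's double-quant scheme uses an FP8-E4M3 scale per 16-element block plus a single FP32 per-tensor scale, and static calibration amounts to fixing that per-tensor scale from an observed amax. A minimal sketch of how the per-tensor scale is commonly derived (the constants are the E4M3 and E2M1 maxima; torchao's exact convention may differ):

import torch

F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
F4E2M1_MAX = 6.0  # largest magnitude representable in FP4 E2M1

def nvfp4_per_tensor_scale(amax: torch.Tensor) -> torch.Tensor:
    # Chosen so the per-block E4M3 scales stay representable after being
    # divided by the per-tensor scale
    return amax.float() / (F8E4M3_MAX * F4E2M1_MAX)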
I had Claude generate this script, which looks pretty good:
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Script for calibrating NVFP4 models with per-tensor amax collection.
Collects calibration data for static quantization.
"""
import random
import numpy as np
import torch
import torch.nn.functional as F
import time
import json
from pathlib import Path
from typing import Optional, Literal, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_primitives import MappingType
from jsonargparse import CLI
from rich import print
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
class NVFP4ObservedLinear(torch.nn.Linear):
    """Linear layer with observers for NVFP4 calibration"""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: torch.Tensor):
        # Observe activations and weights during the forward pass
        observed_input = self.act_obs(input)
        observed_weight = self.weight_obs(self.weight)
        return F.linear(observed_input, observed_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        """Convert a regular Linear to an observed Linear"""
        observed_linear = cls(
            float_linear.in_features,
            float_linear.out_features,
            act_obs,
            weight_obs,
            float_linear.bias is not None,
            device=float_linear.weight.device,
            dtype=float_linear.weight.dtype,
        )
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
        return observed_linear
def create_nvfp4_observer():
    """Create observer for NVFP4 calibration (symmetric, per-tensor)"""
    return AffineQuantizedMinMaxObserver(
        mapping_type=MappingType.SYMMETRIC,  # NVFP4 is symmetric
        target_dtype=torch.float8_e4m3fn,  # Used as a proxy for scale computation
        granularity=PerTensor(),  # Per-tensor amax
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float32,
        zero_point_dtype=torch.float32,
    )
def insert_nvfp4_observers(model):
    """Inject NVFP4 observers into all Linear modules"""

    def is_linear(module, fqn):
        return isinstance(module, torch.nn.Linear)

    def create_observed_linear(module):
        act_obs = create_nvfp4_observer()
        weight_obs = create_nvfp4_observer()
        return NVFP4ObservedLinear.from_float(module, act_obs, weight_obs)

    _replace_with_custom_fn_if_matches_filter(model, create_observed_linear, is_linear)
    num_observed = sum(1 for m in model.modules() if isinstance(m, NVFP4ObservedLinear))
    print(f"Inserted observers into {num_observed} Linear modules")
def collect_calibration_data(model) -> Dict[str, torch.Tensor]:
    """Extract per-tensor amax data from observed modules"""
    calibration_data = {}
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    for name, module in model.named_modules():
        if isinstance(module, NVFP4ObservedLinear):
            # calculate_qparams() returns (scale, zero_point), not the raw amax;
            # with this symmetric float8 mapping the scale should be
            # amax / FP8-E4M3 max, so multiply back to recover the amax
            weight_scale, _ = module.weight_obs.calculate_qparams()
            act_scale, _ = module.act_obs.calculate_qparams()
            # Store with FQN for later lookup
            calibration_data[f"{name}.weight_amax"] = (weight_scale * fp8_max).cpu()
            calibration_data[f"{name}.activation_amax"] = (act_scale * fp8_max).cpu()
    print(f"Collected calibration data for {len(calibration_data) // 2} modules")
    return calibration_data
def run_calibration_loop(
    model, tokenizer, num_samples: int = 100, max_length: int = 512
):
    """Run calibration with random text samples"""
    print(f"Running calibration with {num_samples} samples...")
    # Sample calibration prompts (you can customize this)
    calibration_prompts = [
        "The quick brown fox jumps over the lazy dog.",
        "In the beginning was the Word, and the Word was with God.",
        "To be or not to be, that is the question.",
        "It was the best of times, it was the worst of times.",
        "Call me Ishmael. Some years ago—never mind how long precisely.",
        "In a hole in the ground there lived a hobbit.",
        "Space: the final frontier. These are the voyages of the starship Enterprise.",
        "Mr. and Mrs. Dursley of number four, Privet Drive, were proud to say that they were perfectly normal.",
        "It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife.",
        "Happy families are all alike; every unhappy family is unhappy in its own way.",
    ]
    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            # Cycle through prompts and add some variety
            prompt = calibration_prompts[i % len(calibration_prompts)]
            if i > 0:
                prompt = f"Sample {i}: {prompt}"
            # Tokenize and run forward pass
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
                padding=True,
            ).to(model.device)
            # Forward pass so the observers record activation statistics
            model(**inputs)
            if (i + 1) % 10 == 0:
                print(f"  Processed {i + 1}/{num_samples} samples")
def save_calibration_data(calibration_data: Dict[str, torch.Tensor], output_path: Path):
    """Save calibration data to JSON"""
    # Convert tensors to lists for JSON serialization
    json_data = {key: tensor.tolist() for key, tensor in calibration_data.items()}
    with open(output_path, "w") as f:
        json.dump(json_data, f, indent=2)
    print(f"Saved calibration data to {output_path}")
def load_calibration_data(input_path: Path) -> Dict[str, torch.Tensor]:
    """Load calibration data from JSON"""
    with open(input_path, "r") as f:
        json_data = json.load(f)
    # Convert lists back to tensors
    return {key: torch.tensor(tensor_list) for key, tensor_list in json_data.items()}
def main(
    model_name: str = "facebook/opt-125m",
    output_dir: Optional[str] = None,
    num_calibration_samples: int = 100,
    max_sequence_length: int = 512,
    device_map: str = "cuda",
    save_format: Literal["json", "pt"] = "json",
):
    """
    Calibrate NVFP4 model and collect per-tensor amax data.

    Args:
        model_name: Model to calibrate (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m)
        output_dir: Directory to save calibration data
        num_calibration_samples: Number of samples to use for calibration
        max_sequence_length: Maximum sequence length for calibration
        device_map: Device mapping strategy
        save_format: Format to save calibration data ('json' or 'pt')
    """
    set_seed(42)

    # Set default output directory
    if output_dir is None:
        model_base_name = model_name.split("/")[-1]
        output_dir = f"data/nvfp4-calibration-{model_base_name}"

    print(f"🔧 Calibrating model: {model_name}")
    print(f"📊 Using {num_calibration_samples} calibration samples")
    print(f"📏 Max sequence length: {max_sequence_length}")

    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model and tokenizer
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Insert observers
    print("Inserting NVFP4 observers...")
    insert_nvfp4_observers(model)

    # Run calibration
    start_time = time.time()
    run_calibration_loop(
        model,
        tokenizer,
        num_samples=num_calibration_samples,
        max_length=max_sequence_length,
    )
    calibration_time = time.time() - start_time
    print(f"⏱️ Calibration completed in {calibration_time:.2f} seconds")

    # Collect calibration data
    print("Collecting calibration data...")
    calibration_data = collect_calibration_data(model)

    # Save calibration data
    if save_format == "json":
        output_file = output_dir / "nvfp4_calibration_data.json"
        save_calibration_data(calibration_data, output_file)
    else:
        output_file = output_dir / "nvfp4_calibration_data.pt"
        torch.save(calibration_data, output_file)
        print(f"Saved calibration data to {output_file}")

    # Save metadata
    metadata = {
        "model_name": model_name,
        "num_calibration_samples": num_calibration_samples,
        "max_sequence_length": max_sequence_length,
        "calibration_time_seconds": calibration_time,
        "num_modules_calibrated": len(calibration_data) // 2,
        "device_map": device_map,
    }
    metadata_file = output_dir / "calibration_metadata.json"
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2)
    print(f"📋 Saved metadata to {metadata_file}")

    # Print summary
    print("\n📊 Calibration Summary:")
    print(f"  Model: {model_name}")
    print(f"  Modules calibrated: {len(calibration_data) // 2}")
    print(f"  Calibration samples: {num_calibration_samples}")
    print(f"  Time taken: {calibration_time:.2f}s")
    print(f"  Output directory: {output_dir}")

    # Show a sample of the collected data
    print("\n🔍 Sample calibration data:")
    for key, value in list(calibration_data.items())[:4]:
        print(f"  {key}: {value.item():.6f}")
    if len(calibration_data) > 4:
        print(f"  ... and {len(calibration_data) - 4} more entries")

    print("\n✅ NVFP4 calibration completed successfully!")
    print(
        "💡 Use this calibration data with NVFP4InferenceStaticConfig for static quantization"
    )


if __name__ == "__main__":
    CLI(main)
We should add an e2e config for doing this and ensure it works in vLLM, diffusers, etc.; something like the sketch below.
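For illustration, a rough sketch of what the e2e flow could look like on top of the script above, assuming a config along the lines of the NVFP4InferenceStaticConfig the script mentions (that class name comes from the script's closing hint; its constructor signature here is an assumption, not an existing API):

import torch
from pathlib import Path
from torchao.quantization import quantize_
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.bfloat16, device_map="cuda"
)
# load_calibration_data is the helper defined in the script above
calibration_data = load_calibration_data(
    Path("data/nvfp4-calibration-opt-125m/nvfp4_calibration_data.json")
)
# Assumed signature: a static NVFP4 config that consumes the per-tensor amax dict
quantize_(model, NVFP4InferenceStaticConfig(calibration_data=calibration_data))

The quantized checkpoint could then be loaded in vLLM / diffusers to verify the end-to-end path.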