diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4229577b95..a657f602c1 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -285,7 +285,7 @@ def _replace_with_custom_fn_if_matches_filter( new_module.weight = model.weight new_module.bias = model.bias model = new_module - if filter_fn(model, cur_fqn[:-1]): + if filter_fn(model): if device is not None: model.to(device=device) # move to device before quantization model = replacement_fn(model, *extra_args)