Skip to content

Commit b9dbc5c

Browse files
authored
[Mamba][APC] Add test case to compare apc outputs (vllm-project#34977)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
1 parent 60af7b9 commit b9dbc5c

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

tests/models/language/generation/test_hybrid.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,60 @@ def test_apc_multiple_prompts_partial_cached_outputs(
774774
)
775775

776776

# Test that outputs match whether prefix caching is enabled or not for mamba.
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
def test_same_mamba_output_apc_on_vs_off(
    vllm_runner,
    model: str,
) -> None:
    """Check that falcon-mamba generations agree with APC on vs. off.

    The same prompt pair is generated twice -- once with prefix caching
    disabled, once enabled (with an explicit mamba block size) -- and the
    top logprobs of both runs are asserted to be close.
    """
    num_logprobs = 5
    prompts = [
        "hello what is one plus one what is one plus one what is one plus one the answer is",  # noqa: E501
        "hello what is one plus one what is one plus one what is one plus one the answer is",  # noqa: E501
    ]
    max_tokens = 20
    # Conservative context bound: prompt characters over-count prompt tokens.
    max_model_len = max(len(p) for p in prompts) + max_tokens + 64

    base_kwargs = _get_vllm_runner_params(model, max_model_len)
    base_kwargs.update(
        enforce_eager=True, block_size=16, seed=42, gpu_memory_utilization=0.8
    )

    def _generate(runner_kwargs):
        # Spin up a fresh engine for the given config and collect its outputs.
        with vllm_runner(**runner_kwargs) as vllm_model:
            outputs, _ = _get_vLLM_output(
                vllm_runner,
                runner_kwargs,
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                vllm_model=vllm_model,
            )
        return outputs

    # No prefix caching
    outputs_no_apc = _generate({**base_kwargs, "enable_prefix_caching": False})

    # With prefix caching
    outputs_with_apc = _generate(
        {
            **base_kwargs,
            "enable_prefix_caching": True,
            "mamba_block_size": 16,
        }
    )

    check_logprobs_close(
        outputs_0_lst=outputs_no_apc[0],
        outputs_1_lst=outputs_with_apc[0],
        name_0="vllm_no_apc",
        name_1="vllm_with_apc",
    )
777831
# we have to use a real large model to get reasonable results
778832
# the model can't be a hybrid model as we need block_size 16
779833
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])

0 commit comments

Comments
 (0)