@@ -774,6 +774,60 @@ def test_apc_multiple_prompts_partial_cached_outputs(
774774 )
775775
776776
# Test that outputs match whether prefix caching is enabled or not for mamba.
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
def test_same_mamba_output_apc_on_vs_off(
    vllm_runner,
    model: str,
) -> None:
    num_logprobs = 5
    # Same prompt twice so the second request can hit the prefix cache.
    repeated_prompt = "hello what is one plus one what is one plus one what is one plus one the answer is"  # noqa: E501
    prompts = [repeated_prompt, repeated_prompt]
    max_tokens = 20
    # Character count is a generous upper bound on the prompt's token count.
    max_model_len = max(len(p) for p in prompts) + max_tokens + 64

    base_kwargs = _get_vllm_runner_params(model, max_model_len)
    base_kwargs.update(
        enforce_eager=True,
        block_size=16,
        seed=42,
        gpu_memory_utilization=0.8,
    )

    def _generate(extra_kwargs):
        # Run one generation pass with the base config plus `extra_kwargs`.
        run_kwargs = {**base_kwargs, **extra_kwargs}
        with vllm_runner(**run_kwargs) as vllm_model:
            outputs, _ = _get_vLLM_output(
                vllm_runner,
                run_kwargs,
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                vllm_model=vllm_model,
            )
        return outputs

    # Baseline: prefix caching disabled.
    outputs_no_apc = _generate({"enable_prefix_caching": False})
    # Under test: prefix caching enabled with mamba state blocks.
    outputs_with_apc = _generate(
        {"enable_prefix_caching": True, "mamba_block_size": 16}
    )

    check_logprobs_close(
        outputs_0_lst=outputs_no_apc[0],
        outputs_1_lst=outputs_with_apc[0],
        name_0="vllm_no_apc",
        name_1="vllm_with_apc",
    )
829+
830+
777831# we have to use a real large model to get reasonable results
778832# the model can't be a hybrid model as we need block_size 16
779833@pytest .mark .parametrize ("model" , ["tiiuae/falcon-mamba-7b" ])
0 commit comments