Skip to content

Commit df4c12e

Browse files
authored
Fix tests and improve CI (#11614)
### Summary Fix broken test `test_gather_benchmark_configs` in the CI. Ensure the tests will be triggered when the config script is modified. ### Test plan `python .ci/scripts/tests/test_gather_benchmark_configs.py` ``` ---------------------------------------------------------------------- Ran 14 tests in 0.156s OK ```
1 parent 0536862 commit df4c12e

File tree

4 files changed

+20
-8
lines changed

4 files changed

+20
-8
lines changed

.ci/scripts/gather_benchmark_configs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,11 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
135135
# etLLM recipes for Llama
136136
repo_name = model_name.split("meta-llama/")[1]
137137
if "qlora" in repo_name.lower():
138-
configs.append("llama3_qlora")
138+
configs = ["llama3_qlora"]
139139
elif "spinquant" in repo_name.lower():
140-
configs.append("llama3_spinquant")
140+
configs = ["llama3_spinquant"]
141141
else:
142-
configs.append("llama3_fb16")
143-
configs.append("et_xnnpack_custom_spda_kv_cache_8da4w")
142+
configs.extend(["llama3_fb16", "et_xnnpack_custom_spda_kv_cache_8da4w"])
144143
configs.extend(
145144
[
146145
config

.ci/scripts/tests/test_gather_benchmark_configs.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,24 @@ def test_generate_compatible_configs_llama_model(self):
112112
result = self.gather_benchmark_configs.generate_compatible_configs(
113113
model_name, target_os
114114
)
115-
expected = ["llama3_fb16", "llama3_coreml_ane"]
116-
self.assertEqual(result, expected)
115+
expected = [
116+
"llama3_fb16",
117+
"llama3_coreml_ane",
118+
"et_xnnpack_custom_spda_kv_cache_8da4w",
119+
"hf_xnnpack_custom_spda_kv_cache_8da4w",
120+
]
121+
self.assertCountEqual(result, expected)
117122

118123
target_os = "android"
119124
result = self.gather_benchmark_configs.generate_compatible_configs(
120125
model_name, target_os
121126
)
122-
expected = ["llama3_fb16"]
123-
self.assertEqual(result, expected)
127+
expected = [
128+
"llama3_fb16",
129+
"et_xnnpack_custom_spda_kv_cache_8da4w",
130+
"hf_xnnpack_custom_spda_kv_cache_8da4w",
131+
]
132+
self.assertCountEqual(result, expected)
124133

125134
def test_generate_compatible_configs_quantized_llama_model(self):
126135
model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"

.github/workflows/android-perf.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ on:
66
pull_request:
77
paths:
88
- .github/workflows/android-perf.yml
9+
- .ci/scripts/gather_benchmark_configs.py
910
- extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2
1011
push:
1112
branches:
1213
- main
1314
paths:
1415
- .github/workflows/android-perf.yml
16+
- .ci/scripts/gather_benchmark_configs.py
1517
- extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2
1618
# Note: GitHub has an upper limit of 10 inputs
1719
workflow_dispatch:

.github/workflows/apple-perf.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ on:
66
pull_request:
77
paths:
88
- .github/workflows/apple-perf.yml
9+
- .ci/scripts/gather_benchmark_configs.py
910
- extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml.j2
1011
push:
1112
branches:
1213
- main
1314
paths:
1415
- .github/workflows/apple-perf.yml
16+
- .ci/scripts/gather_benchmark_configs.py
1517
- extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml.j2
1618
# Note: GitHub has an upper limit of 10 inputs
1719
workflow_dispatch:

0 commit comments

Comments
 (0)