Skip to content

Commit f8d1bdc

Browse files
committed
Add pre-generated prompts for benchmark
1 parent 02feea3 commit f8d1bdc

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

benchmark/python/benchmark_e2e.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
8383
generator.generate_next_token()
8484
return tokenizer.decode(generator.get_sequence(0))
8585

86+
# Use prompt length to get pre-defined prompt
def get_prompt_by_length(prompt_length):
    """Return the pre-generated prompt for *prompt_length* from prompts.json.

    Args:
        prompt_length: prompt length key (e.g. 16, 64, 256) — looked up as a
            string key in the JSON file.

    Returns:
        The prompt string stored under that length.

    Raises:
        FileNotFoundError: if prompts.json is not present in the working dir.
        KeyError: if no prompt is defined for that length.
    """
    json_path = "prompts.json"
    # Use a context manager so the file handle is always closed
    # (the original opened the file and never closed it).
    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)
    return data[str(prompt_length)]
93+
8694
def get_target_pip_package_version(target_pip_package_name_list):
8795
# get package name and version
8896
import pkg_resources
@@ -232,6 +240,9 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
232240
# use random tokens instead of generating a prompt using the model and then tokenizing it
233241
tokens = np.random.randint(100, size=(batch_size, prompt_length))
234242
prompt = [tokenizer.decode(tokens[0])] * batch_size
243+
elif args.use_prompt_set:
244+
prompt = get_prompt_by_length(prompt_length)
245+
tokens = tokenizer.encode_batch(prompt)
235246
else:
236247
prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
237248
tokens = tokenizer.encode_batch(prompt)
@@ -424,6 +435,7 @@ def str2strlist(value):
424435
parser.add_argument('-mn', '--model_name', type=str, default='model_name', help='Model name defined by users')
425436
parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info')
426437
parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
438+
parser.add_argument('--use_prompt_set', action='store_true', help='Use pre-generated prompt set instead of generating a prompt')
427439
args = parser.parse_args()
428440

429441
# check max_lengths

benchmark/python/prompts.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"16": "How are astronauts launched into space quickly on those rockets? ",
3+
"64": "",
4+
"256": "",
5+
"1024": "",
6+
"2048": ""
7+
}

0 commit comments

Comments
 (0)