@@ -83,6 +83,14 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
         generator.generate_next_token()
     return tokenizer.decode(generator.get_sequence(0))
 
+# Use prompt length to get a pre-defined prompt
+def get_prompt_by_length(prompt_length):
+    json_path = "prompts.json"
+    with open(json_path) as f:
+        data = json.load(f)
+
+    return data[f"{prompt_length}"]
+
 def get_target_pip_package_version(target_pip_package_name_list):
     # get package name and version
     import pkg_resources
@@ -232,6 +240,9 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
         # use random tokens instead of generating a prompt using the model and then tokenizing it
         tokens = np.random.randint(100, size=(batch_size, prompt_length))
         prompt = [tokenizer.decode(tokens[0])] * batch_size
+    elif args.use_prompt_set:
+        prompt = get_prompt_by_length(prompt_length)
+        tokens = tokenizer.encode_batch(prompt)
     else:
         prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
         tokens = tokenizer.encode_batch(prompt)
@@ -424,6 +435,7 @@ def str2strlist(value):
     parser.add_argument('-mn', '--model_name', type=str, default='model_name', help='Model name defined by users')
     parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info')
     parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
+    parser.add_argument('--use_prompt_set', action='store_true', help='Use pre-generated prompt set instead of generating a prompt')
     args = parser.parse_args()
 
     # check max_lengths
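
Note: get_prompt_by_length reads prompts.json from the working directory and indexes it by the prompt length, so the file is assumed to map each supported length (as a string key) to a pre-defined prompt. That file is not part of this diff; the sketch below, with hypothetical lengths and prompt text, shows one layout the helper would accept.

import json

# Hypothetical prompts.json layout assumed by get_prompt_by_length:
# string keys are prompt lengths, values are the pre-defined prompts
# (or lists of prompts, if one entry should cover a whole batch).
example_prompts = {
    "16": "Summarize the plot of a short detective story in two sentences.",
    "64": "Write a detailed product description for a solar-powered lantern, "
          "covering its materials, battery life, and intended outdoor use cases.",
}

with open("prompts.json", "w") as f:
    json.dump(example_prompts, f, indent=2)

# get_prompt_by_length(16) then returns the entry stored under the key "16".

With such a file next to the benchmark script, passing --use_prompt_set selects the new branch in run_benchmark instead of random tokens or a model-generated prompt.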