@@ -82,6 +82,14 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
         generator.generate_next_token()
     return tokenizer.decode(generator.get_sequence(0))
 
+# Use the prompt length to look up a pre-defined prompt
+def get_prompt_by_length(prompt_length):
+    json_path = "prompts.json"
+    with open(json_path) as prompts_file:
+        content = prompts_file.read()
+    data = json.loads(content)
+    return data[f"{prompt_length}"]
+
 def get_target_pip_package_version(target_pip_package_name_list):
     # get package name and version
     import pkg_resources
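
`get_prompt_by_length` reads a `prompts.json` file keyed by the stringified prompt length. The diff does not show that file, so the layout below is a minimal sketch under that assumption, with made-up prompt text:

```python
import json

# Hypothetical prompts.json layout: keys are stringified prompt lengths,
# values are prompts expected to tokenize to roughly that many tokens.
example_prompts = {
    "16": "Describe the water cycle in one short paragraph.",
    "256": "Write a detailed, step-by-step guide to brewing pour-over coffee.",
}

with open("prompts.json", "w") as prompts_file:
    json.dump(example_prompts, prompts_file, indent=2)

# get_prompt_by_length(16) would now return the "16" entry.
```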
@@ -231,6 +239,18 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
         # use random tokens instead of generating a prompt using the model and then tokenizing it
         tokens = np.random.randint(100, size=(batch_size, prompt_length))
         prompt = [tokenizer.decode(tokens[0])] * batch_size
+    elif args.use_prompt_set:
+        prompt = [get_prompt_by_length(prompt_length)] * batch_size
+        tokens = tokenizer.encode_batch(prompt)
+
+        if tokens.shape[1] > max_length:
+            # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+            tokens = tokens[:, :max_length]
+        elif tokens.shape[1] < max_length:
+            # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+            tokens_first_col = tokens[:, 0:1]
+            for _ in range(max_length - tokens.shape[1]):
+                tokens = np.hstack((tokens_first_col, tokens))
     else:
         prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
         tokens = tokenizer.encode_batch(prompt)
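
The pad/truncate step above relies only on NumPy slicing and `np.hstack`. A self-contained sketch with a dummy token matrix (the values are made up) shows the effect of left-padding each row with its first token:

```python
import numpy as np

tokens = np.array([[101, 7, 9],
                   [101, 4, 2]])  # (batch_size=2, tokenized_length=3)
max_length = 5

if tokens.shape[1] > max_length:
    tokens = tokens[:, :max_length]                     # truncate on the right
elif tokens.shape[1] < max_length:
    tokens_first_col = tokens[:, 0:1]                   # shape (batch_size, 1)
    for _ in range(max_length - tokens.shape[1]):
        tokens = np.hstack((tokens_first_col, tokens))  # left-pad with first token

print(tokens)
# [[101 101 101   7   9]
#  [101 101 101   4   2]]
```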
@@ -416,6 +436,7 @@ def str2strlist(value):
     parser.add_argument('-mn', '--model_name', type=str, default='model_name', help='Model name defined by users')
     parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info')
     parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
+    parser.add_argument('--use_prompt_set', action='store_true', help='Use a pre-generated prompt set instead of generating a prompt')
     args = parser.parse_args()
 
     # check max_lengths
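
One design consequence of the new `elif` branch in `run_benchmark`: `--use_random_tokens` is checked first, so if both flags are passed, the prompt set is silently ignored. A minimal sketch of that precedence, with stub branches:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_random_tokens', action='store_true')
parser.add_argument('--use_prompt_set', action='store_true')
args = parser.parse_args(['--use_random_tokens', '--use_prompt_set'])

# Mirrors the branch order in run_benchmark: random tokens win when both are set.
if args.use_random_tokens:
    source = "random tokens"
elif args.use_prompt_set:
    source = "prompt set"
else:
    source = "model-generated prompt"

print(source)  # "random tokens"
```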