@@ -39,6 +39,21 @@ fn powers_of_two(max_value: usize) -> Vec<usize> {
39
39
result
40
40
}
41
41
42
+ fn generate_bucket_sizes ( bucket_size : usize , max_s : usize , base_exp : usize ) -> Vec < usize > {
43
+ let mut sizes = Vec :: new ( ) ;
44
+ let mut current = bucket_size;
45
+
46
+ while current <= max_s {
47
+ sizes. push ( current) ;
48
+ match current. checked_mul ( base_exp) {
49
+ Some ( next) => current = next,
50
+ None => break ,
51
+ }
52
+ }
53
+
54
+ sizes
55
+ }
56
+
42
57
fn is_hpu ( ) -> bool {
43
58
match Command :: new ( "hl-smi" )
44
59
. args ( [ "-Q" , "name" , "-f" , "csv" ] )
@@ -114,7 +129,7 @@ impl Backend {
114
129
} ;
115
130
let seq_bucket_size: usize = read_env_var ( "PAD_SEQUENCE_TO_MULTIPLE_OF" , 128 ) ;
116
131
let max_warmup_length: usize = read_env_var ( "MAX_WARMUP_SEQUENCE_LENGTH" , 1024 ) ;
117
-
132
+ let seq_len_exp_base : usize = read_env_var ( "SEQ_LEN_EXPONENT_BASE" , 2 ) ;
118
133
let max_batch_size = max_bs. unwrap_or_else ( || read_env_var ( "MAX_WARMUP_BATCH_SIZE" , 8 ) ) ;
119
134
120
135
let mut batch_sizes: Vec < usize > = powers_of_two ( max_batch_size) ;
@@ -135,9 +150,11 @@ impl Backend {
135
150
}
136
151
137
152
max_input_length = std:: cmp:: min ( max_input_length, max_warmup_length) ;
138
- let mut seq_lengths: Vec < usize > = ( seq_bucket_size..=max_input_length)
139
- . step_by ( seq_bucket_size)
140
- . collect ( ) ;
153
+ let mut seq_lengths: Vec < usize > = generate_bucket_sizes (
154
+ seq_bucket_size,
155
+ max_input_length,
156
+ seq_len_exp_base,
157
+ ) ;
141
158
if let Some ( & last) = seq_lengths. last ( ) {
142
159
if last < max_input_length {
143
160
seq_lengths. push ( max_input_length) ;
0 commit comments