
Commit 2adbd12

Change HPU warmup logic: seq length should grow exponentially (#659)
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent f1df357 commit 2adbd12

2 files changed (+28, -7 lines)

backends/python/server/text_embeddings_server/models/types.py

Lines changed: 7 additions & 3 deletions
@@ -11,10 +11,12 @@
 
 tracer = trace.get_tracer(__name__)
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
+SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
 
 
-def round_up(number, k):
-    return (number + k - 1) // k * k
+def round_up_seq(number, k, base):
+    exponent = max(0, math.ceil(math.log(number / k, base)))
+    return int(k * (base**exponent))
 
 
 class Batch(ABC):
@@ -46,7 +48,9 @@ def from_pb(
         batch_size = len(pb.cu_seq_lengths) - 1
         if device.type == "hpu":
             # To better utilize HPU, we need to do batch/seq_len bucketing
-            max_length = round_up(pb.max_length, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            max_length = round_up_seq(
+                pb.max_length, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
+            )
             max_length = min(max_length, max_input_length)
             new_bs = 2 ** math.ceil(math.log2(batch_size))
         else:
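
For a quick sense of the new bucketing, below is a minimal standalone Python sketch contrasting the old round_up with the new round_up_seq. The two functions mirror the diff above; the driver block, the sample lengths, and the printed values are illustrative assumptions, not part of the change.

import math


def round_up(number, k):
    # Old behavior: round up to the next multiple of k.
    return (number + k - 1) // k * k


def round_up_seq(number, k, base):
    # New behavior: round up to the smallest k * base**n >= number, so padded
    # sequence lengths fall into exponentially growing buckets.
    exponent = max(0, math.ceil(math.log(number / k, base)))
    return int(k * (base**exponent))


if __name__ == "__main__":
    k, base = 128, 2  # defaults of PAD_SEQUENCE_TO_MULTIPLE_OF and SEQ_LEN_EXPONENT_BASE
    for length in (100, 300, 900):
        print(length, round_up(length, k), round_up_seq(length, k, base))
    # 100 -> 128 vs 128, 300 -> 384 vs 512, 900 -> 1024 vs 1024

Fewer distinct padded shapes means fewer recompilations for the batch/seq_len bucketing mentioned in the from_pb comment, at the cost of some extra padding for lengths that land just above a bucket boundary.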

backends/src/lib.rs

Lines changed: 21 additions & 4 deletions
@@ -39,6 +39,21 @@ fn powers_of_two(max_value: usize) -> Vec<usize> {
     result
 }
 
+fn generate_bucket_sizes(bucket_size: usize, max_s: usize, base_exp: usize) -> Vec<usize> {
+    let mut sizes = Vec::new();
+    let mut current = bucket_size;
+
+    while current <= max_s {
+        sizes.push(current);
+        match current.checked_mul(base_exp) {
+            Some(next) => current = next,
+            None => break,
+        }
+    }
+
+    sizes
+}
+
 fn is_hpu() -> bool {
     match Command::new("hl-smi")
         .args(["-Q", "name", "-f", "csv"])
@@ -114,7 +129,7 @@ impl Backend {
         };
         let seq_bucket_size: usize = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
         let max_warmup_length: usize = read_env_var("MAX_WARMUP_SEQUENCE_LENGTH", 1024);
-
+        let seq_len_exp_base: usize = read_env_var("SEQ_LEN_EXPONENT_BASE", 2);
         let max_batch_size = max_bs.unwrap_or_else(|| read_env_var("MAX_WARMUP_BATCH_SIZE", 8));
 
         let mut batch_sizes: Vec<usize> = powers_of_two(max_batch_size);
@@ -135,9 +150,11 @@ impl Backend {
         }
 
         max_input_length = std::cmp::min(max_input_length, max_warmup_length);
-        let mut seq_lengths: Vec<usize> = (seq_bucket_size..=max_input_length)
-            .step_by(seq_bucket_size)
-            .collect();
+        let mut seq_lengths: Vec<usize> = generate_bucket_sizes(
+            seq_bucket_size,
+            max_input_length,
+            seq_len_exp_base,
+        );
         if let Some(&last) = seq_lengths.last() {
             if last < max_input_length {
                 seq_lengths.push(max_input_length);
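
To illustrate the warmup schedule this produces, here is a Python sketch that mirrors the new Rust generate_bucket_sizes helper (the Rust version additionally guards the multiplication with checked_mul to avoid overflow). The printed bucket list assumes the default environment values and is not taken from the diff.

def generate_bucket_sizes(bucket_size, max_s, base_exp):
    # Mirror of the Rust helper: bucket_size, bucket_size*base_exp, ... while <= max_s.
    sizes = []
    current = bucket_size
    while current <= max_s:
        sizes.append(current)
        current *= base_exp
    return sizes


if __name__ == "__main__":
    # Assumed defaults: PAD_SEQUENCE_TO_MULTIPLE_OF=128, SEQ_LEN_EXPONENT_BASE=2,
    # MAX_WARMUP_SEQUENCE_LENGTH=1024.
    print(generate_bucket_sizes(128, 1024, 2))  # [128, 256, 512, 1024]

The previous linear schedule warmed up every multiple of 128 up to 1024 (eight sequence lengths); the exponential schedule warms up only four. If the largest generated bucket falls short of max_input_length, the surrounding code still appends max_input_length itself, as the final context lines of the diff show.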

0 commit comments
