
Commit 8eb7a84

Make sliding_window for Qwen2 optional (#546)
No need for that, what you have is good enough for now. There are other things we could do better overall regarding tensor names.
1 parent: d8206a3 · commit: 8eb7a84

File tree (2 files changed: +11 −3)

backends/candle/src/models/flash_qwen2.rs
backends/candle/src/models/qwen2.rs

backends/candle/src/models/flash_qwen2.rs

Lines changed: 10 additions & 2 deletions
@@ -22,7 +22,7 @@ struct Qwen2Attention {
 impl Qwen2Attention {
     pub fn load(vb: VarBuilder, config: &Qwen2Config) -> Result<Self> {
         if config.use_sliding_window {
-            candle::bail!("Sliding window is not supported");
+            candle::bail!("Sliding window is not supported for Qwen2");
         }

         let num_attention_heads = config.num_attention_heads;
@@ -264,7 +264,15 @@ impl FlashQwen2Model {
             ModelType::Embedding(pool) => pool,
         };

-        let vb = vb.pp("model");
+        // Pushing the `model` prefix is apparently only required when the checkpoint comes
+        // from a `Qwen2ForCausalLM` architecture, which nests the transformer under `model`
+        // alongside the `lm_head`; a plain `Qwen2Model` checkpoint has no `model` key,
+        // e.g. https://huggingface.co/mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B
+        let vb = if vb.contains_tensor("model.embed_tokens.weight") {
+            vb.pp("model")
+        } else {
+            vb
+        };

         let embeddings = Embedding::new(
             vb.pp("embed_tokens")

backends/candle/src/models/qwen2.rs

Lines changed: 1 addition & 1 deletion
@@ -13,6 +13,6 @@ pub struct Qwen2Config {
     pub max_position_embeddings: usize,
     pub rms_norm_eps: f32,
     pub rope_theta: f32,
-    pub sliding_window: usize,
+    pub sliding_window: Option<usize>,
     pub use_sliding_window: bool,
 }
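The switch to `Option<usize>` matters at deserialization time: with serde derive (which HF-style config structs like this one typically use), an `Option` field becomes `None` when the key is missing from `config.json`, whereas a bare `usize` field would make parsing fail. A minimal sketch, assuming the serde and serde_json crates and a hypothetical `Qwen2ConfigSketch` struct trimmed down to the relevant fields:

use serde::Deserialize;

#[derive(Deserialize)]
struct Qwen2ConfigSketch {
    rope_theta: f32,
    sliding_window: Option<usize>,
    use_sliding_window: bool,
}

fn main() -> serde_json::Result<()> {
    // `sliding_window` omitted from the JSON: parses fine, field is `None`.
    let without: Qwen2ConfigSketch = serde_json::from_str(
        r#"{"rope_theta": 1000000.0, "use_sliding_window": false}"#,
    )?;
    assert_eq!(without.sliding_window, None);

    // `sliding_window` present: parses to `Some(4096)`.
    let with: Qwen2ConfigSketch = serde_json::from_str(
        r#"{"rope_theta": 1000000.0, "sliding_window": 4096, "use_sliding_window": true}"#,
    )?;
    assert_eq!(with.sliding_window, Some(4096));
    Ok(())
}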
