Commit c4ec55e

Support different Qwen3 sizes in pkg (rasbt#714)
1 parent ddbaf0d commit c4ec55e

4 files changed: +194 additions, -175 deletions


ch05/11_qwen3/README.md

Lines changed: 69 additions & 5 deletions
@@ -6,9 +6,9 @@ This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this f
 
 
 
-### Using Qwen3 0.6B via the `llms-from-scratch` package
+### Using Qwen3 via the `llms-from-scratch` package
 
-For an easy way to use the Qwen3 0.6B from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
+For an easy way to use the Qwen3 from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
 
 
 #### 1) Installation
@@ -36,9 +36,9 @@ TOP_K = 1
 ```
 
 
-#### 3) Weight download and loading
+#### 3a) Weight download and loading of the 0.6B model
 
-This automatically downloads the weight file based on the model choice above:
+The following automatically downloads the weight file based on the model choice (reasoning or base) above. Note that this section focuses on the 0.6B model. Skip this section and continue with section 3b) if you want to work with any of the larger models (1.7B, 4B, 8B, 14B, or 32B).
 
 ```python
 from llms_from_scratch.qwen3 import download_from_huggingface
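
The rest of the step-3a code block is cut off by the diff context above. Below is a minimal sketch of that flow, using the `download_from_huggingface(repo_id, filename, local_dir)` helper whose definition appears in the code removed further down in this commit; the repo id, file names, and the final state-dict loading step are placeholders and assumptions, not taken from the diff:

```python
import torch
from llms_from_scratch.qwen3 import (
    Qwen3Model, QWEN_CONFIG_06_B, download_from_huggingface
)

USE_REASONING_MODEL = True  # assumed to be set in step 2 of the README

# Hypothetical repo id and file names; the README defines the actual values.
repo_id = "rasbt/qwen3-from-scratch"
filename = "qwen3-0.6B.pth" if USE_REASONING_MODEL else "qwen3-0.6B-base.pth"

# Downloads the single weight file and returns its local path
weights_path = download_from_huggingface(
    repo_id=repo_id, filename=filename, local_dir="Qwen3-0.6B"
)

model = Qwen3Model(QWEN_CONFIG_06_B)
# Assuming the downloaded file is a plain PyTorch state dict:
model.load_state_dict(torch.load(weights_path, weights_only=True, map_location="cpu"))
```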
@@ -77,10 +77,74 @@ device = (
     torch.device("mps") if torch.backends.mps.is_available() else
     torch.device("cpu")
 )
-model.to(device)
+model.to(device);
 ```
 
 
+#### 3b) Weight download and loading of the larger Qwen models
+
+If you are interested in working with any of the larger Qwen models (1.7B, 4B, 8B, 14B, or 32B), use the following code instead of the code under 3a). Note that it requires additional dependencies:
+
+```bash
+pip install safetensors huggingface_hub
+```
+
+Then use the following code (change `USE_MODEL` to select the desired model size):
+
+```python
+USE_MODEL = "1.7B"
+
+if USE_MODEL == "1.7B":
+    from llms_from_scratch.qwen3 import QWEN3_CONFIG_1_7B as QWEN3_CONFIG
+elif USE_MODEL == "4B":
+    from llms_from_scratch.qwen3 import QWEN3_CONFIG_4B as QWEN3_CONFIG
+elif USE_MODEL == "8B":
+    from llms_from_scratch.qwen3 import QWEN3_CONFIG_8B as QWEN3_CONFIG
+elif USE_MODEL == "14B":
+    from llms_from_scratch.qwen3 import QWEN3_CONFIG_14B as QWEN3_CONFIG
+elif USE_MODEL == "32B":
+    from llms_from_scratch.qwen3 import QWEN3_CONFIG_32B as QWEN3_CONFIG
+else:
+    raise ValueError("Invalid USE_MODEL name.")
+
+repo_id = f"Qwen/Qwen3-{USE_MODEL}"
+local_dir = f"Qwen3-{USE_MODEL}"
+
+if not USE_REASONING_MODEL:
+    repo_id = f"{repo_id}-Base"
+    local_dir = f"{local_dir}-Base"
+```
+
+Now, download and load the weights into the `model`:
+
+```python
+from llms_from_scratch.qwen3 import (
+    Qwen3Model,
+    download_from_huggingface_from_snapshots,
+    load_weights_into_qwen
+)
+
+model = Qwen3Model(QWEN3_CONFIG)
+
+weights_dict = download_from_huggingface_from_snapshots(
+    repo_id=repo_id,
+    local_dir=local_dir
+)
+load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)
+del weights_dict  # delete the weight dictionary to free up memory
+
+device = (
+    torch.device("cuda") if torch.cuda.is_available() else
+    torch.device("mps") if torch.backends.mps.is_available() else
+    torch.device("cpu")
+)
+
+model.to(device);
+```
+
+
+
+
 #### 4) Initialize tokenizer
 
 The following code downloads and initializes the tokenizer:
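
The tokenizer code itself falls outside the diff context shown above. As a rough sketch of what step 4 can look like, based on the `Qwen3Tokenizer` constructor and `encode`/`decode` methods visible in the code removed later in this commit (the tokenizer file name and the reuse of `repo_id`, `local_dir`, and `USE_REASONING_MODEL` from the earlier steps are assumptions):

```python
from llms_from_scratch.qwen3 import Qwen3Tokenizer

# Downloads tokenizer.json from the model repo if it is not present locally.
# In this implementation, add_generation_prompt must equal add_thinking.
tokenizer = Qwen3Tokenizer(
    tokenizer_file_path=f"{local_dir}/tokenizer.json",
    repo_id=repo_id,
    add_generation_prompt=USE_REASONING_MODEL,
    add_thinking=USE_REASONING_MODEL,
)

input_ids = tokenizer.encode("Give me a short introduction to large language models.")
print(tokenizer.decode(input_ids))
```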

pkg/llms_from_scratch/kv_cache/qwen3.py

Lines changed: 7 additions & 166 deletions
@@ -4,29 +4,17 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
 from .utils import KVCache  # noqa: F401
-
-import os
-import urllib.request
-from pathlib import Path
+from ..qwen3 import (  # noqa: F401
+    QWEN_CONFIG_06_B, QWEN3_CONFIG_1_7B, QWEN3_CONFIG_4B,
+    QWEN3_CONFIG_8B, QWEN3_CONFIG_14B, QWEN3_CONFIG_32B,
+    Qwen3Tokenizer, load_weights_into_qwen,
+    download_from_huggingface,
+    download_from_huggingface_from_snapshots
+)
 
 import torch
 import torch.nn as nn
 
-# 0.6B model
-QWEN_CONFIG_06_B = {
-    "vocab_size": 151_936,     # Vocabulary size
-    "context_length": 40_960,  # Context length that was used to train the model
-    "emb_dim": 1024,           # Embedding dimension
-    "n_heads": 16,             # Number of attention heads
-    "n_layers": 28,            # Number of layers
-    "hidden_dim": 3072,        # Size of the intermediate dimension in FeedForward
-    "head_dim": 128,           # Size of the heads in GQA
-    "qk_norm": True,           # Whether to normalize queries and values in GQA
-    "n_kv_groups": 8,          # Key-Value groups for grouped-query attention
-    "rope_base": 1_000_000.0,  # The base in RoPE's "theta"
-    "dtype": torch.bfloat16,   # Lower-precision dtype to reduce memory usage
-}
-
 
 class Qwen3Model(nn.Module):
     def __init__(self, cfg):
@@ -285,150 +273,3 @@ def forward(self, x):
             norm_x = norm_x + self.shift
 
         return norm_x.to(input_dtype)
-
-
-def load_weights_into_qwen(model, param_config, params):
-    def assign(left, right, tensor_name="unknown"):
-        if left.shape != right.shape:
-            raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
-        return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))
-
-    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
-
-    for l in range(param_config["n_layers"]):
-        block = model.trf_blocks[l]
-        att = block.att
-
-        # Q, K, V projections
-        att.W_query.weight = assign(
-            att.W_query.weight,
-            params[f"model.layers.{l}.self_attn.q_proj.weight"],
-            f"model.layers.{l}.self_attn.q_proj.weight"
-        )
-        att.W_key.weight = assign(
-            att.W_key.weight,
-            params[f"model.layers.{l}.self_attn.k_proj.weight"],
-            f"model.layers.{l}.self_attn.k_proj.weight"
-        )
-        att.W_value.weight = assign(
-            att.W_value.weight,
-            params[f"model.layers.{l}.self_attn.v_proj.weight"],
-            f"model.layers.{l}.self_attn.v_proj.weight"
-        )
-
-        # Output projection
-        att.out_proj.weight = assign(
-            att.out_proj.weight,
-            params[f"model.layers.{l}.self_attn.o_proj.weight"],
-            f"model.layers.{l}.self_attn.o_proj.weight"
-        )
-
-        # QK norms
-        if hasattr(att, "q_norm") and att.q_norm is not None:
-            att.q_norm.scale = assign(
-                att.q_norm.scale,
-                params[f"model.layers.{l}.self_attn.q_norm.weight"],
-                f"model.layers.{l}.self_attn.q_norm.weight"
-            )
-        if hasattr(att, "k_norm") and att.k_norm is not None:
-            att.k_norm.scale = assign(
-                att.k_norm.scale,
-                params[f"model.layers.{l}.self_attn.k_norm.weight"],
-                f"model.layers.{l}.self_attn.k_norm.weight"
-            )
-
-        # Attention layernorm
-        block.norm1.scale = assign(
-            block.norm1.scale,
-            params[f"model.layers.{l}.input_layernorm.weight"],
-            f"model.layers.{l}.input_layernorm.weight"
-        )
-
-        # Feedforward weights
-        block.ff.fc1.weight = assign(
-            block.ff.fc1.weight,
-            params[f"model.layers.{l}.mlp.gate_proj.weight"],
-            f"model.layers.{l}.mlp.gate_proj.weight"
-        )
-        block.ff.fc2.weight = assign(
-            block.ff.fc2.weight,
-            params[f"model.layers.{l}.mlp.up_proj.weight"],
-            f"model.layers.{l}.mlp.up_proj.weight"
-        )
-        block.ff.fc3.weight = assign(
-            block.ff.fc3.weight,
-            params[f"model.layers.{l}.mlp.down_proj.weight"],
-            f"model.layers.{l}.mlp.down_proj.weight"
-        )
-        block.norm2.scale = assign(
-            block.norm2.scale,
-            params[f"model.layers.{l}.post_attention_layernorm.weight"],
-            f"model.layers.{l}.post_attention_layernorm.weight"
-        )
-
-    # Final normalization and output head
-    model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")
-
-    # Model uses weight tying, hence we reuse the embedding layer weights here
-    model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
-
-
-class Qwen3Tokenizer():
-    def __init__(self, tokenizer_file_path="tokenizer.json",
-                 repo_id=None, add_generation_prompt=False, add_thinking=False):
-        from tokenizers import Tokenizer
-        self.tokenizer_file_path = tokenizer_file_path
-
-        if add_generation_prompt != add_thinking:
-            raise ValueError(
-                "Only add_generation_prompt==add_thinking settings are currently supported"
-            )
-
-        self.add_generation_prompt = add_generation_prompt
-        self.add_thinking = add_thinking
-
-        tokenizer_file_path_obj = Path(tokenizer_file_path)
-        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
-            _ = download_from_huggingface(
-                repo_id=repo_id,
-                filename=str(tokenizer_file_path_obj.name),
-                local_dir=str(tokenizer_file_path_obj.parent.name)
-            )
-        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
-
-    def encode(self, prompt):
-        messages = [
-            {"role": "user", "content": prompt}
-        ]
-        formatted_prompt = self.format_qwen_chat(
-            messages,
-            add_generation_prompt=self.add_generation_prompt,
-            add_thinking=self.add_thinking
-        )
-        return self.tokenizer.encode(formatted_prompt).ids
-
-    def decode(self, token_ids):
-        return self.tokenizer.decode(token_ids, skip_special_tokens=False)
-
-    @staticmethod
-    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
-        prompt = ""
-        for msg in messages:
-            prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
-        if add_generation_prompt:
-            prompt += "<|im_start|>assistant"
-            if not add_thinking:
-                prompt += "<|think>\n\n<|/think>\n\n"
-            else:
-                prompt += "\n"
-        return prompt
-
-
-def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
-    base_url = "https://huggingface.co"
-    url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
-    Path(local_dir).mkdir(parents=True, exist_ok=True)
-    dest_path = os.path.join(local_dir, filename)
-    print(f"Downloading {url} to {dest_path}...")
-    urllib.request.urlretrieve(url, dest_path)
-    return dest_path
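
The net effect of this change is that the KV-cache variant of `qwen3.py` no longer duplicates the config dictionaries, tokenizer, and weight-loading helpers but re-exports them from the parent `qwen3` module, so all Qwen3 sizes become available to the KV-cache code path as well. A brief illustration (not part of the commit) of what the re-export enables:

```python
from llms_from_scratch.kv_cache.qwen3 import (
    Qwen3Model,         # KV-cache model variant defined in this file
    QWEN3_CONFIG_1_7B,  # re-exported from llms_from_scratch.qwen3
)
from llms_from_scratch.qwen3 import QWEN3_CONFIG_1_7B as cfg_plain

# The re-export means both modules expose the very same config object.
assert QWEN3_CONFIG_1_7B is cfg_plain

# Larger sizes can now be paired with the KV-cache model as well
# (instantiating the 1.7B model allocates several GB of weights).
model = Qwen3Model(QWEN3_CONFIG_1_7B)
```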
