Skip to content

Commit 7742d05

Browse files
author
郑锐
committed
add code for ch5
1 parent a1b941a commit 7742d05

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+6714
-0
lines changed

.DS_Store

6 KB
Binary file not shown.

code/ch5/.gitignore

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
pip-wheel-metadata/
24+
share/python-wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.nox/
44+
.coverage
45+
.coverage.*
46+
.cache
47+
nosetests.xml
48+
coverage.xml
49+
*.cover
50+
*.py,cover
51+
.hypothesis/
52+
.pytest_cache/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# IPython
81+
profile_default/
82+
ipython_config.py
83+
84+
# pyenv
85+
.python-version
86+
87+
# pipenv
88+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
90+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
91+
# install all needed dependencies.
92+
#Pipfile.lock
93+
94+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
95+
__pypackages__/
96+
97+
# Celery stuff
98+
celerybeat-schedule
99+
celerybeat.pid
100+
101+
# SageMath parsed files
102+
*.sage.py
103+
104+
# Environments
105+
.env
106+
.venv
107+
env/
108+
venv/
109+
ENV/
110+
env.bak/
111+
venv.bak/
112+
113+
# Spyder project settings
114+
.spyderproject
115+
.spyproject
116+
117+
# Rope project settings
118+
.ropeproject
119+
120+
# mkdocs documentation
121+
/site
122+
123+
# mypy
124+
.mypy_cache/
125+
.dmypy.json
126+
dmypy.json
127+
128+
# Pyre type checker
129+
.pyre/
130+
131+
results/
132+
outputs/
133+
134+
.amltconfig
135+
.test_output
136+
*.hdf5
137+
*.h5

code/ch5/README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
## 📰 Chapter 5 Deepspeed-Chat SFT 实践 📰
2+
3+
### 🐼 环境安装
4+
5+
```bash
6+
pip install deepspeed>=0.9.0
7+
8+
pip install -r requirements.txt
9+
pip install -e .
10+
```
11+
12+
### 🐼 数据预处理
13+
14+
在数据处理代码文件`dschat/utils/data/raw_datasets.py``dschat/utils/data/data_utils.py`(新版代码路径与教材中有所不同)中增加对新增数据的处理。
15+
16+
### 🐼 自定义模型
17+
虽然 Deepspeed-Chat 内置了在各项评估上都表现良好的 Llama-2 7B 模型,但是由于模型在 预训练中并没有在足够的中文数据上训练,因此其中文能力并不强。当需要使用支持中文的预训练 模型,或者更换其他模型时,就需要对 Deepspeed-Chat 进行相应的更改来适配其他自定义的模型。对`training/step1_supervised_finetuning/main.py`进行修改来导入相应的模型并在`training/step1_supervised_finetuning/training_scripts`修改对应训练脚本。
18+
19+
20+
### 🐼 模型训练
21+
```bash
22+
# 在路径 training/step1_supervised_finetuning 下运行, 示例中在一台 8 卡 Nvidia A100 机器下进行训练
23+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash training/step1_supervised_finetuning/training_scripts/baichuan/run_baichuan_7b.sh
24+
```

code/ch5/chat.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import argparse
7+
import subprocess
8+
9+
if __name__ == "__main__":
10+
parser = argparse.ArgumentParser()
11+
parser.add_argument("--path",
12+
type=str,
13+
help="Directory containing trained actor model")
14+
parser.add_argument(
15+
"--max_new_tokens",
16+
type=int,
17+
default=128,
18+
help="Maximum new tokens to generate per response",
19+
)
20+
args = parser.parse_args()
21+
22+
cmd = f"python3 ./inference/chatbot.py --path {args.path} --max_new_tokens {args.max_new_tokens}"
23+
p = subprocess.Popen(cmd, shell=True)
24+
p.wait()

0 commit comments

Comments
 (0)