
Commit 8c2c199

Jetstream by default (#118)
* test(slots): add unit tests for slots for Jetstream too
  Implementation is slightly different, so a separate test is added.
* test(truncate): adapt test for Jetstream too
* refactor(test): make TinyLlama test work for Jetstream and Torch/XLA
  Most tests work for both, except for the continuous batching one. This allows removing the old GPT2-based tests, which are quite slow and do not use any sharding or KV cache, so they are not really representative of the models most relevant for TGI.
* test(gpt2): remove old test
  Equivalent tests now exist on the TinyLlama model; they run faster and use the KV cache and sharding. The only test without an equivalent is the continuous batching one, but it did not work for most other models, so it is removed anyway: having it pass was not representative of the current state.
* feat(tgi): Jetstream/Pytorch is now the default engine
  Now that the engine is stable and tested, it is set as the default engine for TGI.
* review(test): refactor slot test to avoid repeating code
* feat(tests): use pytest markers to filter Jetstream and Torch XLA tests
  So far, filtering was done using the name of the test. The selection is now done with a custom marker, which allows for clearer filtering.
* review(tests): clarify the skip-test messages
* ci(torch xla): use the JETSTREAM_PT_DISABLE env var on the command line
  For some reason the env var was not carried over (though Jetstream was disabled anyway). Moving the variable to the command-line invocation removes a warning in the logs.
* review(ci): fix JETSTREAM_PT_DISABLE env var usage again
* fix(tests): remove expected results from tests with do_sample
  Some test results change when operations are done in a slightly different way; this has now happened with the Torch XLA tests, producing different results on the CI. To avoid this, tests now check that the token and text obtained with sampling differ from those obtained with greedy search.
1 parent e7474e0 commit 8c2c199
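To make the marker-based filtering described in the message concrete, here is a minimal illustrative sketch (placeholder test, not taken from the diff): tagging a whole module lets CI select it with "pytest -m jetstream" or exclude it with "pytest -m 'not jetstream'".

import pytest

# Module-level mark, as done for the Jetstream test files in this commit:
# every test in this module is treated as a Jetstream test and is selected
# by "pytest -m jetstream" (and deselected by "pytest -m 'not jetstream'").
pytestmark = pytest.mark.jetstream


def test_placeholder():
    # Placeholder body standing in for the real decode tests.
    assert True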

18 files changed (+181, -250 lines)

.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml

Lines changed: 7 additions & 7 deletions
@@ -34,28 +34,28 @@ jobs:
       - name: Run TGI Jetstream Pytorch - Llama
         run: |
           python -m \
-            pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and Llama"
+            pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Llama"
       - name: Run TGI Jetstream Pytorch - Gemma
         run: |
           python -m \
-            pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and gemma"
+            pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and gemma"
       - name: Run TGI Jetstream Pytorch - Mixtral greedy
         run: |
           python -m \
-            pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and Mixtral and greedy"
+            pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Mixtral and greedy"
       - name: Run TGI Jetstream Pytorch - Quantization Mixtral
         run: |
           python -m \
-            pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Mixtral"
+            pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Mixtral"
       - name: Run TGI Jetstream Pytorch - Quantization Llama-3 8B
         run: |
           python -m \
-            pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Llama-3-8B"
+            pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-8B"
       - name: Run TGI Jetstream Pytorch - Quantization Llama 3 70B
         run: |
           python -m \
-            pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Llama-3-70B"
+            pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-70B"
       - name: Run TGI Jetstream Pytorch - Other tests
         run: |
           python -m \
-            pytest -sv text-generation-inference/tests --runslow -k "jetstream and not decode and not quant"
+            pytest -sv text-generation-inference/tests --runslow -m jetstream -k "not decode"

.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ jobs:
       PJRT_DEVICE: TPU
       HF_TOKEN: ${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }}
       HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface
+      JETSTREAM_PT_DISABLE: 1 # Disable PyTorch to avoid conflicts with PyTorch XLA
     steps:
       - name: Checkout
         uses: actions/checkout@v4

.github/workflows/test-pytorch-xla-tpu-tgi.yml

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ jobs:
     env:
       PJRT_DEVICE: TPU
       HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface
+      JETSTREAM_PT_DISABLE: 1 # Disable PyTorch to avoid conflicts with PyTorch XLA
     steps:
       - name: Checkout
         uses: actions/checkout@v4

.github/workflows/test-pytorch-xla-tpu.yml

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ jobs:
     env:
       PJRT_DEVICE: TPU
       HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface
+      JETSTREAM_PT_DISABLE: 1 # Disable PyTorch to avoid conflicts with PyTorch XLA
     steps:
       - name: Checkout
         uses: actions/checkout@v4

Makefile

Lines changed: 2 additions & 2 deletions
@@ -95,12 +95,12 @@ jetstream_requirements: test_installs
 tgi_test_jetstream: test_installs jetstream_requirements tgi_server
 	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
 	-exec python -m pip install --force-reinstall {} \;
-	JETSTREAM_PT=1 python -m pytest -sv text-generation-inference/tests -k jetstream
+	python -m pytest -sv text-generation-inference/tests -m jetstream
 
 tgi_test: test_installs tgi_server
 	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
 	-exec python -m pip install --force-reinstall {} \;
-	python -m pytest -sv text-generation-inference/tests
+	python -m pytest -sv text-generation-inference/tests -m torch_xla
 
 tgi_docker_test: tpu-tgi
 	python -m pip install -r text-generation-inference/integration-tests/requirements.txt

docs/source/howto/serving.mdx

Lines changed: 4 additions & 2 deletions
@@ -50,9 +50,11 @@ curl localhost/generate_stream \
     -H 'Content-Type: application/json'
 ```
 
-### Using Jetstream Pytorch as backend
+### Jetstream Pytorch and Pytorch XLA backends
+
+[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default if the dependency is available.
+If for some reason you want to use the Pytorch/XLA backend instead, you can set the `JETSTREAM_PT_DISABLE=1` environment variable.
 
-[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. It is possible to use this engine by setting the `JETSTREAM_PT=1` environment variable.
 
 When using Jetstream Pytorch engine, it is possible to enable quantization to reduce the memory footprint and increase the throughput. To enable quantization, set the `QUANTIZATION=1` environment variable.
 
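As a quick illustration of the backend selection documented above, here is a minimal sketch (not part of the commit) that forces the Pytorch/XLA backend and checks the result; it only assumes optimum-tpu is installed, and uses the `JETSTREAM_PT_DISABLE` variable and the `jetstream_pt_available` helper touched by this commit.

import os

# The variable must be set before optimum.tpu decides which engine to use.
os.environ["JETSTREAM_PT_DISABLE"] = "1"

from optimum.tpu import jetstream_pt_available

# With Jetstream Pytorch disabled, the helper reports it as unavailable and
# TGI falls back to the Pytorch/XLA backend.
assert not jetstream_pt_available()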

optimum/tpu/jetstream_pt_support.py

Lines changed: 3 additions & 3 deletions
@@ -8,9 +8,9 @@ def jetstream_pt_available() -> bool:
     """Check if the necessary imports to use jetstream_pt are available.
     """
     try:
-        # For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
-        jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
-        if not jetstream_pt_enabled:
+        # Jetstream Pytorch is enabled by default, it can be disabled with an ENV variable.
+        jetstream_pt_disabled = os.environ.get("JETSTREAM_PT_DISABLE", False) == "1"
+        if jetstream_pt_disabled:
             return False
         # Torch XLA should not be imported before torch_xla2 to avoid conflicts.
         if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:

text-generation-inference/server/text_generation_server/generator.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
         self._max_new_tokens = self._generation_config.max_new_tokens
         # TODO: stop_sequences, ignore_eos_token
 
-    def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
+    def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor = None, selector: TokenSelector = None):
         """Reset the slot for the next generation.
 
         Args:

text-generation-inference/tests/conftest.py

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 import pytest
 
+from optimum.tpu import jetstream_pt_available
+
 
 # See https://stackoverflow.com/a/61193490/217945 for run_slow
 def pytest_addoption(parser):
@@ -33,3 +35,14 @@ def quantization_jetstream_int8():
     # Clean up
     os.environ.clear()
     os.environ.update(old_environ)
+
+
+def pytest_runtest_setup(item):
+    marker_names = [marker.name for marker in item.own_markers]
+    jetstream_pt_enabled = jetstream_pt_available()
+    # Skip tests that require torch xla but not jetstream
+    if "torch_xla" in marker_names and "jetstream" not in marker_names:
+        if jetstream_pt_enabled:
+            pytest.skip("Jetstream is enabled: xla test will be skipped")
+    elif "jetstream" in marker_names and not jetstream_pt_enabled:
+        pytest.skip("Test requires Jetstream PyTorch to be enabled")
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+[pytest]
+markers =
+    jetstream: mark a test as a test that uses jetstream backend
+    torch_xla: mark a test as a test that uses torch_xla backend
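For completeness, the markers registered above can also be driven from Python rather than the shell; a small sketch using pytest's public `pytest.main` entry point (the test path is the one used elsewhere in this repo).

import sys

import pytest

if __name__ == "__main__":
    # Equivalent to running "pytest -m torch_xla text-generation-inference/tests"
    # from the command line; pytest.main returns the usual pytest exit code.
    sys.exit(pytest.main(["-m", "torch_xla", "text-generation-inference/tests"]))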

text-generation-inference/tests/test_decode.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,9 @@
 from decode_tests_utils import DecodeTestParams, decode_single_test
 
 
+# All tests in this file are for torch xla
+pytestmark = pytest.mark.torch_xla
+
 @pytest.mark.parametrize("params",
     [
         DecodeTestParams(
@@ -21,6 +24,7 @@
 def test_decode_single(params):
     decode_single_test(params)
 
+
 @pytest.mark.slow
 @pytest.mark.parametrize("params",
     [

text-generation-inference/tests/test_decode_jetstream.py

Lines changed: 2 additions & 5 deletions
@@ -2,8 +2,9 @@
 import pytest
 from decode_tests_utils import DecodeTestParams, decode_single_test
 
-from optimum.tpu.jetstream_pt_support import jetstream_pt_available
 
+# All tests in this file are for jetstream
+pytestmark = pytest.mark.jetstream
 
 @pytest.mark.slow
 @pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@@ -35,8 +36,6 @@
     ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"],
 )
 def test_decode_single_jetstream_pytorch_slow(params, do_sample):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     params.do_sample = do_sample
     decode_single_test(params)
 
@@ -64,7 +63,5 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
     ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
 )
 def test_decode_single_jetstream_pytorch(params, do_sample):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     params.do_sample = do_sample
     decode_single_test(params)

text-generation-inference/tests/test_decode_jetstream_quant.py

Lines changed: 2 additions & 5 deletions
@@ -2,8 +2,9 @@
 import pytest
 from decode_tests_utils import DecodeTestParams, decode_single_test
 
-from optimum.tpu.jetstream_pt_support import jetstream_pt_available
 
+# All tests in this file are for jetstream
+pytestmark = pytest.mark.jetstream
 
 @pytest.mark.parametrize("params",
     [
@@ -22,8 +23,6 @@
     ids=["gemma-2b", "TinyLLama-v0"],
 )
 def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     decode_single_test(params)
 
 
@@ -49,6 +48,4 @@ def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
     ids=["Mixtral-8x7B", "Meta-Llama-3-8B" ,"Meta-Llama-3-70B"],
 )
 def test_decode_jetstream_quantization_slow(quantization_jetstream_int8, params):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     decode_single_test(params)
Lines changed: 33 additions & 20 deletions
@@ -1,21 +1,13 @@
+import numpy as np
 import pytest
-import torch
 from text_generation_server.pb.generate_pb2 import Request
 from transformers import AutoTokenizer, GenerationConfig
 
 
 TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "openai-community/gpt2"]
 
-
-@pytest.fixture(params=TOKENIZERS)
-def tokenizer(request):
-    t = AutoTokenizer.from_pretrained(request.param)
-    t.padding_side = "left"
-    t.pad_token_id = t.eos_token_id
-    return t
-
-
-@pytest.mark.parametrize(
+# Defining this global variable will parametrize all tests in this file
+pytestmark = pytest.mark.parametrize(
     "input_text, generated_text",
     [
         [
@@ -29,26 +21,31 @@ def tokenizer(request):
     ],
     ids=["spaces", "chinese-utf8", "emojis"],
 )
-def test_decode_streaming(tokenizer, input_text, generated_text):
-    from text_generation_server.generator import Slot
-    # Note: device used is cpu to make it faster
-    slot = Slot(0, tokenizer, "cpu")
+
+
+@pytest.fixture(params=TOKENIZERS)
+def tokenizer(request):
+    t = AutoTokenizer.from_pretrained(request.param)
+    t.padding_side = "left"
+    t.pad_token_id = t.eos_token_id
+    return t
+
+
+def _test_decode_streaming(slot, return_tensors, tokenizer, input_text, generated_text):
     request = Request(id=0, inputs=input_text)
     slot.assign(0, request, GenerationConfig())
-    assert slot.cached_text == input_text
 
-    inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="pt")
+    inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors=return_tensors)
     input_ids = inputs["input_ids"][0]
-    attention_mask = inputs["attention_mask"][0]
     generated_tokens = tokenizer(generated_text, add_special_tokens=False)["input_ids"]
 
     # We need to regenerate the full text as the tokenizer might change it (extra spaces might be added)
-    all_input_ids = torch.cat([input_ids, torch.tensor(generated_tokens)])
+    all_input_ids = np.concatenate([input_ids, generated_tokens])
    full_text = tokenizer.decode(all_input_ids, skip_special_tokens=True)
     regenerated_text = full_text[len(input_text) :]
 
     # Initialize the slot with the inputs
-    slot.reset(input_ids, attention_mask, selector=None)
+    slot.reset(input_ids, selector=None)
 
     assert slot.generated_tokens == 0
 
@@ -60,3 +57,19 @@ def test_decode_streaming(tokenizer, input_text, generated_text):
         decoded_text += text
 
     assert decoded_text == regenerated_text
+
+
+@pytest.mark.jetstream
+def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
+    from text_generation_server.jetstream_pt_support.generator import Slot
+
+    slot = Slot(0, tokenizer)
+    _test_decode_streaming(slot, "np", tokenizer, input_text, generated_text)
+
+@pytest.mark.torch_xla
+def test_decode_streaming(tokenizer, input_text, generated_text):
+    from text_generation_server.generator import Slot
+
+    # Note: device used is cpu to make it faster
+    slot = Slot(0, tokenizer, "cpu")
+    _test_decode_streaming(slot, "pt", tokenizer, input_text, generated_text)
