1
+ import asyncio
2
+ import contextlib
3
+ import os
4
+ import shlex
5
+ import subprocess
6
+ import sys
7
+ import threading
8
+ import time
9
+ from tempfile import TemporaryDirectory
10
+
11
+ import docker
12
+ import pytest
13
+ from docker .errors import NotFound
14
+ import logging
15
+ from gaudi .test_embed import TEST_CONFIGS
16
+ import aiohttp
17
+
18
+ logging .basicConfig (
19
+ level = logging .INFO ,
20
+ format = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>" ,
21
+ stream = sys .stdout ,
22
+ )
23
+ logger = logging .getLogger (__file__ )
24
+
25
+ # Use the latest image from the local docker build
26
+ DOCKER_IMAGE = os .getenv ("DOCKER_IMAGE" , "tei_hpu" )
27
+ DOCKER_VOLUME = os .getenv ("DOCKER_VOLUME" , None )
28
+
29
+ if DOCKER_VOLUME is None :
30
+ logger .warning (
31
+ "DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
32
+ )
33
+
34
+ LOG_LEVEL = os .getenv ("LOG_LEVEL" , "info" )
35
+
36
+ BASE_ENV = {
37
+ "HF_HUB_ENABLE_HF_TRANSFER" : "1" ,
38
+ "LOG_LEVEL" : LOG_LEVEL ,
39
+ "HABANA_VISIBLE_DEVICES" : "all" ,
40
+ }
41
+
42
+ HABANA_RUN_ARGS = {
43
+ "runtime" : "habana" ,
44
+ }
45
+
46
+ def stream_container_logs (container , test_name ):
47
+ """Stream container logs in a separate thread."""
48
+ try :
49
+ for log in container .logs (stream = True , follow = True ):
50
+ print (
51
+ f"[TEI Server Logs - { test_name } ] { log .decode ('utf-8' )} " ,
52
+ end = "" ,
53
+ file = sys .stderr ,
54
+ flush = True ,
55
+ )
56
+ except Exception as e :
57
+ logger .error (f"Error streaming container logs: { str (e )} " )
58
+
59
+
60
+ class LauncherHandle :
61
+ def __init__ (self , port : int ):
62
+ self .port = port
63
+ self .base_url = f"http://127.0.0.1:{ port } "
64
+
65
+ async def generate (self , prompt : str ):
66
+ async with aiohttp .ClientSession () as session :
67
+ async with session .post (
68
+ f"{ self .base_url } /embed" ,
69
+ json = {"inputs" : prompt },
70
+ headers = {"Content-Type" : "application/json" }
71
+ ) as response :
72
+ if response .status != 200 :
73
+ error_text = await response .text ()
74
+ raise RuntimeError (f"Request failed with status { response .status } : { error_text } " )
75
+ return await response .json ()
76
+
77
+ def _inner_health (self ):
78
+ raise NotImplementedError
79
+
80
+ async def health (self , timeout : int = 60 ):
81
+ assert timeout > 0
82
+ start_time = time .time ()
83
+ logger .info (f"Starting health check with timeout of { timeout } s" )
84
+
85
+ for attempt in range (timeout ):
86
+ if not self ._inner_health ():
87
+ logger .error ("Launcher crashed during health check" )
88
+ raise RuntimeError ("Launcher crashed" )
89
+
90
+ try :
91
+ # Try to make a request using generate
92
+ await self .generate ("test" )
93
+ elapsed = time .time () - start_time
94
+ logger .info (f"Health check passed after { elapsed :.1f} s" )
95
+ return
96
+ except (aiohttp .ClientError , asyncio .TimeoutError ) as e :
97
+ if attempt == timeout - 1 :
98
+ logger .error (f"Health check failed after { timeout } s: { str (e )} " )
99
+ raise RuntimeError (f"Health check failed: { str (e )} " )
100
+ if attempt % 10 == 0 and attempt != 0 : # Only log every 10th attempt
101
+ logger .debug (f"Connection attempt { attempt } /{ timeout } failed: { str (e )} " )
102
+ await asyncio .sleep (1 )
103
+ except Exception as e :
104
+ logger .error (f"Unexpected error during health check: { str (e )} " )
105
+ import traceback
106
+ logger .error (f"Full traceback:\n { traceback .format_exc ()} " )
107
+ raise
108
+
109
+
110
+ class ContainerLauncherHandle (LauncherHandle ):
111
+ def __init__ (self , docker_client , container_name , port : int ):
112
+ super ().__init__ (port )
113
+ self .docker_client = docker_client
114
+ self .container_name = container_name
115
+
116
+ def _inner_health (self ) -> bool :
117
+ try :
118
+ container = self .docker_client .containers .get (self .container_name )
119
+ status = container .status
120
+ if status not in ["running" , "created" ]:
121
+ logger .warning (f"Container status is { status } " )
122
+ # Get container logs for debugging
123
+ logs = container .logs ().decode ("utf-8" )
124
+ logger .debug (f"Container logs:\n { logs } " )
125
+ return False
126
+ return True
127
+ except Exception as e :
128
+ logger .error (f"Error checking container health: { str (e )} " )
129
+ return False
130
+
131
+ class ProcessLauncherHandle (LauncherHandle ):
132
+ def __init__ (self , process , port : int ):
133
+ super (ProcessLauncherHandle , self ).__init__ (port )
134
+ self .process = process
135
+
136
+ def _inner_health (self ) -> bool :
137
+ return self .process .poll () is None
138
+
139
+
140
+ @pytest .fixture (scope = "module" )
141
+ def data_volume ():
142
+ tmpdir = TemporaryDirectory ()
143
+ yield tmpdir .name
144
+ try :
145
+ # Cleanup the temporary directory using sudo as it contains root files created by the container
146
+ subprocess .run (shlex .split (f"sudo rm -rf { tmpdir .name } " ), check = True )
147
+ except subprocess .CalledProcessError as e :
148
+ logger .error (f"Error cleaning up temporary directory: { str (e )} " )
149
+
150
+
151
+ @pytest .fixture (scope = "function" )
152
+ def gaudi_launcher (event_loop ):
153
+ @contextlib .contextmanager
154
+ def docker_launcher (
155
+ model_id : str ,
156
+ test_name : str ,
157
+ ):
158
+ logger .info (
159
+ f"Starting docker launcher for model { model_id } and test { test_name } "
160
+ )
161
+
162
+
163
+ port = 8080
164
+
165
+ client = docker .from_env ()
166
+
167
+ container_name = f"tei-hpu-test-{ test_name .replace ('/' , '-' )} "
168
+
169
+ try :
170
+ container = client .containers .get (container_name )
171
+ logger .info (
172
+ f"Stopping existing container { container_name } for test { test_name } "
173
+ )
174
+ container .stop ()
175
+ container .wait ()
176
+ except NotFound :
177
+ pass
178
+ except Exception as e :
179
+ logger .error (f"Error handling existing container: { str (e )} " )
180
+
181
+ tei_args = TEST_CONFIGS [test_name ]["args" ].copy ()
182
+
183
+ # add model_id to tei args
184
+ tei_args .append ("--model-id" )
185
+ tei_args .append (model_id )
186
+
187
+ env = BASE_ENV .copy ()
188
+ env ["HF_TOKEN" ] = os .getenv ("HF_TOKEN" )
189
+
190
+ # Add env config that is definied in the fixture parameter
191
+ if "env_config" in TEST_CONFIGS [test_name ]:
192
+ env .update (TEST_CONFIGS [test_name ]["env_config" ].copy ())
193
+
194
+ volumes = [f"{ DOCKER_VOLUME } :/data" ]
195
+ logger .debug (f"Using volume { volumes } " )
196
+
197
+ try :
198
+ logger .info (f"Creating container with name { container_name } " )
199
+
200
+ # Log equivalent docker run command for debugging, this is not actually executed
201
+ container = client .containers .run (
202
+ DOCKER_IMAGE ,
203
+ command = tei_args ,
204
+ name = container_name ,
205
+ environment = env ,
206
+ detach = True ,
207
+ volumes = volumes ,
208
+ ports = {"80/tcp" : port },
209
+ ** HABANA_RUN_ARGS ,
210
+ )
211
+
212
+ logger .info (f"Container { container_name } started successfully" )
213
+
214
+ # Start log streaming in a background thread
215
+ log_thread = threading .Thread (
216
+ target = stream_container_logs ,
217
+ args = (container , test_name ),
218
+ daemon = True , # This ensures the thread will be killed when the main program exits
219
+ )
220
+ log_thread .start ()
221
+
222
+ # Add a small delay to allow container to initialize
223
+ time .sleep (2 )
224
+
225
+ # Check container status after creation
226
+ status = container .status
227
+ logger .debug (f"Initial container status: { status } " )
228
+ if status not in ["running" , "created" ]:
229
+ logs = container .logs ().decode ("utf-8" )
230
+ logger .error (f"Container failed to start properly. Logs:\n { logs } " )
231
+
232
+ yield ContainerLauncherHandle (client , container .name , port )
233
+
234
+ except Exception as e :
235
+ logger .error (f"Error starting container: { str (e )} " )
236
+ # Get full traceback for debugging
237
+ import traceback
238
+
239
+ logger .error (f"Full traceback:\n { traceback .format_exc ()} " )
240
+ raise
241
+ finally :
242
+ try :
243
+ container = client .containers .get (container_name )
244
+ logger .info (f"Stopping container { container_name } " )
245
+ container .stop ()
246
+ container .wait ()
247
+
248
+ container_output = container .logs ().decode ("utf-8" )
249
+ print (container_output , file = sys .stderr )
250
+
251
+ container .remove ()
252
+ logger .info (f"Container { container_name } removed successfully" )
253
+ except NotFound :
254
+ pass
255
+ except Exception as e :
256
+ logger .warning (f"Error cleaning up container: { str (e )} " )
257
+
258
+ return docker_launcher
0 commit comments