Skip to content

Commit bc3e92a

Browse files
authored
Enqaw drain pusher ws on reconnects (#2397)
* Draining pusher sockets on disconnects * Optimize pusher ws; auto-recovery on failure
1 parent 3e943fb commit bc3e92a

File tree

3 files changed

+29
-49
lines changed

3 files changed

+29
-49
lines changed

backend/routers/pusher.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,29 +29,6 @@ async def _websocket_util_trigger(
2929
websocket_active = True
3030
websocket_close_code = 1000
3131

32-
# heart beat
33-
async def send_heartbeat():
34-
print("pusher send_heartbeat", uid)
35-
nonlocal websocket_active
36-
nonlocal websocket_close_code
37-
try:
38-
while websocket_active:
39-
if websocket.client_state == WebSocketState.CONNECTED:
40-
await websocket.send_json({"type": "ping"})
41-
else:
42-
break
43-
await asyncio.sleep(10)
44-
except WebSocketDisconnect:
45-
print("WebSocket disconnected")
46-
except Exception as e:
47-
print(f'Heartbeat error: {e}')
48-
websocket_close_code = 1011
49-
finally:
50-
websocket_active = False
51-
52-
# start heart beat
53-
heartbeat_task = asyncio.create_task(send_heartbeat())
54-
5532
loop = asyncio.get_event_loop()
5633

5734
# audio bytes
@@ -60,7 +37,7 @@ async def send_heartbeat():
6037
has_audio_apps_enabled = is_audio_bytes_app_enabled(uid)
6138

6239
# task
63-
async def receive_audio_bytes():
40+
async def receive_tasks():
6441
nonlocal websocket_active
6542
nonlocal websocket_close_code
6643

@@ -106,8 +83,8 @@ async def receive_audio_bytes():
10683
websocket_active = False
10784

10885
try:
109-
receive_task = asyncio.create_task(receive_audio_bytes())
110-
await asyncio.gather(receive_task, heartbeat_task)
86+
receive_task = asyncio.create_task(receive_tasks())
87+
await asyncio.gather(receive_task)
11188

11289
except Exception as e:
11390
print(f"Error during WebSocket operation: {e}")

backend/routers/transcribe.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,6 @@ def create_pusher_task_handler():
417417
pusher_connected = False
418418

419419
# Transcript
420-
transcript_ws = None
421420
segment_buffers = []
422421
in_progress_conversation_id = None
423422

@@ -431,11 +430,11 @@ async def transcript_consume():
431430
nonlocal websocket_active
432431
nonlocal segment_buffers
433432
nonlocal in_progress_conversation_id
434-
nonlocal transcript_ws
433+
nonlocal pusher_ws
435434
nonlocal pusher_connected
436435
while websocket_active or len(segment_buffers) > 0:
437436
await asyncio.sleep(1)
438-
if transcript_ws and len(segment_buffers) > 0:
437+
if pusher_connected and pusher_ws and len(segment_buffers) > 0:
439438
try:
440439
# 102|data
441440
data = bytearray()
@@ -444,17 +443,16 @@ async def transcript_consume():
444443
bytes(json.dumps({"segments": segment_buffers, "memory_id": in_progress_conversation_id}),
445444
"utf-8"))
446445
segment_buffers = [] # reset
447-
await transcript_ws.send(data)
446+
await pusher_ws.send(data)
448447
except websockets.exceptions.ConnectionClosed as e:
449448
print(f"Pusher transcripts Connection closed: {e}", uid)
450-
transcript_ws = None
451449
pusher_connected = False
452-
await reconnect()
453450
except Exception as e:
454451
print(f"Pusher transcripts failed: {e}", uid)
452+
if pusher_connected is False:
453+
await connect()
455454

456455
# Audio bytes
457-
audio_bytes_ws = None
458456
audio_buffers = bytearray()
459457
audio_bytes_enabled = bool(get_audio_bytes_webhook_seconds(uid)) or is_audio_bytes_app_enabled(uid)
460458

@@ -465,52 +463,56 @@ def audio_bytes_send(audio_bytes):
465463
async def audio_bytes_consume():
466464
nonlocal websocket_active
467465
nonlocal audio_buffers
468-
nonlocal audio_bytes_ws
466+
nonlocal pusher_ws
469467
nonlocal pusher_connected
470468
while websocket_active or len(audio_buffers) > 0:
471469
await asyncio.sleep(1)
472-
if audio_bytes_ws and len(audio_buffers) > 0:
470+
if pusher_connected and pusher_ws and len(audio_buffers) > 0:
473471
try:
474472
# 101|data
475473
data = bytearray()
476474
data.extend(struct.pack("I", 101))
477475
data.extend(audio_buffers.copy())
478476
audio_buffers = bytearray() # reset
479-
await audio_bytes_ws.send(data)
477+
await pusher_ws.send(data)
480478
except websockets.exceptions.ConnectionClosed as e:
481479
print(f"Pusher audio_bytes Connection closed: {e}", uid)
482-
audio_bytes_ws = None
483480
pusher_connected = False
484-
await reconnect()
485481
except Exception as e:
486482
print(f"Pusher audio_bytes failed: {e}", uid)
483+
if pusher_connected is False:
484+
await connect()
487485

488-
async def reconnect():
486+
async def connect():
489487
nonlocal pusher_connected
490488
nonlocal pusher_connect_lock
489+
nonlocal pusher_ws
491490
async with pusher_connect_lock:
492491
if pusher_connected:
493492
return
494-
await connect()
493+
# drain
494+
if pusher_ws:
495+
try:
496+
await pusher_ws.close()
497+
pusher_ws = None
498+
except Exception as e:
499+
print(f"Pusher draining failed: {e}", uid)
500+
# connect
501+
await _connect()
495502

496-
async def connect():
503+
async def _connect():
497504
nonlocal pusher_ws
498-
nonlocal transcript_ws
499-
nonlocal audio_bytes_ws
500-
nonlocal audio_bytes_enabled
501505
nonlocal pusher_connected
502506

503507
try:
504508
pusher_ws = await connect_to_trigger_pusher(uid, sample_rate)
505509
pusher_connected = True
506-
transcript_ws = pusher_ws
507-
if audio_bytes_enabled:
508-
audio_bytes_ws = pusher_ws
509510
except Exception as e:
510511
print(f"Exception in connect: {e}")
511512

512513
async def close(code: int = 1000):
513-
await pusher_ws.close(code)
514+
if pusher_ws:
515+
await pusher_ws.close(code)
514516

515517
return (connect, close,
516518
transcript_send, transcript_consume,

backend/utils/pusher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ async def _connect_to_trigger_pusher(uid: str, sample_rate: int = 8000):
2424
try:
2525
print("Connecting to Pusher transcripts trigger WebSocket...", uid)
2626
ws_host = PusherAPI.replace("http", "ws")
27-
socket = await websockets.connect(f"{ws_host}/v1/trigger/listen?uid={uid}&sample_rate={sample_rate}")
27+
socket = await websockets.connect(f"{ws_host}/v1/trigger/listen?uid={uid}&sample_rate={sample_rate}",
28+
ping_interval=15,)
2829
print("Connected to Pusher transcripts trigger WebSocket.", uid)
2930
return socket
3031
except Exception as e:

0 commit comments

Comments
 (0)