Skip to content

Commit 9646182

Browse files
committed
feat: ✨ fix redis stream empty subscribe bug && supported redis stream cached(#148)
1 parent a422d8a commit 9646182

File tree

12 files changed

+178
-27
lines changed

12 files changed

+178
-27
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ test.db
77
venv/
88
build/
99
dist/
10+
.idea/
11+
.vscode/

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ Python 3.8+
8484
* `Broadcast('memory://')`
8585
* `Broadcast("redis://localhost:6379")`
8686
* `Broadcast("redis-stream://localhost:6379")`
87+
* `Broadcast("redis-stream-cached://localhost:6379")`
8788
* `Broadcast("postgres://localhost:5432/broadcaster")`
8889
* `Broadcast("kafka://localhost:9092")`
8990

broadcaster/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from ._base import Broadcast, Event
1+
from ._base import Broadcast
2+
from ._event import Event
23
from .backends.base import BroadcastBackend
34

45
__version__ = "0.3.1"

broadcaster/_base.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,12 @@
55
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, cast
66
from urllib.parse import urlparse
77

8-
if TYPE_CHECKING: # pragma: no cover
9-
from broadcaster.backends.base import BroadcastBackend
10-
11-
12-
class Event:
13-
def __init__(self, channel: str, message: str) -> None:
14-
self.channel = channel
15-
self.message = message
8+
from broadcaster.backends.base import BroadcastCacheBackend
169

17-
def __eq__(self, other: object) -> bool:
18-
return isinstance(other, Event) and self.channel == other.channel and self.message == other.message
10+
from ._event import Event
1911

20-
def __repr__(self) -> str:
21-
return f"Event(channel={self.channel!r}, message={self.message!r})"
12+
if TYPE_CHECKING: # pragma: no cover
13+
from broadcaster.backends.base import BroadcastBackend
2214

2315

2416
class Unsubscribed(Exception):
@@ -43,6 +35,11 @@ def _create_backend(self, url: str) -> BroadcastBackend:
4335

4436
return RedisStreamBackend(url)
4537

38+
elif parsed_url.scheme == "redis-stream-cached":
39+
from broadcaster.backends.redis import RedisStreamCachedBackend
40+
41+
return RedisStreamCachedBackend(url)
42+
4643
elif parsed_url.scheme in ("postgres", "postgresql"):
4744
from broadcaster.backends.postgres import PostgresBackend
4845

@@ -87,15 +84,28 @@ async def publish(self, channel: str, message: Any) -> None:
8784
await self._backend.publish(channel, message)
8885

8986
@asynccontextmanager
90-
async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
87+
async def subscribe(self, channel: str, history: int | None = None) -> AsyncIterator[Subscriber]:
9188
queue: asyncio.Queue[Event | None] = asyncio.Queue()
9289

9390
try:
9491
if not self._subscribers.get(channel):
9592
await self._backend.subscribe(channel)
9693
self._subscribers[channel] = {queue}
9794
else:
98-
self._subscribers[channel].add(queue)
95+
if isinstance(self._backend, BroadcastCacheBackend):
96+
try:
97+
current_id = await self._backend.get_current_channel_id(channel)
98+
self._backend._ready.clear()
99+
messages = await self._backend.get_history_messages(channel, current_id, history)
100+
for message in messages:
101+
queue.put_nowait(message)
102+
self._subscribers[channel].add(queue)
103+
finally:
104+
# wake up the listener after inqueue history messages
105+
# for sorted messages by publish time
106+
self._backend._ready.set()
107+
else:
108+
self._subscribers[channel].add(queue)
99109

100110
yield Subscriber(queue)
101111
finally:

broadcaster/_event.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class Event:
2+
def __init__(self, channel: str, message: str) -> None:
3+
self.channel = channel
4+
self.message = message
5+
6+
def __eq__(self, other: object) -> bool:
7+
return isinstance(other, Event) and self.channel == other.channel and self.message == other.message
8+
9+
def __repr__(self) -> str:
10+
return f"Event(channel={self.channel!r}, message={self.message!r})"

broadcaster/backends/base.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import Any
1+
from __future__ import annotations
22

3-
from .._base import Event
3+
from typing import Any, AsyncGenerator
4+
5+
from .._event import Event
46

57

68
class BroadcastBackend:
@@ -24,3 +26,16 @@ async def publish(self, channel: str, message: Any) -> None:
2426

2527
async def next_published(self) -> Event:
2628
raise NotImplementedError()
29+
30+
31+
class BroadcastCacheBackend(BroadcastBackend):
32+
async def get_current_channel_id(self, channel: str):
33+
raise NotImplementedError()
34+
35+
async def get_history_messages(
36+
self,
37+
channel: str,
38+
msg_id: int | bytes | str | memoryview,
39+
count: int | None = None,
40+
) -> AsyncGenerator[Event, None]:
41+
raise NotImplementedError()

broadcaster/backends/kafka.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
88

9-
from .._base import Event
9+
from .._event import Event
1010
from .base import BroadcastBackend
1111

1212

broadcaster/backends/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import typing
55

6-
from .._base import Event
6+
from .._event import Event
77
from .base import BroadcastBackend
88

99

broadcaster/backends/postgres.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import asyncpg
55

6-
from .._base import Event
6+
from .._event import Event
77
from .base import BroadcastBackend
88

99

broadcaster/backends/redis.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
from redis import asyncio as redis
77

8-
from .._base import Event
9-
from .base import BroadcastBackend
8+
from .._event import Event
9+
from .base import BroadcastBackend, BroadcastCacheBackend
1010

1111

1212
class RedisBackend(BroadcastBackend):
@@ -88,14 +88,71 @@ async def subscribe(self, channel: str) -> None:
8888

8989
async def unsubscribe(self, channel: str) -> None:
9090
self.streams.pop(channel, None)
91+
if not self.streams:
92+
self._ready.clear()
93+
94+
async def publish(self, channel: str, message: typing.Any) -> None:
95+
await self._producer.xadd(channel, {"message": message})
96+
97+
async def wait_for_messages(self) -> list[StreamMessageType]:
98+
messages = None
99+
while not messages:
100+
if not self.streams:
101+
# 1. save cpu usage
102+
# 2. redis raise expection when self.streams is empty
103+
self._ready.clear()
104+
await self._ready.wait()
105+
messages = await self._consumer.xread(self.streams, count=1, block=100)
106+
return messages
107+
108+
async def next_published(self) -> Event:
109+
messages = await self.wait_for_messages()
110+
stream, events = messages[0]
111+
_msg_id, message = events[0]
112+
self.streams[stream.decode("utf-8")] = _msg_id.decode("utf-8")
113+
return Event(
114+
channel=stream.decode("utf-8"),
115+
message=message.get(b"message", b"").decode("utf-8"),
116+
)
117+
118+
119+
class RedisStreamCachedBackend(BroadcastCacheBackend):
120+
def __init__(self, url: str):
121+
url = url.replace("redis-stream-cached", "redis", 1)
122+
self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {}
123+
self._ready = asyncio.Event()
124+
self._producer = redis.Redis.from_url(url)
125+
self._consumer = redis.Redis.from_url(url)
126+
127+
async def connect(self) -> None:
128+
pass
129+
130+
async def disconnect(self) -> None:
131+
await self._producer.aclose()
132+
await self._consumer.aclose()
133+
134+
async def subscribe(self, channel: str) -> None:
135+
# read from beginning
136+
last_id = "0"
137+
self.streams[channel] = last_id
138+
self._ready.set()
139+
140+
async def unsubscribe(self, channel: str) -> None:
141+
self.streams.pop(channel, None)
142+
if not self.streams:
143+
self._ready.clear()
91144

92145
async def publish(self, channel: str, message: typing.Any) -> None:
93146
await self._producer.xadd(channel, {"message": message})
94147

95148
async def wait_for_messages(self) -> list[StreamMessageType]:
96-
await self._ready.wait()
97149
messages = None
98150
while not messages:
151+
if not self.streams:
152+
# 1. save cpu usage
153+
# 2. redis raise expection when self.streams is empty
154+
self._ready.clear()
155+
await self._ready.wait()
99156
messages = await self._consumer.xread(self.streams, count=1, block=100)
100157
return messages
101158

@@ -108,3 +165,24 @@ async def next_published(self) -> Event:
108165
channel=stream.decode("utf-8"),
109166
message=message.get(b"message", b"").decode("utf-8"),
110167
)
168+
169+
async def get_current_channel_id(self, channel: str):
170+
try:
171+
info = await self._consumer.xinfo_stream(channel)
172+
last_id = info["last-generated-id"]
173+
except redis.ResponseError:
174+
last_id = "0"
175+
return last_id
176+
177+
async def get_history_messages(
178+
self,
179+
channel: str,
180+
msg_id: int | bytes | str | memoryview,
181+
count: int | None = None,
182+
) -> typing.AsyncGenerator[Event, None]:
183+
messages = await self._consumer.xrevrange(channel, max=msg_id, count=count)
184+
for _, message in reversed(messages or []):
185+
yield Event(
186+
channel=channel,
187+
message=message.get(b"message", b"").decode("utf-8"),
188+
)

example/app.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
import os
1+
from __future__ import annotations
22

3+
import os
4+
from pathlib import Path
35

46
import anyio
57
from starlette.applications import Starlette
@@ -10,8 +12,9 @@
1012

1113
BROADCAST_URL = os.environ.get("BROADCAST_URL", "memory://")
1214

13-
broadcast = Broadcast(BROADCAST_URL)
14-
templates = Jinja2Templates("example/templates")
15+
templates_dir = Path(__file__).parent / "templates"
16+
broadcast = Broadcast("redis-stream-cached://localhost:6379/8")
17+
templates = Jinja2Templates(templates_dir)
1518

1619

1720
async def homepage(request):
@@ -51,5 +54,13 @@ async def chatroom_ws_sender(websocket):
5154

5255

5356
app = Starlette(
54-
routes=routes, on_startup=[broadcast.connect], on_shutdown=[broadcast.disconnect],
57+
routes=routes,
58+
on_startup=[broadcast.connect],
59+
on_shutdown=[broadcast.disconnect],
5560
)
61+
62+
63+
if __name__ == "__main__":
64+
import uvicorn
65+
66+
uvicorn.run(app, host="0.0.0.0", port=7777)

tests/test_broadcast.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,29 @@ async def test_redis_stream():
7171
assert event.message == "hello"
7272

7373

74+
@pytest.mark.asyncio
75+
async def test_redis_stream_cache():
76+
messages = ["hello", "I'm cached"]
77+
async with Broadcast("redis-stream-cached://localhost:6379") as broadcast:
78+
await broadcast.publish("chatroom_cached", messages[0])
79+
await broadcast.publish("chatroom_cached", messages[1])
80+
await broadcast.publish("chatroom_cached", "quit")
81+
sub1_messages = []
82+
async with broadcast.subscribe("chatroom_cached") as subscriber:
83+
async for event in subscriber:
84+
if event.message == "quit":
85+
break
86+
sub1_messages.append(event.message)
87+
sub2_messages = []
88+
async with broadcast.subscribe("chatroom_cached") as subscriber:
89+
async for event in subscriber:
90+
if event.message == "quit":
91+
break
92+
sub2_messages.append(event.message)
93+
94+
assert sub1_messages == sub2_messages == messages
95+
96+
7497
@pytest.mark.asyncio
7598
async def test_postgres():
7699
async with Broadcast("postgres://postgres:postgres@localhost:5432/broadcaster") as broadcast:

0 commit comments

Comments
 (0)