diff --git a/ring.go b/ring.go index fe8a6dc47..8a004b8c0 100644 --- a/ring.go +++ b/ring.go @@ -847,3 +847,26 @@ func (c *Ring) Close() error { return c.sharding.Close() } + +// GetShardClients returns a list of all shard clients in the ring. +// This can be used to create dedicated connections (e.g., PubSub) for each shard. +func (c *Ring) GetShardClients() []*Client { + shards := c.sharding.List() + clients := make([]*Client, 0, len(shards)) + for _, shard := range shards { + if shard.IsUp() { + clients = append(clients, shard.Client) + } + } + return clients +} + +// GetShardClientForKey returns the shard client that would handle the given key. +// This can be used to determine which shard a particular key/channel would be routed to. +func (c *Ring) GetShardClientForKey(key string) (*Client, error) { + shard, err := c.sharding.GetByKey(key) + if err != nil { + return nil, err + } + return shard.Client, nil +} diff --git a/ring_test.go b/ring_test.go index 599f6888a..aaac74dc9 100644 --- a/ring_test.go +++ b/ring_test.go @@ -782,3 +782,82 @@ var _ = Describe("Ring Tx timeout", func() { testTimeout() }) }) + +var _ = Describe("Ring GetShardClients and GetShardClientForKey", func() { + var ring *redis.Ring + + BeforeEach(func() { + ring = redis.NewRing(&redis.RingOptions{ + Addrs: map[string]string{ + "shard1": ":6379", + "shard2": ":6380", + }, + }) + }) + + AfterEach(func() { + Expect(ring.Close()).NotTo(HaveOccurred()) + }) + + It("GetShardClients returns active shard clients", func() { + shards := ring.GetShardClients() + // Note: This test will pass even if Redis servers are not running, + // because GetShardClients only returns clients that are marked as "up", + // and newly created shards start as "up" until the first health check fails. + + if len(shards) == 0 { + // Expected if Redis servers are not running + Skip("No active shards found (Redis servers not running)") + } else { + Expect(len(shards)).To(BeNumerically(">", 0)) + for _, client := range shards { + Expect(client).NotTo(BeNil()) + } + } + }) + + It("GetShardClientForKey returns correct shard for keys", func() { + testKeys := []string{"key1", "key2", "user:123", "channel:test"} + + for _, key := range testKeys { + client, err := ring.GetShardClientForKey(key) + Expect(err).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + } + }) + + It("GetShardClientForKey is consistent for same key", func() { + key := "test:consistency" + + // Call GetShardClientForKey multiple times with the same key + // Should always return the same shard + var firstClient *redis.Client + for i := 0; i < 5; i++ { + client, err := ring.GetShardClientForKey(key) + Expect(err).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + + if i == 0 { + firstClient = client + } else { + Expect(client.String()).To(Equal(firstClient.String())) + } + } + }) + + It("GetShardClientForKey distributes keys across shards", func() { + testKeys := []string{"key1", "key2", "key3", "key4", "key5"} + shardMap := make(map[string]int) + + for _, key := range testKeys { + client, err := ring.GetShardClientForKey(key) + Expect(err).NotTo(HaveOccurred()) + shardMap[client.String()]++ + } + + // Should have at least 1 shard (could be all keys go to same shard due to hashing) + Expect(len(shardMap)).To(BeNumerically(">=", 1)) + // But with multiple keys, we expect some distribution + Expect(len(shardMap)).To(BeNumerically("<=", 2)) // At most 2 shards (our setup) + }) +})