Skip to content

Commit ae86f55

Browse files
Improving some internal error-handling (#846)
* refactor: use more errors.Is to allow errors to be wrapped more safely * refactor(dialer): wrap returned errors with context * refactor: decorate more errors * revert some error wrap changes to reduce PR noise * revert conn changes
1 parent ec59669 commit ae86f55

File tree

15 files changed

+66
-53
lines changed

15 files changed

+66
-53
lines changed

batch.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package kafka
22

33
import (
44
"bufio"
5+
"errors"
56
"io"
67
"sync"
78
"time"
@@ -82,7 +83,7 @@ func (batch *Batch) close() (err error) {
8283
batch.msgs.discard()
8384
}
8485

85-
if err = batch.err; err == io.EOF {
86+
if err = batch.err; errors.Is(batch.err, io.EOF) {
8687
err = nil
8788
}
8889

@@ -93,7 +94,8 @@ func (batch *Batch) close() (err error) {
9394
conn.mutex.Unlock()
9495

9596
if err != nil {
96-
if _, ok := err.(Error); !ok && err != io.ErrShortBuffer {
97+
var kafkaError Error
98+
if !errors.As(err, &kafkaError) && !errors.Is(err, io.ErrShortBuffer) {
9799
conn.Close()
98100
}
99101
}
@@ -238,11 +240,11 @@ func (batch *Batch) readMessage(
238240

239241
var lastOffset int64
240242
offset, lastOffset, timestamp, headers, err = batch.msgs.readMessage(batch.offset, key, val)
241-
switch err {
242-
case nil:
243+
switch {
244+
case err == nil:
243245
batch.offset = offset + 1
244246
batch.lastOffset = lastOffset
245-
case errShortRead:
247+
case errors.Is(err, errShortRead):
246248
// As an "optimization" kafka truncates the returned response after
247249
// producing MaxBytes, which could then cause the code to return
248250
// errShortRead.
@@ -272,7 +274,7 @@ func (batch *Batch) readMessage(
272274
// to MaxBytes truncation
273275
// - `batch.lastOffset` to ensure that the message format contains
274276
// `lastOffset`
275-
if batch.err == io.EOF && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 {
277+
if errors.Is(batch.err, io.EOF) && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 {
276278
// Log compaction can create batches that end with compacted
277279
// records so the normal strategy that increments the "next"
278280
// offset as records are read doesn't work as the compacted

batch_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package kafka
22

33
import (
44
"context"
5+
"errors"
56
"io"
67
"net"
78
"strconv"
@@ -30,11 +31,11 @@ func TestBatchDontExpectEOF(t *testing.T) {
3031

3132
batch := conn.ReadBatch(1024, 8192)
3233

33-
if _, err := batch.ReadMessage(); err != io.ErrUnexpectedEOF {
34+
if _, err := batch.ReadMessage(); !errors.Is(err, io.ErrUnexpectedEOF) {
3435
t.Error("bad error when reading message:", err)
3536
}
3637

37-
if err := batch.Close(); err != io.ErrUnexpectedEOF {
38+
if err := batch.Close(); !errors.Is(err, io.ErrUnexpectedEOF) {
3839
t.Error("bad error when closing the batch:", err)
3940
}
4041
}

client.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package kafka
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"net"
78
"time"
89

@@ -67,7 +68,7 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int
6768
})
6869

6970
if err != nil {
70-
return nil, err
71+
return nil, fmt.Errorf("failed to get topic metadata :%w", err)
7172
}
7273

7374
topic := metadata.Topics[0]
@@ -85,7 +86,7 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int
8586
})
8687

8788
if err != nil {
88-
return nil, err
89+
return nil, fmt.Errorf("failed to get offsets: %w", err)
8990
}
9091

9192
topicOffsets := offsets.Topics[topic.Name]

compress/snappy/xerial.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package snappy
33
import (
44
"bytes"
55
"encoding/binary"
6+
"errors"
67
"io"
78

89
"github.com/klauspost/compress/snappy"
@@ -64,7 +65,7 @@ func (x *xerialReader) WriteTo(w io.Writer) (int64, error) {
6465
}
6566

6667
if _, err := x.readChunk(nil); err != nil {
67-
if err == io.EOF {
68+
if errors.Is(err, io.EOF) {
6869
err = nil
6970
}
7071
return wn, err
@@ -128,7 +129,7 @@ func (x *xerialReader) readChunk(dst []byte) (int, error) {
128129
n, err := x.read(x.input[len(x.input):cap(x.input)])
129130
x.input = x.input[:len(x.input)+n]
130131
if err != nil {
131-
if err == io.EOF && len(x.input) > 0 {
132+
if errors.Is(err, io.EOF) && len(x.input) > 0 {
132133
break
133134
}
134135
return 0, err
@@ -212,7 +213,7 @@ func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) {
212213
}
213214

214215
if err != nil {
215-
if err == io.EOF {
216+
if errors.Is(err, io.EOF) {
216217
err = nil
217218
}
218219
return wn, err

conn.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
853853
default:
854854
throttle, highWaterMark, remain, err = readFetchResponseHeaderV2(&c.rbuf, size)
855855
}
856-
if err == errShortRead {
856+
if errors.Is(err, errShortRead) {
857857
err = checkTimeoutErr(adjustedDeadline)
858858
}
859859

@@ -865,9 +865,10 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
865865
msgs, err = newMessageSetReader(&c.rbuf, remain)
866866
}
867867
}
868-
if err == errShortRead {
868+
if errors.Is(err, errShortRead) {
869869
err = checkTimeoutErr(adjustedDeadline)
870870
}
871+
871872
return &Batch{
872873
conn: c,
873874
msgs: msgs,

conn_test.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package kafka
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"fmt"
78
"io"
89
"math/rand"
@@ -640,10 +641,13 @@ func testConnReadBatchWithMaxWait(t *testing.T, conn *Conn) {
640641
conn.Seek(0, SeekAbsolute)
641642
conn.SetDeadline(time.Now().Add(50 * time.Millisecond))
642643
batch = conn.ReadBatchWith(cfg)
644+
var netErr net.Error
643645
if err := batch.Err(); err == nil {
644646
t.Fatal("should have timed out, but got no error")
645-
} else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
646-
t.Fatalf("should have timed out, but got: %v", err)
647+
} else if errors.As(err, &netErr) {
648+
if !netErr.Timeout() {
649+
t.Fatalf("should have timed out, but got: %v", err)
650+
}
647651
}
648652
}
649653

@@ -761,7 +765,7 @@ func testConnFindCoordinator(t *testing.T, conn *Conn) {
761765

762766
func testConnJoinGroupInvalidGroupID(t *testing.T, conn *Conn) {
763767
_, err := conn.joinGroup(joinGroupRequestV1{})
764-
if err != InvalidGroupId && err != NotCoordinatorForGroup {
768+
if !errors.Is(err, InvalidGroupId) && !errors.Is(err, NotCoordinatorForGroup) {
765769
t.Fatalf("expected %v or %v; got %v", InvalidGroupId, NotCoordinatorForGroup, err)
766770
}
767771
}
@@ -773,7 +777,7 @@ func testConnJoinGroupInvalidSessionTimeout(t *testing.T, conn *Conn) {
773777
_, err := conn.joinGroup(joinGroupRequestV1{
774778
GroupID: groupID,
775779
})
776-
if err != InvalidSessionTimeout && err != NotCoordinatorForGroup {
780+
if !errors.Is(err, InvalidSessionTimeout) && !errors.Is(err, NotCoordinatorForGroup) {
777781
t.Fatalf("expected %v or %v; got %v", InvalidSessionTimeout, NotCoordinatorForGroup, err)
778782
}
779783
}
@@ -786,7 +790,7 @@ func testConnJoinGroupInvalidRefreshTimeout(t *testing.T, conn *Conn) {
786790
GroupID: groupID,
787791
SessionTimeout: int32(3 * time.Second / time.Millisecond),
788792
})
789-
if err != InvalidSessionTimeout && err != NotCoordinatorForGroup {
793+
if !errors.Is(err, InvalidSessionTimeout) && !errors.Is(err, NotCoordinatorForGroup) {
790794
t.Fatalf("expected %v or %v; got %v", InvalidSessionTimeout, NotCoordinatorForGroup, err)
791795
}
792796
}
@@ -798,7 +802,7 @@ func testConnHeartbeatErr(t *testing.T, conn *Conn) {
798802
_, err := conn.syncGroup(syncGroupRequestV0{
799803
GroupID: groupID,
800804
})
801-
if err != UnknownMemberId && err != NotCoordinatorForGroup {
805+
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
802806
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
803807
}
804808
}
@@ -810,7 +814,7 @@ func testConnLeaveGroupErr(t *testing.T, conn *Conn) {
810814
_, err := conn.leaveGroup(leaveGroupRequestV0{
811815
GroupID: groupID,
812816
})
813-
if err != UnknownMemberId && err != NotCoordinatorForGroup {
817+
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
814818
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
815819
}
816820
}
@@ -822,7 +826,7 @@ func testConnSyncGroupErr(t *testing.T, conn *Conn) {
822826
_, err := conn.syncGroup(syncGroupRequestV0{
823827
GroupID: groupID,
824828
})
825-
if err != UnknownMemberId && err != NotCoordinatorForGroup {
829+
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
826830
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
827831
}
828832
}
@@ -985,7 +989,7 @@ func testConnReadShortBuffer(t *testing.T, conn *Conn) {
985989
b[3] = 0
986990

987991
n, err := conn.Read(b)
988-
if err != io.ErrShortBuffer {
992+
if !errors.Is(err, io.ErrShortBuffer) {
989993
t.Error("bad error:", i, err)
990994
}
991995
if n != 4 {
@@ -1061,7 +1065,7 @@ func testDeleteTopicsInvalidTopic(t *testing.T, conn *Conn) {
10611065
}
10621066
conn.SetDeadline(time.Now().Add(5 * time.Second))
10631067
err = conn.DeleteTopics("invalid-topic", topic)
1064-
if err != UnknownTopicOrPartition {
1068+
if !errors.Is(err, UnknownTopicOrPartition) {
10651069
t.Fatalf("expected UnknownTopicOrPartition error, but got %v", err)
10661070
}
10671071
partitions, err := conn.ReadPartitions(topic)
@@ -1154,7 +1158,7 @@ func TestUnsupportedSASLMechanism(t *testing.T) {
11541158
}
11551159
defer conn.Close()
11561160

1157-
if err := conn.saslHandshake("FOO"); err != UnsupportedSASLMechanism {
1161+
if err := conn.saslHandshake("FOO"); !errors.Is(err, UnsupportedSASLMechanism) {
11581162
t.Errorf("Expected UnsupportedSASLMechanism but got %v", err)
11591163
}
11601164
}

consumergroup.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,7 @@ func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroup
10261026
// assignments for the topic. this matches the behavior of the official
10271027
// clients: java, python, and librdkafka.
10281028
// a topic watcher can trigger a rebalance when the topic comes into being.
1029-
if err != nil && err != UnknownTopicOrPartition {
1029+
if err != nil && !errors.Is(err, UnknownTopicOrPartition) {
10301030
return nil, err
10311031
}
10321032

consumergroup_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ func TestConsumerGroup(t *testing.T) {
285285
if gen != nil {
286286
t.Errorf("expected generation to be nil")
287287
}
288-
if err != context.Canceled {
288+
if !errors.Is(err, context.Canceled) {
289289
t.Errorf("expected context.Canceled, but got %+v", err)
290290
}
291291
},
@@ -301,7 +301,7 @@ func TestConsumerGroup(t *testing.T) {
301301
if gen != nil {
302302
t.Errorf("expected generation to be nil")
303303
}
304-
if err != ErrGroupClosed {
304+
if !errors.Is(err, ErrGroupClosed) {
305305
t.Errorf("expected ErrGroupClosed, but got %+v", err)
306306
}
307307
},
@@ -398,7 +398,7 @@ func TestConsumerGroupErrors(t *testing.T) {
398398
gen, err := group.Next(ctx)
399399
if err == nil {
400400
t.Errorf("expected an error")
401-
} else if err != NotCoordinatorForGroup {
401+
} else if !errors.Is(err, NotCoordinatorForGroup) {
402402
t.Errorf("got wrong error: %+v", err)
403403
}
404404
if gen != nil {
@@ -460,7 +460,7 @@ func TestConsumerGroupErrors(t *testing.T) {
460460
gen, err := group.Next(ctx)
461461
if err == nil {
462462
t.Errorf("expected an error")
463-
} else if err != InvalidTopic {
463+
} else if !errors.Is(err, InvalidTopic) {
464464
t.Errorf("got wrong error: %+v", err)
465465
}
466466
if gen != nil {
@@ -540,7 +540,7 @@ func TestConsumerGroupErrors(t *testing.T) {
540540
gen, err := group.Next(ctx)
541541
if err == nil {
542542
t.Errorf("expected an error")
543-
} else if err != InvalidTopic {
543+
} else if !errors.Is(err, InvalidTopic) {
544544
t.Errorf("got wrong error: %+v", err)
545545
}
546546
if gen != nil {
@@ -672,7 +672,7 @@ func TestGenerationStartsFunctionAfterClosed(t *testing.T) {
672672
case <-time.After(time.Second):
673673
t.Fatal("timed out waiting for func to run")
674674
case err := <-ch:
675-
if err != ErrGenerationEnded {
675+
if !errors.Is(err, ErrGenerationEnded) {
676676
t.Fatalf("expected %v but got %v", ErrGenerationEnded, err)
677677
}
678678
}

0 commit comments

Comments
 (0)