Skip to content

Commit 44eaf79

Browse files
authored
Accept context in core GPBFT interfaces (#965)
Accept the context from caller in core GPBFT messages instead of shadowing it with internal running context. Separately, Refactor adversary types to reduce duplicate code and pass existing context in metric collection wherever possible.
1 parent 70dab9c commit 44eaf79

31 files changed

+311
-376
lines changed

blssig/aggregation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (a *aggregation) VerifyAggregate(mask []int, msg []byte, signature []byte)
8282
}
8383

8484
metrics.verifyAggregate.Record(
85-
context.TODO(), int64(len(mask)),
85+
context.Background(), int64(len(mask)),
8686
metric.WithAttributes(status),
8787
)
8888
}()

certchain/certchain.go

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func New(o ...Option) (*CertChain, error) {
4545
}, nil
4646
}
4747

48-
func (cc *CertChain) GetCommittee(instance uint64) (*gpbft.Committee, error) {
48+
func (cc *CertChain) GetCommittee(ctx context.Context, instance uint64) (*gpbft.Committee, error) {
4949
var committeeEpoch int64
5050
if instance < cc.m.InitialInstance+cc.m.CommitteeLookback {
5151
committeeEpoch = cc.m.BootstrapEpoch - cc.m.EC.Finality
@@ -54,31 +54,27 @@ func (cc *CertChain) GetCommittee(instance uint64) (*gpbft.Committee, error) {
5454
certAtLookback := cc.certificates[lookbackIndex]
5555
committeeEpoch = certAtLookback.ECChain.Head().Epoch
5656
}
57-
//TODO refactor CommitteeProvider in gpbft to take context.
58-
ctx := context.TODO()
5957
tspt, err := cc.getTipSetWithPowerTableByEpoch(ctx, committeeEpoch)
6058
if err != nil {
6159
return nil, err
6260
}
6361
return cc.getCommittee(tspt)
6462
}
6563

66-
func (cc *CertChain) GetProposal(instance uint64) (*gpbft.SupplementalData, *gpbft.ECChain, error) {
67-
//TODO refactor ProposalProvider in gpbft to take context.
68-
ctx := context.TODO()
64+
func (cc *CertChain) GetProposal(ctx context.Context, instance uint64) (*gpbft.SupplementalData, *gpbft.ECChain, error) {
6965
proposal, err := cc.generateProposal(ctx, instance)
7066
if err != nil {
7167
return nil, nil, err
7268
}
73-
suppData, err := cc.getSupplementalData(instance)
69+
suppData, err := cc.getSupplementalData(ctx, instance)
7470
if err != nil {
7571
return nil, nil, err
7672
}
7773
return suppData, proposal, nil
7874
}
7975

80-
func (cc *CertChain) getSupplementalData(instance uint64) (*gpbft.SupplementalData, error) {
81-
nextCommittee, err := cc.GetCommittee(instance + 1)
76+
func (cc *CertChain) getSupplementalData(ctx context.Context, instance uint64) (*gpbft.SupplementalData, error) {
77+
nextCommittee, err := cc.GetCommittee(ctx, instance+1)
8278
if err != nil {
8379
return nil, err
8480
}
@@ -270,13 +266,13 @@ func (cc *CertChain) Generate(ctx context.Context, length uint64) ([]*certs.Fina
270266
}
271267

272268
instance := cc.m.InitialInstance
273-
committee, err := cc.GetCommittee(instance)
269+
committee, err := cc.GetCommittee(ctx, instance)
274270
if err != nil {
275271
return nil, err
276272
}
277273
var nextCommittee *gpbft.Committee
278274
for range length {
279-
suppData, proposal, err := cc.GetProposal(instance)
275+
suppData, proposal, err := cc.GetProposal(ctx, instance)
280276
if err != nil {
281277
return nil, err
282278
}
@@ -291,7 +287,7 @@ func (cc *CertChain) Generate(ctx context.Context, length uint64) ([]*certs.Fina
291287
return nil, err
292288
}
293289

294-
nextCommittee, err = cc.GetCommittee(instance + 1)
290+
nextCommittee, err = cc.GetCommittee(ctx, instance+1)
295291
if err != nil {
296292
return nil, err
297293
}
@@ -314,14 +310,14 @@ func (cc *CertChain) Validate(ctx context.Context, crts []*certs.FinalityCertifi
314310
for _, cert := range crts {
315311
instance := cert.GPBFTInstance
316312
proposal := cert.ECChain
317-
suppData, err := cc.getSupplementalData(instance)
313+
suppData, err := cc.getSupplementalData(ctx, instance)
318314
if err != nil {
319315
return err
320316
}
321317
if !suppData.Eq(&cert.SupplementalData) {
322318
return fmt.Errorf("supplemental data mismatch at instance %d", instance)
323319
}
324-
committee, err := cc.GetCommittee(instance)
320+
committee, err := cc.GetCommittee(ctx, instance)
325321
if err != nil {
326322
return fmt.Errorf("getting committee for instance %d: %w", instance, err)
327323
}

certchain/certchain_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func TestCertChain_GenerateAndVerify(t *testing.T) {
7373
generatedChain, err := subject.Generate(ctx, certChainLength)
7474
require.NoError(t, err)
7575

76-
initialCommittee, err := subject.GetCommittee(m.InitialInstance)
76+
initialCommittee, err := subject.GetCommittee(ctx, m.InitialInstance)
7777
require.NoError(t, err)
7878

7979
nextInstance, _, _, err := certs.ValidateFinalityCertificates(

certexchange/polling/poller.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//go:generate go run golang.org/x/tools/cmd/stringer@v0.22.0 -type=PollStatus
1+
//go:generate go run golang.org/x/tools/cmd/stringer@v0.32.0 -type=PollStatus
22
package polling
33

44
import (

consensus_inputs.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ func (h *gpbftInputs) GetProposal(ctx context.Context, instance uint64) (_ *gpbf
191191

192192
func (h *gpbftInputs) GetCommittee(ctx context.Context, instance uint64) (_ *gpbft.Committee, _err error) {
193193
defer func(start time.Time) {
194-
metrics.committeeFetchTime.Record(context.TODO(), time.Since(start).Seconds(), metric.WithAttributes(attrStatusFromErr(_err)))
194+
metrics.committeeFetchTime.Record(ctx, time.Since(start).Seconds(), metric.WithAttributes(attrStatusFromErr(_err)))
195195
}(time.Now())
196196

197197
var powerTsk gpbft.TipSetKey

emulator/driver.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func (d *Driver) PeekLastBroadcastRequest() *gpbft.GMessage {
6868

6969
func (d *Driver) DeliverAlarm() (bool, error) {
7070
if d.host.maybeReceiveAlarm() {
71-
return true, d.subject.ReceiveAlarm()
71+
return true, d.subject.ReceiveAlarm(context.Background())
7272
}
7373
return false, nil
7474
}
@@ -98,9 +98,10 @@ func (d *Driver) prepareMessage(partialMessage *gpbft.GMessage) *gpbft.GMessage
9898
}
9999

100100
func (d *Driver) deliverMessage(msg *gpbft.GMessage) error {
101-
if validated, err := d.subject.ValidateMessage(msg); err != nil {
101+
ctx := context.Background()
102+
if validated, err := d.subject.ValidateMessage(ctx, msg); err != nil {
102103
return err
103104
} else {
104-
return d.subject.ReceiveMessage(validated)
105+
return d.subject.ReceiveMessage(ctx, validated)
105106
}
106107
}

emulator/host.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (h *driverHost) RequestRebroadcast(instant gpbft.Instant) error {
6262
return nil
6363
}
6464

65-
func (h *driverHost) ReceiveDecision(decision *gpbft.Justification) (time.Time, error) {
65+
func (h *driverHost) ReceiveDecision(_ context.Context, decision *gpbft.Justification) (time.Time, error) {
6666
require.NoError(h.t, h.maybeReceiveDecision(decision))
6767
return h.now, nil
6868
}
@@ -79,15 +79,15 @@ func (h *driverHost) maybeReceiveDecision(decision *gpbft.Justification) error {
7979
}
8080
}
8181

82-
func (h *driverHost) GetProposal(id uint64) (*gpbft.SupplementalData, *gpbft.ECChain, error) {
82+
func (h *driverHost) GetProposal(_ context.Context, id uint64) (*gpbft.SupplementalData, *gpbft.ECChain, error) {
8383
instance := h.chain[id]
8484
if instance == nil {
8585
return nil, nil, fmt.Errorf("instance ID %d not found", id)
8686
}
8787
return &instance.supplementalData, instance.Proposal(), nil
8888
}
8989

90-
func (h *driverHost) GetCommittee(id uint64) (*gpbft.Committee, error) {
90+
func (h *driverHost) GetCommittee(_ context.Context, id uint64) (*gpbft.Committee, error) {
9191
instance := h.chain[id]
9292
if instance == nil {
9393
return nil, fmt.Errorf("instance ID %d not found", id)

gpbft/api.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,15 @@ type Instant struct {
1414
}
1515

1616
type MessageValidator interface {
17-
// Validates a Granite message.
18-
// An invalid message can never become valid, so may be dropped.
19-
// Returns an error, wrapping (use errors.Is()/Unwrap()):
20-
// - ErrValidationTooOld if the message is for a prior instance;
21-
// - both ErrValidationNoCommittee and an error describing the reason;
22-
// if there is no committee available with with to validate the message;
23-
// - both ErrValidationInvalid and a cause if the message is invalid,
24-
// Returns a validated message if the message is valid.
17+
// ValidateMessage validates a GPBFT message.
18+
//
19+
// An invalid message can never become valid, which may be dropped. If a message
20+
// is invalid, an error of type ValidationError is returned, wrapping the cause.
21+
// Otherwise, returns a validated message that may be passed onto MessageReceiver
22+
// for processing.
2523
//
2624
// Implementations must be safe for concurrent use.
27-
ValidateMessage(msg *GMessage) (valid ValidatedMessage, err error)
25+
ValidateMessage(context.Context, *GMessage) (valid ValidatedMessage, err error)
2826
}
2927

3028
// Opaque type tagging a validated message.
@@ -43,9 +41,9 @@ type MessageReceiver interface {
4341
// - ErrValidationWrongBase if the message has an invalid base chain
4442
// - ErrReceivedAfterTermination if the message is received after the instance has terminated (a programming error)
4543
// - both ErrReceivedInternalError and a cause if there was an internal error processing the message
46-
ReceiveMessage(msg ValidatedMessage) error
44+
ReceiveMessage(ctx context.Context, msg ValidatedMessage) error
4745
// ReceiveAlarm signals the trigger of the alarm set by Clock.SetAlarm.
48-
ReceiveAlarm() error
46+
ReceiveAlarm(ctx context.Context) error
4947
}
5048

5149
// Interface from host to a network participant.
@@ -70,7 +68,7 @@ type ProposalProvider interface {
7068
// supplemental data.
7169
//
7270
// Returns an error if the chain for the specified instance is not available.
73-
GetProposal(instance uint64) (data *SupplementalData, chain *ECChain, err error)
71+
GetProposal(ctx context.Context, instance uint64) (data *SupplementalData, chain *ECChain, err error)
7472
}
7573

7674
// CommitteeProvider defines an interface for retrieving committee information
@@ -81,7 +79,7 @@ type CommitteeProvider interface {
8179
// final, with the offset determined by the host.
8280
//
8381
// Returns an error if the committee is unavailable for the specified instance.
84-
GetCommittee(instance uint64) (*Committee, error)
82+
GetCommittee(ctx context.Context, instance uint64) (*Committee, error)
8583
}
8684

8785
// Committee captures the voting power and beacon value associated to an instance
@@ -153,7 +151,7 @@ type DecisionReceiver interface {
153151
// The notification must return the timestamp at which the next instance should begin,
154152
// based on the decision received (which may be in the past).
155153
// E.g. this might be: finalised tipset timestamp + epoch duration + stabilisation delay.
156-
ReceiveDecision(decision *Justification) (time.Time, error)
154+
ReceiveDecision(ctx context.Context, decision *Justification) (time.Time, error)
157155
}
158156

159157
// Tracer collects trace logs that capture logical state changes.

gpbft/committee.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gpbft
22

33
import (
4+
"context"
45
"fmt"
56
"sync"
67
)
@@ -22,13 +23,13 @@ func newCachedCommitteeProvider(delegate CommitteeProvider) *cachedCommitteeProv
2223
}
2324
}
2425

25-
func (c *cachedCommitteeProvider) GetCommittee(instance uint64) (*Committee, error) {
26+
func (c *cachedCommitteeProvider) GetCommittee(ctx context.Context, instance uint64) (*Committee, error) {
2627
c.mu.Lock()
2728
defer c.mu.Unlock()
2829
if committee, found := c.committees[instance]; found {
2930
return committee, nil
3031
}
31-
switch committee, err := c.delegate.GetCommittee(instance); {
32+
switch committee, err := c.delegate.GetCommittee(ctx, instance); {
3233
case err != nil:
3334
return nil, fmt.Errorf("instance %d: %w: %w", instance, ErrValidationNoCommittee, err)
3435
case committee == nil:

gpbft/committee_test.go

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gpbft
22

33
import (
4+
"context"
45
"errors"
56
"math/rand/v2"
67
"testing"
@@ -17,8 +18,8 @@ type mockCommitteeProvider struct {
1718
mock.Mock
1819
}
1920

20-
func (m *mockCommitteeProvider) GetCommittee(instance uint64) (*Committee, error) {
21-
args := m.Called(instance)
21+
func (m *mockCommitteeProvider) GetCommittee(ctx context.Context, instance uint64) (*Committee, error) {
22+
args := m.Called(ctx, instance)
2223
if committee, ok := args.Get(0).(*Committee); ok {
2324
return committee, args.Error(1)
2425
}
@@ -27,6 +28,7 @@ func (m *mockCommitteeProvider) GetCommittee(instance uint64) (*Committee, error
2728

2829
func TestCachedCommitteeProvider_GetCommittee(t *testing.T) {
2930
var (
31+
ctx = context.Background()
3032
instance1 = uint64(1)
3133
instance2 = uint64(2)
3234
instance3 = uint64(3)
@@ -53,64 +55,64 @@ func TestCachedCommitteeProvider_GetCommittee(t *testing.T) {
5355
subject = newCachedCommitteeProvider(mockDelegate)
5456
)
5557

56-
mockDelegate.On("GetCommittee", instance1).Return(committeeWithValidPowerTable, nil)
58+
mockDelegate.On("GetCommittee", mock.Anything, instance1).Return(committeeWithValidPowerTable, nil)
5759
t.Run("delegates cache miss", func(t *testing.T) {
58-
result, err := subject.GetCommittee(1)
60+
result, err := subject.GetCommittee(ctx, 1)
5961
require.NoError(t, err)
6062
require.Equal(t, committeeWithValidPowerTable, result)
61-
mockDelegate.AssertCalled(t, "GetCommittee", instance1)
63+
mockDelegate.AssertCalled(t, "GetCommittee", mock.Anything, instance1)
6264
})
6365
t.Run("caches", func(t *testing.T) {
64-
result, err := subject.GetCommittee(1)
66+
result, err := subject.GetCommittee(ctx, 1)
6567
require.NoError(t, err)
6668
require.Equal(t, committeeWithValidPowerTable, result)
6769
mockDelegate.AssertNotCalled(t, "GetCommittee")
6870
})
6971
t.Run("delegates error", func(t *testing.T) {
7072
wantErr := errors.New("undadasea")
71-
mockDelegate.On("GetCommittee", instance2).Return(nil, wantErr)
72-
result, err := subject.GetCommittee(instance2)
73+
mockDelegate.On("GetCommittee", mock.Anything, instance2).Return(nil, wantErr)
74+
result, err := subject.GetCommittee(ctx, instance2)
7375
require.ErrorIs(t, err, ErrValidationNoCommittee)
7476
require.ErrorIs(t, err, wantErr)
7577
require.Nil(t, result)
76-
mockDelegate.AssertCalled(t, "GetCommittee", instance2)
78+
mockDelegate.AssertCalled(t, "GetCommittee", mock.Anything, instance2)
7779
})
7880
t.Run("checks nil committee", func(t *testing.T) {
79-
mockDelegate.On("GetCommittee", instance3).Return(nil, nil)
80-
result, err := subject.GetCommittee(instance3)
81+
mockDelegate.On("GetCommittee", mock.Anything, instance3).Return(nil, nil)
82+
result, err := subject.GetCommittee(ctx, instance3)
8183
require.ErrorContains(t, err, "unexpected")
8284
require.Nil(t, result)
83-
mockDelegate.AssertCalled(t, "GetCommittee", instance3)
85+
mockDelegate.AssertCalled(t, "GetCommittee", mock.Anything, instance3)
8486
})
8587
t.Run("evicts instances before given", func(t *testing.T) {
86-
mockDelegate.On("GetCommittee", instance5).Return(committee5, nil)
87-
mockDelegate.On("GetCommittee", instance6).Return(committee6, nil)
88-
mockDelegate.On("GetCommittee", instance7).Return(committee7, nil)
88+
mockDelegate.On("GetCommittee", mock.Anything, instance5).Return(committee5, nil)
89+
mockDelegate.On("GetCommittee", mock.Anything, instance6).Return(committee6, nil)
90+
mockDelegate.On("GetCommittee", mock.Anything, instance7).Return(committee7, nil)
8991

9092
// Populate
91-
result, err := subject.GetCommittee(instance5)
93+
result, err := subject.GetCommittee(ctx, instance5)
9294
require.NoError(t, err)
9395
require.Equal(t, committee5, result)
94-
mockDelegate.AssertCalled(t, "GetCommittee", instance5)
95-
result, err = subject.GetCommittee(instance6)
96+
mockDelegate.AssertCalled(t, "GetCommittee", mock.Anything, instance5)
97+
result, err = subject.GetCommittee(ctx, instance6)
9698
require.NoError(t, err)
9799
require.Equal(t, committee6, result)
98-
mockDelegate.AssertCalled(t, "GetCommittee", instance6)
99-
result, err = subject.GetCommittee(instance7)
100+
mockDelegate.AssertCalled(t, "GetCommittee", mock.Anything, instance6)
101+
result, err = subject.GetCommittee(ctx, instance7)
100102
require.NoError(t, err)
101103
require.Equal(t, committee7, result)
102-
mockDelegate.AssertCalled(t, "GetCommittee", instance7)
104+
mockDelegate.AssertCalled(t, "GetCommittee", mock.Anything, instance7)
103105

104106
// Assert cache hit.
105-
result, err = subject.GetCommittee(instance5)
107+
result, err = subject.GetCommittee(ctx, instance5)
106108
require.NoError(t, err)
107109
require.Equal(t, committee5, result)
108110
mockDelegate.AssertNotCalled(t, "GetCommittee")
109-
result, err = subject.GetCommittee(instance6)
111+
result, err = subject.GetCommittee(ctx, instance6)
110112
require.NoError(t, err)
111113
require.Equal(t, committee6, result)
112114
mockDelegate.AssertNotCalled(t, "GetCommittee")
113-
result, err = subject.GetCommittee(instance7)
115+
result, err = subject.GetCommittee(ctx, instance7)
114116
require.NoError(t, err)
115117
require.Equal(t, committee7, result)
116118
mockDelegate.AssertNotCalled(t, "GetCommittee")
@@ -119,14 +121,14 @@ func TestCachedCommitteeProvider_GetCommittee(t *testing.T) {
119121
subject.EvictCommitteesBefore(instance6)
120122

121123
// Assert cache miss.
122-
result, err = subject.GetCommittee(instance5)
124+
result, err = subject.GetCommittee(ctx, instance5)
123125
require.NoError(t, err)
124126
require.Equal(t, committee5, result)
125-
mockDelegate.AssertCalled(t, "GetCommittee", instance5)
126-
result, err = subject.GetCommittee(instance1)
127+
mockDelegate.AssertCalled(t, "GetCommittee", mock.Anything, instance5)
128+
result, err = subject.GetCommittee(ctx, instance1)
127129
require.NoError(t, err)
128130
require.Equal(t, committeeWithValidPowerTable, result)
129-
mockDelegate.AssertCalled(t, "GetCommittee", instance1)
131+
mockDelegate.AssertCalled(t, "GetCommittee", mock.Anything, instance1)
130132
})
131133
}
132134

0 commit comments

Comments
 (0)