diff --git a/pkg/chainaccessor/legacy_accessor.go b/pkg/chainaccessor/legacy_accessor.go index 366e44546..23e394469 100644 --- a/pkg/chainaccessor/legacy_accessor.go +++ b/pkg/chainaccessor/legacy_accessor.go @@ -7,8 +7,12 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/types/query" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + "github.com/smartcontractkit/chainlink-ccip/pkg/consts" "github.com/smartcontractkit/chainlink-ccip/pkg/contractreader" + "github.com/smartcontractkit/chainlink-ccip/pkg/logutil" cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3" ) @@ -59,16 +63,13 @@ func (l *LegacyAccessor) GetContractAddress(contractName string) ([]byte, error) return addressBytes, nil } -func (l *LegacyAccessor) GetChainFeeComponents( - ctx context.Context, -) map[cciptypes.ChainSelector]cciptypes.ChainFeeComponents { - // TODO(NONEVM-1865): implement - panic("implement me") -} +func (l *LegacyAccessor) GetChainFeeComponents(ctx context.Context) (cciptypes.ChainFeeComponents, error) { + fc, err := l.contractWriter.GetFeeComponents(ctx) + if err != nil { + return cciptypes.ChainFeeComponents{}, fmt.Errorf("get fee components: %w", err) + } -func (l *LegacyAccessor) GetDestChainFeeComponents(ctx context.Context) (types.ChainFeeComponents, error) { - // TODO(NONEVM-1865): implement - panic("implement me") + return *fc, nil } func (l *LegacyAccessor) Sync(ctx context.Context, contracts cciptypes.ContractAddresses) error { @@ -78,11 +79,82 @@ func (l *LegacyAccessor) Sync(ctx context.Context, contracts cciptypes.ContractA func (l *LegacyAccessor) MsgsBetweenSeqNums( ctx context.Context, - dest cciptypes.ChainSelector, + destChainSelector cciptypes.ChainSelector, seqNumRange cciptypes.SeqNumRange, ) ([]cciptypes.Message, error) { - // TODO(NONEVM-1865): implement - panic("implement me") + lggr := logutil.WithContextValues(ctx, l.lggr) + seq, err := l.contractReader.ExtendedQueryKey( + ctx, + consts.ContractNameOnRamp, + query.KeyFilter{ + Key: consts.EventNameCCIPMessageSent, + Expressions: []query.Expression{ + query.Comparator(consts.EventAttributeSourceChain, primitives.ValueComparator{ + Value: l.chainSelector, + Operator: primitives.Eq, + }), + query.Comparator(consts.EventAttributeDestChain, primitives.ValueComparator{ + Value: destChainSelector, + Operator: primitives.Eq, + }), + query.Comparator(consts.EventAttributeSequenceNumber, primitives.ValueComparator{ + Value: seqNumRange.Start(), + Operator: primitives.Gte, + }, primitives.ValueComparator{ + Value: seqNumRange.End(), + Operator: primitives.Lte, + }), + query.Confidence(primitives.Finalized), + }, + }, + query.LimitAndSort{ + SortBy: []query.SortBy{ + query.NewSortBySequence(query.Asc), + }, + Limit: query.Limit{ + Count: uint64(seqNumRange.End() - seqNumRange.Start() + 1), + }, + }, + &SendRequestedEvent{}, + ) + if err != nil { + return nil, fmt.Errorf("failed to query onRamp: %w", err) + } + + lggr.Infow("queried messages between sequence numbers", + "numMsgs", len(seq), + "sourceChainSelector", l.chainSelector, + "seqNumRange", seqNumRange.String(), + ) + + onRampAddress, err := l.GetContractAddress(consts.ContractNameOnRamp) + if err != nil { + return nil, fmt.Errorf("failed to get onRamp contract address: %w", err) + } + + msgs := make([]cciptypes.Message, 0) + for _, item := range seq { + msg, ok := item.Data.(*SendRequestedEvent) + if !ok { + return nil, fmt.Errorf("failed to cast %v to Message", item.Data) + } + + if err := ValidateSendRequestedEvent(msg, l.chainSelector, destChainSelector, seqNumRange); err != nil { + lggr.Errorw("validate send requested event", "err", err, "message", msg) + continue + } + + msg.Message.Header.OnRamp = onRampAddress + msgs = append(msgs, msg.Message) + } + + lggr.Infow("decoded messages between sequence numbers", + "msgs", msgs, + "sourceChainSelector", l.chainSelector, + "seqNumRange", seqNumRange.String(), + ) + + return msgs, nil } func (l *LegacyAccessor) LatestMsgSeqNum(ctx context.Context, dest cciptypes.ChainSelector) (cciptypes.SeqNum, error) { diff --git a/pkg/chainaccessor/types.go b/pkg/chainaccessor/types.go new file mode 100644 index 000000000..3b72f12e1 --- /dev/null +++ b/pkg/chainaccessor/types.go @@ -0,0 +1,10 @@ +package chainaccessor + +import cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3" + +// SendRequestedEvent represents the contents of the event emitted by the CCIP onramp when a message is sent. +type SendRequestedEvent struct { + DestChainSelector cciptypes.ChainSelector + SequenceNumber cciptypes.SeqNum + Message cciptypes.Message +} diff --git a/pkg/chainaccessor/validators.go b/pkg/chainaccessor/validators.go new file mode 100644 index 000000000..7af12c63a --- /dev/null +++ b/pkg/chainaccessor/validators.go @@ -0,0 +1,57 @@ +package chainaccessor + +import ( + "fmt" + + cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3" +) + +func ValidateSendRequestedEvent( + ev *SendRequestedEvent, source, dest cciptypes.ChainSelector, seqNumRange cciptypes.SeqNumRange, +) error { + if ev == nil { + return fmt.Errorf("send requested event is nil") + } + + if ev.Message.Header.DestChainSelector != dest { + return fmt.Errorf("msg dest chain is not the expected queried one") + } + if ev.DestChainSelector != dest { + return fmt.Errorf("dest chain is not the expected queried one") + } + + if ev.Message.Header.SourceChainSelector != source { + return fmt.Errorf("source chain is not the expected queried one") + } + + if ev.SequenceNumber != ev.Message.Header.SequenceNumber { + return fmt.Errorf("event sequence number does not match the message sequence number %d != %d", + ev.SequenceNumber, ev.Message.Header.SequenceNumber) + } + + if ev.SequenceNumber < seqNumRange.Start() || ev.SequenceNumber > seqNumRange.End() { + return fmt.Errorf("send requested event sequence number is not in the expected range") + } + + if ev.Message.Header.MessageID.IsEmpty() { + return fmt.Errorf("message ID is zero") + } + + if len(ev.Message.Receiver) == 0 { + return fmt.Errorf("empty receiver address: %s", ev.Message.Receiver.String()) + } + + if ev.Message.Sender.IsZeroOrEmpty() { + return fmt.Errorf("invalid sender address: %s", ev.Message.Sender.String()) + } + + if ev.Message.FeeTokenAmount.IsEmpty() { + return fmt.Errorf("fee token amount is zero") + } + + if ev.Message.FeeToken.IsZeroOrEmpty() { + return fmt.Errorf("invalid fee token: %s", ev.Message.FeeToken.String()) + } + + return nil +} diff --git a/pkg/reader/ccip.go b/pkg/reader/ccip.go index 56913548f..1282d7a32 100644 --- a/pkg/reader/ccip.go +++ b/pkg/reader/ccip.go @@ -500,64 +500,20 @@ func createExecutedMessagesKeyFilter( return keyFilter, countSqNrs } -type SendRequestedEvent struct { - DestChainSelector cciptypes.ChainSelector - SequenceNumber cciptypes.SeqNum - Message cciptypes.Message -} - func (r *ccipChainReader) MsgsBetweenSeqNums( ctx context.Context, sourceChainSelector cciptypes.ChainSelector, seqNumRange cciptypes.SeqNumRange, ) ([]cciptypes.Message, error) { - lggr := logutil.WithContextValues(ctx, r.lggr) - if err := validateAccessorExistence(r.accessors, sourceChainSelector); err != nil { return nil, err } - onRampAddress, err := r.accessors[sourceChainSelector].GetContractAddress(consts.ContractNameOnRamp) + onRampAddressBeforeQuery, err := r.accessors[sourceChainSelector].GetContractAddress(consts.ContractNameOnRamp) if err != nil { return nil, fmt.Errorf("get onRamp address: %w", err) } - if err := validateReaderExistence(r.contractReaders, sourceChainSelector); err != nil { - return nil, err - } - seq, err := r.contractReaders[sourceChainSelector].ExtendedQueryKey( - ctx, - consts.ContractNameOnRamp, - query.KeyFilter{ - Key: consts.EventNameCCIPMessageSent, - Expressions: []query.Expression{ - query.Comparator(consts.EventAttributeSourceChain, primitives.ValueComparator{ - Value: sourceChainSelector, - Operator: primitives.Eq, - }), - query.Comparator(consts.EventAttributeDestChain, primitives.ValueComparator{ - Value: r.destChain, - Operator: primitives.Eq, - }), - query.Comparator(consts.EventAttributeSequenceNumber, primitives.ValueComparator{ - Value: seqNumRange.Start(), - Operator: primitives.Gte, - }, primitives.ValueComparator{ - Value: seqNumRange.End(), - Operator: primitives.Lte, - }), - query.Confidence(primitives.Finalized), - }, - }, - query.LimitAndSort{ - SortBy: []query.SortBy{ - query.NewSortBySequence(query.Asc), - }, - Limit: query.Limit{ - Count: uint64(seqNumRange.End() - seqNumRange.Start() + 1), - }, - }, - &SendRequestedEvent{}, - ) + messages, err := r.accessors[sourceChainSelector].MsgsBetweenSeqNums(ctx, r.destChain, seqNumRange) if err != nil { - return nil, fmt.Errorf("failed to query onRamp: %w", err) + return nil, fmt.Errorf("failed to call MsgsBetweenSeqNums on accessor: %w", err) } onRampAddressAfterQuery, err := r.accessors[sourceChainSelector].GetContractAddress(consts.ContractNameOnRamp) @@ -566,43 +522,11 @@ func (r *ccipChainReader) MsgsBetweenSeqNums( } // Ensure the onRamp address hasn't changed during the query. - if !bytes.Equal(onRampAddress, onRampAddressAfterQuery) { - return nil, fmt.Errorf("onRamp address has changed from %s to %s", onRampAddress, onRampAddressAfterQuery) - } - - lggr.Infow("queried messages between sequence numbers", - "numMsgs", len(seq), - "sourceChainSelector", sourceChainSelector, - "seqNumRange", seqNumRange.String(), - ) - - msgs := make([]cciptypes.Message, 0) - for _, item := range seq { - msg, ok := item.Data.(*SendRequestedEvent) - if !ok { - return nil, fmt.Errorf("failed to cast %v to Message", item.Data) - } - - if err := validateSendRequestedEvent(msg, sourceChainSelector, r.destChain, seqNumRange); err != nil { - lggr.Errorw("validate send requested event", "err", err, "message", msg) - continue - } - - txHash, err := contractreader.ExtractTxHash(item.Cursor) - if err != nil { - lggr.Warnw("failed to extract transaction hash from cursor", "err", err, "cursor", item.Cursor) - } - - msg.Message.Header.OnRamp = onRampAddress - msg.Message.Header.TxHash = txHash - msgs = append(msgs, msg.Message) + if !bytes.Equal(onRampAddressBeforeQuery, onRampAddressAfterQuery) { + return nil, fmt.Errorf("onRamp address has changed from %s to %s", onRampAddressBeforeQuery, onRampAddressAfterQuery) } - lggr.Infow("decoded messages between sequence numbers", "msgs", msgs, - "sourceChainSelector", sourceChainSelector, - "seqNumRange", seqNumRange.String()) - - return msgs, nil + return messages, nil } // LatestMsgSeqNum reads the source chain and returns the latest finalized message sequence number. @@ -636,7 +560,7 @@ func (r *ccipChainReader) LatestMsgSeqNum( }, Limit: query.Limit{Count: 1}, }, - &SendRequestedEvent{}, + &chainaccessor.SendRequestedEvent{}, ) if err != nil { return 0, fmt.Errorf("failed to query onRamp: %w", err) @@ -652,12 +576,12 @@ func (r *ccipChainReader) LatestMsgSeqNum( } item := seq[0] - msg, ok := item.Data.(*SendRequestedEvent) + msg, ok := item.Data.(*chainaccessor.SendRequestedEvent) if !ok { return 0, fmt.Errorf("failed to cast %v to SendRequestedEvent", item.Data) } - if err := validateSendRequestedEvent(msg, chain, r.destChain, + if err := chainaccessor.ValidateSendRequestedEvent(msg, chain, r.destChain, cciptypes.NewSeqNumRange(msg.Message.Header.SequenceNumber, msg.Message.Header.SequenceNumber)); err != nil { return 0, fmt.Errorf("message invalid msg %v: %w", msg, err) } @@ -881,12 +805,12 @@ func (r *ccipChainReader) GetChainsFeeComponents( feeComponents := make(map[cciptypes.ChainSelector]types.ChainFeeComponents, len(r.contractWriters)) for _, chain := range chains { - chainWriter, ok := r.contractWriters[chain] - if !ok { - lggr.Errorw("contract writer not found", "chain", chain) + if err := validateAccessorExistence(r.accessors, chain); err != nil { + lggr.Errorw("accessor not found", "chain", chain, "err", err) continue } - feeComponent, err := chainWriter.GetFeeComponents(ctx) + + feeComponent, err := r.accessors[chain].GetChainFeeComponents(ctx) if err != nil { lggr.Errorw("failed to get chain fee components", "chain", chain, "err", err) continue @@ -901,7 +825,7 @@ func (r *ccipChainReader) GetChainsFeeComponents( continue } - feeComponents[chain] = *feeComponent + feeComponents[chain] = feeComponent } return feeComponents } @@ -2227,55 +2151,6 @@ func validateExecutionStateChangedEvent( return nil } -func validateSendRequestedEvent( - ev *SendRequestedEvent, source, dest cciptypes.ChainSelector, seqNumRange cciptypes.SeqNumRange) error { - if ev == nil { - return fmt.Errorf("send requested event is nil") - } - - if ev.Message.Header.DestChainSelector != dest { - return fmt.Errorf("msg dest chain is not the expected queried one") - } - if ev.DestChainSelector != dest { - return fmt.Errorf("dest chain is not the expected queried one") - } - - if ev.Message.Header.SourceChainSelector != source { - return fmt.Errorf("source chain is not the expected queried one") - } - - if ev.SequenceNumber != ev.Message.Header.SequenceNumber { - return fmt.Errorf("event sequence number does not match the message sequence number %d != %d", - ev.SequenceNumber, ev.Message.Header.SequenceNumber) - } - - if ev.SequenceNumber < seqNumRange.Start() || ev.SequenceNumber > seqNumRange.End() { - return fmt.Errorf("send requested event sequence number is not in the expected range") - } - - if ev.Message.Header.MessageID.IsEmpty() { - return fmt.Errorf("message ID is zero") - } - - if len(ev.Message.Receiver) == 0 { - return fmt.Errorf("empty receiver address: %s", ev.Message.Receiver.String()) - } - - if ev.Message.Sender.IsZeroOrEmpty() { - return fmt.Errorf("invalid sender address: %s", ev.Message.Sender.String()) - } - - if ev.Message.FeeTokenAmount.IsEmpty() { - return fmt.Errorf("fee token amount is zero") - } - - if ev.Message.FeeToken.IsZeroOrEmpty() { - return fmt.Errorf("invalid fee token: %s", ev.Message.FeeToken.String()) - } - - return nil -} - // ccipReaderInternal defines the interface that ConfigPoller needs from the ccipChainReader // This allows for better encapsulation and easier testing through mocking type ccipReaderInternal interface { diff --git a/pkg/reader/ccip_test.go b/pkg/reader/ccip_test.go index b1bd60562..7c8c0a6a3 100644 --- a/pkg/reader/ccip_test.go +++ b/pkg/reader/ccip_test.go @@ -1085,9 +1085,6 @@ func TestCCIPChainReader_getFeeQuoterTokenPriceUSD(t *testing.T) { } func TestCCIPFeeComponents_HappyPath(t *testing.T) { - destCR := reader_mocks.NewMockContractReaderFacade(t) - destCR.EXPECT().Bind(mock.Anything, mock.Anything).Return(nil) - cw := writer_mocks.NewMockContractWriter(t) cw.EXPECT().GetFeeComponents(mock.Anything).Return( &types.ChainFeeComponents{ @@ -1097,15 +1094,25 @@ func TestCCIPFeeComponents_HappyPath(t *testing.T) { ) contractWriters := make(map[cciptypes.ChainSelector]types.ContractWriter) - // Missing writer for chainB contractWriters[chainA] = cw + contractWriters[chainB] = cw contractWriters[chainC] = cw + sourceCRs := make(map[cciptypes.ChainSelector]*reader_mocks.MockContractReaderFacade) + for _, chain := range []cciptypes.ChainSelector{chainA, chainB, chainC} { + sourceCRs[chain] = reader_mocks.NewMockContractReaderFacade(t) + } + + destChain := chainC + sourceCRs[destChain].EXPECT().Bind(mock.Anything, mock.Anything).Return(nil) + ccipReader, err := newCCIPChainReaderInternal( tests.Context(t), logger.Test(t), map[cciptypes.ChainSelector]contractreader.ContractReaderFacade{ - chainC: destCR, + chainA: sourceCRs[chainA], + chainB: sourceCRs[chainB], + chainC: sourceCRs[chainC], }, contractWriters, chainC, @@ -1124,10 +1131,12 @@ func TestCCIPFeeComponents_HappyPath(t *testing.T) { ctx := context.Background() feeComponents := ccipReader.GetChainsFeeComponents(ctx, []cciptypes.ChainSelector{chainA, chainB, chainC}) - assert.Len(t, feeComponents, 2) + assert.Len(t, feeComponents, 3) assert.Equal(t, big.NewInt(1), feeComponents[chainA].ExecutionFee) - assert.Equal(t, big.NewInt(2), feeComponents[chainA].DataAvailabilityFee) + assert.Equal(t, big.NewInt(1), feeComponents[chainB].ExecutionFee) assert.Equal(t, big.NewInt(1), feeComponents[chainC].ExecutionFee) + assert.Equal(t, big.NewInt(2), feeComponents[chainA].DataAvailabilityFee) + assert.Equal(t, big.NewInt(2), feeComponents[chainB].DataAvailabilityFee) assert.Equal(t, big.NewInt(2), feeComponents[chainC].DataAvailabilityFee) destChainFeeComponent, err := ccipReader.GetDestChainFeeComponents(ctx) diff --git a/pkg/types/ccipocr3/chain_accessor.go b/pkg/types/ccipocr3/chain_accessor.go index c107953ba..95471c868 100644 --- a/pkg/types/ccipocr3/chain_accessor.go +++ b/pkg/types/ccipocr3/chain_accessor.go @@ -68,17 +68,7 @@ type AllAccessors interface { // Access Type: ChainWriter // Contract: N/A // Confidence: N/A - GetChainFeeComponents( - ctx context.Context, - ) map[ChainSelector]ChainFeeComponents - - // GetDestChainFeeComponents seems redundant. If the error is needed lets - // add it to GetChainFeeComponents. - // - // Deprecated: use GetChainFeeComponents instead. - GetDestChainFeeComponents( - ctx context.Context, - ) (types.ChainFeeComponents, error) + GetChainFeeComponents(ctx context.Context) (ChainFeeComponents, error) // Sync can be used to perform frequent syncing operations inside the reader implementation. // Returns a bool indicating whether something was updated.