diff --git a/in_session.go b/in_session.go index 816b4b240..6edbb8409 100644 --- a/in_session.go +++ b/in_session.go @@ -225,38 +225,31 @@ func (state inSession) handleResendRequest(session *session, msg *Message) (next return state } -func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int, inReplyTo Message) (err error) { +func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int, inReplyTo Message) error { if session.DisableMessagePersist { - err = state.generateSequenceReset(session, beginSeqNo, endSeqNo+1, inReplyTo) - return - } - - msgs, err := session.store.GetMessages(beginSeqNo, endSeqNo) - if err != nil { - session.log.OnEventf("error retrieving messages from store: %s", err.Error()) - return + return state.generateSequenceReset(session, beginSeqNo, endSeqNo+1, inReplyTo) } seqNum := beginSeqNo nextSeqNum := seqNum msg := NewMessage() - for _, msgBytes := range msgs { - err = ParseMessageWithDataDictionary(msg, bytes.NewBuffer(msgBytes), session.transportDataDictionary, session.appDataDictionary) + err := session.store.IterateMessages(beginSeqNo, endSeqNo, func(msgBytes []byte) error { + err := ParseMessageWithDataDictionary(msg, bytes.NewBuffer(msgBytes), session.transportDataDictionary, session.appDataDictionary) if err != nil { session.log.OnEventf("Resend Msg Parse Error: %v, %v", err.Error(), bytes.NewBuffer(msgBytes).String()) - return // We cant continue with a message that cant be parsed correctly. + return err // We cant continue with a message that cant be parsed correctly. } msgType, _ := msg.Header.GetBytes(tagMsgType) sentMessageSeqNum, _ := msg.Header.GetInt(tagMsgSeqNum) if isAdminMessageType(msgType) { nextSeqNum = sentMessageSeqNum + 1 - continue + return nil } if !session.resend(msg) { nextSeqNum = sentMessageSeqNum + 1 - continue + return nil } if seqNum != sentMessageSeqNum { @@ -271,6 +264,11 @@ func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int seqNum = sentMessageSeqNum + 1 nextSeqNum = seqNum + return nil + }) + if err != nil { + session.log.OnEventf("error retrieving messages from store: %s", err.Error()) + return err } if seqNum != nextSeqNum { // gapfill for catch-up @@ -279,7 +277,7 @@ func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int } } - return + return nil } func (state inSession) processReject(session *session, msg *Message, rej MessageRejectError) sessionState { diff --git a/internal/testsuite/store_suite.go b/internal/testsuite/store_suite.go index 0563a5861..49b76b05a 100644 --- a/internal/testsuite/store_suite.go +++ b/internal/testsuite/store_suite.go @@ -77,6 +77,28 @@ func (s *StoreTestSuite) TestMessageStoreReset() { s.Equal(1, s.MsgStore.NextTargetMsgSeqNum()) } +func (s *StoreTestSuite) fetchMessages(beginSeqNum, endSeqNum int) (msgs [][]byte) { + s.T().Helper() + + // Fetch messages from the new iterator + err := s.MsgStore.IterateMessages(beginSeqNum, endSeqNum, func(msg []byte) error { + msgs = append(msgs, msg) + return nil + }) + s.Require().Nil(err) + + // Fetch messages from the old getter + oldMsgs, err := s.MsgStore.GetMessages(beginSeqNum, endSeqNum) + s.Require().Nil(err) + + // Ensure the output is the same + s.Require().Len(msgs, len(oldMsgs)) + for idx, msg := range msgs { + s.Require().EqualValues(msg, oldMsgs[idx]) + } + return +} + func (s *StoreTestSuite) TestMessageStoreSaveMessageGetMessage() { // Given the following saved messages expectedMsgsBySeqNum := map[int]string{ @@ -89,8 +111,7 @@ func (s *StoreTestSuite) TestMessageStoreSaveMessageGetMessage() { } // When the messages are retrieved from the MessageStore - actualMsgs, err := s.MsgStore.GetMessages(1, 3) - s.Require().Nil(err) + actualMsgs := s.fetchMessages(1, 3) // Then the messages should be s.Require().Len(actualMsgs, 3) @@ -102,8 +123,7 @@ func (s *StoreTestSuite) TestMessageStoreSaveMessageGetMessage() { s.Require().Nil(s.MsgStore.Refresh()) // And the messages are retrieved from the MessageStore - actualMsgs, err = s.MsgStore.GetMessages(1, 3) - s.Require().Nil(err) + actualMsgs = s.fetchMessages(1, 3) // Then the messages should still be s.Require().Len(actualMsgs, 3) @@ -127,8 +147,7 @@ func (s *StoreTestSuite) TestMessageStoreSaveMessageAndIncrementGetMessage() { s.Equal(423, s.MsgStore.NextSenderMsgSeqNum()) // When the messages are retrieved from the MessageStore - actualMsgs, err := s.MsgStore.GetMessages(1, 3) - s.Require().Nil(err) + actualMsgs := s.fetchMessages(1, 3) // Then the messages should be s.Require().Len(actualMsgs, 3) @@ -140,8 +159,7 @@ func (s *StoreTestSuite) TestMessageStoreSaveMessageAndIncrementGetMessage() { s.Require().Nil(s.MsgStore.Refresh()) // And the messages are retrieved from the MessageStore - actualMsgs, err = s.MsgStore.GetMessages(1, 3) - s.Require().Nil(err) + actualMsgs = s.fetchMessages(1, 3) s.Equal(423, s.MsgStore.NextSenderMsgSeqNum()) @@ -154,8 +172,7 @@ func (s *StoreTestSuite) TestMessageStoreSaveMessageAndIncrementGetMessage() { func (s *StoreTestSuite) TestMessageStoreGetMessagesEmptyStore() { // When messages are retrieved from an empty store - messages, err := s.MsgStore.GetMessages(1, 2) - require.Nil(s.T(), err) + messages := s.fetchMessages(1, 2) // Then no messages should be returned require.Empty(s.T(), messages, "Did not expect messages from empty store") @@ -187,8 +204,7 @@ func (s *StoreTestSuite) TestMessageStoreGetMessagesVariousRanges() { // Then the returned messages should be for _, tc := range testCases { - actualMsgs, err := s.MsgStore.GetMessages(tc.beginSeqNo, tc.endSeqNo) - require.Nil(t, err) + actualMsgs := s.fetchMessages(tc.beginSeqNo, tc.endSeqNo) require.Len(t, actualMsgs, len(tc.expectedBytes)) for i, expectedMsg := range tc.expectedBytes { assert.Equal(t, string(expectedMsg), string(actualMsgs[i])) diff --git a/memorystore.go b/memorystore.go index 2379016e2..7c57ca783 100644 --- a/memorystore.go +++ b/memorystore.go @@ -97,14 +97,24 @@ func (store *memoryStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg return store.IncrNextSenderMsgSeqNum() } -func (store *memoryStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { - var msgs [][]byte +func (store *memoryStore) IterateMessages(beginSeqNum, endSeqNum int, cb func([]byte) error) error { for seqNum := beginSeqNum; seqNum <= endSeqNum; seqNum++ { if m, ok := store.messageMap[seqNum]; ok { - msgs = append(msgs, m) + if err := cb(m); err != nil { + return err + } } } - return msgs, nil + return nil +} + +func (store *memoryStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { + var msgs [][]byte + err := store.IterateMessages(beginSeqNum, endSeqNum, func(m []byte) error { + msgs = append(msgs, m) + return nil + }) + return msgs, err } type memoryStoreFactory struct{} diff --git a/store.go b/store.go index 34e2570e4..6689297ba 100644 --- a/store.go +++ b/store.go @@ -36,6 +36,7 @@ type MessageStore interface { SaveMessage(seqNum int, msg []byte) error SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) + IterateMessages(beginSeqNum, endSeqNum int, cb func([]byte) error) error Refresh() error Reset() error diff --git a/store/file/filestore.go b/store/file/filestore.go index 0cc1e3c33..ae44540b1 100644 --- a/store/file/filestore.go +++ b/store/file/filestore.go @@ -378,18 +378,28 @@ func (store *fileStore) getMessage(seqNum int) (msg []byte, found bool, err erro return msg, true, nil } -func (store *fileStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { - var msgs [][]byte +func (store *fileStore) IterateMessages(beginSeqNum, endSeqNum int, cb func([]byte) error) error { for seqNum := beginSeqNum; seqNum <= endSeqNum; seqNum++ { m, found, err := store.getMessage(seqNum) if err != nil { - return nil, err + return err } if found { - msgs = append(msgs, m) + if err = cb(m); err != nil { + return err + } } } - return msgs, nil + return nil +} + +func (store *fileStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { + var msgs [][]byte + err := store.IterateMessages(beginSeqNum, endSeqNum, func(msg []byte) error { + msgs = append(msgs, msg) + return nil + }) + return msgs, err } // Close closes the store's files. diff --git a/store/mongo/mongostore.go b/store/mongo/mongostore.go index 5af278e10..42696388e 100644 --- a/store/mongo/mongostore.go +++ b/store/mongo/mongostore.go @@ -338,17 +338,17 @@ func (store *mongoStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg [ return store.cache.SetNextSenderMsgSeqNum(next) } -func (store *mongoStore) GetMessages(beginSeqNum, endSeqNum int) (msgs [][]byte, err error) { +func (store *mongoStore) IterateMessages(beginSeqNum, endSeqNum int, cb func([]byte) error) error { msgFilter := generateMessageFilter(&store.sessionID) // Marshal into database form. msgFilterBytes, err := bson.Marshal(msgFilter) if err != nil { - return + return err } seqFilter := bson.M{} err = bson.Unmarshal(msgFilterBytes, &seqFilter) if err != nil { - return + return err } // Modify the query to use a range for the sequence filter. seqFilter["msgseq"] = bson.M{ @@ -358,18 +358,26 @@ func (store *mongoStore) GetMessages(beginSeqNum, endSeqNum int) (msgs [][]byte, sortOpt := options.Find().SetSort(bson.D{{Key: "msgseq", Value: 1}}) cursor, err := store.db.Database(store.mongoDatabase).Collection(store.messagesCollection).Find(context.Background(), seqFilter, sortOpt) if err != nil { - return + return err } - + defer func() { _ = cursor.Close(context.Background()) }() for cursor.Next(context.Background()) { if err = cursor.Decode(&msgFilter); err != nil { - return + return err + } else if err = cb(msgFilter.Message); err != nil { + return err } - msgs = append(msgs, msgFilter.Message) } + return nil +} - err = cursor.Close(context.Background()) - return +func (store *mongoStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { + var msgs [][]byte + err := store.IterateMessages(beginSeqNum, endSeqNum, func(msg []byte) error { + msgs = append(msgs, msg) + return nil + }) + return msgs, err } // Close closes the store's database connection. diff --git a/store/sql/sqlstore.go b/store/sql/sqlstore.go index 7550e3012..aeaeb6eb3 100644 --- a/store/sql/sqlstore.go +++ b/store/sql/sqlstore.go @@ -352,9 +352,8 @@ func (store *sqlStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []b return store.cache.SetNextSenderMsgSeqNum(next) } -func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { +func (store *sqlStore) IterateMessages(beginSeqNum, endSeqNum int, cb func([]byte) error) error { s := store.sessionID - var msgs [][]byte rows, err := store.db.Query(sqlString(`SELECT message FROM messages WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? @@ -366,23 +365,29 @@ func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) s.TargetCompID, s.TargetSubID, s.TargetLocationID, beginSeqNum, endSeqNum) if err != nil { - return nil, err + return err } - defer rows.Close() + defer func() { _ = rows.Close() }() for rows.Next() { var message string - if err := rows.Scan(&message); err != nil { - return nil, err + if err = rows.Scan(&message); err != nil { + return err + } else if err = cb([]byte(message)); err != nil { + return err } - msgs = append(msgs, []byte(message)) } - if err := rows.Err(); err != nil { - return nil, err - } + return rows.Err() +} - return msgs, nil +func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { + var msgs [][]byte + err := store.IterateMessages(beginSeqNum, endSeqNum, func(msg []byte) error { + msgs = append(msgs, msg) + return nil + }) + return msgs, err } // Close closes the store's database connection.