Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions v2/inflightmessages.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package shuttle

import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
)

// inFlightMessageAbandonTimeout bounds each individual AbandonMessage call
// issued by inFlightMessages.close; every tracked message gets its own
// timeout of this length.
const inFlightMessageAbandonTimeout = 10 * time.Second

// inFlightMessages is a mutex-guarded set of received messages, keyed by
// message pointer. Messages are added with track, removed with forget, and
// abandoned in bulk with close.
type inFlightMessages struct {
	// mu guards tracked.
	mu sync.RWMutex
	// tracked holds the messages added by track that have not been removed yet.
	tracked map[*azservicebus.ReceivedMessage]struct{}
}

// newInFlightMessages returns an empty tracker whose underlying set is
// allocated and ready for track/forget calls.
func newInFlightMessages() *inFlightMessages {
	tracker := new(inFlightMessages)
	tracker.tracked = map[*azservicebus.ReceivedMessage]struct{}{}
	return tracker
}

// track adds message to the in-flight set. Tracking a message pointer that is
// already present leaves the set unchanged.
func (m *inFlightMessages) track(message *azservicebus.ReceivedMessage) {
	var present struct{}

	m.mu.Lock()
	defer m.mu.Unlock()

	m.tracked[message] = present
}

// forget removes message from the in-flight set. It is a no-op when the
// message is not currently tracked.
func (m *inFlightMessages) forget(message *azservicebus.ReceivedMessage) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.tracked, message)
}

// close abandons every message that is still tracked, using settler, and
// removes each successfully abandoned message from the set.
//
// Each abandon call runs under its own context derived from ctx with a
// timeout of inFlightMessageAbandonTimeout. A message whose abandon fails
// stays tracked, so a later close attempts it again. The returned error joins
// all abandon failures; it is nil when every abandon succeeded or nothing was
// tracked.
func (m *inFlightMessages) close(ctx context.Context, settler MessageSettler) error {
	// abandon settles a single message under a fresh per-message timeout.
	abandon := func(msg *azservicebus.ReceivedMessage) error {
		abandonCtx, cancel := context.WithTimeout(ctx, inFlightMessageAbandonTimeout)
		defer cancel()
		return settler.AbandonMessage(abandonCtx, msg, nil)
	}

	var failures []error
	// Iterate over a snapshot so forget can mutate the set while we loop.
	for _, msg := range m.messages() {
		err := abandon(msg)
		if err == nil {
			m.forget(msg)
			continue
		}
		failures = append(failures, fmt.Errorf("failed to abandon message %s during processor close: %w", msg.MessageID, err))
	}
	return errors.Join(failures...)
}

// messages returns a snapshot of the currently tracked messages. The slice is
// freshly allocated (never nil) and its order follows map iteration, so
// callers must not rely on it.
func (m *inFlightMessages) messages() []*azservicebus.ReceivedMessage {
	m.mu.RLock()
	defer m.mu.RUnlock()

	snapshot := make([]*azservicebus.ReceivedMessage, len(m.tracked))
	next := 0
	for msg := range m.tracked {
		snapshot[next] = msg
		next++
	}
	return snapshot
}
167 changes: 167 additions & 0 deletions v2/inflightmessages_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package shuttle

import (
"context"
"errors"
"sync"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
"github.com/stretchr/testify/require"
)

func TestInFlightMessages_CloseAbandonsTrackedMessages(t *testing.T) {
inFlight := newInFlightMessages()
settler := &inFlightMessageSettler{}
trackedMessage := &azservicebus.ReceivedMessage{MessageID: "tracked"}
forgottenMessage := &azservicebus.ReceivedMessage{MessageID: "forgotten"}

inFlight.track(trackedMessage)
inFlight.track(forgottenMessage)
inFlight.forget(forgottenMessage)

require.NoError(t, inFlight.close(context.Background(), settler))

messages := settler.abandonedMessages()
require.Len(t, messages, 1)
require.Same(t, trackedMessage, messages[0])
}

func TestInFlightMessages_CloseRemovesTrackedMessages(t *testing.T) {
inFlight := newInFlightMessages()
settler := &inFlightMessageSettler{}
message := &azservicebus.ReceivedMessage{MessageID: "message"}

inFlight.track(message)

require.NoError(t, inFlight.close(context.Background(), settler))
require.NoError(t, inFlight.close(context.Background(), settler))

messages := settler.abandonedMessages()
require.Len(t, messages, 1)
require.Same(t, message, messages[0])
}

func TestInFlightMessages_CloseReturnsAbandonErrors(t *testing.T) {
firstErr := errors.New("first abandon failed")
secondErr := errors.New("second abandon failed")
inFlight := newInFlightMessages()
settler := &inFlightMessageSettler{
abandonErrors: []error{firstErr, secondErr},
}

inFlight.track(&azservicebus.ReceivedMessage{MessageID: "first"})
inFlight.track(&azservicebus.ReceivedMessage{MessageID: "second"})

err := inFlight.close(context.Background(), settler)

require.Error(t, err)
require.ErrorIs(t, err, firstErr)
require.ErrorIs(t, err, secondErr)
}

func TestInFlightMessages_CloseKeepsMessagesWhenAbandonFails(t *testing.T) {
abandonErr := errors.New("abandon failed")
inFlight := newInFlightMessages()
failingSettler := &inFlightMessageSettler{
abandonErrors: []error{abandonErr},
}
message := &azservicebus.ReceivedMessage{MessageID: "message"}

inFlight.track(message)

err := inFlight.close(context.Background(), failingSettler)

require.ErrorIs(t, err, abandonErr)
messages := inFlight.messages()
require.Len(t, messages, 1)
require.Same(t, message, messages[0])

successfulSettler := &inFlightMessageSettler{}
require.NoError(t, inFlight.close(context.Background(), successfulSettler))
require.Empty(t, inFlight.messages())
abandonedMessages := successfulSettler.abandonedMessages()
require.Len(t, abandonedMessages, 1)
require.Same(t, message, abandonedMessages[0])
}

// TestInFlightMessages_CloseUsesAbandonTimeout verifies that close passes
// AbandonMessage a context whose deadline is inFlightMessageAbandonTimeout
// after the call, within a small tolerance.
//
// The expected window is derived from inFlightMessageAbandonTimeout rather
// than hard-coded, so the test keeps checking the right thing if the constant
// is ever tuned.
func TestInFlightMessages_CloseUsesAbandonTimeout(t *testing.T) {
	// tolerance is the slack allowed on either side of the expected deadline.
	const tolerance = time.Second

	inFlight := newInFlightMessages()
	settler := &inFlightMessageSettler{}

	inFlight.track(&azservicebus.ReceivedMessage{MessageID: "message"})

	// Bracket the close call so the deadline can be bounded from both sides.
	beforeClose := time.Now()
	require.NoError(t, inFlight.close(context.Background(), settler))
	afterClose := time.Now()

	deadlines := settler.abandonDeadlines()
	require.Len(t, deadlines, 1)
	require.True(t, deadlines[0].ok, "abandon context should carry a deadline")

	earliest := beforeClose.Add(inFlightMessageAbandonTimeout - tolerance)
	latest := afterClose.Add(inFlightMessageAbandonTimeout + tolerance)
	require.True(t, deadlines[0].deadline.After(earliest), "abandon deadline is earlier than the configured timeout allows")
	require.True(t, deadlines[0].deadline.Before(latest), "abandon deadline is later than the configured timeout allows")
}

// inFlightMessageSettler is a test double for the settler passed to
// inFlightMessages.close. It records every AbandonMessage call and can be
// scripted to fail; the other settlement methods are no-ops.
type inFlightMessageSettler struct {
	// mu guards all fields below.
	mu sync.Mutex
	// abandoned lists the messages passed to AbandonMessage, in call order.
	abandoned []*azservicebus.ReceivedMessage
	// deadlines records the context deadline observed on each AbandonMessage call.
	deadlines []abandonDeadline
	// abandonErrors scripts the results of successive AbandonMessage calls:
	// call i returns abandonErrors[i]; calls past the end return nil.
	abandonErrors []error
	// abandonAttempt counts AbandonMessage calls made so far.
	abandonAttempt int
}

// abandonDeadline captures the two results of ctx.Deadline() as seen by one
// AbandonMessage call.
type abandonDeadline struct {
	// deadline is the time returned by ctx.Deadline().
	deadline time.Time
	// ok reports whether the context had a deadline set.
	ok bool
}

// AbandonMessage records the message and the deadline of ctx, then returns
// the scripted error for this attempt: abandonErrors[attempt] while scripted
// errors remain, nil afterwards. The options argument is ignored.
func (s *inFlightMessageSettler) AbandonMessage(ctx context.Context, message *azservicebus.ReceivedMessage, options *azservicebus.AbandonMessageOptions) error {
	// Read the deadline before taking the lock, as it needs no shared state.
	var observed abandonDeadline
	observed.deadline, observed.ok = ctx.Deadline()

	s.mu.Lock()
	defer s.mu.Unlock()

	attempt := s.abandonAttempt
	s.abandonAttempt++

	s.abandoned = append(s.abandoned, message)
	s.deadlines = append(s.deadlines, observed)

	if attempt >= len(s.abandonErrors) {
		return nil
	}
	return s.abandonErrors[attempt]
}

// CompleteMessage is a no-op that always returns nil.
// NOTE(review): presumably present only so the fake satisfies the settler
// interface; confirm against the MessageSettler definition.
func (s *inFlightMessageSettler) CompleteMessage(ctx context.Context, message *azservicebus.ReceivedMessage, options *azservicebus.CompleteMessageOptions) error {
	return nil
}

// DeadLetterMessage is a no-op that always returns nil.
// NOTE(review): presumably present only so the fake satisfies the settler
// interface; confirm against the MessageSettler definition.
func (s *inFlightMessageSettler) DeadLetterMessage(ctx context.Context, message *azservicebus.ReceivedMessage, options *azservicebus.DeadLetterOptions) error {
	return nil
}

// DeferMessage is a no-op that always returns nil.
// NOTE(review): presumably present only so the fake satisfies the settler
// interface; confirm against the MessageSettler definition.
func (s *inFlightMessageSettler) DeferMessage(ctx context.Context, message *azservicebus.ReceivedMessage, options *azservicebus.DeferMessageOptions) error {
	return nil
}

// RenewMessageLock is a no-op that always returns nil.
// NOTE(review): presumably present only so the fake satisfies the settler
// interface; confirm against the MessageSettler definition.
func (s *inFlightMessageSettler) RenewMessageLock(ctx context.Context, message *azservicebus.ReceivedMessage, options *azservicebus.RenewMessageLockOptions) error {
	return nil
}

// abandonedMessages returns a copy of the messages passed to AbandonMessage so
// far, in call order. The copy is taken under the lock, so the caller may
// inspect it without further synchronization.
func (s *inFlightMessageSettler) abandonedMessages() []*azservicebus.ReceivedMessage {
	s.mu.Lock()
	defer s.mu.Unlock()

	snapshot := make([]*azservicebus.ReceivedMessage, 0, len(s.abandoned))
	return append(snapshot, s.abandoned...)
}

// abandonDeadlines returns a copy of the context deadlines recorded by
// AbandonMessage so far, in call order. The copy is taken under the lock.
func (s *inFlightMessageSettler) abandonDeadlines() []abandonDeadline {
	s.mu.Lock()
	defer s.mu.Unlock()

	snapshot := make([]abandonDeadline, 0, len(s.deadlines))
	return append(snapshot, s.deadlines...)
}
84 changes: 67 additions & 17 deletions v2/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
Expand Down Expand Up @@ -42,7 +43,11 @@ type Processor struct {
receiver Receiver
options ProcessorOptions
handle Handler
concurrencyTokens chan struct{} // tracks how many concurrent messages are currently being handled by the processor
concurrencyTokens chan struct{} // TODO: remove once receive sizing and in-flight tracking share a simpler lifecycle model.
inFlightMessages *inFlightMessages
shutdownCtx context.Context
shutdownCancel context.CancelFunc
receiveMu sync.Mutex
}

// ProcessorOptions configures the processor
Expand Down Expand Up @@ -101,18 +106,27 @@ func applyProcessorOptions(options *ProcessorOptions) *ProcessorOptions {
// NewProcessor creates a new processor with the provided receiver and handler.
// Options are normalized through applyProcessorOptions. The processor gets an
// empty in-flight message tracker and its own shutdown context, derived from
// context.Background(), which Close cancels.
func NewProcessor(receiver Receiver, handler HandlerFunc, options *ProcessorOptions) *Processor {
	resolved := applyProcessorOptions(options)

	p := &Processor{
		receiver:          receiver,
		handle:            handler,
		options:           *resolved,
		concurrencyTokens: make(chan struct{}, resolved.MaxConcurrency),
		inFlightMessages:  newInFlightMessages(),
	}
	p.shutdownCtx, p.shutdownCancel = context.WithCancel(context.Background())
	return p
}

// Start starts processing on the receiver and blocks until the processor is stopped or the context is canceled.
// It will retry starting the processor based on the StartMaxAttempt and StartRetryDelayStrategy.
// Returns a combined list of errors encountered during each processor start attempt.
func (p *Processor) Start(ctx context.Context) (err error) {
if p.isClosed() {
return context.Canceled
}
ctx, cancel := p.startContext(ctx)
defer cancel()
defer func() {
if rec := recover(); rec != nil {
err = fmt.Errorf("panic recovered from processor: %s", rec)
Expand All @@ -121,6 +135,15 @@ func (p *Processor) Start(ctx context.Context) (err error) {
return p.startWithRetries(ctx)
}

// Close stops receiving new messages, cancels in-flight message handlers, and
// abandons messages currently held by the processor.
//
// The provided ctx is the parent of the per-message abandon contexts. The
// returned error joins the failures of individual abandon calls and is nil
// when all tracked messages were abandoned.
func (p *Processor) Close(ctx context.Context) error {
	// Cancel first so that contexts linked through startContext are canceled
	// before we wait for the receive lock.
	p.shutdownCancel()
	// receiveAndProcess holds receiveMu for the whole receive-and-dispatch
	// step; taking it here means abandoning starts only after any such step
	// has finished.
	p.receiveMu.Lock()
	defer p.receiveMu.Unlock()
	// The receiver doubles as the settler used to abandon tracked messages.
	return p.inFlightMessages.close(ctx, p.receiver)
}

// startWithRetries starts a processor and blocks until an error occurs or the context is canceled.
// It will retry starting the processor based on the StartMaxAttempt and StartRetryDelayStrategy.
// Returns a combined list of errors during the start attempts or ctx.Err() if the context
Expand All @@ -147,18 +170,32 @@ func (p *Processor) startWithRetries(ctx context.Context) error {
return savedError
}

// startContext derives a cancellable context from ctx that is also canceled
// when the processor's shutdown context is done. If the processor is already
// closed, the returned context is canceled immediately. The caller must
// invoke the returned CancelFunc, which also lets the watcher goroutine exit.
func (p *Processor) startContext(ctx context.Context) (context.Context, context.CancelFunc) {
	linked, stop := context.WithCancel(ctx)

	if p.isClosed() {
		stop()
		return linked, stop
	}

	// Watcher: propagate shutdown into the linked context, or exit as soon as
	// the linked context finishes on its own.
	go func() {
		select {
		case <-linked.Done():
		case <-p.shutdownCtx.Done():
			stop()
		}
	}()

	return linked, stop
}

// isClosed reports whether the processor's shutdown context has been
// canceled.
func (p *Processor) isClosed() bool {
	shutdownErr := p.shutdownCtx.Err()
	return shutdownErr != nil
}

// start starts the processor and blocks until an error occurs or the context is canceled.
func (p *Processor) start(ctx context.Context) error {
logger := getLogger(ctx)
logger.Info("starting processor")
messages, err := p.receiver.ReceiveMessages(ctx, p.options.MaxReceiveCount, nil)
if err != nil {
return fmt.Errorf("failed to receive messages: %w", err)
}
logger.Info(fmt.Sprintf("received %d messages - initial", len(messages)))
processor.Metric.IncMessageReceived(float64(len(messages)))
for _, msg := range messages {
p.process(ctx, msg)
if err := p.receiveAndProcess(ctx, p.options.MaxReceiveCount, "initial"); err != nil {
return err
}
for ctx.Err() == nil {
select {
Expand All @@ -167,14 +204,8 @@ func (p *Processor) start(ctx context.Context) error {
if ctx.Err() != nil || maxMessages == 0 {
break
}
messages, err := p.receiver.ReceiveMessages(ctx, maxMessages, nil)
if err != nil {
return fmt.Errorf("failed to receive messages: %w", err)
}
logger.Info(fmt.Sprintf("received %d messages from processor loop", len(messages)))
processor.Metric.IncMessageReceived(float64(len(messages)))
for _, msg := range messages {
p.process(ctx, msg)
if err := p.receiveAndProcess(ctx, maxMessages, "from processor loop"); err != nil {
return err
}
case <-ctx.Done():
logger.Info("context done, stop receiving from processor")
Expand All @@ -184,13 +215,32 @@ func (p *Processor) start(ctx context.Context) error {
return ctx.Err()
}

// receiveAndProcess receives up to maxMessages from the receiver and hands
// each one to process, holding receiveMu for the whole step. source is a
// free-form label included in the log line to identify the call site.
// It returns a wrapped error when the receive fails and nil otherwise.
func (p *Processor) receiveAndProcess(ctx context.Context, maxMessages int, source string) error {
	p.receiveMu.Lock()
	defer p.receiveMu.Unlock()

	batch, err := p.receiver.ReceiveMessages(ctx, maxMessages, nil)
	if err != nil {
		return fmt.Errorf("failed to receive messages: %w", err)
	}

	received := len(batch)
	logger := getLogger(ctx)
	logger.Info(fmt.Sprintf("received %d messages - %s", received, source))
	processor.Metric.IncMessageReceived(float64(received))

	for i := range batch {
		p.process(ctx, batch[i])
	}
	return nil
}

func (p *Processor) process(ctx context.Context, message *azservicebus.ReceivedMessage) {
p.concurrencyTokens <- struct{}{}
p.inFlightMessages.track(message)

go func() {
msgContext, cancel := context.WithCancel(ctx)
// cancel messageContext when we get out of this goroutine
defer cancel()
defer func() {
p.inFlightMessages.forget(message)
<-p.concurrencyTokens
processor.Metric.IncMessageHandled(message)
processor.Metric.DecConcurrentMessageCount(message)
Expand Down
Loading
Loading