Skip to content

Commit 717ebae

Browse files
committed
feat: improve consumer shutdown procedure
1 parent 9cb9578 commit 717ebae

File tree

2 files changed

+88
-13
lines changed

2 files changed

+88
-13
lines changed

internal/cmd/consumer/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"go.uber.org/zap"
1616
)
1717

18-
func main(_ consumer.Consumer) {
18+
func main(_ *consumer.Consumer) {
1919
area, _ := pterm.DefaultArea.WithCenter().Start()
2020
text, _ := pterm.DefaultBigText.WithLetters(putils.LettersFromString("Redpanda101")).Srender()
2121
area.Update(text)

internal/infra/consumer/consumer.go

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package consumer
33
import (
44
"context"
55
"encoding/json"
6+
"sync"
67
"time"
78

89
"github.com/1995parham-teaching/redpanda101/internal/domain/model"
@@ -23,42 +24,96 @@ type Consumer struct {
2324
db *pgxpool.Pool
2425
tracer *kotel.Tracer
2526
metric *Metric
27+
wg sync.WaitGroup
2628
}
2729

28-
func Provide(
30+
func Provide( // nolint: funlen
2931
lc fx.Lifecycle,
3032
client *kgo.Client,
3133
logger *zap.Logger,
3234
db *pgxpool.Pool,
3335
tracer *kotel.Tracer,
3436
tele telemetry.Telemetery,
35-
) Consumer {
36-
c := Consumer{
37+
) *Consumer {
38+
c := &Consumer{
3739
client: client,
3840
logger: logger,
3941
db: db,
4042
tracer: tracer,
4143
metric: NewMetric(tele.MeterRegistry, tele.Namespace, tele.ServiceName),
44+
wg: sync.WaitGroup{},
4245
}
4346

47+
shutdown := make(chan struct{})
48+
4449
client.AddConsumeTopics(constant.Topic)
4550

46-
lc.Append(fx.StartHook(func() {
47-
go c.Consume()
48-
}))
51+
lc.Append(fx.Hook{
52+
OnStart: func(ctx context.Context) error {
53+
ctx = context.WithoutCancel(ctx)
54+
ctx, cancel := context.WithCancel(ctx)
55+
56+
go func() {
57+
for {
58+
<-shutdown
59+
cancel()
60+
}
61+
}()
62+
63+
go c.Consume(ctx)
64+
65+
return nil
66+
},
67+
OnStop: func(ctx context.Context) error {
68+
logger.Info("shutting down consumer gracefully")
69+
70+
close(shutdown)
71+
72+
done := make(chan struct{})
73+
74+
go func() {
75+
c.wg.Wait()
76+
77+
close(done)
78+
}()
79+
80+
select {
81+
case <-done:
82+
logger.Info("consumer shutdown completed successfully")
83+
case <-ctx.Done():
84+
logger.Warn("consumer shutdown timed out, forcing shutdown")
85+
}
86+
87+
return nil
88+
},
89+
})
4990

5091
return c
5192
}
5293

53-
func (c Consumer) Consume() {
54-
ch := make(chan *kgo.Record)
94+
func (c *Consumer) Consume(ctx context.Context) {
95+
ch := make(chan *kgo.Record, numberOfProcessors)
5596

5697
for range numberOfProcessors {
57-
go c.process(ch)
98+
c.wg.Add(1)
99+
100+
go c.process(ctx, ch)
58101
}
59102

103+
c.logger.Info("consumer started", zap.Int("workers", numberOfProcessors))
104+
105+
// Main consume loop
60106
for {
61-
fetches := c.client.PollFetches(context.Background())
107+
select {
108+
case <-ctx.Done():
109+
c.logger.Info("consumer cancelled, stopping fetch loop")
110+
close(ch)
111+
112+
return
113+
default:
114+
}
115+
116+
fetches := c.client.PollFetches(ctx)
62117

63118
if errs := fetches.Errors(); len(errs) > 0 {
64119
for _, err := range errs {
@@ -74,12 +129,24 @@ func (c Consumer) Consume() {
74129
iter := fetches.RecordIter()
75130
for !iter.Done() {
76131
record := iter.Next()
77-
ch <- record
132+
133+
select {
134+
case <-ctx.Done():
135+
c.logger.Info("consumer cancelled while sending record to workers")
136+
close(ch)
137+
138+
return
139+
case ch <- record:
140+
}
78141
}
79142
}
80143
}
81144

82-
func (c Consumer) process(ch <-chan *kgo.Record) {
145+
func (c *Consumer) process(_ context.Context, ch <-chan *kgo.Record) {
146+
defer c.wg.Done()
147+
148+
c.logger.Debug("worker started")
149+
83150
for record := range ch {
84151
ctx, span := c.tracer.WithProcessSpan(record)
85152

@@ -90,12 +157,17 @@ func (c Consumer) process(ch <-chan *kgo.Record) {
90157
err := json.Unmarshal(record.Value, &order)
91158
if err != nil {
92159
c.logger.Error("failed to parse an order from json", zap.Error(err), zap.ByteString("record", record.Value))
160+
span.RecordError(err)
161+
span.End()
162+
163+
continue
93164
}
94165

95166
c.logger.Info("new order received", zap.Any("order", order))
96167

97168
start := time.Now()
98169

170+
// nolint: contextcheck
99171
_, err = c.db.Exec(
100172
ctx,
101173
"INSERT INTO orders (description, src_currency, dst_currency, channel) VALUES ($1, $2, $3, $4)",
@@ -106,10 +178,13 @@ func (c Consumer) process(ch <-chan *kgo.Record) {
106178
)
107179
if err != nil {
108180
c.logger.Error("database insertion failed", zap.Error(err))
181+
span.RecordError(err)
109182
}
110183

111184
c.metric.DatabaseInsertionTime.Observe(time.Since(start).Seconds())
112185

113186
span.End()
114187
}
188+
189+
c.logger.Debug("worker stopped")
115190
}

0 commit comments

Comments
 (0)