Skip to content

Commit 1af9b42

Browse files
Implement full context support in ZDNS and better SIGTERM/SIGINT handling (#538)
* added context and compiles. Killing now kills everything much faster * add ctx handling for the SIGINT and SIGTERM os signals * fix up tests with context.Background * have status handler print a Scan Aborted msg if a user SIGINT's or SIGTERM's a scan rather than Scan Complete * comment and ensure chan was written to and not closed --------- Co-authored-by: Zakir Durumeric <zakird@gmail.com>
1 parent ff137b1 commit 1af9b42

File tree

23 files changed

+187
-110
lines changed

23 files changed

+187
-110
lines changed

examples/all_nameservers_lookup/main.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package main
1515

1616
import (
17+
"context"
1718
"net"
1819

1920
"github.com/miekg/dns"
@@ -33,15 +34,15 @@ func main() {
3334
}
3435
defer resolver.Close()
3536
// LookupAllNameserversIterative will query all root nameservers, and then all TLD nameservers, and then all authoritative nameservers for the domain.
36-
result, _, status, err := resolver.LookupAllNameserversIterative(dnsQuestion, nil)
37+
result, _, status, err := resolver.LookupAllNameserversIterative(context.Background(), dnsQuestion, nil)
3738
if err != nil {
3839
log.Fatal("Error looking up domain: ", err)
3940
}
4041
log.Warnf("Result: %v", result)
4142
log.Warnf("Status: %v", status)
4243
log.Info("We can also specify which root nameservers to use by setting the argument.")
4344

44-
result, _, status, err = resolver.LookupAllNameserversIterative(dnsQuestion, []zdns.NameServer{{IP: net.ParseIP("198.41.0.4"), Port: 53}}) // a.root-servers.net
45+
result, _, status, err = resolver.LookupAllNameserversIterative(context.Background(), dnsQuestion, []zdns.NameServer{{IP: net.ParseIP("198.41.0.4"), Port: 53}}) // a.root-servers.net
4546
if err != nil {
4647
log.Fatal("Error looking up domain: ", err)
4748
}
@@ -50,7 +51,7 @@ func main() {
5051

5152
log.Info("You can query multiple recursive resolvers as well")
5253

53-
externalResult, _, status, err := resolver.LookupAllNameserversExternal(dnsQuestion, []zdns.NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}, {IP: net.ParseIP("8.8.8.8"), Port: 53}}) // Cloudflare and Google recursive resolvers, respectively
54+
externalResult, _, status, err := resolver.LookupAllNameserversExternal(context.Background(), dnsQuestion, []zdns.NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}, {IP: net.ParseIP("8.8.8.8"), Port: 53}}) // Cloudflare and Google recursive resolvers, respectively
5455
if err != nil {
5556
log.Fatal("Error looking up domain: ", err)
5657
}

src/cli/cli.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package cli
1616

1717
import (
18+
"context"
1819
"errors"
1920
"fmt"
2021
"net"
@@ -32,13 +33,13 @@ import (
3233
var parser *flags.Parser
3334

3435
type InputHandler interface {
35-
FeedChannel(in chan<- string, wg *sync.WaitGroup) error
36+
FeedChannel(ctx context.Context, in chan<- string, wg *sync.WaitGroup) error
3637
}
3738
type OutputHandler interface {
3839
WriteResults(results <-chan string, wg *sync.WaitGroup) error
3940
}
4041
type StatusHandler interface {
41-
LogPeriodicUpdates(statusChan <-chan zdns.Status, wg *sync.WaitGroup) error
42+
LogPeriodicUpdates(statusChan <-chan zdns.Status, statusAbortChan <-chan struct{}, wg *sync.WaitGroup) error
4243
}
4344

4445
// GeneralOptions core options for all ZDNS modules

src/cli/iohandlers/file_handlers.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package iohandlers
1616

1717
import (
1818
"bufio"
19+
"context"
1920
"os"
2021
"sync"
2122

@@ -36,7 +37,7 @@ func NewFileInputHandler(filepath string) *FileInputHandler {
3637
}
3738
}
3839

39-
func (h *FileInputHandler) FeedChannel(in chan<- string, wg *sync.WaitGroup) error {
40+
func (h *FileInputHandler) FeedChannel(ctx context.Context, in chan<- string, wg *sync.WaitGroup) error {
4041
defer close(in)
4142
defer (*wg).Done()
4243

@@ -52,6 +53,9 @@ func (h *FileInputHandler) FeedChannel(in chan<- string, wg *sync.WaitGroup) err
5253
}
5354
s := bufio.NewScanner(f)
5455
for s.Scan() {
56+
if util.HasCtxExpired(ctx) {
57+
return nil // context expired, exiting
58+
}
5559
in <- s.Text()
5660
}
5761
if err := s.Err(); err != nil {

src/cli/iohandlers/status_handler.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func NewStatusHandler(filePath string) *StatusHandler {
4747
}
4848

4949
// LogPeriodicUpdates prints a per-second update to the user scan progress and per-status statistics
50-
func (h *StatusHandler) LogPeriodicUpdates(statusChan <-chan zdns.Status, wg *sync.WaitGroup) error {
50+
func (h *StatusHandler) LogPeriodicUpdates(statusChan <-chan zdns.Status, statusAbortChan <-chan struct{}, wg *sync.WaitGroup) error {
5151
defer wg.Done()
5252
// open file for writing
5353
var f *os.File
@@ -66,20 +66,25 @@ func (h *StatusHandler) LogPeriodicUpdates(statusChan <-chan zdns.Status, wg *sy
6666
}
6767
}(f)
6868
}
69-
if err := h.statusLoop(statusChan, f); err != nil {
69+
if err := h.statusLoop(statusChan, statusAbortChan, f); err != nil {
7070
return errors.Wrap(err, "error encountered in status loop")
7171
}
7272
return nil
7373
}
7474

7575
// statusLoop will print a per-second summary of the scan progress and per-status statistics
76-
func (h *StatusHandler) statusLoop(statusChan <-chan zdns.Status, statusFile *os.File) error {
76+
// statusChan is a channel that will receive the statuses from each lookup and updates it's internal state stats
77+
// statusAbortChan is used for the main thread to notify the status loop that it has aborted and we should notify the
78+
// user appropriately
79+
// statusFile is where to write the status updates to
80+
func (h *StatusHandler) statusLoop(statusChan <-chan zdns.Status, statusAbortChan <-chan struct{}, statusFile *os.File) error {
7781
// initialize stats
7882
stats := scanStats{
7983
statusOccurance: make(map[zdns.Status]int),
8084
scanStartTime: time.Now(),
8185
}
8286
ticker := time.NewTicker(time.Second)
87+
scanAborted := false
8388
statusLoop:
8489
for {
8590
select {
@@ -111,13 +116,26 @@ statusLoop:
111116
stats.statusOccurance[status] = 0
112117
}
113118
stats.statusOccurance[status] += 1
119+
case _, ok := <-statusAbortChan:
120+
if ok {
121+
// We can't exit the loop here because the lookup threads will block on writing to the status channel.
122+
// Instead, we'll set a flag and when we break out of this loop later (because the status channel is closed)
123+
// we'll print an abort message for the user.
124+
scanAborted = true
125+
}
114126
}
115127
}
128+
scanStateString := "Scan Complete"
129+
if scanAborted {
130+
scanStateString = "Scan Aborted"
131+
}
132+
116133
timeSinceStart := time.Since(stats.scanStartTime)
117-
s := fmt.Sprintf("%02dh:%02dm:%02ds; Scan Complete; %d names scanned; %.02f names/sec; %.01f%% success rate; %s\n",
134+
s := fmt.Sprintf("%02dh:%02dm:%02ds; %s; %d names scanned; %.02f names/sec; %.01f%% success rate; %s\n",
118135
int(timeSinceStart.Hours()),
119136
int(timeSinceStart.Minutes())%60,
120137
int(timeSinceStart.Seconds())%60,
138+
scanStateString,
121139
stats.domainsScanned,
122140
float64(stats.domainsScanned)/time.Since(stats.scanStartTime).Seconds(),
123141
float64(stats.domainsSuccess*100)/float64(stats.domainsScanned),

src/cli/iohandlers/stream_handlers.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ package iohandlers
1616

1717
import (
1818
"bufio"
19+
"context"
1920
"io"
2021
"sync"
2122

2223
"github.com/pkg/errors"
2324

2425
log "github.com/sirupsen/logrus"
26+
27+
"github.com/zmap/zdns/v2/src/internal/util"
2528
)
2629

2730
type StreamInputHandler struct {
@@ -34,12 +37,15 @@ func NewStreamInputHandler(r io.Reader) *StreamInputHandler {
3437
}
3538
}
3639

37-
func (h *StreamInputHandler) FeedChannel(in chan<- string, wg *sync.WaitGroup) error {
40+
func (h *StreamInputHandler) FeedChannel(ctx context.Context, in chan<- string, wg *sync.WaitGroup) error {
3841
defer close(in)
3942
defer (*wg).Done()
4043

4144
s := bufio.NewScanner(h.reader)
4245
for s.Scan() {
46+
if util.HasCtxExpired(ctx) {
47+
return nil // context expired, exiting
48+
}
4349
in <- s.Text()
4450
}
4551
if err := s.Err(); err != nil {

src/cli/iohandlers/string_slice_input_handler.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
package iohandlers
1616

1717
import (
18+
"context"
1819
"sync"
1920

2021
log "github.com/sirupsen/logrus"
22+
23+
"github.com/zmap/zdns/v2/src/internal/util"
2124
)
2225

2326
// StringSliceInputHandler Feeds a channel with the strings in the slice.
@@ -32,10 +35,13 @@ func NewStringSliceInputHandler(domains []string) *StringSliceInputHandler {
3235
return &StringSliceInputHandler{Names: domains}
3336
}
3437

35-
func (h *StringSliceInputHandler) FeedChannel(in chan<- string, wg *sync.WaitGroup) error {
38+
func (h *StringSliceInputHandler) FeedChannel(ctx context.Context, in chan<- string, wg *sync.WaitGroup) error {
3639
defer close(in)
3740
defer wg.Done()
3841
for _, name := range h.Names {
42+
if util.HasCtxExpired(ctx) {
43+
return nil
44+
}
3945
in <- name
4046
}
4147
return nil

src/cli/modules.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626

2727
type LookupModule interface {
2828
CLIInit(gc *CLIConf, rc *zdns.ResolverConfig) error
29-
Lookup(resolver *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error)
29+
Lookup(ctx context.Context, resolver *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error)
3030
Help() string // needed to satisfy the ZCommander interface in ZFlags.
3131
GetDescription() string // needed to add a command to the parser, printed to the user. Printed to the user when they run the help command for a given module
3232
Validate(args []string) error // needed to satisfy the ZCommander interface in ZFlags
@@ -183,17 +183,17 @@ func (lm *BasicLookupModule) NewFlags() interface{} {
183183
// non-Iterative query -> we'll send a query to the nameserver provided. If none provided, a random nameserver from the resolver's external nameservers will be used
184184
// iterative + all-Nameservers query -> we'll send a query to each root NS and query all nameservers down the chain.
185185
// iterative query -> we'll send a query to a random root NS and query all nameservers down the chain.
186-
func (lm *BasicLookupModule) Lookup(resolver *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
186+
func (lm *BasicLookupModule) Lookup(ctx context.Context, resolver *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
187187
if lm.LookupAllNameServers && lm.IsIterative {
188-
return resolver.LookupAllNameserversIterative(&zdns.Question{Name: lookupName, Type: lm.DNSType, Class: lm.DNSClass}, nil)
188+
return resolver.LookupAllNameserversIterative(ctx, &zdns.Question{Name: lookupName, Type: lm.DNSType, Class: lm.DNSClass}, nil)
189189
}
190190
if lm.LookupAllNameServers {
191-
return resolver.LookupAllNameserversExternal(&zdns.Question{Name: lookupName, Type: lm.DNSType, Class: lm.DNSClass}, nil)
191+
return resolver.LookupAllNameserversExternal(ctx, &zdns.Question{Name: lookupName, Type: lm.DNSType, Class: lm.DNSClass}, nil)
192192
}
193193
if lm.IsIterative {
194-
return resolver.IterativeLookup(context.Background(), &zdns.Question{Name: lookupName, Type: lm.DNSType, Class: lm.DNSClass})
194+
return resolver.IterativeLookup(ctx, &zdns.Question{Name: lookupName, Type: lm.DNSType, Class: lm.DNSClass})
195195
}
196-
return resolver.ExternalLookup(context.Background(), &zdns.Question{Type: lm.DNSType, Class: lm.DNSClass, Name: lookupName}, nameServer)
196+
return resolver.ExternalLookup(ctx, &zdns.Question{Type: lm.DNSType, Class: lm.DNSClass, Name: lookupName}, nameServer)
197197
}
198198

199199
func GetLookupModule(name string) (LookupModule, error) {

src/cli/worker_manager.go

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,20 @@
1414
package cli
1515

1616
import (
17+
"context"
1718
"encoding/csv"
1819
"encoding/json"
1920
"fmt"
2021
"io"
2122
"math/rand"
2223
"net"
2324
"os"
25+
"os/signal"
2426
"runtime"
2527
"strconv"
2628
"strings"
2729
"sync"
30+
"syscall"
2831
"time"
2932

3033
"github.com/zmap/zcrypto/x509"
@@ -482,6 +485,10 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res
482485
}
483486

484487
func Run(gc CLIConf) {
488+
// Create a context that is cancelled on SIGINT or SIGTERM.
489+
ctx, cancel := context.WithCancel(context.Background())
490+
defer cancel()
491+
485492
gc = *populateCLIConfig(&gc)
486493
resolverConfig := populateResolverConfig(&gc)
487494
// Log any information about the resolver configuration, according to log level
@@ -502,10 +509,21 @@ func Run(gc CLIConf) {
502509
// - process until inChan closes, then wg.done()
503510
// Once we processing threads have all finished, wait until the
504511
// output and metadata threads have completed
505-
inChan := make(chan string)
506-
outChan := make(chan string)
507-
metaChan := make(chan routineMetadata, gc.Threads)
508-
statusChan := make(chan zdns.Status)
512+
inChan := make(chan string) // input handler feeds inChan with input
513+
outChan := make(chan string) // lookup workers write to outChan, output handler reads from outChan
514+
metaChan := make(chan routineMetadata, gc.Threads) // lookup workers write to metaChan, metadata is collected at end of scan
515+
statusChan := make(chan zdns.Status) // lookup workers write to status chan after each lookup, status handler reads from statusChan
516+
statusAbortChan := make(chan struct{}) // used by this thread to signal to the status handler that we've aborted
517+
518+
sigs := make(chan os.Signal, 1)
519+
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
520+
521+
go func() {
522+
<-sigs // SIGINT or SIGTERM received
523+
cancel() // signal to all goroutines to clean up
524+
statusAbortChan <- struct{}{} // signal to status handler that we've aborted so it will print a suitable message
525+
// to the user vs. the 'Scan Complete' it normally prints
526+
}()
509527
var routineWG sync.WaitGroup
510528

511529
inHandler := gc.InputHandler
@@ -525,7 +543,7 @@ func Run(gc CLIConf) {
525543

526544
// Use handlers to populate the input and output/results channel
527545
go func() {
528-
if inErr := inHandler.FeedChannel(inChan, &routineWG); inErr != nil {
546+
if inErr := inHandler.FeedChannel(ctx, inChan, &routineWG); inErr != nil {
529547
log.Fatal(fmt.Sprintf("could not feed input channel: %v", inErr))
530548
}
531549
}()
@@ -539,7 +557,7 @@ func Run(gc CLIConf) {
539557

540558
if !gc.QuietStatusUpdates {
541559
go func() {
542-
if statusErr := statusHandler.LogPeriodicUpdates(statusChan, &routineWG); statusErr != nil {
560+
if statusErr := statusHandler.LogPeriodicUpdates(statusChan, statusAbortChan, &routineWG); statusErr != nil {
543561
log.Fatal(fmt.Sprintf("could not log periodic status updates: %v", statusErr))
544562
}
545563
}()
@@ -554,7 +572,7 @@ func Run(gc CLIConf) {
554572
for i := 0; i < gc.Threads; i++ {
555573
i := i
556574
go func(threadID int) {
557-
initWorkerErr := doLookupWorker(&gc, resolverConfig, inChan, outChan, metaChan, statusChan, &lookupWG)
575+
initWorkerErr := doLookupWorker(ctx, &gc, resolverConfig, inChan, outChan, metaChan, statusChan, &lookupWG)
558576
if initWorkerErr != nil {
559577
log.Fatalf("could not start lookup worker #%d: %v", i, initWorkerErr)
560578
}
@@ -564,6 +582,7 @@ func Run(gc CLIConf) {
564582
close(outChan)
565583
close(metaChan)
566584
close(statusChan)
585+
close(statusAbortChan)
567586
routineWG.Wait()
568587
if gc.MetadataFilePath != "" {
569588
// we're done processing data. aggregate all the data from individual routines
@@ -610,25 +629,27 @@ func Run(gc CLIConf) {
610629
}
611630

612631
// doLookupWorker is a single worker thread that processes lookups from the input channel. It calls wg.Done when it is finished.
613-
func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, inputChan <-chan string, outputChan chan<- string, metaChan chan<- routineMetadata, statusChan chan<- zdns.Status, wg *sync.WaitGroup) error {
632+
func doLookupWorker(ctx context.Context, gc *CLIConf, rc *zdns.ResolverConfig, inputChan <-chan string, outputChan chan<- string, metaChan chan<- routineMetadata, statusChan chan<- zdns.Status, wg *sync.WaitGroup) error {
614633
defer wg.Done()
615634
resolver, err := zdns.InitResolver(rc)
616635
if err != nil {
617636
return fmt.Errorf("could not init resolver: %w", err)
618637
}
638+
defer resolver.Close() // close the resolver, freeing up resources
619639
var metadata routineMetadata
620640
metadata.Status = make(map[zdns.Status]int)
621641

622642
for line := range inputChan {
623-
handleWorkerInput(gc, rc, line, resolver, &metadata, outputChan, statusChan)
643+
if util.HasCtxExpired(ctx) {
644+
break
645+
}
646+
handleWorkerInput(ctx, gc, rc, line, resolver, &metadata, outputChan, statusChan)
624647
}
625-
// close the resolver, freeing up resources
626-
resolver.Close()
627648
metaChan <- metadata
628649
return nil
629650
}
630651

631-
func handleWorkerInput(gc *CLIConf, rc *zdns.ResolverConfig, line string, resolver *zdns.Resolver, metadata *routineMetadata, outputChan chan<- string, statusChan chan<- zdns.Status) {
652+
func handleWorkerInput(ctx context.Context, gc *CLIConf, rc *zdns.ResolverConfig, line string, resolver *zdns.Resolver, metadata *routineMetadata, outputChan chan<- string, statusChan chan<- zdns.Status) {
632653
// we'll process each module sequentially, parallelism is per-domain
633654
res := zdns.Result{Results: make(map[string]zdns.SingleModuleResult, len(gc.ActiveModules))}
634655
// get the fields that won't change for each lookup module
@@ -685,7 +706,7 @@ func handleWorkerInput(gc *CLIConf, rc *zdns.ResolverConfig, line string, resolv
685706
res.Class = dns.Class(gc.Class).String()
686707

687708
startTime := time.Now()
688-
innerRes, trace, status, err = module.Lookup(resolver, lookupName, nameServer)
709+
innerRes, trace, status, err = module.Lookup(ctx, resolver, lookupName, nameServer)
689710

690711
lookupRes := zdns.SingleModuleResult{
691712
Timestamp: time.Now().Format(gc.TimeFormat),

src/modules/alookup/a_lookup.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package alookup
1616

1717
import (
18+
"context"
19+
1820
"github.com/pkg/errors"
1921

2022
"github.com/zmap/zdns/v2/src/cli"
@@ -51,8 +53,8 @@ func (aMod *ALookupModule) Init(ipv4Lookup bool, ipv6Lookup bool) {
5153
aMod.IPv6Lookup = ipv6Lookup
5254
}
5355

54-
func (aMod *ALookupModule) Lookup(r *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
55-
ipResult, trace, status, err := r.DoTargetedLookup(lookupName, nameServer, aMod.baseModule.IsIterative, aMod.IPv4Lookup, aMod.IPv6Lookup)
56+
func (aMod *ALookupModule) Lookup(ctx context.Context, r *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
57+
ipResult, trace, status, err := r.DoTargetedLookup(ctx, lookupName, nameServer, aMod.baseModule.IsIterative, aMod.IPv4Lookup, aMod.IPv6Lookup)
5658
return ipResult, trace, status, err
5759
}
5860

0 commit comments

Comments
 (0)