diff --git a/go.mod b/go.mod index ff3bc144..56ea0b20 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/quic-go/qpack v0.5.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/txthinking/runnergroup v0.0.0-20230325130830-408dc5853f86 // indirect + github.com/yuin/gopher-lua v1.1.2-0.20241109074121-ccacf662c9d2 // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/crypto v0.35.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect diff --git a/go.sum b/go.sum index e131ec8f..9a9ab26d 100644 --- a/go.sum +++ b/go.sum @@ -92,6 +92,8 @@ github.com/txthinking/runnergroup v0.0.0-20230325130830-408dc5853f86/go.mod h1:c github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 h1:d/Wr/Vl/wiJHc3AHYbYs5I3PucJvRuw3SvbmlIRf+oM= github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301/go.mod h1:ntmMHL/xPq1WLeKiw8p/eRATaae6PiVRNipHFJxI8PM= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/gopher-lua v1.1.2-0.20241109074121-ccacf662c9d2 h1:PF/PSu+RcSOzNpCqGo6KEO+4qw+LEiX/oqVxGm8rup8= +github.com/yuin/gopher-lua v1.1.2-0.20241109074121-ccacf662c9d2/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/lua-edns0.go b/lua-edns0.go new file mode 100644 index 00000000..2740fbd4 --- /dev/null +++ b/lua-edns0.go @@ -0,0 +1,900 @@ +package rdns + +import ( + "fmt" + "net" + + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// EDNS0 functions + +func (s *LuaScript) RegisterEDNS0Types() { + s.registerEDNS0COOKIEType() + s.registerEDNS0DAUType() + s.registerEDNS0DHUType() + s.registerEDNS0EDEType() + s.registerEDNS0ESUType() + s.registerEDNS0EXPIREType() + s.registerEDNS0LLQType() + s.registerEDNS0LOCALType() + s.registerEDNS0N3UType() + s.registerEDNS0NSIDType() + s.registerEDNS0PADDINGType() + s.registerEDNS0SUBNETType() + s.registerEDNS0TCPKEEPALIVEType() + s.registerEDNS0ULType() +} + +func (s *LuaScript) registerEDNS0COOKIEType() { + L := s.L + mtName := "EDNS0_COOKIE" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_COOKIE) + e.Code = dns.EDNS0COOKIE + nArgs := L.GetTop() + if nArgs >= 1 { // Cookie + e.Cookie = L.CheckString(1) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_COOKIE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "cookie": + L.Push(lua.LString(e.Cookie)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_COOKIE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "cookie": + e.Cookie = L.CheckString(3) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0DAUType() { + L := s.L + mtName := "EDNS0_DAU" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_DAU) + e.Code = dns.EDNS0DAU + nArgs := L.GetTop() + if nArgs >= 1 { // Alg Codes + values, _ := getNumberSlice[uint8](L, 1) + e.AlgCode = values + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_DAU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "algcode": + L.Push(numberSliceToTable(L, e.AlgCode)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_DAU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "algcode": + values, _ := getNumberSlice[uint8](L, 3) + e.AlgCode = values + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0DHUType() { + L := s.L + mtName := "EDNS0_DHU" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_DHU) + e.Code = dns.EDNS0DHU + nArgs := L.GetTop() + if nArgs >= 1 { // Alg Codes + values, _ := getNumberSlice[uint8](L, 1) + e.AlgCode = values + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_DHU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "algcode": + L.Push(numberSliceToTable(L, e.AlgCode)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_DHU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "algcode": + values, _ := getNumberSlice[uint8](L, 3) + e.AlgCode = values + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0EDEType() { + L := s.L + mtName := "EDNS0_EDE" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_EDE) + nArgs := L.GetTop() + if nArgs >= 1 { // Code + e.InfoCode = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // Extra Text + e.ExtraText = L.CheckString(2) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_EDE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "infocode": + L.Push(lua.LNumber(e.InfoCode)) + case "extratext": + L.Push(lua.LString(e.ExtraText)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_EDE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "infocode": + e.InfoCode = uint16(L.CheckNumber(3)) + case "extratext": + e.ExtraText = L.CheckString(3) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0ESUType() { + L := s.L + mtName := "EDNS0_ESU" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_ESU) + e.Code = dns.EDNS0ESU + nArgs := L.GetTop() + if nArgs >= 1 { // URI + e.Uri = L.CheckString(1) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_ESU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "uri": + L.Push(lua.LString(e.Uri)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_ESU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "uri": + e.Uri = L.CheckString(3) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0EXPIREType() { + L := s.L + mtName := "EDNS0_EXPIRE" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_EXPIRE) + e.Code = dns.EDNS0EXPIRE + nArgs := L.GetTop() + if nArgs >= 1 { // Expire + e.Expire = uint32(L.CheckNumber(1)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_EXPIRE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "expire": + L.Push(lua.LNumber(e.Expire)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_EXPIRE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "expire": + e.Expire = uint32(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0LLQType() { + L := s.L + mtName := "EDNS0_LLQ" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_LLQ) + e.Code = dns.EDNS0LLQ + nArgs := L.GetTop() + if nArgs >= 1 { // Version + e.Version = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // Opcode + e.Opcode = uint16(L.CheckNumber(2)) + } + if nArgs >= 3 { // Error + e.Error = uint16(L.CheckNumber(3)) + } + if nArgs >= 4 { // Id + e.Id = uint64(L.CheckNumber(4)) + } + if nArgs >= 5 { // LeaseLife + e.LeaseLife = uint32(L.CheckNumber(5)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_LLQ](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "version": + L.Push(lua.LNumber(e.Version)) + case "opcode": + L.Push(lua.LNumber(e.Opcode)) + case "error": + L.Push(lua.LNumber(e.Error)) + case "id": + L.Push(lua.LNumber(e.Id)) + case "leaselife": + L.Push(lua.LNumber(e.LeaseLife)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_LLQ](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "version": + e.Version = uint16(L.CheckNumber(3)) + case "opcode": + e.Opcode = uint16(L.CheckNumber(3)) + case "error": + e.Error = uint16(L.CheckNumber(3)) + case "id": + e.Id = uint64(L.CheckNumber(3)) + case "leaselife": + e.LeaseLife = uint32(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0LOCALType() { + L := s.L + mtName := "EDNS0_LOCAL" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_LOCAL) + nArgs := L.GetTop() + if nArgs >= 1 { // Code + e.Code = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // Data + e.Data = []byte(L.CheckString(2)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_LOCAL](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "code": + L.Push(lua.LNumber(e.Code)) + case "data": + L.Push(lua.LString(e.Data)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_LOCAL](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "code": + e.Code = uint16(L.CheckNumber(3)) + case "data": + e.Data = []byte(L.CheckString(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0N3UType() { + L := s.L + mtName := "EDNS0_N3U" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_N3U) + e.Code = dns.EDNS0N3U + nArgs := L.GetTop() + if nArgs >= 1 { // Alg Codes + values, _ := getNumberSlice[uint8](L, 1) + e.AlgCode = values + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_N3U](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "algcode": + L.Push(numberSliceToTable(L, e.AlgCode)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_N3U](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "algcode": + values, _ := getNumberSlice[uint8](L, 3) + e.AlgCode = values + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0NSIDType() { + L := s.L + mtName := "EDNS0_NSID" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_NSID) + e.Code = dns.EDNS0NSID + nArgs := L.GetTop() + if nArgs >= 1 { // NSID + e.Nsid = L.CheckString(1) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_NSID](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "nsid": + L.Push(lua.LString(e.Nsid)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_NSID](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "nsid": + e.Nsid = L.CheckString(3) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0PADDINGType() { + L := s.L + mtName := "EDNS0_PADDING" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_PADDING) + nArgs := L.GetTop() + if nArgs >= 1 { // NSID + e.Padding = []byte(L.CheckString(1)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_PADDING](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "padding": + L.Push(lua.LString(e.Padding)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_PADDING](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "padding": + e.Padding = []byte(L.CheckString(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0SUBNETType() { + L := s.L + mtName := "EDNS0_SUBNET" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + nArgs := L.GetTop() + if nArgs >= 1 { // Family + e.Family = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // SourceNetmask + e.SourceNetmask = uint8(L.CheckNumber(2)) + } + if nArgs >= 3 { // SourceScope + e.SourceScope = uint8(L.CheckNumber(3)) + } + if nArgs >= 4 { // Address + value := L.CheckString(4) + ip := net.ParseIP(value) + if ip == nil { + L.ArgError(4, fmt.Sprintf("expected IP address, got %q", value)) + return 0 + } + e.Address = ip + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_SUBNET](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "family": + L.Push(lua.LNumber(e.Family)) + case "sourcenetmask": + L.Push(lua.LNumber(e.SourceNetmask)) + case "sourcescope": + L.Push(lua.LNumber(e.SourceScope)) + case "address": + L.Push(lua.LString(e.Address.String())) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_SUBNET](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "family": + e.Family = uint16(L.CheckNumber(3)) + case "sourcenetmask": + e.SourceNetmask = uint8(L.CheckNumber(3)) + case "sourcescope": + e.SourceScope = uint8(L.CheckNumber(3)) + case "address": + value := L.CheckString(3) + ip := net.ParseIP(value) + if ip == nil { + L.ArgError(4, fmt.Sprintf("expected IP address, got %q", value)) + return 0 + } + e.Address = ip + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0TCPKEEPALIVEType() { + L := s.L + mtName := "EDNS0_TCP_KEEPALIVE" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_TCP_KEEPALIVE) + e.Code = dns.EDNS0TCPKEEPALIVE + nArgs := L.GetTop() + if nArgs >= 1 { // Timeout + e.Timeout = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // Length + e.Length = uint16(L.CheckNumber(2)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_TCP_KEEPALIVE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "timeout": + L.Push(lua.LNumber(e.Timeout)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_TCP_KEEPALIVE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "timeout": + e.Timeout = uint16(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0ULType() { + L := s.L + mtName := "EDNS0_UL" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_UL) + e.Code = dns.EDNS0UL + nArgs := L.GetTop() + if nArgs >= 1 { // Lease + e.Lease = uint32(L.CheckNumber(1)) + } + if nArgs >= 2 { // KeyLease + e.KeyLease = uint32(L.CheckNumber(2)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_UL](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "lease": + L.Push(lua.LNumber(e.Lease)) + case "keylease": + L.Push(lua.LNumber(e.KeyLease)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_UL](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "lease": + e.Lease = uint32(L.CheckNumber(3)) + case "keylease": + e.KeyLease = uint32(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} diff --git a/lua-error.go b/lua-error.go new file mode 100644 index 00000000..8224d97b --- /dev/null +++ b/lua-error.go @@ -0,0 +1,33 @@ +package rdns + +import ( + "errors" + + lua "github.com/yuin/gopher-lua" +) + +// Error functions + +const luaErrorMetatableName = "Error" + +func (s *LuaScript) RegisterErrorType() { + L := s.L + mt := L.NewTypeMetatable(luaErrorMetatableName) + L.SetGlobal(luaErrorMetatableName, mt) + + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + err := errors.New(L.CheckString(1)) + L.Push(userDataWithMetatable(L, luaErrorMetatableName, err)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "error": method(func(L *lua.LState, r error) int { + L.Push(lua.LString(r.Error())) + return 1 + }), + })) +} diff --git a/lua-helpers.go b/lua-helpers.go new file mode 100644 index 00000000..0acf2ba1 --- /dev/null +++ b/lua-helpers.go @@ -0,0 +1,72 @@ +package rdns + +import ( + "fmt" + "reflect" + + lua "github.com/yuin/gopher-lua" +) + +// Helper functions + +func method[T any](f func(*lua.LState, T) int) func(*lua.LState) int { + return func(L *lua.LState) int { + if L.GetTop() < 1 { + L.ArgError(1, "expected at least 1 argument") + return 0 + } + ud := L.CheckUserData(1) + r, ok := ud.Value.(T) + if !ok { + L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) + return 0 + } + return f(L, r) + } +} + +func getUserDataArg[T any](L *lua.LState, n int) (T, bool) { + ud := L.CheckUserData(n) + v, ok := ud.Value.(T) + if !ok { + L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value)) + return v, false + } + return v, true +} + +func userDataWithMetatable(L *lua.LState, mtName string, value any) *lua.LUserData { + ud := L.NewUserData() + ud.Value = value + L.SetMetatable(ud, L.GetTypeMetatable(mtName)) + return ud +} + +type numbers interface { + int | int8 | int16 | int32 | int64 | float32 | float64 | uint | uint8 | uint16 | uint32 | uint64 +} + +func getNumberSlice[T numbers](L *lua.LState, n int) ([]T, bool) { + table := L.CheckTable(n) + size := table.Len() + values := make([]T, 0, size) + for i := range size { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTNumber { + L.ArgError(n, "invalid type, expected number") + return nil, false + } + value := T(element.(lua.LNumber)) + values = append(values, value) + } + return values, true +} + +func numberSliceToTable[T numbers](L *lua.LState, values []T) *lua.LTable { + table := L.CreateTable(len(values), 0) + for _, value := range values { + table.Append(lua.LNumber(value)) + } + L.Push(table) + return table +} diff --git a/lua-message.go b/lua-message.go new file mode 100644 index 00000000..c3490305 --- /dev/null +++ b/lua-message.go @@ -0,0 +1,143 @@ +package rdns + +import ( + "fmt" + + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// Message functions + +const luaMessageMetatableName = "Message" + +func (s *LuaScript) RegisterMessageType() { + L := s.L + mt := L.NewTypeMetatable(luaMessageMetatableName) + L.SetGlobal(luaMessageMetatableName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction(func(L *lua.LState) int { + L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg))) + return 1 + })) + // methods and fields + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + msg, ok := getUserDataArg[*dns.Msg](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "questions": + table := L.CreateTable(len(msg.Question), 0) + for _, q := range msg.Question { + lv := userDataWithMetatable(L, luaQuestionMetatableName, &q) + table.Append(lv) + } + L.Push(table) + case "id": + L.Push(lua.LNumber(msg.Id)) + case "response": + L.Push(lua.LBool(msg.Response)) + case "rcode": + L.Push(lua.LNumber(msg.Rcode)) + case "recursion_desired": + L.Push(lua.LBool(msg.RecursionDesired)) + case "recursion_available": + L.Push(lua.LBool(msg.RecursionAvailable)) + case "authenticated_data": + L.Push(lua.LBool(msg.AuthenticatedData)) + case "set_reply": + L.Push(L.NewFunction( + method(func(L *lua.LState, msg *dns.Msg) int { + request, ok := getUserDataArg[*dns.Msg](L, 2) + if !ok { + return 0 + } + msg.SetReply(request) + lv := userDataWithMetatable(L, luaMessageMetatableName, msg) + L.Push(lv) + return 1 + }))) + case "set_question": + L.Push(L.NewFunction( + method(func(L *lua.LState, msg *dns.Msg) int { + msg.SetQuestion(L.CheckString(2), uint16(L.CheckNumber(3))) + lv := userDataWithMetatable(L, luaMessageMetatableName, msg) + L.Push(lv) + return 1 + }))) + case "is_edns0": + L.Push(L.NewFunction( + method(func(L *lua.LState, msg *dns.Msg) int { + opt := msg.IsEdns0() + if opt == nil { + L.Push(lua.LNil) + return 1 + } + // TODO: Return an OPT metatable name here, not generic RR + lv := userDataWithMetatable(L, luaRRHeaderMetatableName, opt) + L.Push(lv) + return 1 + }))) + case "set_edns0": + L.Push(L.NewFunction( + method(func(L *lua.LState, msg *dns.Msg) int { + msg.SetEdns0(uint16(L.CheckNumber(2)), L.CheckBool(3)) + lv := userDataWithMetatable(L, luaMessageMetatableName, msg) + L.Push(lv) + return 1 + }))) + default: + L.ArgError(2, fmt.Sprintf("message does not have field %q", fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + msg, ok := getUserDataArg[*dns.Msg](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "questions": + table := L.CheckTable(3) + n := table.Len() + questions := make([]dns.Question, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTUserData { + L.ArgError(3, "invalid type, expected userdata") + return 0 + } + lq := element.(*lua.LUserData) + q, ok := lq.Value.(*dns.Question) + if !ok { + L.ArgError(3, "invalid type, expected question") + return 0 + } + questions = append(questions, *q) + } + msg.Question = questions + case "id": + msg.Id = uint16(L.CheckInt(3)) + case "response": + msg.Response = L.CheckBool(3) + case "rcode": + msg.Rcode = L.CheckInt(3) + case "recursion_desired": + msg.RecursionDesired = L.CheckBool(3) + case "recursion_available": + msg.RecursionAvailable = L.CheckBool(3) + case "authenticated_data": + msg.AuthenticatedData = L.CheckBool(3) + default: + L.ArgError(2, fmt.Sprintf("question does not have field %q", fieldName)) + return 0 + } + return 0 + })) +} diff --git a/lua-question.go b/lua-question.go new file mode 100644 index 00000000..b4fc94f9 --- /dev/null +++ b/lua-question.go @@ -0,0 +1,77 @@ +package rdns + +import ( + "fmt" + + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// Question functions + +const luaQuestionMetatableName = "Question" + +func (s *LuaScript) RegisterQuestionType() { + L := s.L + mt := L.NewTypeMetatable(luaQuestionMetatableName) + L.SetGlobal(luaQuestionMetatableName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + q := &dns.Question{Qclass: dns.ClassINET} + nArgs := L.GetTop() + if nArgs >= 1 { // Name provided + q.Name = L.CheckString(1) + } + if nArgs >= 2 { // Name and type + q.Qtype = uint16(L.CheckNumber(2)) + } + if nArgs >= 3 { // Name, type and class + q.Qclass = uint16(L.CheckNumber(3)) + } + L.Push(userDataWithMetatable(L, luaQuestionMetatableName, q)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + question, ok := getUserDataArg[*dns.Question](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "name": + L.Push(lua.LString(question.Name)) + case "qtype": + L.Push(lua.LNumber(question.Qtype)) + case "qclass": + L.Push(lua.LNumber(question.Qclass)) + default: + L.ArgError(2, fmt.Sprintf("question does not have field %q", fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + question, ok := getUserDataArg[*dns.Question](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "name": + question.Name = L.CheckString(3) + case "qtype": + question.Qtype = uint16(L.CheckNumber(3)) + case "qclass": + question.Qclass = uint16(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("question does not have field %q", fieldName)) + return 0 + } + return 0 + })) +} diff --git a/lua-resolver.go b/lua-resolver.go new file mode 100644 index 00000000..63f0e146 --- /dev/null +++ b/lua-resolver.go @@ -0,0 +1,61 @@ +package rdns + +import ( + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// Resolver functions + +const luaResolverMetatableName = "Resolver" + +func (s *LuaScript) InjectResolvers(resolvers []Resolver) { + L := s.L + mt := L.NewTypeMetatable(luaResolverMetatableName) + L.SetGlobal("Resolver", mt) + + // Methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "resolve": resolverResolve, + })) + + table := L.CreateTable(len(resolvers), 0) + for _, r := range resolvers { + lv := userDataWithMetatable(L, luaResolverMetatableName, r) + table.Append(lv) + } + L.SetGlobal("Resolvers", table) +} + +func resolverResolve(L *lua.LState) int { + if L.GetTop() != 3 { + L.ArgError(1, "expected 2 arguments") + return 0 + } + r, ok := getUserDataArg[Resolver](L, 1) + if !ok { + return 0 + } + msg, ok := getUserDataArg[*dns.Msg](L, 2) + if !ok { + return 0 + } + ci, ok := getUserDataArg[ClientInfo](L, 3) + if !ok { + return 0 + } + + resp, err := r.Resolve(msg, ci) + + // Return the answer + L.Push(userDataWithMetatable(L, luaMessageMetatableName, resp)) + + // Return the error + if err != nil { + L.Push(userDataWithMetatable(L, luaErrorMetatableName, err)) + } else { + L.Push(lua.LNil) + } + + return 2 +} diff --git a/lua-rr.go b/lua-rr.go new file mode 100644 index 00000000..22f74a8d --- /dev/null +++ b/lua-rr.go @@ -0,0 +1,448 @@ +package rdns + +import ( + "fmt" + "net" + "reflect" + "strings" + + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// RR functions + +const luaRRHeaderMetatableName = "RR" + +func (s *LuaScript) RegisterRRTypes() { + L := s.L + + mt := L.NewTypeMetatable(luaRRHeaderMetatableName) + L.SetGlobal(luaRRHeaderMetatableName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + table := L.CheckTable(1) + lrtype, ok := table.RawGetString("rtype").(lua.LNumber) + if !ok { + L.ArgError(1, "rtype must be a number") + return 0 + } + rtype := uint16(lrtype) + + rrFunc, ok := dns.TypeToRR[rtype] + if !ok { + L.ArgError(1, "unknown rtype") + return 0 + } + rr := rrFunc() + + var err error + table.ForEach(func(k, v lua.LValue) { + if k.Type() != lua.LTString { + if err != nil { // Only record the first error + err = fmt.Errorf("expecte string keys, got %s", k.Type().String()) + } + return + } + if k.String() == "rtype" { + // We don't allow this to be set or updated + rr.Header().Rrtype = rtype + return + } + if setErr := rrDB.set(L, rr, k.String(), v); setErr != nil && err == nil { + err = setErr + } + }) + if err != nil { + L.ArgError(1, err.Error()) + return 0 + } + + L.Push(userDataWithMetatable(L, luaRRHeaderMetatableName, rr)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + rr, ok := getUserDataArg[dns.RR](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + + lv, err := rrDB.get(L, rr, fieldName) + if err != nil { + L.ArgError(1, err.Error()) // TODO: figure out arg position + return 0 + } + L.Push(lv) + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + rr, ok := getUserDataArg[dns.RR](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + if fieldName == "" { + return 0 + } + value := L.CheckAny(3) + + if err := rrDB.set(L, rr, fieldName, value); err != nil { + L.ArgError(1, err.Error()) // TODO: figure out arg position + return 0 + } + return 0 + })) +} + +type rrFieldDB map[reflect.Type]map[string]rrFieldAccessors + +type rrFieldAccessors struct { + index []int + get func(*lua.LState, reflect.Value) (lua.LValue, error) + set func(*lua.LState, reflect.Value, lua.LValue) error +} + +var rrDB = func() rrFieldDB { + db := make(map[reflect.Type]map[string]rrFieldAccessors) + + for _, rrFunc := range dns.TypeToRR { + rr := rrFunc() + typ := reflect.TypeOf(rr) + db[typ] = rrFieldsForType(typ.Elem(), nil) + } + return db +}() + +func rrFieldsForType(typ reflect.Type, index []int) map[string]rrFieldAccessors { + fields := make(map[string]rrFieldAccessors) + for _, field := range reflect.VisibleFields(typ) { + if !field.IsExported() { + continue + } + // All RR have a header and we handle that directly, without reflection + if field.Name == "Hdr" { + continue + } + a := rrFieldAccessors{ + index: append(index, field.Index...), + } + switch field.Type { + case reflect.TypeOf(net.IP{}): + a.get, a.set = getIPField, setIPField + case reflect.TypeOf(""): + a.get, a.set = getStringField, setStringField + case reflect.TypeOf(uint8(0)): + a.get, a.set = getUint8Field, setUint8Field + case reflect.TypeOf(uint16(0)): + a.get, a.set = getUint16Field, setUint16Field + case reflect.TypeOf(uint32(0)): + a.get, a.set = getUint32Field, setUint32Field + case reflect.TypeOf(uint64(0)): + a.get, a.set = getUint64Field, setUint64Field + case reflect.TypeOf([]uint16{}): + a.get, a.set = getUint16SliceField, setUint16SliceField + case reflect.TypeOf([]string{}): + a.get, a.set = getStringSliceField, setStringSliceField + case reflect.TypeOf(dns.DS{}): // Composed in DLV + return rrFieldsForType(reflect.TypeOf(dns.DS{}), field.Index) + case reflect.TypeOf(dns.SVCB{}): // Composed in HTTPS + return rrFieldsForType(reflect.TypeOf(dns.SVCB{}), field.Index) + case reflect.TypeOf(dns.NSEC{}): // Composed in NXT + return rrFieldsForType(reflect.TypeOf(dns.NSEC{}), field.Index) + case reflect.TypeOf([]dns.APLPrefix{}): // Used in APL + a.get, a.set = getUnsupported(field.Name), setUnsupported(field.Name) + case reflect.TypeOf(dns.RRSIG{}): // Composed in SIG + return rrFieldsForType(reflect.TypeOf(dns.RRSIG{}), field.Index) + case reflect.TypeOf(dns.DNSKEY{}): // Composed in KEY + return rrFieldsForType(reflect.TypeOf(dns.DNSKEY{}), field.Index) + case reflect.TypeOf([]dns.SVCBKeyValue{}): // interface + a.get, a.set = getUnsupported(field.Name), setUnsupported(field.Name) + case reflect.TypeOf([]dns.EDNS0{}): // in OPT + a.get, a.set = getEDNS0SliceField, setEDNS0SliceField + default: + panic(fmt.Errorf("unsupported RR field value type %v in %s", field.Type, typ)) + } + + fields[strings.ToLower(field.Name)] = a + } + return fields +} + +func (db rrFieldDB) get(L *lua.LState, rr dns.RR, name string) (lua.LValue, error) { + // If the field is in the header, we handle that directly + switch name { + case "name": + return lua.LString(rr.Header().Name), nil + case "rtype": + return lua.LNumber(rr.Header().Rrtype), nil + case "class": + return lua.LNumber(rr.Header().Class), nil + case "ttl": + return lua.LNumber(rr.Header().Ttl), nil + case "rdlength": + return lua.LNumber(rr.Header().Rdlength), nil + } + + // Lookup the fields for this type + typeFields, ok := db[reflect.TypeOf(rr)] + if !ok { + return nil, luaArgError{1, fmt.Errorf("unsupported resource record type %v", reflect.TypeOf(rr).String())} + } + a, ok := typeFields[name] + if !ok { + return nil, luaArgError{2, fmt.Errorf("unknown field name %q for type %v", name, reflect.TypeOf(rr).String())} + } + fieldValue := reflect.ValueOf(rr).Elem().FieldByIndex(a.index) + return a.get(L, fieldValue) +} + +func (db rrFieldDB) set(L *lua.LState, rr dns.RR, name string, value lua.LValue) error { + // If the field is in the header, we handle that directly + switch name { + case "name": + if value.Type() != lua.LTString { + return luaArgError{3, fmt.Errorf("expected string value, got %v", value.Type().String())} + } + rr.Header().Name = value.String() + return nil + case "rtype": + return luaArgError{2, fmt.Errorf("cannot change rtype directly")} + case "class": + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number value, got %v", value.Type().String())} + } + rr.Header().Class = uint16(value.(lua.LNumber)) + return nil + case "ttl": + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number value, got %v", value.Type().String())} + } + rr.Header().Ttl = uint32(value.(lua.LNumber)) + return nil + case "rdlength": + return luaArgError{2, fmt.Errorf("cannot change rdlength")} + } + + // Lookup the fields for this type + typeFields, ok := db[reflect.TypeOf(rr)] + if !ok { + return luaArgError{1, fmt.Errorf("unsupported resource record type %v", reflect.TypeOf(rr).String())} + } + a, ok := typeFields[name] + if !ok { + return luaArgError{2, fmt.Errorf("unknown field name %q for type %v", name, reflect.TypeOf(rr).String())} + } + fieldValue := reflect.ValueOf(rr).Elem().FieldByIndex(a.index) + return a.set(L, fieldValue, value) +} + +type luaArgError struct { + position int + error +} + +func getStringField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(string) + return lua.LString(field), nil +} + +func setStringField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTString { + return luaArgError{3, fmt.Errorf("expected string value, got %v", value.Type().String())} + } + fieldValue.SetString(value.String()) + return nil +} + +func getStringSliceField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().([]string) + table := L.CreateTable(len(field), 0) + for _, v := range field { + table.Append(lua.LString(v)) + } + return table, nil +} + +func setStringSliceField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTTable { + return luaArgError{3, fmt.Errorf("expected array, got %v", value.Type().String())} + } + table := value.(*lua.LTable) + n := table.Len() + stringValues := make([]string, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTString { + return luaArgError{3, fmt.Errorf("expected string, got %v", element.Type().String())} + } + s := element.String() + stringValues = append(stringValues, s) + } + newVal := reflect.ValueOf(stringValues) + fieldValue.Set(newVal) + return nil +} + +func getUint16SliceField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().([]uint16) + table := L.CreateTable(len(field), 0) + for _, v := range field { + table.Append(lua.LNumber(v)) + } + return table, nil +} + +func setUint16SliceField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTTable { + return luaArgError{3, fmt.Errorf("expected array, got %v", value.Type().String())} + } + table := value.(*lua.LTable) + n := table.Len() + values := make([]uint16, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", element.Type().String())} + } + lv := element.(lua.LNumber) + values = append(values, uint16(lv)) + } + newVal := reflect.ValueOf(values) + fieldValue.Set(newVal) + return nil +} + +func getUint8Field(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(uint8) + return lua.LNumber(field), nil +} + +func setUint8Field(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} + } + fieldValue.SetUint(uint64(value.(lua.LNumber))) + return nil +} + +func getUint16Field(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(uint16) + return lua.LNumber(field), nil +} + +func setUint16Field(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} + } + fieldValue.SetUint(uint64(value.(lua.LNumber))) + return nil +} + +func getUint32Field(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(uint32) + return lua.LNumber(field), nil +} + +func setUint32Field(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} + } + fieldValue.SetUint(uint64(value.(lua.LNumber))) + return nil +} + +func getUint64Field(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(uint64) + return lua.LNumber(field), nil +} + +func setUint64Field(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} + } + fieldValue.SetUint(uint64(value.(lua.LNumber))) + return nil +} + +func getIPField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(net.IP) + if field == nil { + return lua.LNil, nil + } + return lua.LString(field.String()), nil +} + +func setIPField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + switch value.Type() { + case lua.LTString: + ip := net.ParseIP(value.String()) + if ip == nil { + return nil + } + fieldValue.SetBytes(ip) + case lua.LTNil: + fieldValue.SetZero() + default: + return luaArgError{3, fmt.Errorf("expected string or nil, got %v", value.Type().String())} + } + return nil +} + +func getEDNS0SliceField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().([]dns.EDNS0) + table := L.CreateTable(len(field), 0) + for _, v := range field { + // TODO: This is a hacky way to determine the name of the + // metatable for EDNS0 recods. Ideally we reference some name + // constant or function exposed by the code that registers the + // EDNS0 types. + mtName := reflect.TypeOf(v).String() + if i := strings.LastIndex(mtName, "."); i >= 0 { + mtName = mtName[i+1:] + } + lv := userDataWithMetatable(L, mtName, v) + table.Append(lv) + } + return table, nil +} + +func setEDNS0SliceField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTTable { + return luaArgError{3, fmt.Errorf("expected array, got %v", value.Type().String())} + } + table := value.(*lua.LTable) + n := table.Len() + stringValues := make([]dns.EDNS0, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + ud, ok := element.(*lua.LUserData) + if !ok { + return luaArgError{3, fmt.Errorf("expected userdata, got %v", element.Type().String())} + } + value, ok := ud.Value.(dns.EDNS0) + if !ok { + return luaArgError{3, fmt.Errorf("expected EDNS0, got %T", ud)} + } + stringValues = append(stringValues, value) + } + newVal := reflect.ValueOf(stringValues) + fieldValue.Set(newVal) + return nil +} +func getUnsupported(name string) func(L *lua.LState, v reflect.Value) (lua.LValue, error) { + return func(L *lua.LState, v reflect.Value) (lua.LValue, error) { + return nil, luaArgError{2, fmt.Errorf("getting %q not supported", name)} + } +} + +func setUnsupported(name string) func(L *lua.LState, v reflect.Value, value lua.LValue) error { + return func(L *lua.LState, v reflect.Value, value lua.LValue) error { + return luaArgError{2, fmt.Errorf("setting %q not supported", name)} + } +} diff --git a/lua-script.go b/lua-script.go new file mode 100644 index 00000000..be20ff27 --- /dev/null +++ b/lua-script.go @@ -0,0 +1,91 @@ +package rdns + +import ( + "fmt" + "io" + "slices" + + lua "github.com/yuin/gopher-lua" + "github.com/yuin/gopher-lua/parse" +) + +type ByteCode struct { + *lua.FunctionProto +} + +type LuaScript struct { + L *lua.LState +} + +// LuaCompile compiles lua script into bytecode. The returned bytecode can be used +// to instantiate one or more scripts. +func LuaCompile(reader io.Reader, name string) (ByteCode, error) { + chunk, err := parse.Parse(reader, name) + if err != nil { + return ByteCode{}, err + } + proto, err := lua.Compile(chunk, name) + if err != nil { + return ByteCode{}, err + } + return ByteCode{proto}, nil +} + +// NewScriptFromByteCode creates a new lua script from bytecode. +func NewScriptFromByteCode(b ByteCode) (*LuaScript, error) { + L := lua.NewState() + lfunc := L.NewFunctionFromProto(b.FunctionProto) + L.Push(lfunc) + return &LuaScript{L: L}, L.PCall(0, lua.MultRet, nil) +} + +func (s *LuaScript) HasFunction(name string) bool { + return s.L.GetGlobal(name).Type() == lua.LTFunction +} + +func (s *LuaScript) Call(fnName string, nret int, params ...any) ([]any, error) { + // args := make([]lua.LValue, 0, len(params)) + // TODO: implement + // for _, p := range params { + // args = append(args, userDataWithType(s.L, "", p)) + // } + + args := []lua.LValue{ + userDataWithMetatable(s.L, luaMessageMetatableName, params[0]), + userDataWithMetatable(s.L, "", params[1]), + } + + // Call the resolve() function in the lua script + if err := s.L.CallByParam(lua.P{ + Fn: s.L.GetGlobal(fnName), + NRet: nret, + Protect: true, + }, args...); err != nil { + return nil, fmt.Errorf("failed to call lua: %w", err) + } + + // Grab return values from the stack and add them to the result slice + // in reverse order + ret := make([]any, nret) + for i := range slices.Backward(ret) { + lv := s.L.Get(-1) + s.L.Pop(1) + + var v any + + switch lv.Type() { + case lua.LTNil: + v = nil + case lua.LTUserData: + ud := lv.(*lua.LUserData) + v = ud.Value + case lua.LTString: + v = lv.String() + default: + return nil, fmt.Errorf("unsupported return type: %v", lv.Type()) + } + ret[i] = v + } + + return ret, nil +} diff --git a/lua-types.go b/lua-types.go new file mode 100644 index 00000000..8ab20daa --- /dev/null +++ b/lua-types.go @@ -0,0 +1,44 @@ +package rdns + +import ( + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +func (s *LuaScript) RegisterConstants() { + L := s.L + + // Register TypeA, TypeAAAA, etc + for value, name := range dns.TypeToString { + L.SetGlobal("Type"+name, lua.LNumber(value)) + } + + // Register ClassIN, etc + for value, name := range dns.ClassToString { + L.SetGlobal("Class"+name, lua.LNumber(value)) + } + + // Register Rcodes, RcodeNOERROR, RcodeNXDOMAIN, etc + for value, name := range dns.RcodeToString { + L.SetGlobal("Rcode"+name, lua.LNumber(value)) + } + + // Register EDNS0 option codes + for name, value := range map[string]uint16{ + "EDNS0LLQ": 0x1, + "EDNS0UL": 0x2, + "EDNS0NSID": 0x3, + "EDNS0ESU": 0x4, + "EDNS0DAU": 0x5, + "EDNS0DHU": 0x6, + "EDNS0N3U": 0x7, + "EDNS0SUBNET": 0x8, + "EDNS0EXPIRE": 0x9, + "EDNS0COOKIE": 0xa, + "EDNS0TCPKEEPALIVE": 0xb, + "EDNS0PADDING": 0xc, + "EDNS0EDE": 0xf, + } { + L.SetGlobal(name, lua.LNumber(value)) + } +} diff --git a/lua.go b/lua.go new file mode 100644 index 00000000..5fb1fe9c --- /dev/null +++ b/lua.go @@ -0,0 +1,115 @@ +package rdns + +import ( + "errors" + "fmt" + "strings" + + "github.com/miekg/dns" +) + +type Lua struct { + id string + resolvers []Resolver + scripts chan *LuaScript + bytecode ByteCode + + opt LuaOptions +} + +var _ Resolver = &Lua{} + +type LuaOptions struct { + Script string + Concurrency uint +} + +func NewLua(id string, opt LuaOptions, resolvers ...Resolver) (*Lua, error) { + if opt.Concurrency == 0 { + opt.Concurrency = 4 + } + + // Compile the script + bytecode, err := LuaCompile(strings.NewReader(opt.Script), id) + if err != nil { + return nil, err + } + + r := &Lua{ + id: id, + resolvers: resolvers, + opt: opt, + scripts: make(chan *LuaScript, opt.Concurrency), + bytecode: bytecode, + } + + // Initialize scripts + for range opt.Concurrency { + s, err := r.newScript() + if err != nil { + return nil, err + } + r.scripts <- s + } + return r, nil +} + +func (r *Lua) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { + s := <-r.scripts + defer func() { r.scripts <- s }() + + log := logger(r.id, q, ci) + + // Call the "resolve" function in the script. It should return 2 values. + ret, err := s.Call("Resolve", 2, q, ci) + if err != nil { + log.Error("failed to run lua script", "error", err) + return nil, err + } + + // Extract the answer and error from the returned values + if len(ret) != 2 { + return nil, fmt.Errorf("invalid return value, expected 2, got %d", len(ret)) + } + + answer, ok := ret[0].(*dns.Msg) + if ret[0] != nil && !ok { + return nil, fmt.Errorf("invalid return value, expected Message, got %T", ret[0]) + } + + err, ok = ret[1].(error) + if ret[1] != nil && !ok { + return nil, fmt.Errorf("invalid return value, expected Error, got %T", ret[1]) + } + + return answer, err +} + +func (r *Lua) String() string { + return r.id +} + +func (r *Lua) newScript() (*LuaScript, error) { + s, err := NewScriptFromByteCode(r.bytecode) + if err != nil { + return nil, err + } + + // Register types and methods + s.RegisterConstants() + s.RegisterMessageType() + s.RegisterQuestionType() + s.RegisterRRTypes() + s.RegisterEDNS0Types() + s.RegisterErrorType() + + // Inject the resolvers into the state (so they can be used in the script) + s.InjectResolvers(r.resolvers) + + // The script must contain a Resolve() function which is the entry point + if !s.HasFunction("Resolve") { + return nil, errors.New("no Resolve() function found in lua script") + } + + return s, nil +} diff --git a/lua_test.go b/lua_test.go new file mode 100644 index 00000000..340f89eb --- /dev/null +++ b/lua_test.go @@ -0,0 +1,461 @@ +package rdns + +import ( + "net" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +func TestLuaSimplePassthrough(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + local resolver = Resolvers[1] + local answer, err = resolver:resolve(msg, ci) + if err ~= nil then + return nil, err + end + return answer, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + + _, err = r.Resolve(q, ci) + require.NoError(t, err) + require.Equal(t, 1, resolver.HitCount()) +} + +func TestLuaMissingResolveFunc(t *testing.T) { + opt := LuaOptions{ + Script: `function test() return nil, nil end`, + } + + resolver := new(TestResolver) + + _, err := NewLua("test-lua", opt, resolver) + require.Error(t, err) +} + +func TestLuaResolveError(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + return nil, Error.new("no bueno") +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + + _, err = r.Resolve(q, ci) + require.Error(t, err) + require.Zero(t, resolver.HitCount()) +} + +func TestLuaStaticAnswer(t *testing.T) { + tests := map[string]LuaOptions{ + "set_questions": { + Script: ` +function Resolve(msg, ci) + local question = Question.new("example.com.", TypeA) + local answer = Message.new() + answer.id = msg.id + answer.questions = { question } + answer.response = true + answer.rcode = RcodeNXDOMAIN + return answer, nil +end`, + }, + "set_question": { + Script: ` +function Resolve(msg, ci) + local answer = Message.new() + answer:set_question("example.com.", TypeA) + answer.id = msg.id + answer.response = true + answer.rcode = RcodeNXDOMAIN + return answer, nil +end`, + }, + "set_reply": { + Script: ` +function Resolve(msg, ci) + local answer = Message.new() + answer:set_reply(msg) + answer.rcode = RcodeNXDOMAIN + return answer, nil +end`, + }, + } + + for name, opt := range tests { + t.Run(name, func(t *testing.T) { + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + q.Id = 1234 + + answer, err := r.Resolve(q, ci) + require.NoError(t, err) + require.Equal(t, 0, resolver.HitCount()) + require.Equal(t, "example.com.", answer.Question[0].Name) + require.Equal(t, dns.TypeA, answer.Question[0].Qtype) + require.Equal(t, uint16(1234), answer.Id) + require.Equal(t, dns.RcodeNameError, answer.Rcode) + require.True(t, answer.Response) + }) + } +} + +func TestLuaQuestionOperations(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + -- Create a new Question record and test value set/get operations + local question = Question.new("example.com.", TypeA) + if question.name ~= "example.com." or question.qtype ~= TypeA then + return nil, Error.new("unexpected name value") + end + question.name = "testing." + if question.name ~= "testing." then + return nil, Error.new("unexpected name value") + end + return nil, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeMX) + + _, err = r.Resolve(q, ci) + require.NoError(t, err) +} + +func TestLuaRROperations(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + -- Create a new TXT record and test value set/get operations + rr = RR.new({rtype=TypeTXT, name="example.com.", class=ClassIN, ttl=60, txt={"hello", "world"}}) + if rr.txt[1] ~= "hello" then + return nil, Error.new("unexpected value") + end + rr.txt = {"bla"} + if rr.txt[1] ~= "bla" then + return nil, Error.new("unexpected value in txt") + end + + -- Create a new A record and test value set/get operations + rr = RR.new({rtype=TypeA, name="example.com.", class=ClassIN, ttl=60, a="1.2.3.4"}) + if rr.rtype ~= TypeA or rr.name ~= "example.com." then + return nil, Error.new("unexpected name value") + end + rr.a = "1.1.1.1" + if rr.a ~= "1.1.1.1" then + return nil, Error.new("unexpected ip value") + end + return nil, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeMX) + + _, err = r.Resolve(q, ci) + require.NoError(t, err) +} + +func TestLuaEDNS0Operations(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + -- Create a new EDNS0 COOKIE and test value set/get operations + edns0 = EDNS0_COOKIE.new("24a5ac1a012345ff") + if edns0.cookie ~= "24a5ac1a012345ff" then + return nil, Error.new("unexpected value") + end + edns0.cookie = "bla" + if edns0.cookie ~= "bla" then + return nil, Error.new("unexpected value in edns0 cookie") + end + + -- Create a new EDNS0 DAU and test value set/get operations + edns0 = EDNS0_DAU.new({ 1, 2, 3 }) + if edns0.algcode[1] ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.algcode = { 0 } + if edns0.algcode[1] ~= 0 then + return nil, Error.new("unexpected value in edns0 dau") + end + + -- Create a new EDNS0 DHU and test value set/get operations + edns0 = EDNS0_DHU.new({ 1, 2, 3 }) + if edns0.algcode[1] ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.algcode = { 0 } + if edns0.algcode[1] ~= 0 then + return nil, Error.new("unexpected value in edns0 dhu") + end + + -- Create a new EDNS0 EDE and test value set/get operations + edns0 = EDNS0_EDE.new(15, "domain blocked") + if edns0.infocode ~= 15 then + return nil, Error.new("unexpected value") + end + edns0.extratext = "testing" + if edns0.extratext ~= "testing" then + return nil, Error.new("unexpected value in edns0 ede") + end + + -- Create a new EDNS0 ESU and test value set/get operations + edns0 = EDNS0_ESU.new("http://example.com") + if edns0.uri ~= "http://example.com" then + return nil, Error.new("unexpected value") + end + edns0.uri = "http://example.org" + if edns0.uri ~= "http://example.org" then + return nil, Error.new("unexpected value in edns0 ede") + end + + -- Create a new EDNS0 EXPIRE and test value set/get operations + edns0 = EDNS0_EXPIRE.new(123) + if edns0.expire ~= 123 then + return nil, Error.new("unexpected value") + end + edns0.expire = 124 + if edns0.expire ~= 124 then + return nil, Error.new("unexpected value in edns0 expire") + end + + -- Create a new EDNS0 LLQ and test value set/get operations + edns0 = EDNS0_LLQ.new(1, 16, 0, 1234, 4321) + if edns0.version ~= 1 then + return nil, Error.new("unexpected value") + end + if edns0.opcode ~= 16 then + return nil, Error.new("unexpected value") + end + if edns0.error ~= 0 then + return nil, Error.new("unexpected value") + end + if edns0.id ~= 1234 then + return nil, Error.new("unexpected value") + end + if edns0.leaselife ~= 4321 then + return nil, Error.new("unexpected value") + end + edns0.error = 1 + if edns0.error ~= 1 then + return nil, Error.new("unexpected value in edns0 llq") + end + + -- Create a new EDNS0 LOCAL and test value set/get operations + edns0 = EDNS0_LOCAL.new(65001, "somedata") + if edns0.code ~= 65001 then + return nil, Error.new("unexpected value") + end + edns0.data = "otherdata" + if edns0.data ~= "otherdata" then + return nil, Error.new("unexpected value in edns0 local") + end + + -- Create a new EDNS0 N3U and test value set/get operations + edns0 = EDNS0_N3U.new({ 1, 2, 3 }) + if edns0.algcode[1] ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.algcode = { 0 } + if edns0.algcode[1] ~= 0 then + return nil, Error.new("unexpected value in edns0 n3u") + end + + -- Create a new EDNS0 NSID and test value set/get operations + edns0 = EDNS0_NSID.new("someid") + if edns0.nsid ~= "someid" then + return nil, Error.new("unexpected value") + end + edns0.nsid = "otherid" + if edns0.nsid ~= "otherid" then + return nil, Error.new("unexpected value in edns0 nsid") + end + + -- Create a new EDNS0 PADDING and test value set/get operations + edns0 = EDNS0_PADDING.new("somepadding") + if edns0.padding ~= "somepadding" then + return nil, Error.new("unexpected value") + end + edns0.padding = "otherpadding" + if edns0.padding ~= "otherpadding" then + return nil, Error.new("unexpected value in edns0 padding") + end + + -- Create a new EDNS0 SUBNET and test value set/get operations + edns0 = EDNS0_SUBNET.new(1, 32, 0, "192.168.0.0") + if edns0.family ~= 1 then + return nil, Error.new("unexpected value") + end + if edns0.sourcenetmask ~= 32 then + return nil, Error.new("unexpected value") + end + if edns0.sourcescope ~= 0 then + return nil, Error.new("unexpected value") + end + if edns0.address ~= "192.168.0.0" then + return nil, Error.new("unexpected value") + end + edns0.address = "172.16.0.0" + if edns0.address ~= "172.16.0.0" then + return nil, Error.new("unexpected value in edns0 subnet") + end + + -- Create a new EDNS0 TCP_KEEPALIVE and test value set/get operations + edns0 = EDNS0_TCP_KEEPALIVE.new(1) + if edns0.timeout ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.timeout = 2 + if edns0.timeout ~= 2 then + return nil, Error.new("unexpected value in edns0 tcp keepalive") + end + + -- Create a new EDNS0 UL and test value set/get operations + edns0 = EDNS0_UL.new(1, 2) + if edns0.lease ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.keylease = 3 + if edns0.keylease ~= 3 then + return nil, Error.new("unexpected value in edns0 ul") + end + + return nil, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeMX) + + _, err = r.Resolve(q, ci) + require.NoError(t, err) +} + +func TestLuaRREDNS0(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + -- Read the EDNS0 options + opt = msg:is_edns0() + if opt == nil then + return nil, Error.new("no edns0") + end + + -- Grab the options + options = opt.option + + -- The first one should be a cookie + if options[1].option ~= EDNS0COOKIE then + return nil, Error.new("unexpected subnet option value in cookie option") + end + if options[1].cookie ~= "testing" then + return nil, Error.new("unexpected value in edns0 cookie option") + end + + -- The second one should be a subnet option + if options[2].option ~= EDNS0SUBNET then + return nil, Error.new("unexpected subnet option value in subnet option") + end + if options[2].family ~= 1 then + return nil, Error.new("unexpected value in edns0 subnet option") + end + + -- Reply with an extended error message + local answer = Message.new() + answer:set_reply(msg) + answer.rcode = RcodeNXDOMAIN + answer:set_edns0(4096, false) + edns0 = answer:is_edns0() + edns0.option = { + EDNS0_EDE.new(15, "totally blocked"), + } + + return answer, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeMX) + q.SetEdns0(4096, false) + edns0 := q.IsEdns0() + + edns0.Option = append(edns0.Option, + &dns.EDNS0_COOKIE{ + Code: dns.EDNS0COOKIE, + Cookie: "testing", + }, + &dns.EDNS0_SUBNET{ + Code: dns.EDNS0SUBNET, + Family: 1, + SourceNetmask: 32, + SourceScope: 0, + Address: net.ParseIP("127.0.0.1"), + }, + ) + + a, err := r.Resolve(q, ci) + require.NoError(t, err) + + edns0 = a.IsEdns0() + require.NotNil(t, edns0) + require.Len(t, edns0.Option, 1) + require.Equal(t, uint16(dns.EDNS0EDE), edns0.Option[0].Option()) + ede := edns0.Option[0].(*dns.EDNS0_EDE) + require.Equal(t, uint16(15), ede.InfoCode) + require.Equal(t, "totally blocked", ede.ExtraText) +}