From c6824c4028a709b6847457d9cdc91f92f44a15b6 Mon Sep 17 00:00:00 2001 From: folbrich Date: Sun, 20 Apr 2025 09:55:04 +0200 Subject: [PATCH 1/9] Support for Lua scripting --- go.mod | 1 + go.sum | 2 + lua.go | 352 ++++++++++++++++++++++++++++++++++++++++++++++++++++ lua_test.go | 95 ++++++++++++++ 4 files changed, 450 insertions(+) create mode 100644 lua.go create mode 100644 lua_test.go diff --git a/go.mod b/go.mod index ff3bc144..56ea0b20 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/quic-go/qpack v0.5.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/txthinking/runnergroup v0.0.0-20230325130830-408dc5853f86 // indirect + github.com/yuin/gopher-lua v1.1.2-0.20241109074121-ccacf662c9d2 // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/crypto v0.35.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect diff --git a/go.sum b/go.sum index e131ec8f..9a9ab26d 100644 --- a/go.sum +++ b/go.sum @@ -92,6 +92,8 @@ github.com/txthinking/runnergroup v0.0.0-20230325130830-408dc5853f86/go.mod h1:c github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 h1:d/Wr/Vl/wiJHc3AHYbYs5I3PucJvRuw3SvbmlIRf+oM= github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301/go.mod h1:ntmMHL/xPq1WLeKiw8p/eRATaae6PiVRNipHFJxI8PM= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/gopher-lua v1.1.2-0.20241109074121-ccacf662c9d2 h1:PF/PSu+RcSOzNpCqGo6KEO+4qw+LEiX/oqVxGm8rup8= +github.com/yuin/gopher-lua v1.1.2-0.20241109074121-ccacf662c9d2/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/lua.go b/lua.go new file mode 100644 index 00000000..f514c83f --- /dev/null +++ b/lua.go @@ -0,0 +1,352 @@ +package rdns + +import ( + "errors" + "fmt" + "reflect" + + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +type Lua struct { + id string + resolvers []Resolver + states chan *lua.LState + + opt LuaOptions +} + +var _ Resolver = &Lua{} + +type LuaOptions struct { + Script string + Concurrency uint +} + +func NewLua(id string, opt LuaOptions, resolvers ...Resolver) (*Lua, error) { + if opt.Concurrency == 0 { + opt.Concurrency = 4 + } + r := &Lua{ + id: id, + resolvers: resolvers, + opt: opt, + states: make(chan *lua.LState, opt.Concurrency), + } + + // Initialize lua states + for range opt.Concurrency { + L, err := r.newState() + if err != nil { + return nil, err + } + r.states <- L + } + return r, nil +} + +func (r *Lua) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { + L := <-r.states + defer func() { r.states <- L }() + + log := logger(r.id, q, ci) + + lq := userDataWithType(L, luaMessageTypeName, q) + lci := L.NewUserData() + lci.Value = ci + + // Call the resolve() function in the lua script + if err := L.CallByParam(lua.P{ + Fn: L.GetGlobal("resolve"), + NRet: 2, + Protect: true, + }, lq, lci); err != nil { + log.Error("failed to run lua script", "error", err) + return nil, fmt.Errorf("failed to run lua script: %w", err) + } + + // Grab return values from the stack + lanswer := L.Get(-2) + lerr := L.Get(-1) + L.Pop(2) + + // Check for errors + switch lerr.Type() { + case lua.LTNil: // No error + case lua.LTUserData: + ud := lerr.(*lua.LUserData) + err, ok := ud.Value.(error) + if !ok { + err := fmt.Errorf("invalid respone type from lua script, expected error, got %T", ud.Value) + log.Error("failed to run lua script", "error", err) + return nil, err + } + return nil, err + + default: + err := fmt.Errorf("invalid respone type from lua script, expected userdata, got %T", lerr) + log.Error("failed to run lua script", "error", err) + return nil, err + } + + // Check the response + switch lanswer.Type() { + case lua.LTNil: + return nil, nil + + case lua.LTUserData: + ud := lanswer.(*lua.LUserData) + msg, ok := ud.Value.(*dns.Msg) + if !ok { + err := fmt.Errorf("invalid respone type from lua script, expected Message, got %T", ud.Value) + log.Error("failed to run lua script", "error", err) + return nil, err + } + return msg, nil + + default: + err := fmt.Errorf("invalid respone type from lua script, expected userdata, got %T", lerr) + log.Error("failed to run lua script", "error", err) + return nil, err + } +} + +func (r *Lua) String() string { + return r.id +} + +func (r *Lua) newState() (*lua.LState, error) { + L := lua.NewState() + + // Register types + registerMessageType(L) + registerQuestionType(L) + registerErrorType(L) + + // Inject the resolvers into the state (so they can be used in the script) + registerResolvers(L, r.resolvers) + + if err := L.DoString(r.opt.Script); err != nil { + return nil, err + } + + // The script must contain a resolve() function which is the entry point + if resolveFunc := L.GetGlobal("resolve"); resolveFunc.Type() != lua.LTFunction { + return nil, errors.New("no resolve() function found in lua script") + } + + return L, nil +} + +// Define Lua types +const ( + luaResolverTypeName = "Resolver" + luaMessageTypeName = "Message" + luaQuestionTypeName = "Question" + luaErrorTypeName = "Error" +) + +// Resolver functions + +func registerResolvers(L *lua.LState, resolvers []Resolver) { + mt := L.NewTypeMetatable(luaResolverTypeName) + L.SetGlobal("Resolver", mt) + + // Methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "resolve": resolverResolve, + })) + + table := L.CreateTable(len(resolvers), 0) + for _, r := range resolvers { + lv := userDataWithType(L, luaResolverTypeName, r) + table.Append(lv) + } + L.SetGlobal("Resolvers", table) +} + +func resolverResolve(L *lua.LState) int { + if L.GetTop() != 3 { + L.ArgError(1, "expected at 2 argument") + return 0 + } + r, ok := getUserDataArg[Resolver](L, 1) + if !ok { + return 0 + } + msg, ok := getUserDataArg[*dns.Msg](L, 2) + if !ok { + return 0 + } + ci, ok := getUserDataArg[ClientInfo](L, 3) + if !ok { + return 0 + } + + resp, err := r.Resolve(msg, ci) + + // Return the answer + L.Push(userDataWithType(L, luaMessageTypeName, resp)) + + // Return the error + if err != nil { + L.Push(userDataWithType(L, luaErrorTypeName, err)) + } else { + L.Push(lua.LNil) + } + + return 2 +} + +func getUserDataArg[T any](L *lua.LState, n int) (T, bool) { + ud := L.CheckUserData(n) + v, ok := ud.Value.(T) + if !ok { + L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value)) + return v, false + } + return v, true +} + +// Message functions + +func registerMessageType(L *lua.LState) { + mt := L.NewTypeMetatable(luaMessageTypeName) + L.SetGlobal("Message", mt) + // static attributes + L.SetField(mt, "new", L.NewFunction(newMessage)) + // methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "get_question": getter(messageGetQuestion), + "set_question": setter(messageSetQuestion), + })) +} + +func newMessage(L *lua.LState) int { + L.Push(userDataWithType(L, luaMessageTypeName, new(dns.Msg))) + return 1 +} + +func messageGetQuestion(L *lua.LState, msg *dns.Msg) { + table := L.CreateTable(len(msg.Question), 0) + for _, q := range msg.Question { + lv := userDataWithType(L, luaQuestionTypeName, &q) + table.Append(lv) + } + L.Push(table) +} + +func messageSetQuestion(L *lua.LState, msg *dns.Msg) { + table := L.CheckTable(2) + n := table.Len() + questions := make([]dns.Question, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTUserData { + L.ArgError(1, "invalid type, expected userdata") + return + } + lq := element.(*lua.LUserData) + q, ok := lq.Value.(*dns.Question) + if !ok { + L.ArgError(1, "invalid type, expected question") + return + } + questions = append(questions, *q) + } + msg.Question = questions +} + +// Question functions + +func registerQuestionType(L *lua.LState) { + mt := L.NewTypeMetatable(luaQuestionTypeName) + L.SetGlobal("Question", mt) + // static attributes + L.SetField(mt, "new", L.NewFunction(newQuestion)) + // methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "get_name": getter(questionGetName), + "get_qtype": getter(questionGetQType), + "get_qclass": getter(questionGetQClass), + "set_name": setter(questionSetName), + "set_qtype": setter(questionSetQType), + "set_qclass": setter(questionSetQClass), + })) +} + +func newQuestion(L *lua.LState) int { + L.Push(userDataWithType(L, luaQuestionTypeName, new(dns.Question))) + return 1 +} + +func questionGetName(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) } +func questionGetQType(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) } +func questionGetQClass(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) } + +func questionSetName(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) } +func questionSetQType(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) } +func questionSetQClass(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) } + +// Error functions + +func registerErrorType(L *lua.LState) { + mt := L.NewTypeMetatable(luaErrorTypeName) + L.SetGlobal("Error", mt) + // static attributes + L.SetField(mt, "new", L.NewFunction(newError)) + // methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "error": getter(errorGetError), + })) +} + +func newError(L *lua.LState) int { + err := errors.New(L.CheckString(1)) + L.Push(userDataWithType(L, luaErrorTypeName, err)) + return 1 +} +func errorGetError(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) } + +// Helper functions + +func userDataWithType(L *lua.LState, typ string, value any) *lua.LUserData { + ud := L.NewUserData() + ud.Value = value + L.SetMetatable(ud, L.GetTypeMetatable(typ)) + return ud +} + +func getter[T any](f func(*lua.LState, T)) func(*lua.LState) int { + return func(L *lua.LState) int { + if L.GetTop() > 1 { + L.ArgError(1, "no arguments expected") + return 0 + } + ud := L.CheckUserData(1) + r, ok := ud.Value.(T) + if !ok { + L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) + return 0 + } + f(L, r) + return 1 + } +} +func setter[T any](f func(*lua.LState, T)) func(*lua.LState) int { + return func(L *lua.LState) int { + if L.GetTop() < 2 { + L.ArgError(1, "expected at least 1 argument") + return 0 + } + ud := L.CheckUserData(1) + r, ok := ud.Value.(T) + if !ok { + L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) + return 0 + } + f(L, r) + return 1 + } +} diff --git a/lua_test.go b/lua_test.go new file mode 100644 index 00000000..8ffe7842 --- /dev/null +++ b/lua_test.go @@ -0,0 +1,95 @@ +package rdns + +import ( + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +func TestLuaSimplePassthrough(t *testing.T) { + opt := LuaOptions{ + Script: ` +function resolve(msg, ci) + resolver = Resolvers[1] + answer, err = resolver:resolve(msg, ci) + if err ~= nil then + return nil, err + end + return answer, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + + _, err = r.Resolve(q, ci) + require.NoError(t, err) + require.Equal(t, 1, resolver.HitCount()) +} + +func TestLuaMissingResolveFunc(t *testing.T) { + opt := LuaOptions{ + Script: `function test() return nil, nil end`, + } + + resolver := new(TestResolver) + + _, err := NewLua("test-lua", opt, resolver) + require.Error(t, err) +} + +func TestLuaResolveError(t *testing.T) { + opt := LuaOptions{ + Script: ` +function resolve(msg, ci) + return nil, Error.new("no bueno") +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + + _, err = r.Resolve(q, ci) + require.Error(t, err) + require.Zero(t, resolver.HitCount()) +} + +func TestLuaStaticAnswer(t *testing.T) { + opt := LuaOptions{ + Script: ` +function resolve(msg, ci) + answer = Message.new() + question = Question.new() + question:set_name("example.com.") + answer:set_question({question}) + return answer, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + + answer, err := r.Resolve(q, ci) + require.NoError(t, err) + require.Equal(t, 0, resolver.HitCount()) + require.Equal(t, "example.com.", answer.Question[0].Name) +} From 8062ba854ad3b6f9d485d4bb50adc9afa23e235a Mon Sep 17 00:00:00 2001 From: folbrich Date: Sun, 20 Apr 2025 16:02:36 +0200 Subject: [PATCH 2/9] reorg --- lua-error.go | 28 ++++ lua-helpers.go | 50 ++++++++ lua-message.go | 56 ++++++++ lua-question.go | 38 ++++++ lua-resolver.go | 72 +++++++++++ lua-script.go | 91 +++++++++++++ lua-types.go | 9 ++ lua.go | 330 +++++++----------------------------------------- 8 files changed, 389 insertions(+), 285 deletions(-) create mode 100644 lua-error.go create mode 100644 lua-helpers.go create mode 100644 lua-message.go create mode 100644 lua-question.go create mode 100644 lua-resolver.go create mode 100644 lua-script.go create mode 100644 lua-types.go diff --git a/lua-error.go b/lua-error.go new file mode 100644 index 00000000..4162c721 --- /dev/null +++ b/lua-error.go @@ -0,0 +1,28 @@ +package rdns + +import ( + "errors" + + lua "github.com/yuin/gopher-lua" +) + +// Error functions + +func (s *LuaScript) RegisterErrorType() { + L := s.L + mt := L.NewTypeMetatable(luaErrorTypeName) + L.SetGlobal("Error", mt) + // static attributes + L.SetField(mt, "new", L.NewFunction(newError)) + // methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "error": getter(errorGetError), + })) +} + +func newError(L *lua.LState) int { + err := errors.New(L.CheckString(1)) + L.Push(userDataWithType(L, luaErrorTypeName, err)) + return 1 +} +func errorGetError(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) } diff --git a/lua-helpers.go b/lua-helpers.go new file mode 100644 index 00000000..40243334 --- /dev/null +++ b/lua-helpers.go @@ -0,0 +1,50 @@ +package rdns + +import ( + "fmt" + "reflect" + + lua "github.com/yuin/gopher-lua" +) + +// Helper functions + +func userDataWithType(L *lua.LState, typ string, value any) *lua.LUserData { + ud := L.NewUserData() + ud.Value = value + L.SetMetatable(ud, L.GetTypeMetatable(typ)) + return ud +} + +func getter[T any](f func(*lua.LState, T)) func(*lua.LState) int { + return func(L *lua.LState) int { + if L.GetTop() > 1 { + L.ArgError(1, "no arguments expected") + return 0 + } + ud := L.CheckUserData(1) + r, ok := ud.Value.(T) + if !ok { + L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) + return 0 + } + f(L, r) + return 1 + } +} +func setter[T any](f func(*lua.LState, T)) func(*lua.LState) int { + return func(L *lua.LState) int { + if L.GetTop() < 2 { + L.ArgError(1, "expected at least 1 argument") + return 0 + } + ud := L.CheckUserData(1) + r, ok := ud.Value.(T) + if !ok { + L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) + return 0 + } + f(L, r) + return 1 + } +} diff --git a/lua-message.go b/lua-message.go new file mode 100644 index 00000000..43000dd2 --- /dev/null +++ b/lua-message.go @@ -0,0 +1,56 @@ +package rdns + +import ( + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// Message functions + +func (s *LuaScript) RegisterMessageType() { + L := s.L + mt := L.NewTypeMetatable(luaMessageTypeName) + L.SetGlobal("Message", mt) + // static attributes + L.SetField(mt, "new", L.NewFunction(newMessage)) + // methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "get_question": getter(messageGetQuestion), + "set_question": setter(messageSetQuestion), + })) +} + +func newMessage(L *lua.LState) int { + L.Push(userDataWithType(L, luaMessageTypeName, new(dns.Msg))) + return 1 +} + +func messageGetQuestion(L *lua.LState, msg *dns.Msg) { + table := L.CreateTable(len(msg.Question), 0) + for _, q := range msg.Question { + lv := userDataWithType(L, luaQuestionTypeName, &q) + table.Append(lv) + } + L.Push(table) +} + +func messageSetQuestion(L *lua.LState, msg *dns.Msg) { + table := L.CheckTable(2) + n := table.Len() + questions := make([]dns.Question, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTUserData { + L.ArgError(1, "invalid type, expected userdata") + return + } + lq := element.(*lua.LUserData) + q, ok := lq.Value.(*dns.Question) + if !ok { + L.ArgError(1, "invalid type, expected question") + return + } + questions = append(questions, *q) + } + msg.Question = questions +} diff --git a/lua-question.go b/lua-question.go new file mode 100644 index 00000000..00de0ff4 --- /dev/null +++ b/lua-question.go @@ -0,0 +1,38 @@ +package rdns + +import ( + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// Question functions + +func (s *LuaScript) RegisterQuestionType() { + L := s.L + mt := L.NewTypeMetatable(luaQuestionTypeName) + L.SetGlobal("Question", mt) + // static attributes + L.SetField(mt, "new", L.NewFunction(newQuestion)) + // methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "get_name": getter(questionGetName), + "get_qtype": getter(questionGetQType), + "get_qclass": getter(questionGetQClass), + "set_name": setter(questionSetName), + "set_qtype": setter(questionSetQType), + "set_qclass": setter(questionSetQClass), + })) +} + +func newQuestion(L *lua.LState) int { + L.Push(userDataWithType(L, luaQuestionTypeName, new(dns.Question))) + return 1 +} + +func questionGetName(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) } +func questionGetQType(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) } +func questionGetQClass(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) } + +func questionSetName(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) } +func questionSetQType(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) } +func questionSetQClass(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) } diff --git a/lua-resolver.go b/lua-resolver.go new file mode 100644 index 00000000..06d6e45b --- /dev/null +++ b/lua-resolver.go @@ -0,0 +1,72 @@ +package rdns + +import ( + "fmt" + "reflect" + + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// Resolver functions + +func (s *LuaScript) InjectResolvers(resolvers []Resolver) { + L := s.L + mt := L.NewTypeMetatable(luaResolverTypeName) + L.SetGlobal("Resolver", mt) + + // Methods + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ + "resolve": resolverResolve, + })) + + table := L.CreateTable(len(resolvers), 0) + for _, r := range resolvers { + lv := userDataWithType(L, luaResolverTypeName, r) + table.Append(lv) + } + L.SetGlobal("Resolvers", table) +} + +func resolverResolve(L *lua.LState) int { + if L.GetTop() != 3 { + L.ArgError(1, "expected 2 arguments") + return 0 + } + r, ok := getUserDataArg[Resolver](L, 1) + if !ok { + return 0 + } + msg, ok := getUserDataArg[*dns.Msg](L, 2) + if !ok { + return 0 + } + ci, ok := getUserDataArg[ClientInfo](L, 3) + if !ok { + return 0 + } + + resp, err := r.Resolve(msg, ci) + + // Return the answer + L.Push(userDataWithType(L, luaMessageTypeName, resp)) + + // Return the error + if err != nil { + L.Push(userDataWithType(L, luaErrorTypeName, err)) + } else { + L.Push(lua.LNil) + } + + return 2 +} + +func getUserDataArg[T any](L *lua.LState, n int) (T, bool) { + ud := L.CheckUserData(n) + v, ok := ud.Value.(T) + if !ok { + L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value)) + return v, false + } + return v, true +} diff --git a/lua-script.go b/lua-script.go new file mode 100644 index 00000000..8f897c5e --- /dev/null +++ b/lua-script.go @@ -0,0 +1,91 @@ +package rdns + +import ( + "fmt" + "io" + "slices" + + lua "github.com/yuin/gopher-lua" + "github.com/yuin/gopher-lua/parse" +) + +type ByteCode struct { + *lua.FunctionProto +} + +type LuaScript struct { + L *lua.LState +} + +// LuaCompile compiles lua script into bytecode. The returned bytecode can be used +// to instantiate one or more scripts. +func LuaCompile(reader io.Reader, name string) (ByteCode, error) { + chunk, err := parse.Parse(reader, name) + if err != nil { + return ByteCode{}, err + } + proto, err := lua.Compile(chunk, name) + if err != nil { + return ByteCode{}, err + } + return ByteCode{proto}, nil +} + +// NewScriptFromByteCode creates a new lua script from bytecode. +func NewScriptFromByteCode(b ByteCode) (*LuaScript, error) { + L := lua.NewState() + lfunc := L.NewFunctionFromProto(b.FunctionProto) + L.Push(lfunc) + return &LuaScript{L: L}, L.PCall(0, lua.MultRet, nil) +} + +func (s *LuaScript) HasFunction(name string) bool { + return s.L.GetGlobal(name).Type() == lua.LTFunction +} + +func (s *LuaScript) Call(fnName string, nret int, params ...any) ([]any, error) { + // args := make([]lua.LValue, 0, len(params)) + // TODO: implement + // for _, p := range params { + // args = append(args, userDataWithType(s.L, "", p)) + // } + + args := []lua.LValue{ + userDataWithType(s.L, luaMessageTypeName, params[0]), + userDataWithType(s.L, "", params[1]), + } + + // Call the resolve() function in the lua script + if err := s.L.CallByParam(lua.P{ + Fn: s.L.GetGlobal("resolve"), + NRet: nret, + Protect: true, + }, args...); err != nil { + return nil, fmt.Errorf("failed to call lua: %w", err) + } + + // Grab return values from the stack and add them to the result slice + // in reverse order + ret := make([]any, nret) + for i := range slices.Backward(ret) { + lv := s.L.Get(-1) + s.L.Pop(1) + + var v any + + switch lv.Type() { + case lua.LTNil: + v = nil + case lua.LTUserData: + ud := lv.(*lua.LUserData) + v = ud.Value + case lua.LTString: + v = lv.String() + default: + return nil, fmt.Errorf("unsupported return type: %v", lv.Type()) + } + ret[i] = v + } + + return ret, nil +} diff --git a/lua-types.go b/lua-types.go new file mode 100644 index 00000000..196044b8 --- /dev/null +++ b/lua-types.go @@ -0,0 +1,9 @@ +package rdns + +// Define Lua types +const ( + luaResolverTypeName = "Resolver" + luaMessageTypeName = "Message" + luaQuestionTypeName = "Question" + luaErrorTypeName = "Error" +) diff --git a/lua.go b/lua.go index f514c83f..a43aba5f 100644 --- a/lua.go +++ b/lua.go @@ -3,16 +3,16 @@ package rdns import ( "errors" "fmt" - "reflect" + "strings" "github.com/miekg/dns" - lua "github.com/yuin/gopher-lua" ) type Lua struct { id string resolvers []Resolver - states chan *lua.LState + scripts chan *LuaScript + bytecode ByteCode opt LuaOptions } @@ -28,325 +28,85 @@ func NewLua(id string, opt LuaOptions, resolvers ...Resolver) (*Lua, error) { if opt.Concurrency == 0 { opt.Concurrency = 4 } + + // Compile the script + bytecode, err := LuaCompile(strings.NewReader(opt.Script), id) + if err != nil { + return nil, err + } + r := &Lua{ id: id, resolvers: resolvers, opt: opt, - states: make(chan *lua.LState, opt.Concurrency), + scripts: make(chan *LuaScript, opt.Concurrency), + bytecode: bytecode, } - // Initialize lua states + // Initialize scripts for range opt.Concurrency { - L, err := r.newState() + s, err := r.newScript() if err != nil { return nil, err } - r.states <- L + r.scripts <- s } return r, nil } func (r *Lua) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { - L := <-r.states - defer func() { r.states <- L }() + s := <-r.scripts + defer func() { r.scripts <- s }() log := logger(r.id, q, ci) - lq := userDataWithType(L, luaMessageTypeName, q) - lci := L.NewUserData() - lci.Value = ci - - // Call the resolve() function in the lua script - if err := L.CallByParam(lua.P{ - Fn: L.GetGlobal("resolve"), - NRet: 2, - Protect: true, - }, lq, lci); err != nil { - log.Error("failed to run lua script", "error", err) - return nil, fmt.Errorf("failed to run lua script: %w", err) - } - - // Grab return values from the stack - lanswer := L.Get(-2) - lerr := L.Get(-1) - L.Pop(2) - - // Check for errors - switch lerr.Type() { - case lua.LTNil: // No error - case lua.LTUserData: - ud := lerr.(*lua.LUserData) - err, ok := ud.Value.(error) - if !ok { - err := fmt.Errorf("invalid respone type from lua script, expected error, got %T", ud.Value) - log.Error("failed to run lua script", "error", err) - return nil, err - } - return nil, err - - default: - err := fmt.Errorf("invalid respone type from lua script, expected userdata, got %T", lerr) + // Call the "resolve" function in the script. It should return 2 values. + ret, err := s.Call("resolve", 2, q, ci) + if err != nil { log.Error("failed to run lua script", "error", err) return nil, err } - // Check the response - switch lanswer.Type() { - case lua.LTNil: - return nil, nil - - case lua.LTUserData: - ud := lanswer.(*lua.LUserData) - msg, ok := ud.Value.(*dns.Msg) - if !ok { - err := fmt.Errorf("invalid respone type from lua script, expected Message, got %T", ud.Value) - log.Error("failed to run lua script", "error", err) - return nil, err - } - return msg, nil - - default: - err := fmt.Errorf("invalid respone type from lua script, expected userdata, got %T", lerr) - log.Error("failed to run lua script", "error", err) - return nil, err + // Extract the answer and error from the returned values + if len(ret) != 2 { + return nil, fmt.Errorf("invalid return value, expected 2, got %d", len(ret)) } -} -func (r *Lua) String() string { - return r.id -} - -func (r *Lua) newState() (*lua.LState, error) { - L := lua.NewState() - - // Register types - registerMessageType(L) - registerQuestionType(L) - registerErrorType(L) - - // Inject the resolvers into the state (so they can be used in the script) - registerResolvers(L, r.resolvers) - - if err := L.DoString(r.opt.Script); err != nil { - return nil, err + answer, ok := ret[0].(*dns.Msg) + if ret[0] != nil && !ok { + return nil, fmt.Errorf("invalid return value, expected Message, got %T", ret[0]) } - // The script must contain a resolve() function which is the entry point - if resolveFunc := L.GetGlobal("resolve"); resolveFunc.Type() != lua.LTFunction { - return nil, errors.New("no resolve() function found in lua script") + err, ok = ret[1].(error) + if ret[1] != nil && !ok { + return nil, fmt.Errorf("invalid return value, expected Error, got %T", ret[1]) } - return L, nil + return answer, err } -// Define Lua types -const ( - luaResolverTypeName = "Resolver" - luaMessageTypeName = "Message" - luaQuestionTypeName = "Question" - luaErrorTypeName = "Error" -) - -// Resolver functions - -func registerResolvers(L *lua.LState, resolvers []Resolver) { - mt := L.NewTypeMetatable(luaResolverTypeName) - L.SetGlobal("Resolver", mt) - - // Methods - L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "resolve": resolverResolve, - })) - - table := L.CreateTable(len(resolvers), 0) - for _, r := range resolvers { - lv := userDataWithType(L, luaResolverTypeName, r) - table.Append(lv) - } - L.SetGlobal("Resolvers", table) +func (r *Lua) String() string { + return r.id } -func resolverResolve(L *lua.LState) int { - if L.GetTop() != 3 { - L.ArgError(1, "expected at 2 argument") - return 0 - } - r, ok := getUserDataArg[Resolver](L, 1) - if !ok { - return 0 - } - msg, ok := getUserDataArg[*dns.Msg](L, 2) - if !ok { - return 0 - } - ci, ok := getUserDataArg[ClientInfo](L, 3) - if !ok { - return 0 - } - - resp, err := r.Resolve(msg, ci) - - // Return the answer - L.Push(userDataWithType(L, luaMessageTypeName, resp)) - - // Return the error +func (r *Lua) newScript() (*LuaScript, error) { + s, err := NewScriptFromByteCode(r.bytecode) if err != nil { - L.Push(userDataWithType(L, luaErrorTypeName, err)) - } else { - L.Push(lua.LNil) - } - - return 2 -} - -func getUserDataArg[T any](L *lua.LState, n int) (T, bool) { - ud := L.CheckUserData(n) - v, ok := ud.Value.(T) - if !ok { - L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value)) - return v, false + return nil, err } - return v, true -} - -// Message functions -func registerMessageType(L *lua.LState) { - mt := L.NewTypeMetatable(luaMessageTypeName) - L.SetGlobal("Message", mt) - // static attributes - L.SetField(mt, "new", L.NewFunction(newMessage)) - // methods - L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "get_question": getter(messageGetQuestion), - "set_question": setter(messageSetQuestion), - })) -} - -func newMessage(L *lua.LState) int { - L.Push(userDataWithType(L, luaMessageTypeName, new(dns.Msg))) - return 1 -} + // Register types and methods + s.RegisterMessageType() + s.RegisterQuestionType() + s.RegisterErrorType() -func messageGetQuestion(L *lua.LState, msg *dns.Msg) { - table := L.CreateTable(len(msg.Question), 0) - for _, q := range msg.Question { - lv := userDataWithType(L, luaQuestionTypeName, &q) - table.Append(lv) - } - L.Push(table) -} + // Inject the resolvers into the state (so they can be used in the script) + s.InjectResolvers(r.resolvers) -func messageSetQuestion(L *lua.LState, msg *dns.Msg) { - table := L.CheckTable(2) - n := table.Len() - questions := make([]dns.Question, 0, n) - for i := range n { - element := table.RawGetInt(i + 1) - if element.Type() != lua.LTUserData { - L.ArgError(1, "invalid type, expected userdata") - return - } - lq := element.(*lua.LUserData) - q, ok := lq.Value.(*dns.Question) - if !ok { - L.ArgError(1, "invalid type, expected question") - return - } - questions = append(questions, *q) + // The script must contain a resolve() function which is the entry point + if !s.HasFunction("resolve") { + return nil, errors.New("no resolve() function found in lua script") } - msg.Question = questions -} -// Question functions - -func registerQuestionType(L *lua.LState) { - mt := L.NewTypeMetatable(luaQuestionTypeName) - L.SetGlobal("Question", mt) - // static attributes - L.SetField(mt, "new", L.NewFunction(newQuestion)) - // methods - L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "get_name": getter(questionGetName), - "get_qtype": getter(questionGetQType), - "get_qclass": getter(questionGetQClass), - "set_name": setter(questionSetName), - "set_qtype": setter(questionSetQType), - "set_qclass": setter(questionSetQClass), - })) -} - -func newQuestion(L *lua.LState) int { - L.Push(userDataWithType(L, luaQuestionTypeName, new(dns.Question))) - return 1 -} - -func questionGetName(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) } -func questionGetQType(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) } -func questionGetQClass(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) } - -func questionSetName(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) } -func questionSetQType(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) } -func questionSetQClass(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) } - -// Error functions - -func registerErrorType(L *lua.LState) { - mt := L.NewTypeMetatable(luaErrorTypeName) - L.SetGlobal("Error", mt) - // static attributes - L.SetField(mt, "new", L.NewFunction(newError)) - // methods - L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "error": getter(errorGetError), - })) -} - -func newError(L *lua.LState) int { - err := errors.New(L.CheckString(1)) - L.Push(userDataWithType(L, luaErrorTypeName, err)) - return 1 -} -func errorGetError(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) } - -// Helper functions - -func userDataWithType(L *lua.LState, typ string, value any) *lua.LUserData { - ud := L.NewUserData() - ud.Value = value - L.SetMetatable(ud, L.GetTypeMetatable(typ)) - return ud -} - -func getter[T any](f func(*lua.LState, T)) func(*lua.LState) int { - return func(L *lua.LState) int { - if L.GetTop() > 1 { - L.ArgError(1, "no arguments expected") - return 0 - } - ud := L.CheckUserData(1) - r, ok := ud.Value.(T) - if !ok { - L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) - return 0 - } - f(L, r) - return 1 - } -} -func setter[T any](f func(*lua.LState, T)) func(*lua.LState) int { - return func(L *lua.LState) int { - if L.GetTop() < 2 { - L.ArgError(1, "expected at least 1 argument") - return 0 - } - ud := L.CheckUserData(1) - r, ok := ud.Value.(T) - if !ok { - L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) - return 0 - } - f(L, r) - return 1 - } + return s, nil } From b337e0206da33cd2dc7a55f2213ab92ac06b912f Mon Sep 17 00:00:00 2001 From: folbrich Date: Sat, 26 Apr 2025 20:08:10 +0200 Subject: [PATCH 3/9] update --- lua-error.go | 7 +++++-- lua-helpers.go | 7 ------- lua-message.go | 10 ++++++---- lua-question.go | 6 ++++-- lua-resolver.go | 10 ++++++---- lua-script.go | 6 +++--- lua-types.go | 15 ++++++++------- lua.go | 8 ++++---- lua_test.go | 14 ++++++++------ 9 files changed, 44 insertions(+), 39 deletions(-) diff --git a/lua-error.go b/lua-error.go index 4162c721..aecb8efd 100644 --- a/lua-error.go +++ b/lua-error.go @@ -8,9 +8,11 @@ import ( // Error functions +const luaErrorMetatableName = "Error" + func (s *LuaScript) RegisterErrorType() { L := s.L - mt := L.NewTypeMetatable(luaErrorTypeName) + mt := L.NewTypeMetatable(luaErrorMetatableName) L.SetGlobal("Error", mt) // static attributes L.SetField(mt, "new", L.NewFunction(newError)) @@ -22,7 +24,8 @@ func (s *LuaScript) RegisterErrorType() { func newError(L *lua.LState) int { err := errors.New(L.CheckString(1)) - L.Push(userDataWithType(L, luaErrorTypeName, err)) + L.Push(userDataWithMetatable(L, luaErrorMetatableName, err)) return 1 } + func errorGetError(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) } diff --git a/lua-helpers.go b/lua-helpers.go index 40243334..54d513f8 100644 --- a/lua-helpers.go +++ b/lua-helpers.go @@ -9,13 +9,6 @@ import ( // Helper functions -func userDataWithType(L *lua.LState, typ string, value any) *lua.LUserData { - ud := L.NewUserData() - ud.Value = value - L.SetMetatable(ud, L.GetTypeMetatable(typ)) - return ud -} - func getter[T any](f func(*lua.LState, T)) func(*lua.LState) int { return func(L *lua.LState) int { if L.GetTop() > 1 { diff --git a/lua-message.go b/lua-message.go index 43000dd2..3953fbe7 100644 --- a/lua-message.go +++ b/lua-message.go @@ -7,10 +7,12 @@ import ( // Message functions +const luaMessageMetatableName = "Message" + func (s *LuaScript) RegisterMessageType() { L := s.L - mt := L.NewTypeMetatable(luaMessageTypeName) - L.SetGlobal("Message", mt) + mt := L.NewTypeMetatable(luaMessageMetatableName) + L.SetGlobal(luaMessageMetatableName, mt) // static attributes L.SetField(mt, "new", L.NewFunction(newMessage)) // methods @@ -21,14 +23,14 @@ func (s *LuaScript) RegisterMessageType() { } func newMessage(L *lua.LState) int { - L.Push(userDataWithType(L, luaMessageTypeName, new(dns.Msg))) + L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg))) return 1 } func messageGetQuestion(L *lua.LState, msg *dns.Msg) { table := L.CreateTable(len(msg.Question), 0) for _, q := range msg.Question { - lv := userDataWithType(L, luaQuestionTypeName, &q) + lv := userDataWithMetatable(L, luaQuestionMetatableName, &q) table.Append(lv) } L.Push(table) diff --git a/lua-question.go b/lua-question.go index 00de0ff4..d42b75f0 100644 --- a/lua-question.go +++ b/lua-question.go @@ -7,9 +7,11 @@ import ( // Question functions +const luaQuestionMetatableName = "Question" + func (s *LuaScript) RegisterQuestionType() { L := s.L - mt := L.NewTypeMetatable(luaQuestionTypeName) + mt := L.NewTypeMetatable(luaQuestionMetatableName) L.SetGlobal("Question", mt) // static attributes L.SetField(mt, "new", L.NewFunction(newQuestion)) @@ -25,7 +27,7 @@ func (s *LuaScript) RegisterQuestionType() { } func newQuestion(L *lua.LState) int { - L.Push(userDataWithType(L, luaQuestionTypeName, new(dns.Question))) + L.Push(userDataWithMetatable(L, luaQuestionMetatableName, new(dns.Question))) return 1 } diff --git a/lua-resolver.go b/lua-resolver.go index 06d6e45b..54ff25c4 100644 --- a/lua-resolver.go +++ b/lua-resolver.go @@ -10,9 +10,11 @@ import ( // Resolver functions +const luaResolverMetatableName = "Resolver" + func (s *LuaScript) InjectResolvers(resolvers []Resolver) { L := s.L - mt := L.NewTypeMetatable(luaResolverTypeName) + mt := L.NewTypeMetatable(luaResolverMetatableName) L.SetGlobal("Resolver", mt) // Methods @@ -22,7 +24,7 @@ func (s *LuaScript) InjectResolvers(resolvers []Resolver) { table := L.CreateTable(len(resolvers), 0) for _, r := range resolvers { - lv := userDataWithType(L, luaResolverTypeName, r) + lv := userDataWithMetatable(L, luaResolverMetatableName, r) table.Append(lv) } L.SetGlobal("Resolvers", table) @@ -49,11 +51,11 @@ func resolverResolve(L *lua.LState) int { resp, err := r.Resolve(msg, ci) // Return the answer - L.Push(userDataWithType(L, luaMessageTypeName, resp)) + L.Push(userDataWithMetatable(L, luaMessageMetatableName, resp)) // Return the error if err != nil { - L.Push(userDataWithType(L, luaErrorTypeName, err)) + L.Push(userDataWithMetatable(L, luaErrorMetatableName, err)) } else { L.Push(lua.LNil) } diff --git a/lua-script.go b/lua-script.go index 8f897c5e..be20ff27 100644 --- a/lua-script.go +++ b/lua-script.go @@ -51,13 +51,13 @@ func (s *LuaScript) Call(fnName string, nret int, params ...any) ([]any, error) // } args := []lua.LValue{ - userDataWithType(s.L, luaMessageTypeName, params[0]), - userDataWithType(s.L, "", params[1]), + userDataWithMetatable(s.L, luaMessageMetatableName, params[0]), + userDataWithMetatable(s.L, "", params[1]), } // Call the resolve() function in the lua script if err := s.L.CallByParam(lua.P{ - Fn: s.L.GetGlobal("resolve"), + Fn: s.L.GetGlobal(fnName), NRet: nret, Protect: true, }, args...); err != nil { diff --git a/lua-types.go b/lua-types.go index 196044b8..98bbb2c7 100644 --- a/lua-types.go +++ b/lua-types.go @@ -1,9 +1,10 @@ package rdns -// Define Lua types -const ( - luaResolverTypeName = "Resolver" - luaMessageTypeName = "Message" - luaQuestionTypeName = "Question" - luaErrorTypeName = "Error" -) +import lua "github.com/yuin/gopher-lua" + +func userDataWithMetatable(L *lua.LState, mtName string, value any) *lua.LUserData { + ud := L.NewUserData() + ud.Value = value + L.SetMetatable(ud, L.GetTypeMetatable(mtName)) + return ud +} diff --git a/lua.go b/lua.go index a43aba5f..c61e24ed 100644 --- a/lua.go +++ b/lua.go @@ -61,7 +61,7 @@ func (r *Lua) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) { log := logger(r.id, q, ci) // Call the "resolve" function in the script. It should return 2 values. - ret, err := s.Call("resolve", 2, q, ci) + ret, err := s.Call("Resolve", 2, q, ci) if err != nil { log.Error("failed to run lua script", "error", err) return nil, err @@ -103,9 +103,9 @@ func (r *Lua) newScript() (*LuaScript, error) { // Inject the resolvers into the state (so they can be used in the script) s.InjectResolvers(r.resolvers) - // The script must contain a resolve() function which is the entry point - if !s.HasFunction("resolve") { - return nil, errors.New("no resolve() function found in lua script") + // The script must contain a Resolve() function which is the entry point + if !s.HasFunction("Resolve") { + return nil, errors.New("no Resolve() function found in lua script") } return s, nil diff --git a/lua_test.go b/lua_test.go index 8ffe7842..366b4fa5 100644 --- a/lua_test.go +++ b/lua_test.go @@ -11,8 +11,10 @@ func TestLuaSimplePassthrough(t *testing.T) { opt := LuaOptions{ Script: ` function resolve(msg, ci) - resolver = Resolvers[1] - answer, err = resolver:resolve(msg, ci) + end +function Resolve(msg, ci) + local resolver = Resolvers[1] + local answer, err = resolver:resolve(msg, ci) if err ~= nil then return nil, err end @@ -48,7 +50,7 @@ func TestLuaMissingResolveFunc(t *testing.T) { func TestLuaResolveError(t *testing.T) { opt := LuaOptions{ Script: ` -function resolve(msg, ci) +function Resolve(msg, ci) return nil, Error.new("no bueno") end`, } @@ -70,9 +72,9 @@ end`, func TestLuaStaticAnswer(t *testing.T) { opt := LuaOptions{ Script: ` -function resolve(msg, ci) - answer = Message.new() - question = Question.new() +function Resolve(msg, ci) + local answer = Message.new() + local question = Question.new() question:set_name("example.com.") answer:set_question({question}) return answer, nil From 3e73248325ee25101420323fdf10a4675bb4e985 Mon Sep 17 00:00:00 2001 From: folbrich Date: Sun, 27 Apr 2025 16:35:50 +0200 Subject: [PATCH 4/9] update --- lua-error.go | 21 +++++---- lua-helpers.go | 10 +++++ lua-message.go | 116 ++++++++++++++++++++++++++++++++---------------- lua-question.go | 45 ++++++++++--------- lua-resolver.go | 13 ------ lua-types.go | 24 +++++++++- lua.go | 1 + lua_test.go | 71 ++++++++++++++++++++--------- 8 files changed, 197 insertions(+), 104 deletions(-) diff --git a/lua-error.go b/lua-error.go index aecb8efd..ce562c40 100644 --- a/lua-error.go +++ b/lua-error.go @@ -13,19 +13,18 @@ const luaErrorMetatableName = "Error" func (s *LuaScript) RegisterErrorType() { L := s.L mt := L.NewTypeMetatable(luaErrorMetatableName) - L.SetGlobal("Error", mt) + L.SetGlobal(luaErrorMetatableName, mt) + // static attributes - L.SetField(mt, "new", L.NewFunction(newError)) + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + err := errors.New(L.CheckString(1)) + L.Push(userDataWithMetatable(L, luaErrorMetatableName, err)) + return 1 + })) + // methods L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "error": getter(errorGetError), + "error": getter(func(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) }), })) } - -func newError(L *lua.LState) int { - err := errors.New(L.CheckString(1)) - L.Push(userDataWithMetatable(L, luaErrorMetatableName, err)) - return 1 -} - -func errorGetError(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) } diff --git a/lua-helpers.go b/lua-helpers.go index 54d513f8..62d39961 100644 --- a/lua-helpers.go +++ b/lua-helpers.go @@ -41,3 +41,13 @@ func setter[T any](f func(*lua.LState, T)) func(*lua.LState) int { return 1 } } + +func getUserDataArg[T any](L *lua.LState, n int) (T, bool) { + ud := L.CheckUserData(n) + v, ok := ud.Value.(T) + if !ok { + L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value)) + return v, false + } + return v, true +} diff --git a/lua-message.go b/lua-message.go index 3953fbe7..f29071ef 100644 --- a/lua-message.go +++ b/lua-message.go @@ -14,45 +14,85 @@ func (s *LuaScript) RegisterMessageType() { mt := L.NewTypeMetatable(luaMessageMetatableName) L.SetGlobal(luaMessageMetatableName, mt) // static attributes - L.SetField(mt, "new", L.NewFunction(newMessage)) + L.SetField(mt, "new", L.NewFunction(func(L *lua.LState) int { + L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg))) + return 1 + })) // methods L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "get_question": getter(messageGetQuestion), - "set_question": setter(messageSetQuestion), + "get_questions": getter(func(L *lua.LState, msg *dns.Msg) { + table := L.CreateTable(len(msg.Question), 0) + for _, q := range msg.Question { + lv := userDataWithMetatable(L, luaQuestionMetatableName, &q) + table.Append(lv) + } + L.Push(table) + }), + "set_questions": setter(func(L *lua.LState, msg *dns.Msg) { + table := L.CheckTable(2) + n := table.Len() + questions := make([]dns.Question, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTUserData { + L.ArgError(1, "invalid type, expected userdata") + return + } + lq := element.(*lua.LUserData) + q, ok := lq.Value.(*dns.Question) + if !ok { + L.ArgError(1, "invalid type, expected question") + return + } + questions = append(questions, *q) + } + msg.Question = questions + }), + "set_question": setter(func(L *lua.LState, msg *dns.Msg) { + msg.SetQuestion(L.CheckString(2), uint16(L.CheckNumber(3))) + }), + "set_id": setter(func(L *lua.LState, msg *dns.Msg) { + msg.Id = uint16(L.CheckInt(2)) + }), + "get_id": getter(func(L *lua.LState, msg *dns.Msg) { + L.Push(lua.LNumber(msg.Id)) + }), + "set_response": setter(func(L *lua.LState, msg *dns.Msg) { + msg.Response = L.CheckBool(2) + }), + "get_response": getter(func(L *lua.LState, msg *dns.Msg) { + L.Push(lua.LBool(msg.Response)) + }), + "set_reply": setter(func(L *lua.LState, msg *dns.Msg) { + request, ok := getUserDataArg[*dns.Msg](L, 2) + if !ok { + return + } + msg.SetReply(request) + }), + "set_rcode": setter(func(L *lua.LState, msg *dns.Msg) { + msg.Rcode = L.CheckInt(2) + }), + "get_rcode": getter(func(L *lua.LState, msg *dns.Msg) { + L.Push(lua.LNumber(msg.Rcode)) + }), + "set_rd": setter(func(L *lua.LState, msg *dns.Msg) { + msg.RecursionDesired = L.CheckBool(2) + }), + "get_rd": getter(func(L *lua.LState, msg *dns.Msg) { + L.Push(lua.LBool(msg.RecursionDesired)) + }), + "set_ra": setter(func(L *lua.LState, msg *dns.Msg) { + msg.RecursionAvailable = L.CheckBool(2) + }), + "get_ra": getter(func(L *lua.LState, msg *dns.Msg) { + L.Push(lua.LBool(msg.RecursionAvailable)) + }), + "set_ad": setter(func(L *lua.LState, msg *dns.Msg) { + msg.AuthenticatedData = L.CheckBool(2) + }), + "get_ad": getter(func(L *lua.LState, msg *dns.Msg) { + L.Push(lua.LBool(msg.AuthenticatedData)) + }), })) } - -func newMessage(L *lua.LState) int { - L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg))) - return 1 -} - -func messageGetQuestion(L *lua.LState, msg *dns.Msg) { - table := L.CreateTable(len(msg.Question), 0) - for _, q := range msg.Question { - lv := userDataWithMetatable(L, luaQuestionMetatableName, &q) - table.Append(lv) - } - L.Push(table) -} - -func messageSetQuestion(L *lua.LState, msg *dns.Msg) { - table := L.CheckTable(2) - n := table.Len() - questions := make([]dns.Question, 0, n) - for i := range n { - element := table.RawGetInt(i + 1) - if element.Type() != lua.LTUserData { - L.ArgError(1, "invalid type, expected userdata") - return - } - lq := element.(*lua.LUserData) - q, ok := lq.Value.(*dns.Question) - if !ok { - L.ArgError(1, "invalid type, expected question") - return - } - questions = append(questions, *q) - } - msg.Question = questions -} diff --git a/lua-question.go b/lua-question.go index d42b75f0..4170d0a5 100644 --- a/lua-question.go +++ b/lua-question.go @@ -12,29 +12,32 @@ const luaQuestionMetatableName = "Question" func (s *LuaScript) RegisterQuestionType() { L := s.L mt := L.NewTypeMetatable(luaQuestionMetatableName) - L.SetGlobal("Question", mt) + L.SetGlobal(luaQuestionMetatableName, mt) // static attributes - L.SetField(mt, "new", L.NewFunction(newQuestion)) + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + q := &dns.Question{Qclass: dns.ClassINET} + nArgs := L.GetTop() + if nArgs >= 1 { // Name provided + q.Name = L.CheckString(1) + } + if nArgs >= 2 { // Name and type + q.Qtype = uint16(L.CheckNumber(2)) + } + if nArgs >= 3 { // Name, type and class + q.Qclass = uint16(L.CheckNumber(3)) + } + L.Push(userDataWithMetatable(L, luaQuestionMetatableName, q)) + return 1 + })) + // methods L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "get_name": getter(questionGetName), - "get_qtype": getter(questionGetQType), - "get_qclass": getter(questionGetQClass), - "set_name": setter(questionSetName), - "set_qtype": setter(questionSetQType), - "set_qclass": setter(questionSetQClass), + "get_name": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) }), + "get_qtype": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) }), + "get_qclass": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) }), + "set_name": setter(func(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) }), + "set_qtype": setter(func(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) }), + "set_qclass": setter(func(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) }), })) } - -func newQuestion(L *lua.LState) int { - L.Push(userDataWithMetatable(L, luaQuestionMetatableName, new(dns.Question))) - return 1 -} - -func questionGetName(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) } -func questionGetQType(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) } -func questionGetQClass(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) } - -func questionSetName(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) } -func questionSetQType(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) } -func questionSetQClass(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) } diff --git a/lua-resolver.go b/lua-resolver.go index 54ff25c4..63f0e146 100644 --- a/lua-resolver.go +++ b/lua-resolver.go @@ -1,9 +1,6 @@ package rdns import ( - "fmt" - "reflect" - "github.com/miekg/dns" lua "github.com/yuin/gopher-lua" ) @@ -62,13 +59,3 @@ func resolverResolve(L *lua.LState) int { return 2 } - -func getUserDataArg[T any](L *lua.LState, n int) (T, bool) { - ud := L.CheckUserData(n) - v, ok := ud.Value.(T) - if !ok { - L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value)) - return v, false - } - return v, true -} diff --git a/lua-types.go b/lua-types.go index 98bbb2c7..c6ba79fb 100644 --- a/lua-types.go +++ b/lua-types.go @@ -1,6 +1,28 @@ package rdns -import lua "github.com/yuin/gopher-lua" +import ( + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +func (s *LuaScript) RegisterConstants() { + L := s.L + + // Register TypeA, TypeAAAA, etc + for value, name := range dns.TypeToString { + L.SetGlobal("Type"+name, lua.LNumber(value)) + } + + // Register ClassINET, etc + for value, name := range dns.ClassToString { + L.SetGlobal("Class"+name, lua.LNumber(value)) + } + + // Register Rcodes, RcodeNOERROR, RcodeNXDOMAIN, etc + for value, name := range dns.RcodeToString { + L.SetGlobal("Rcode"+name, lua.LNumber(value)) + } +} func userDataWithMetatable(L *lua.LState, mtName string, value any) *lua.LUserData { ud := L.NewUserData() diff --git a/lua.go b/lua.go index c61e24ed..5c1780f7 100644 --- a/lua.go +++ b/lua.go @@ -96,6 +96,7 @@ func (r *Lua) newScript() (*LuaScript, error) { } // Register types and methods + s.RegisterConstants() s.RegisterMessageType() s.RegisterQuestionType() s.RegisterErrorType() diff --git a/lua_test.go b/lua_test.go index 366b4fa5..94c14949 100644 --- a/lua_test.go +++ b/lua_test.go @@ -10,8 +10,6 @@ import ( func TestLuaSimplePassthrough(t *testing.T) { opt := LuaOptions{ Script: ` -function resolve(msg, ci) - end function Resolve(msg, ci) local resolver = Resolvers[1] local answer, err = resolver:resolve(msg, ci) @@ -70,28 +68,61 @@ end`, } func TestLuaStaticAnswer(t *testing.T) { - opt := LuaOptions{ - Script: ` + tests := map[string]LuaOptions{ + "set_questions": { + Script: ` +function Resolve(msg, ci) + local question = Question.new("example.com.", TypeA) + local answer = Message.new() + answer:set_id(msg:get_id()) + answer:set_questions({question}) + answer:set_response(true) + answer:set_rcode(RcodeNXDOMAIN) + return answer, nil +end`, + }, + "set_question": { + Script: ` function Resolve(msg, ci) local answer = Message.new() - local question = Question.new() - question:set_name("example.com.") - answer:set_question({question}) + answer:set_question("example.com.", TypeA) + answer:set_id(msg:get_id()) + answer:set_response(true) + answer:set_rcode(RcodeNXDOMAIN) return answer, nil end`, + }, + "set_reply": { + Script: ` +function Resolve(msg, ci) + local answer = Message.new() + answer:set_reply(msg) + answer:set_rcode(RcodeNXDOMAIN) + return answer, nil +end`, + }, } - var ci ClientInfo - resolver := new(TestResolver) - - r, err := NewLua("test-lua", opt, resolver) - require.NoError(t, err) - - q := new(dns.Msg) - q.SetQuestion("example.com.", dns.TypeA) - - answer, err := r.Resolve(q, ci) - require.NoError(t, err) - require.Equal(t, 0, resolver.HitCount()) - require.Equal(t, "example.com.", answer.Question[0].Name) + for name, opt := range tests { + t.Run(name, func(t *testing.T) { + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + q.Id = 1234 + + answer, err := r.Resolve(q, ci) + require.NoError(t, err) + require.Equal(t, 0, resolver.HitCount()) + require.Equal(t, "example.com.", answer.Question[0].Name) + require.Equal(t, dns.TypeA, answer.Question[0].Qtype) + require.Equal(t, uint16(1234), answer.Id) + require.Equal(t, dns.RcodeNameError, answer.Rcode) + require.True(t, answer.Response) + }) + } } From 9126b84966feafc4f9e27a2b3b0c7c8ffe7360cf Mon Sep 17 00:00:00 2001 From: folbrich Date: Sun, 4 May 2025 15:01:27 +0200 Subject: [PATCH 5/9] RR types --- lua-rr.go | 409 +++++++++++++++++++++++++++++++++++++++++++++++++++ lua-types.go | 2 +- lua.go | 1 + lua_test.go | 40 +++++ 4 files changed, 451 insertions(+), 1 deletion(-) create mode 100644 lua-rr.go diff --git a/lua-rr.go b/lua-rr.go new file mode 100644 index 00000000..cbe3b33b --- /dev/null +++ b/lua-rr.go @@ -0,0 +1,409 @@ +package rdns + +import ( + "fmt" + "net" + "reflect" + "strings" + + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// RR functions + +const luaRRHeaderMetatableName = "RR" + +func (s *LuaScript) RegisterRRTypes() { + L := s.L + + mt := L.NewTypeMetatable(luaRRHeaderMetatableName) + L.SetGlobal(luaRRHeaderMetatableName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + table := L.CheckTable(1) + lrtype, ok := table.RawGetString("rtype").(lua.LNumber) + if !ok { + L.ArgError(1, "rtype must be a number") + return 0 + } + rtype := uint16(lrtype) + + rrFunc, ok := dns.TypeToRR[rtype] + if !ok { + L.ArgError(1, "unknown rtype") + return 0 + } + rr := rrFunc() + + var err error + table.ForEach(func(k, v lua.LValue) { + if k.Type() != lua.LTString { + if err != nil { // Only record the first error + err = fmt.Errorf("expecte string keys, got %s", k.Type().String()) + } + return + } + if k.String() == "rtype" { + // We don't allow this to be set or updated + rr.Header().Rrtype = rtype + return + } + if setErr := rrDB.set(rr, k.String(), v); setErr != nil && err == nil { + err = setErr + } + }) + if err != nil { + L.ArgError(1, err.Error()) + return 0 + } + + L.Push(userDataWithMetatable(L, luaRRHeaderMetatableName, rr)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + rr, ok := getUserDataArg[dns.RR](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + + lv, err := rrDB.get(rr, fieldName) + if err != nil { + L.ArgError(1, err.Error()) // TODO: figure out arg position + return 0 + } + L.Push(lv) + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + rr, ok := getUserDataArg[dns.RR](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + if fieldName == "" { + return 0 + } + value := L.CheckAny(3) + + // TODO: handle header fields directly + + if err := rrDB.set(rr, fieldName, value); err != nil { + L.ArgError(1, err.Error()) // TODO: figure out arg position + return 0 + } + return 0 + })) +} + +type rrFieldDB map[reflect.Type]map[string]rrFieldAccessors + +type rrFieldAccessors struct { + index []int + get func(reflect.Value) (lua.LValue, error) + set func(reflect.Value, lua.LValue) error +} + +var rrDB = func() rrFieldDB { + db := make(map[reflect.Type]map[string]rrFieldAccessors) + + for _, rrFunc := range dns.TypeToRR { + rr := rrFunc() + typ := reflect.TypeOf(rr) + db[typ] = rrFieldsForType(typ.Elem(), nil) + } + return db +}() + +func rrFieldsForType(typ reflect.Type, index []int) map[string]rrFieldAccessors { + fields := make(map[string]rrFieldAccessors) + for _, field := range reflect.VisibleFields(typ) { + if !field.IsExported() { + continue + } + // All RR have a header and we handle that directly, without reflection + if field.Name == "Hdr" { + continue + } + a := rrFieldAccessors{ + index: append(index, field.Index...), + } + switch field.Type { + case reflect.TypeOf(net.IP{}): + a.get, a.set = getIPField, setIPField + case reflect.TypeOf(""): + a.get, a.set = getStringField, setStringField + case reflect.TypeOf(uint8(0)): + a.get, a.set = getUint8Field, setUint8Field + case reflect.TypeOf(uint16(0)): + a.get, a.set = getUint16Field, setUint16Field + case reflect.TypeOf(uint32(0)): + a.get, a.set = getUint32Field, setUint32Field + case reflect.TypeOf(uint64(0)): + a.get, a.set = getUint64Field, setUint64Field + case reflect.TypeOf([]uint16{}): + a.get, a.set = getUint16SliceField, setUint16SliceField + case reflect.TypeOf([]string{}): + a.get, a.set = getStringSliceField, setStringSliceField + case reflect.TypeOf(dns.DS{}): // Composed in DLV + return rrFieldsForType(reflect.TypeOf(dns.DS{}), field.Index) + case reflect.TypeOf(dns.SVCB{}): // Composed in HTTPS + return rrFieldsForType(reflect.TypeOf(dns.SVCB{}), field.Index) + case reflect.TypeOf(dns.NSEC{}): // Composed in NXT + return rrFieldsForType(reflect.TypeOf(dns.NSEC{}), field.Index) + case reflect.TypeOf([]dns.APLPrefix{}): // Used in APL + a.get, a.set = getUnsupported(field.Name), setUnsupported(field.Name) + case reflect.TypeOf(dns.RRSIG{}): // Composed in SIG + return rrFieldsForType(reflect.TypeOf(dns.RRSIG{}), field.Index) + case reflect.TypeOf(dns.DNSKEY{}): // Composed in KEY + return rrFieldsForType(reflect.TypeOf(dns.DNSKEY{}), field.Index) + case reflect.TypeOf([]dns.SVCBKeyValue{}): // interface + a.get, a.set = getUnsupported(field.Name), setUnsupported(field.Name) + case reflect.TypeOf([]dns.EDNS0{}): // in OPT + // TODO:implement + default: + panic(fmt.Errorf("unsupported RR field value type %v in %s", field.Type, typ)) + } + + fields[strings.ToLower(field.Name)] = a + } + return fields +} + +func (db rrFieldDB) get(rr dns.RR, name string) (lua.LValue, error) { + // If the field is in the header, we handle that directly + switch name { + case "name": + return lua.LString(rr.Header().Name), nil + case "rtype": + return lua.LNumber(rr.Header().Rrtype), nil + case "class": + return lua.LNumber(rr.Header().Class), nil + case "ttl": + return lua.LNumber(rr.Header().Ttl), nil + case "rdlength": + return lua.LNumber(rr.Header().Rdlength), nil + } + + // Lookup the fields for this type + typeFields, ok := db[reflect.TypeOf(rr)] + if !ok { + return nil, luaArgError{1, fmt.Errorf("unsupported resource record type %v", reflect.TypeOf(rr).String())} + } + a, ok := typeFields[name] + if !ok { + return nil, luaArgError{2, fmt.Errorf("unknown field name %q for type %v", name, reflect.TypeOf(rr).String())} + } + fieldValue := reflect.ValueOf(rr).Elem().FieldByIndex(a.index) + return a.get(fieldValue) +} + +func (db rrFieldDB) set(rr dns.RR, name string, value lua.LValue) error { + // If the field is in the header, we handle that directly + switch name { + case "name": + if value.Type() != lua.LTString { + return luaArgError{3, fmt.Errorf("expected string value, got %v", value.Type().String())} + } + rr.Header().Name = value.String() + return nil + case "rtype": + return luaArgError{2, fmt.Errorf("cannot change rtype directly")} + case "class": + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number value, got %v", value.Type().String())} + } + rr.Header().Class = uint16(value.(lua.LNumber)) + return nil + case "ttl": + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number value, got %v", value.Type().String())} + } + rr.Header().Ttl = uint32(value.(lua.LNumber)) + return nil + case "rdlength": + return luaArgError{2, fmt.Errorf("cannot change rdlength")} + } + + // Lookup the fields for this type + typeFields, ok := db[reflect.TypeOf(rr)] + if !ok { + return luaArgError{1, fmt.Errorf("unsupported resource record type %v", reflect.TypeOf(rr).String())} + } + a, ok := typeFields[name] + if !ok { + return luaArgError{2, fmt.Errorf("unknown field name %q for type %v", name, reflect.TypeOf(rr).String())} + } + fieldValue := reflect.ValueOf(rr).Elem().FieldByIndex(a.index) + return a.set(fieldValue, value) +} + +type luaArgError struct { + position int + error +} + +func getStringField(fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(string) + return lua.LString(field), nil +} + +func setStringField(fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTString { + return luaArgError{3, fmt.Errorf("expected string value, got %v", value.Type().String())} + } + fieldValue.SetString(value.String()) + return nil +} + +func getStringSliceField(fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().([]string) + table := (*lua.LState)(nil).CreateTable(len(field), 0) + for _, v := range field { + table.Append(lua.LString(v)) + } + return table, nil +} + +func setStringSliceField(fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTTable { + return luaArgError{3, fmt.Errorf("expected array, got %v", value.Type().String())} + } + table := value.(*lua.LTable) + n := table.Len() + stringValues := make([]string, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTString { + return luaArgError{3, fmt.Errorf("expected string, got %v", element.Type().String())} + } + s := element.String() + stringValues = append(stringValues, s) + } + newVal := reflect.ValueOf(stringValues) + fieldValue.Set(newVal) + return nil +} + +func getUint16SliceField(fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().([]uint16) + table := (*lua.LState)(nil).CreateTable(len(field), 0) + for _, v := range field { + table.Append(lua.LNumber(v)) + } + return table, nil +} + +func setUint16SliceField(fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTTable { + return luaArgError{3, fmt.Errorf("expected array, got %v", value.Type().String())} + } + table := value.(*lua.LTable) + n := table.Len() + values := make([]uint16, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", element.Type().String())} + } + lv := element.(lua.LNumber) + values = append(values, uint16(lv)) + } + newVal := reflect.ValueOf(values) + fieldValue.Set(newVal) + return nil +} + +func getUint8Field(fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(uint8) + return lua.LNumber(field), nil +} + +func setUint8Field(fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} + } + fieldValue.SetUint(uint64(value.(lua.LNumber))) + return nil +} + +func getUint16Field(fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(uint16) + return lua.LNumber(field), nil +} + +func setUint16Field(fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} + } + fieldValue.SetUint(uint64(value.(lua.LNumber))) + return nil +} + +func getUint32Field(fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(uint32) + return lua.LNumber(field), nil +} + +func setUint32Field(fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} + } + fieldValue.SetUint(uint64(value.(lua.LNumber))) + return nil +} + +func getUint64Field(fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(uint64) + return lua.LNumber(field), nil +} + +func setUint64Field(fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTNumber { + return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} + } + fieldValue.SetUint(uint64(value.(lua.LNumber))) + return nil +} + +func getIPField(fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().(net.IP) + if field == nil { + return lua.LNil, nil + } + return lua.LString(field.String()), nil +} + +func setIPField(fieldValue reflect.Value, value lua.LValue) error { + switch value.Type() { + case lua.LTString: + ip := net.ParseIP(value.String()) + if ip == nil { + return nil + } + fieldValue.SetBytes(ip) + case lua.LTNil: + fieldValue.SetZero() + default: + return luaArgError{3, fmt.Errorf("expected string or nil, got %v", value.Type().String())} + } + return nil +} + +func getUnsupported(name string) func(v reflect.Value) (lua.LValue, error) { + return func(v reflect.Value) (lua.LValue, error) { + return nil, luaArgError{2, fmt.Errorf("getting %q not supported", name)} + } +} + +func setUnsupported(name string) func(v reflect.Value, value lua.LValue) error { + return func(v reflect.Value, value lua.LValue) error { + return luaArgError{2, fmt.Errorf("setting %q not supported", name)} + } +} diff --git a/lua-types.go b/lua-types.go index c6ba79fb..a9e95d1b 100644 --- a/lua-types.go +++ b/lua-types.go @@ -13,7 +13,7 @@ func (s *LuaScript) RegisterConstants() { L.SetGlobal("Type"+name, lua.LNumber(value)) } - // Register ClassINET, etc + // Register ClassIN, etc for value, name := range dns.ClassToString { L.SetGlobal("Class"+name, lua.LNumber(value)) } diff --git a/lua.go b/lua.go index 5c1780f7..ee59a37f 100644 --- a/lua.go +++ b/lua.go @@ -99,6 +99,7 @@ func (r *Lua) newScript() (*LuaScript, error) { s.RegisterConstants() s.RegisterMessageType() s.RegisterQuestionType() + s.RegisterRRTypes() s.RegisterErrorType() // Inject the resolvers into the state (so they can be used in the script) diff --git a/lua_test.go b/lua_test.go index 94c14949..d219556c 100644 --- a/lua_test.go +++ b/lua_test.go @@ -126,3 +126,43 @@ end`, }) } } + +func TestLuaRROperations(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + -- Create a new TXT record and test value set/get operations + rr = RR.new({rtype=TypeTXT, name="example.com.", class=ClassIN, ttl=60, txt={"hello", "world"}}) + if rr.txt[1] ~= "hello" then + return nil, Error.new("unexpected value") + end + rr.txt = {"bla"} + if rr.txt[1] ~= "bla" then + return nil, Error.new("unexpected value in txt") + end + + -- Create a new A record and test value set/get operations + rr = RR.new({rtype=TypeA, name="example.com.", class=ClassIN, ttl=60, a="1.2.3.4"}) + if rr.rtype ~= TypeA or rr.name ~= "example.com." then + return nil, Error.new("unexpected name value") + end + rr.a = "1.1.1.1" + if rr.a ~= "1.1.1.1" then + return nil, Error.new("unexpected ip value") + end + return nil, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeMX) + + _, err = r.Resolve(q, ci) + require.NoError(t, err) +} From 7dc00f4562a849c254a1e1a905b9ab08577cdf17 Mon Sep 17 00:00:00 2001 From: folbrich Date: Sun, 4 May 2025 18:35:57 +0200 Subject: [PATCH 6/9] rework message fields and methods --- lua-error.go | 2 +- lua-helpers.go | 20 +----- lua-message.go | 160 +++++++++++++++++++++++++++--------------------- lua-question.go | 50 ++++++++++++--- lua-rr.go | 2 - lua_test.go | 46 +++++++++++--- 6 files changed, 172 insertions(+), 108 deletions(-) diff --git a/lua-error.go b/lua-error.go index ce562c40..50ac5b29 100644 --- a/lua-error.go +++ b/lua-error.go @@ -25,6 +25,6 @@ func (s *LuaScript) RegisterErrorType() { // methods L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "error": getter(func(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) }), + "error": method(func(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) }), })) } diff --git a/lua-helpers.go b/lua-helpers.go index 62d39961..bdf99b98 100644 --- a/lua-helpers.go +++ b/lua-helpers.go @@ -9,25 +9,9 @@ import ( // Helper functions -func getter[T any](f func(*lua.LState, T)) func(*lua.LState) int { +func method[T any](f func(*lua.LState, T)) func(*lua.LState) int { return func(L *lua.LState) int { - if L.GetTop() > 1 { - L.ArgError(1, "no arguments expected") - return 0 - } - ud := L.CheckUserData(1) - r, ok := ud.Value.(T) - if !ok { - L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) - return 0 - } - f(L, r) - return 1 - } -} -func setter[T any](f func(*lua.LState, T)) func(*lua.LState) int { - return func(L *lua.LState) int { - if L.GetTop() < 2 { + if L.GetTop() < 1 { L.ArgError(1, "expected at least 1 argument") return 0 } diff --git a/lua-message.go b/lua-message.go index f29071ef..09c65bd5 100644 --- a/lua-message.go +++ b/lua-message.go @@ -1,6 +1,8 @@ package rdns import ( + "fmt" + "github.com/miekg/dns" lua "github.com/yuin/gopher-lua" ) @@ -19,80 +21,96 @@ func (s *LuaScript) RegisterMessageType() { return 1 })) // methods - L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "get_questions": getter(func(L *lua.LState, msg *dns.Msg) { - table := L.CreateTable(len(msg.Question), 0) - for _, q := range msg.Question { - lv := userDataWithMetatable(L, luaQuestionMetatableName, &q) - table.Append(lv) + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + msg, ok := getUserDataArg[*dns.Msg](L, 1) + if !ok { + return 0 } - L.Push(table) - }), - "set_questions": setter(func(L *lua.LState, msg *dns.Msg) { - table := L.CheckTable(2) - n := table.Len() - questions := make([]dns.Question, 0, n) - for i := range n { - element := table.RawGetInt(i + 1) - if element.Type() != lua.LTUserData { - L.ArgError(1, "invalid type, expected userdata") - return - } - lq := element.(*lua.LUserData) - q, ok := lq.Value.(*dns.Question) - if !ok { - L.ArgError(1, "invalid type, expected question") - return + fieldName := L.CheckString(2) + switch fieldName { + case "questions": + table := L.CreateTable(len(msg.Question), 0) + for _, q := range msg.Question { + lv := userDataWithMetatable(L, luaQuestionMetatableName, &q) + table.Append(lv) } - questions = append(questions, *q) + L.Push(table) + case "id": + L.Push(lua.LNumber(msg.Id)) + case "response": + L.Push(lua.LBool(msg.Response)) + case "rcode": + L.Push(lua.LNumber(msg.Rcode)) + case "recursion_desired": + L.Push(lua.LBool(msg.RecursionDesired)) + case "recursion_available": + L.Push(lua.LBool(msg.RecursionAvailable)) + case "authenticated_data": + L.Push(lua.LBool(msg.AuthenticatedData)) + case "set_reply": + L.Push(L.NewFunction( + method(func(L *lua.LState, msg *dns.Msg) { + request, ok := getUserDataArg[*dns.Msg](L, 2) + if !ok { + return + } + msg.SetReply(request) + }))) + case "set_question": + L.Push(L.NewFunction( + method(func(L *lua.LState, msg *dns.Msg) { + msg.SetQuestion(L.CheckString(2), uint16(L.CheckNumber(3))) + }))) + default: + L.ArgError(2, fmt.Sprintf("message does not have field %q", fieldName)) + return 0 } - msg.Question = questions - }), - "set_question": setter(func(L *lua.LState, msg *dns.Msg) { - msg.SetQuestion(L.CheckString(2), uint16(L.CheckNumber(3))) - }), - "set_id": setter(func(L *lua.LState, msg *dns.Msg) { - msg.Id = uint16(L.CheckInt(2)) - }), - "get_id": getter(func(L *lua.LState, msg *dns.Msg) { - L.Push(lua.LNumber(msg.Id)) - }), - "set_response": setter(func(L *lua.LState, msg *dns.Msg) { - msg.Response = L.CheckBool(2) - }), - "get_response": getter(func(L *lua.LState, msg *dns.Msg) { - L.Push(lua.LBool(msg.Response)) - }), - "set_reply": setter(func(L *lua.LState, msg *dns.Msg) { - request, ok := getUserDataArg[*dns.Msg](L, 2) + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + msg, ok := getUserDataArg[*dns.Msg](L, 1) if !ok { - return + return 0 } - msg.SetReply(request) - }), - "set_rcode": setter(func(L *lua.LState, msg *dns.Msg) { - msg.Rcode = L.CheckInt(2) - }), - "get_rcode": getter(func(L *lua.LState, msg *dns.Msg) { - L.Push(lua.LNumber(msg.Rcode)) - }), - "set_rd": setter(func(L *lua.LState, msg *dns.Msg) { - msg.RecursionDesired = L.CheckBool(2) - }), - "get_rd": getter(func(L *lua.LState, msg *dns.Msg) { - L.Push(lua.LBool(msg.RecursionDesired)) - }), - "set_ra": setter(func(L *lua.LState, msg *dns.Msg) { - msg.RecursionAvailable = L.CheckBool(2) - }), - "get_ra": getter(func(L *lua.LState, msg *dns.Msg) { - L.Push(lua.LBool(msg.RecursionAvailable)) - }), - "set_ad": setter(func(L *lua.LState, msg *dns.Msg) { - msg.AuthenticatedData = L.CheckBool(2) - }), - "get_ad": getter(func(L *lua.LState, msg *dns.Msg) { - L.Push(lua.LBool(msg.AuthenticatedData)) - }), - })) + fieldName := L.CheckString(2) + switch fieldName { + case "questions": + table := L.CheckTable(3) + n := table.Len() + questions := make([]dns.Question, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTUserData { + L.ArgError(3, "invalid type, expected userdata") + return 0 + } + lq := element.(*lua.LUserData) + q, ok := lq.Value.(*dns.Question) + if !ok { + L.ArgError(3, "invalid type, expected question") + return 0 + } + questions = append(questions, *q) + } + msg.Question = questions + case "id": + msg.Id = uint16(L.CheckInt(3)) + case "response": + msg.Response = L.CheckBool(3) + case "rcode": + msg.Rcode = L.CheckInt(3) + case "recursion_desired": + msg.RecursionDesired = L.CheckBool(3) + case "recursion_available": + msg.RecursionAvailable = L.CheckBool(3) + case "authenticated_data": + msg.AuthenticatedData = L.CheckBool(3) + default: + L.ArgError(2, fmt.Sprintf("question does not have field %q", fieldName)) + return 0 + } + return 0 + })) } diff --git a/lua-question.go b/lua-question.go index 4170d0a5..b4fc94f9 100644 --- a/lua-question.go +++ b/lua-question.go @@ -1,6 +1,8 @@ package rdns import ( + "fmt" + "github.com/miekg/dns" lua "github.com/yuin/gopher-lua" ) @@ -32,12 +34,44 @@ func (s *LuaScript) RegisterQuestionType() { })) // methods - L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "get_name": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) }), - "get_qtype": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) }), - "get_qclass": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) }), - "set_name": setter(func(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) }), - "set_qtype": setter(func(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) }), - "set_qclass": setter(func(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) }), - })) + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + question, ok := getUserDataArg[*dns.Question](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "name": + L.Push(lua.LString(question.Name)) + case "qtype": + L.Push(lua.LNumber(question.Qtype)) + case "qclass": + L.Push(lua.LNumber(question.Qclass)) + default: + L.ArgError(2, fmt.Sprintf("question does not have field %q", fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + question, ok := getUserDataArg[*dns.Question](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "name": + question.Name = L.CheckString(3) + case "qtype": + question.Qtype = uint16(L.CheckNumber(3)) + case "qclass": + question.Qclass = uint16(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("question does not have field %q", fieldName)) + return 0 + } + return 0 + })) } diff --git a/lua-rr.go b/lua-rr.go index cbe3b33b..a0715809 100644 --- a/lua-rr.go +++ b/lua-rr.go @@ -92,8 +92,6 @@ func (s *LuaScript) RegisterRRTypes() { } value := L.CheckAny(3) - // TODO: handle header fields directly - if err := rrDB.set(rr, fieldName, value); err != nil { L.ArgError(1, err.Error()) // TODO: figure out arg position return 0 diff --git a/lua_test.go b/lua_test.go index d219556c..98303332 100644 --- a/lua_test.go +++ b/lua_test.go @@ -74,10 +74,10 @@ func TestLuaStaticAnswer(t *testing.T) { function Resolve(msg, ci) local question = Question.new("example.com.", TypeA) local answer = Message.new() - answer:set_id(msg:get_id()) - answer:set_questions({question}) - answer:set_response(true) - answer:set_rcode(RcodeNXDOMAIN) + answer.id = msg.id + answer.questions = { question } + answer.response = true + answer.rcode = RcodeNXDOMAIN return answer, nil end`, }, @@ -86,9 +86,9 @@ end`, function Resolve(msg, ci) local answer = Message.new() answer:set_question("example.com.", TypeA) - answer:set_id(msg:get_id()) - answer:set_response(true) - answer:set_rcode(RcodeNXDOMAIN) + answer.id = msg.id + answer.response = true + answer.rcode = RcodeNXDOMAIN return answer, nil end`, }, @@ -97,7 +97,7 @@ end`, function Resolve(msg, ci) local answer = Message.new() answer:set_reply(msg) - answer:set_rcode(RcodeNXDOMAIN) + answer.rcode = RcodeNXDOMAIN return answer, nil end`, }, @@ -127,6 +127,36 @@ end`, } } +func TestLuaQuestionOperations(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + -- Create a new Question record and test value set/get operations + local question = Question.new("example.com.", TypeA) + if question.name ~= "example.com." or question.qtype ~= TypeA then + return nil, Error.new("unexpected name value") + end + question.name = "testing." + if question.name ~= "testing." then + return nil, Error.new("unexpected name value") + end + return nil, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeMX) + + _, err = r.Resolve(q, ci) + require.NoError(t, err) +} + func TestLuaRROperations(t *testing.T) { opt := LuaOptions{ Script: ` From f9cc8ac89273dc0d1a10dc590a1861f8c7cf6616 Mon Sep 17 00:00:00 2001 From: folbrich Date: Sun, 1 Jun 2025 17:13:33 +0200 Subject: [PATCH 7/9] dev --- lua-error.go | 5 +- lua-helpers.go | 34 ++++++- lua-message.go | 35 ++++++- lua-rr.go | 105 ++++++++++++++------ lua-types.go | 19 ++++ lua.go | 1 + lua_test.go | 263 +++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 422 insertions(+), 40 deletions(-) diff --git a/lua-error.go b/lua-error.go index 50ac5b29..8224d97b 100644 --- a/lua-error.go +++ b/lua-error.go @@ -25,6 +25,9 @@ func (s *LuaScript) RegisterErrorType() { // methods L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "error": method(func(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) }), + "error": method(func(L *lua.LState, r error) int { + L.Push(lua.LString(r.Error())) + return 1 + }), })) } diff --git a/lua-helpers.go b/lua-helpers.go index bdf99b98..9efb94ba 100644 --- a/lua-helpers.go +++ b/lua-helpers.go @@ -9,7 +9,7 @@ import ( // Helper functions -func method[T any](f func(*lua.LState, T)) func(*lua.LState) int { +func method[T any](f func(*lua.LState, T) int) func(*lua.LState) int { return func(L *lua.LState) int { if L.GetTop() < 1 { L.ArgError(1, "expected at least 1 argument") @@ -21,8 +21,7 @@ func method[T any](f func(*lua.LState, T)) func(*lua.LState) int { L.ArgError(1, fmt.Sprintf("%v expected", reflect.TypeFor[T]())) return 0 } - f(L, r) - return 1 + return f(L, r) } } @@ -35,3 +34,32 @@ func getUserDataArg[T any](L *lua.LState, n int) (T, bool) { } return v, true } + +type numbers interface { + int | int8 | int16 | int32 | int64 | float32 | float64 | uint | uint8 | uint16 | uint32 | uint64 +} + +func getNumberSlice[T numbers](L *lua.LState, n int) ([]T, bool) { + table := L.CheckTable(n) + size := table.Len() + values := make([]T, 0, size) + for i := range size { + element := table.RawGetInt(i + 1) + if element.Type() != lua.LTNumber { + L.ArgError(n, "invalid type, expected number") + return nil, false + } + value := T(element.(lua.LNumber)) + values = append(values, value) + } + return values, true +} + +func numberSliceToTable[T numbers](L *lua.LState, values []T) *lua.LTable { + table := L.CreateTable(len(values), 0) + for _, value := range values { + table.Append(lua.LNumber(value)) + } + L.Push(table) + return table +} diff --git a/lua-message.go b/lua-message.go index 09c65bd5..c3490305 100644 --- a/lua-message.go +++ b/lua-message.go @@ -20,7 +20,7 @@ func (s *LuaScript) RegisterMessageType() { L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg))) return 1 })) - // methods + // methods and fields L.SetField(mt, "__index", L.NewFunction( func(L *lua.LState) int { msg, ok := getUserDataArg[*dns.Msg](L, 1) @@ -50,17 +50,44 @@ func (s *LuaScript) RegisterMessageType() { L.Push(lua.LBool(msg.AuthenticatedData)) case "set_reply": L.Push(L.NewFunction( - method(func(L *lua.LState, msg *dns.Msg) { + method(func(L *lua.LState, msg *dns.Msg) int { request, ok := getUserDataArg[*dns.Msg](L, 2) if !ok { - return + return 0 } msg.SetReply(request) + lv := userDataWithMetatable(L, luaMessageMetatableName, msg) + L.Push(lv) + return 1 }))) case "set_question": L.Push(L.NewFunction( - method(func(L *lua.LState, msg *dns.Msg) { + method(func(L *lua.LState, msg *dns.Msg) int { msg.SetQuestion(L.CheckString(2), uint16(L.CheckNumber(3))) + lv := userDataWithMetatable(L, luaMessageMetatableName, msg) + L.Push(lv) + return 1 + }))) + case "is_edns0": + L.Push(L.NewFunction( + method(func(L *lua.LState, msg *dns.Msg) int { + opt := msg.IsEdns0() + if opt == nil { + L.Push(lua.LNil) + return 1 + } + // TODO: Return an OPT metatable name here, not generic RR + lv := userDataWithMetatable(L, luaRRHeaderMetatableName, opt) + L.Push(lv) + return 1 + }))) + case "set_edns0": + L.Push(L.NewFunction( + method(func(L *lua.LState, msg *dns.Msg) int { + msg.SetEdns0(uint16(L.CheckNumber(2)), L.CheckBool(3)) + lv := userDataWithMetatable(L, luaMessageMetatableName, msg) + L.Push(lv) + return 1 }))) default: L.ArgError(2, fmt.Sprintf("message does not have field %q", fieldName)) diff --git a/lua-rr.go b/lua-rr.go index a0715809..22f74a8d 100644 --- a/lua-rr.go +++ b/lua-rr.go @@ -50,7 +50,7 @@ func (s *LuaScript) RegisterRRTypes() { rr.Header().Rrtype = rtype return } - if setErr := rrDB.set(rr, k.String(), v); setErr != nil && err == nil { + if setErr := rrDB.set(L, rr, k.String(), v); setErr != nil && err == nil { err = setErr } }) @@ -72,7 +72,7 @@ func (s *LuaScript) RegisterRRTypes() { } fieldName := L.CheckString(2) - lv, err := rrDB.get(rr, fieldName) + lv, err := rrDB.get(L, rr, fieldName) if err != nil { L.ArgError(1, err.Error()) // TODO: figure out arg position return 0 @@ -92,7 +92,7 @@ func (s *LuaScript) RegisterRRTypes() { } value := L.CheckAny(3) - if err := rrDB.set(rr, fieldName, value); err != nil { + if err := rrDB.set(L, rr, fieldName, value); err != nil { L.ArgError(1, err.Error()) // TODO: figure out arg position return 0 } @@ -104,8 +104,8 @@ type rrFieldDB map[reflect.Type]map[string]rrFieldAccessors type rrFieldAccessors struct { index []int - get func(reflect.Value) (lua.LValue, error) - set func(reflect.Value, lua.LValue) error + get func(*lua.LState, reflect.Value) (lua.LValue, error) + set func(*lua.LState, reflect.Value, lua.LValue) error } var rrDB = func() rrFieldDB { @@ -164,7 +164,7 @@ func rrFieldsForType(typ reflect.Type, index []int) map[string]rrFieldAccessors case reflect.TypeOf([]dns.SVCBKeyValue{}): // interface a.get, a.set = getUnsupported(field.Name), setUnsupported(field.Name) case reflect.TypeOf([]dns.EDNS0{}): // in OPT - // TODO:implement + a.get, a.set = getEDNS0SliceField, setEDNS0SliceField default: panic(fmt.Errorf("unsupported RR field value type %v in %s", field.Type, typ)) } @@ -174,7 +174,7 @@ func rrFieldsForType(typ reflect.Type, index []int) map[string]rrFieldAccessors return fields } -func (db rrFieldDB) get(rr dns.RR, name string) (lua.LValue, error) { +func (db rrFieldDB) get(L *lua.LState, rr dns.RR, name string) (lua.LValue, error) { // If the field is in the header, we handle that directly switch name { case "name": @@ -199,10 +199,10 @@ func (db rrFieldDB) get(rr dns.RR, name string) (lua.LValue, error) { return nil, luaArgError{2, fmt.Errorf("unknown field name %q for type %v", name, reflect.TypeOf(rr).String())} } fieldValue := reflect.ValueOf(rr).Elem().FieldByIndex(a.index) - return a.get(fieldValue) + return a.get(L, fieldValue) } -func (db rrFieldDB) set(rr dns.RR, name string, value lua.LValue) error { +func (db rrFieldDB) set(L *lua.LState, rr dns.RR, name string, value lua.LValue) error { // If the field is in the header, we handle that directly switch name { case "name": @@ -239,7 +239,7 @@ func (db rrFieldDB) set(rr dns.RR, name string, value lua.LValue) error { return luaArgError{2, fmt.Errorf("unknown field name %q for type %v", name, reflect.TypeOf(rr).String())} } fieldValue := reflect.ValueOf(rr).Elem().FieldByIndex(a.index) - return a.set(fieldValue, value) + return a.set(L, fieldValue, value) } type luaArgError struct { @@ -247,12 +247,12 @@ type luaArgError struct { error } -func getStringField(fieldValue reflect.Value) (lua.LValue, error) { +func getStringField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { field := fieldValue.Interface().(string) return lua.LString(field), nil } -func setStringField(fieldValue reflect.Value, value lua.LValue) error { +func setStringField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { if value.Type() != lua.LTString { return luaArgError{3, fmt.Errorf("expected string value, got %v", value.Type().String())} } @@ -260,16 +260,16 @@ func setStringField(fieldValue reflect.Value, value lua.LValue) error { return nil } -func getStringSliceField(fieldValue reflect.Value) (lua.LValue, error) { +func getStringSliceField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { field := fieldValue.Interface().([]string) - table := (*lua.LState)(nil).CreateTable(len(field), 0) + table := L.CreateTable(len(field), 0) for _, v := range field { table.Append(lua.LString(v)) } return table, nil } -func setStringSliceField(fieldValue reflect.Value, value lua.LValue) error { +func setStringSliceField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { if value.Type() != lua.LTTable { return luaArgError{3, fmt.Errorf("expected array, got %v", value.Type().String())} } @@ -289,16 +289,16 @@ func setStringSliceField(fieldValue reflect.Value, value lua.LValue) error { return nil } -func getUint16SliceField(fieldValue reflect.Value) (lua.LValue, error) { +func getUint16SliceField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { field := fieldValue.Interface().([]uint16) - table := (*lua.LState)(nil).CreateTable(len(field), 0) + table := L.CreateTable(len(field), 0) for _, v := range field { table.Append(lua.LNumber(v)) } return table, nil } -func setUint16SliceField(fieldValue reflect.Value, value lua.LValue) error { +func setUint16SliceField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { if value.Type() != lua.LTTable { return luaArgError{3, fmt.Errorf("expected array, got %v", value.Type().String())} } @@ -318,12 +318,12 @@ func setUint16SliceField(fieldValue reflect.Value, value lua.LValue) error { return nil } -func getUint8Field(fieldValue reflect.Value) (lua.LValue, error) { +func getUint8Field(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { field := fieldValue.Interface().(uint8) return lua.LNumber(field), nil } -func setUint8Field(fieldValue reflect.Value, value lua.LValue) error { +func setUint8Field(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { if value.Type() != lua.LTNumber { return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} } @@ -331,12 +331,12 @@ func setUint8Field(fieldValue reflect.Value, value lua.LValue) error { return nil } -func getUint16Field(fieldValue reflect.Value) (lua.LValue, error) { +func getUint16Field(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { field := fieldValue.Interface().(uint16) return lua.LNumber(field), nil } -func setUint16Field(fieldValue reflect.Value, value lua.LValue) error { +func setUint16Field(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { if value.Type() != lua.LTNumber { return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} } @@ -344,12 +344,12 @@ func setUint16Field(fieldValue reflect.Value, value lua.LValue) error { return nil } -func getUint32Field(fieldValue reflect.Value) (lua.LValue, error) { +func getUint32Field(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { field := fieldValue.Interface().(uint32) return lua.LNumber(field), nil } -func setUint32Field(fieldValue reflect.Value, value lua.LValue) error { +func setUint32Field(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { if value.Type() != lua.LTNumber { return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} } @@ -357,12 +357,12 @@ func setUint32Field(fieldValue reflect.Value, value lua.LValue) error { return nil } -func getUint64Field(fieldValue reflect.Value) (lua.LValue, error) { +func getUint64Field(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { field := fieldValue.Interface().(uint64) return lua.LNumber(field), nil } -func setUint64Field(fieldValue reflect.Value, value lua.LValue) error { +func setUint64Field(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { if value.Type() != lua.LTNumber { return luaArgError{3, fmt.Errorf("expected number, got %v", value.Type().String())} } @@ -370,7 +370,7 @@ func setUint64Field(fieldValue reflect.Value, value lua.LValue) error { return nil } -func getIPField(fieldValue reflect.Value) (lua.LValue, error) { +func getIPField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { field := fieldValue.Interface().(net.IP) if field == nil { return lua.LNil, nil @@ -378,7 +378,7 @@ func getIPField(fieldValue reflect.Value) (lua.LValue, error) { return lua.LString(field.String()), nil } -func setIPField(fieldValue reflect.Value, value lua.LValue) error { +func setIPField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { switch value.Type() { case lua.LTString: ip := net.ParseIP(value.String()) @@ -394,14 +394,55 @@ func setIPField(fieldValue reflect.Value, value lua.LValue) error { return nil } -func getUnsupported(name string) func(v reflect.Value) (lua.LValue, error) { - return func(v reflect.Value) (lua.LValue, error) { +func getEDNS0SliceField(L *lua.LState, fieldValue reflect.Value) (lua.LValue, error) { + field := fieldValue.Interface().([]dns.EDNS0) + table := L.CreateTable(len(field), 0) + for _, v := range field { + // TODO: This is a hacky way to determine the name of the + // metatable for EDNS0 recods. Ideally we reference some name + // constant or function exposed by the code that registers the + // EDNS0 types. + mtName := reflect.TypeOf(v).String() + if i := strings.LastIndex(mtName, "."); i >= 0 { + mtName = mtName[i+1:] + } + lv := userDataWithMetatable(L, mtName, v) + table.Append(lv) + } + return table, nil +} + +func setEDNS0SliceField(L *lua.LState, fieldValue reflect.Value, value lua.LValue) error { + if value.Type() != lua.LTTable { + return luaArgError{3, fmt.Errorf("expected array, got %v", value.Type().String())} + } + table := value.(*lua.LTable) + n := table.Len() + stringValues := make([]dns.EDNS0, 0, n) + for i := range n { + element := table.RawGetInt(i + 1) + ud, ok := element.(*lua.LUserData) + if !ok { + return luaArgError{3, fmt.Errorf("expected userdata, got %v", element.Type().String())} + } + value, ok := ud.Value.(dns.EDNS0) + if !ok { + return luaArgError{3, fmt.Errorf("expected EDNS0, got %T", ud)} + } + stringValues = append(stringValues, value) + } + newVal := reflect.ValueOf(stringValues) + fieldValue.Set(newVal) + return nil +} +func getUnsupported(name string) func(L *lua.LState, v reflect.Value) (lua.LValue, error) { + return func(L *lua.LState, v reflect.Value) (lua.LValue, error) { return nil, luaArgError{2, fmt.Errorf("getting %q not supported", name)} } } -func setUnsupported(name string) func(v reflect.Value, value lua.LValue) error { - return func(v reflect.Value, value lua.LValue) error { +func setUnsupported(name string) func(L *lua.LState, v reflect.Value, value lua.LValue) error { + return func(L *lua.LState, v reflect.Value, value lua.LValue) error { return luaArgError{2, fmt.Errorf("setting %q not supported", name)} } } diff --git a/lua-types.go b/lua-types.go index a9e95d1b..614fd062 100644 --- a/lua-types.go +++ b/lua-types.go @@ -22,6 +22,25 @@ func (s *LuaScript) RegisterConstants() { for value, name := range dns.RcodeToString { L.SetGlobal("Rcode"+name, lua.LNumber(value)) } + + // Register EDNS0 option codes + for name, value := range map[string]uint16{ + "EDNS0LLQ": 0x1, + "EDNS0UL": 0x2, + "EDNS0NSID": 0x3, + "EDNS0ESU": 0x4, + "EDNS0DAU": 0x5, + "EDNS0DHU": 0x6, + "EDNS0N3U": 0x7, + "EDNS0SUBNET": 0x8, + "EDNS0EXPIRE": 0x9, + "EDNS0COOKIE": 0xa, + "EDNS0TCPKEEPALIVE": 0xb, + "EDNS0PADDING": 0xc, + "EDNS0EDE": 0xf, + } { + L.SetGlobal(name, lua.LNumber(value)) + } } func userDataWithMetatable(L *lua.LState, mtName string, value any) *lua.LUserData { diff --git a/lua.go b/lua.go index ee59a37f..5fb1fe9c 100644 --- a/lua.go +++ b/lua.go @@ -100,6 +100,7 @@ func (r *Lua) newScript() (*LuaScript, error) { s.RegisterMessageType() s.RegisterQuestionType() s.RegisterRRTypes() + s.RegisterEDNS0Types() s.RegisterErrorType() // Inject the resolvers into the state (so they can be used in the script) diff --git a/lua_test.go b/lua_test.go index 98303332..340f89eb 100644 --- a/lua_test.go +++ b/lua_test.go @@ -1,6 +1,7 @@ package rdns import ( + "net" "testing" "github.com/miekg/dns" @@ -196,3 +197,265 @@ end`, _, err = r.Resolve(q, ci) require.NoError(t, err) } + +func TestLuaEDNS0Operations(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + -- Create a new EDNS0 COOKIE and test value set/get operations + edns0 = EDNS0_COOKIE.new("24a5ac1a012345ff") + if edns0.cookie ~= "24a5ac1a012345ff" then + return nil, Error.new("unexpected value") + end + edns0.cookie = "bla" + if edns0.cookie ~= "bla" then + return nil, Error.new("unexpected value in edns0 cookie") + end + + -- Create a new EDNS0 DAU and test value set/get operations + edns0 = EDNS0_DAU.new({ 1, 2, 3 }) + if edns0.algcode[1] ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.algcode = { 0 } + if edns0.algcode[1] ~= 0 then + return nil, Error.new("unexpected value in edns0 dau") + end + + -- Create a new EDNS0 DHU and test value set/get operations + edns0 = EDNS0_DHU.new({ 1, 2, 3 }) + if edns0.algcode[1] ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.algcode = { 0 } + if edns0.algcode[1] ~= 0 then + return nil, Error.new("unexpected value in edns0 dhu") + end + + -- Create a new EDNS0 EDE and test value set/get operations + edns0 = EDNS0_EDE.new(15, "domain blocked") + if edns0.infocode ~= 15 then + return nil, Error.new("unexpected value") + end + edns0.extratext = "testing" + if edns0.extratext ~= "testing" then + return nil, Error.new("unexpected value in edns0 ede") + end + + -- Create a new EDNS0 ESU and test value set/get operations + edns0 = EDNS0_ESU.new("http://example.com") + if edns0.uri ~= "http://example.com" then + return nil, Error.new("unexpected value") + end + edns0.uri = "http://example.org" + if edns0.uri ~= "http://example.org" then + return nil, Error.new("unexpected value in edns0 ede") + end + + -- Create a new EDNS0 EXPIRE and test value set/get operations + edns0 = EDNS0_EXPIRE.new(123) + if edns0.expire ~= 123 then + return nil, Error.new("unexpected value") + end + edns0.expire = 124 + if edns0.expire ~= 124 then + return nil, Error.new("unexpected value in edns0 expire") + end + + -- Create a new EDNS0 LLQ and test value set/get operations + edns0 = EDNS0_LLQ.new(1, 16, 0, 1234, 4321) + if edns0.version ~= 1 then + return nil, Error.new("unexpected value") + end + if edns0.opcode ~= 16 then + return nil, Error.new("unexpected value") + end + if edns0.error ~= 0 then + return nil, Error.new("unexpected value") + end + if edns0.id ~= 1234 then + return nil, Error.new("unexpected value") + end + if edns0.leaselife ~= 4321 then + return nil, Error.new("unexpected value") + end + edns0.error = 1 + if edns0.error ~= 1 then + return nil, Error.new("unexpected value in edns0 llq") + end + + -- Create a new EDNS0 LOCAL and test value set/get operations + edns0 = EDNS0_LOCAL.new(65001, "somedata") + if edns0.code ~= 65001 then + return nil, Error.new("unexpected value") + end + edns0.data = "otherdata" + if edns0.data ~= "otherdata" then + return nil, Error.new("unexpected value in edns0 local") + end + + -- Create a new EDNS0 N3U and test value set/get operations + edns0 = EDNS0_N3U.new({ 1, 2, 3 }) + if edns0.algcode[1] ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.algcode = { 0 } + if edns0.algcode[1] ~= 0 then + return nil, Error.new("unexpected value in edns0 n3u") + end + + -- Create a new EDNS0 NSID and test value set/get operations + edns0 = EDNS0_NSID.new("someid") + if edns0.nsid ~= "someid" then + return nil, Error.new("unexpected value") + end + edns0.nsid = "otherid" + if edns0.nsid ~= "otherid" then + return nil, Error.new("unexpected value in edns0 nsid") + end + + -- Create a new EDNS0 PADDING and test value set/get operations + edns0 = EDNS0_PADDING.new("somepadding") + if edns0.padding ~= "somepadding" then + return nil, Error.new("unexpected value") + end + edns0.padding = "otherpadding" + if edns0.padding ~= "otherpadding" then + return nil, Error.new("unexpected value in edns0 padding") + end + + -- Create a new EDNS0 SUBNET and test value set/get operations + edns0 = EDNS0_SUBNET.new(1, 32, 0, "192.168.0.0") + if edns0.family ~= 1 then + return nil, Error.new("unexpected value") + end + if edns0.sourcenetmask ~= 32 then + return nil, Error.new("unexpected value") + end + if edns0.sourcescope ~= 0 then + return nil, Error.new("unexpected value") + end + if edns0.address ~= "192.168.0.0" then + return nil, Error.new("unexpected value") + end + edns0.address = "172.16.0.0" + if edns0.address ~= "172.16.0.0" then + return nil, Error.new("unexpected value in edns0 subnet") + end + + -- Create a new EDNS0 TCP_KEEPALIVE and test value set/get operations + edns0 = EDNS0_TCP_KEEPALIVE.new(1) + if edns0.timeout ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.timeout = 2 + if edns0.timeout ~= 2 then + return nil, Error.new("unexpected value in edns0 tcp keepalive") + end + + -- Create a new EDNS0 UL and test value set/get operations + edns0 = EDNS0_UL.new(1, 2) + if edns0.lease ~= 1 then + return nil, Error.new("unexpected value") + end + edns0.keylease = 3 + if edns0.keylease ~= 3 then + return nil, Error.new("unexpected value in edns0 ul") + end + + return nil, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeMX) + + _, err = r.Resolve(q, ci) + require.NoError(t, err) +} + +func TestLuaRREDNS0(t *testing.T) { + opt := LuaOptions{ + Script: ` +function Resolve(msg, ci) + -- Read the EDNS0 options + opt = msg:is_edns0() + if opt == nil then + return nil, Error.new("no edns0") + end + + -- Grab the options + options = opt.option + + -- The first one should be a cookie + if options[1].option ~= EDNS0COOKIE then + return nil, Error.new("unexpected subnet option value in cookie option") + end + if options[1].cookie ~= "testing" then + return nil, Error.new("unexpected value in edns0 cookie option") + end + + -- The second one should be a subnet option + if options[2].option ~= EDNS0SUBNET then + return nil, Error.new("unexpected subnet option value in subnet option") + end + if options[2].family ~= 1 then + return nil, Error.new("unexpected value in edns0 subnet option") + end + + -- Reply with an extended error message + local answer = Message.new() + answer:set_reply(msg) + answer.rcode = RcodeNXDOMAIN + answer:set_edns0(4096, false) + edns0 = answer:is_edns0() + edns0.option = { + EDNS0_EDE.new(15, "totally blocked"), + } + + return answer, nil +end`, + } + + var ci ClientInfo + resolver := new(TestResolver) + + r, err := NewLua("test-lua", opt, resolver) + require.NoError(t, err) + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeMX) + q.SetEdns0(4096, false) + edns0 := q.IsEdns0() + + edns0.Option = append(edns0.Option, + &dns.EDNS0_COOKIE{ + Code: dns.EDNS0COOKIE, + Cookie: "testing", + }, + &dns.EDNS0_SUBNET{ + Code: dns.EDNS0SUBNET, + Family: 1, + SourceNetmask: 32, + SourceScope: 0, + Address: net.ParseIP("127.0.0.1"), + }, + ) + + a, err := r.Resolve(q, ci) + require.NoError(t, err) + + edns0 = a.IsEdns0() + require.NotNil(t, edns0) + require.Len(t, edns0.Option, 1) + require.Equal(t, uint16(dns.EDNS0EDE), edns0.Option[0].Option()) + ede := edns0.Option[0].(*dns.EDNS0_EDE) + require.Equal(t, uint16(15), ede.InfoCode) + require.Equal(t, "totally blocked", ede.ExtraText) +} From bb4c141bf518d8599174a4fffb94b14f2679fcd6 Mon Sep 17 00:00:00 2001 From: folbrich Date: Sun, 1 Jun 2025 18:55:42 +0200 Subject: [PATCH 8/9] dev --- lua-edns0.go | 900 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 900 insertions(+) create mode 100644 lua-edns0.go diff --git a/lua-edns0.go b/lua-edns0.go new file mode 100644 index 00000000..2740fbd4 --- /dev/null +++ b/lua-edns0.go @@ -0,0 +1,900 @@ +package rdns + +import ( + "fmt" + "net" + + "github.com/miekg/dns" + lua "github.com/yuin/gopher-lua" +) + +// EDNS0 functions + +func (s *LuaScript) RegisterEDNS0Types() { + s.registerEDNS0COOKIEType() + s.registerEDNS0DAUType() + s.registerEDNS0DHUType() + s.registerEDNS0EDEType() + s.registerEDNS0ESUType() + s.registerEDNS0EXPIREType() + s.registerEDNS0LLQType() + s.registerEDNS0LOCALType() + s.registerEDNS0N3UType() + s.registerEDNS0NSIDType() + s.registerEDNS0PADDINGType() + s.registerEDNS0SUBNETType() + s.registerEDNS0TCPKEEPALIVEType() + s.registerEDNS0ULType() +} + +func (s *LuaScript) registerEDNS0COOKIEType() { + L := s.L + mtName := "EDNS0_COOKIE" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_COOKIE) + e.Code = dns.EDNS0COOKIE + nArgs := L.GetTop() + if nArgs >= 1 { // Cookie + e.Cookie = L.CheckString(1) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_COOKIE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "cookie": + L.Push(lua.LString(e.Cookie)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_COOKIE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "cookie": + e.Cookie = L.CheckString(3) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0DAUType() { + L := s.L + mtName := "EDNS0_DAU" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_DAU) + e.Code = dns.EDNS0DAU + nArgs := L.GetTop() + if nArgs >= 1 { // Alg Codes + values, _ := getNumberSlice[uint8](L, 1) + e.AlgCode = values + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_DAU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "algcode": + L.Push(numberSliceToTable(L, e.AlgCode)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_DAU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "algcode": + values, _ := getNumberSlice[uint8](L, 3) + e.AlgCode = values + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0DHUType() { + L := s.L + mtName := "EDNS0_DHU" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_DHU) + e.Code = dns.EDNS0DHU + nArgs := L.GetTop() + if nArgs >= 1 { // Alg Codes + values, _ := getNumberSlice[uint8](L, 1) + e.AlgCode = values + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_DHU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "algcode": + L.Push(numberSliceToTable(L, e.AlgCode)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_DHU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "algcode": + values, _ := getNumberSlice[uint8](L, 3) + e.AlgCode = values + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0EDEType() { + L := s.L + mtName := "EDNS0_EDE" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_EDE) + nArgs := L.GetTop() + if nArgs >= 1 { // Code + e.InfoCode = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // Extra Text + e.ExtraText = L.CheckString(2) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_EDE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "infocode": + L.Push(lua.LNumber(e.InfoCode)) + case "extratext": + L.Push(lua.LString(e.ExtraText)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_EDE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "infocode": + e.InfoCode = uint16(L.CheckNumber(3)) + case "extratext": + e.ExtraText = L.CheckString(3) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0ESUType() { + L := s.L + mtName := "EDNS0_ESU" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_ESU) + e.Code = dns.EDNS0ESU + nArgs := L.GetTop() + if nArgs >= 1 { // URI + e.Uri = L.CheckString(1) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_ESU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "uri": + L.Push(lua.LString(e.Uri)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_ESU](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "uri": + e.Uri = L.CheckString(3) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0EXPIREType() { + L := s.L + mtName := "EDNS0_EXPIRE" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_EXPIRE) + e.Code = dns.EDNS0EXPIRE + nArgs := L.GetTop() + if nArgs >= 1 { // Expire + e.Expire = uint32(L.CheckNumber(1)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_EXPIRE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "expire": + L.Push(lua.LNumber(e.Expire)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_EXPIRE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "expire": + e.Expire = uint32(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0LLQType() { + L := s.L + mtName := "EDNS0_LLQ" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_LLQ) + e.Code = dns.EDNS0LLQ + nArgs := L.GetTop() + if nArgs >= 1 { // Version + e.Version = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // Opcode + e.Opcode = uint16(L.CheckNumber(2)) + } + if nArgs >= 3 { // Error + e.Error = uint16(L.CheckNumber(3)) + } + if nArgs >= 4 { // Id + e.Id = uint64(L.CheckNumber(4)) + } + if nArgs >= 5 { // LeaseLife + e.LeaseLife = uint32(L.CheckNumber(5)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_LLQ](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "version": + L.Push(lua.LNumber(e.Version)) + case "opcode": + L.Push(lua.LNumber(e.Opcode)) + case "error": + L.Push(lua.LNumber(e.Error)) + case "id": + L.Push(lua.LNumber(e.Id)) + case "leaselife": + L.Push(lua.LNumber(e.LeaseLife)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_LLQ](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "version": + e.Version = uint16(L.CheckNumber(3)) + case "opcode": + e.Opcode = uint16(L.CheckNumber(3)) + case "error": + e.Error = uint16(L.CheckNumber(3)) + case "id": + e.Id = uint64(L.CheckNumber(3)) + case "leaselife": + e.LeaseLife = uint32(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0LOCALType() { + L := s.L + mtName := "EDNS0_LOCAL" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_LOCAL) + nArgs := L.GetTop() + if nArgs >= 1 { // Code + e.Code = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // Data + e.Data = []byte(L.CheckString(2)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_LOCAL](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "code": + L.Push(lua.LNumber(e.Code)) + case "data": + L.Push(lua.LString(e.Data)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_LOCAL](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "code": + e.Code = uint16(L.CheckNumber(3)) + case "data": + e.Data = []byte(L.CheckString(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0N3UType() { + L := s.L + mtName := "EDNS0_N3U" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_N3U) + e.Code = dns.EDNS0N3U + nArgs := L.GetTop() + if nArgs >= 1 { // Alg Codes + values, _ := getNumberSlice[uint8](L, 1) + e.AlgCode = values + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_N3U](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "algcode": + L.Push(numberSliceToTable(L, e.AlgCode)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_N3U](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "algcode": + values, _ := getNumberSlice[uint8](L, 3) + e.AlgCode = values + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0NSIDType() { + L := s.L + mtName := "EDNS0_NSID" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_NSID) + e.Code = dns.EDNS0NSID + nArgs := L.GetTop() + if nArgs >= 1 { // NSID + e.Nsid = L.CheckString(1) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_NSID](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "nsid": + L.Push(lua.LString(e.Nsid)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_NSID](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "nsid": + e.Nsid = L.CheckString(3) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0PADDINGType() { + L := s.L + mtName := "EDNS0_PADDING" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_PADDING) + nArgs := L.GetTop() + if nArgs >= 1 { // NSID + e.Padding = []byte(L.CheckString(1)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_PADDING](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "padding": + L.Push(lua.LString(e.Padding)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_PADDING](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "padding": + e.Padding = []byte(L.CheckString(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0SUBNETType() { + L := s.L + mtName := "EDNS0_SUBNET" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + nArgs := L.GetTop() + if nArgs >= 1 { // Family + e.Family = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // SourceNetmask + e.SourceNetmask = uint8(L.CheckNumber(2)) + } + if nArgs >= 3 { // SourceScope + e.SourceScope = uint8(L.CheckNumber(3)) + } + if nArgs >= 4 { // Address + value := L.CheckString(4) + ip := net.ParseIP(value) + if ip == nil { + L.ArgError(4, fmt.Sprintf("expected IP address, got %q", value)) + return 0 + } + e.Address = ip + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_SUBNET](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "family": + L.Push(lua.LNumber(e.Family)) + case "sourcenetmask": + L.Push(lua.LNumber(e.SourceNetmask)) + case "sourcescope": + L.Push(lua.LNumber(e.SourceScope)) + case "address": + L.Push(lua.LString(e.Address.String())) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_SUBNET](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "family": + e.Family = uint16(L.CheckNumber(3)) + case "sourcenetmask": + e.SourceNetmask = uint8(L.CheckNumber(3)) + case "sourcescope": + e.SourceScope = uint8(L.CheckNumber(3)) + case "address": + value := L.CheckString(3) + ip := net.ParseIP(value) + if ip == nil { + L.ArgError(4, fmt.Sprintf("expected IP address, got %q", value)) + return 0 + } + e.Address = ip + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0TCPKEEPALIVEType() { + L := s.L + mtName := "EDNS0_TCP_KEEPALIVE" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_TCP_KEEPALIVE) + e.Code = dns.EDNS0TCPKEEPALIVE + nArgs := L.GetTop() + if nArgs >= 1 { // Timeout + e.Timeout = uint16(L.CheckNumber(1)) + } + if nArgs >= 2 { // Length + e.Length = uint16(L.CheckNumber(2)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_TCP_KEEPALIVE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "timeout": + L.Push(lua.LNumber(e.Timeout)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_TCP_KEEPALIVE](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "timeout": + e.Timeout = uint16(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} + +func (s *LuaScript) registerEDNS0ULType() { + L := s.L + mtName := "EDNS0_UL" + mt := L.NewTypeMetatable(mtName) + L.SetGlobal(mtName, mt) + // static attributes + L.SetField(mt, "new", L.NewFunction( + func(L *lua.LState) int { + e := new(dns.EDNS0_UL) + e.Code = dns.EDNS0UL + nArgs := L.GetTop() + if nArgs >= 1 { // Lease + e.Lease = uint32(L.CheckNumber(1)) + } + if nArgs >= 2 { // KeyLease + e.KeyLease = uint32(L.CheckNumber(2)) + } + L.Push(userDataWithMetatable(L, mtName, e)) + return 1 + })) + + // methods + L.SetField(mt, "__index", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_UL](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "option": + L.Push(lua.LNumber(e.Option())) + return 1 + case "lease": + L.Push(lua.LNumber(e.Lease)) + case "keylease": + L.Push(lua.LNumber(e.KeyLease)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 1 + })) + L.SetField(mt, "__newindex", L.NewFunction( + func(L *lua.LState) int { + e, ok := getUserDataArg[*dns.EDNS0_UL](L, 1) + if !ok { + return 0 + } + fieldName := L.CheckString(2) + switch fieldName { + case "lease": + e.Lease = uint32(L.CheckNumber(3)) + case "keylease": + e.KeyLease = uint32(L.CheckNumber(3)) + default: + L.ArgError(2, fmt.Sprintf("%s does not have field %q", mtName, fieldName)) + return 0 + } + return 0 + })) +} From 2a4ec93c10ecf77340cecc343e514fd3672b1919 Mon Sep 17 00:00:00 2001 From: folbrich Date: Mon, 7 Jul 2025 09:40:02 +0200 Subject: [PATCH 9/9] dev --- lua-helpers.go | 7 +++++++ lua-types.go | 7 ------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lua-helpers.go b/lua-helpers.go index 9efb94ba..0acf2ba1 100644 --- a/lua-helpers.go +++ b/lua-helpers.go @@ -35,6 +35,13 @@ func getUserDataArg[T any](L *lua.LState, n int) (T, bool) { return v, true } +func userDataWithMetatable(L *lua.LState, mtName string, value any) *lua.LUserData { + ud := L.NewUserData() + ud.Value = value + L.SetMetatable(ud, L.GetTypeMetatable(mtName)) + return ud +} + type numbers interface { int | int8 | int16 | int32 | int64 | float32 | float64 | uint | uint8 | uint16 | uint32 | uint64 } diff --git a/lua-types.go b/lua-types.go index 614fd062..8ab20daa 100644 --- a/lua-types.go +++ b/lua-types.go @@ -42,10 +42,3 @@ func (s *LuaScript) RegisterConstants() { L.SetGlobal(name, lua.LNumber(value)) } } - -func userDataWithMetatable(L *lua.LState, mtName string, value any) *lua.LUserData { - ud := L.NewUserData() - ud.Value = value - L.SetMetatable(ud, L.GetTypeMetatable(mtName)) - return ud -}