Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions cmd_scripting.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (

func commandsScripting(m *Miniredis) {
m.srv.Register("EVAL", m.cmdEval)
m.srv.Register("EVAL_RO", m.cmdEvalro, server.ReadOnlyOption())
m.srv.Register("EVALSHA", m.cmdEvalsha)
m.srv.Register("EVALSHA_RO", m.cmdEvalshaRo, server.ReadOnlyOption())
m.srv.Register("SCRIPT", m.cmdScript)
}

Expand All @@ -28,7 +30,7 @@ var (

// Execute lua. Needs to run m.Lock()ed, from within withTx().
// Returns true if the lua was OK (and hence should be cached).
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool {
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, readOnly bool, args []string) bool {
l := lua.NewState(lua.Options{SkipOpenLibs: true})
defer l.Close()

Expand Down Expand Up @@ -85,7 +87,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []stri
}
l.SetGlobal("ARGV", argvTable)

redisFuncs, redisConstants := mkLua(m.srv, c, sha)
redisFuncs, redisConstants := mkLua(m.srv, c, sha, readOnly)
// Register command handlers
l.Push(l.NewFunction(func(l *lua.LState) int {
mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
Expand Down Expand Up @@ -150,7 +152,8 @@ func compile(script string) (*lua.FunctionProto, error) {
return proto, nil
}

func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
// Shared implementation for EVAL and EVALRO
func (m *Miniredis) cmdEvalShared(c *server.Peer, cmd string, readOnly bool, args []string) {
if !m.isValidCMD(c, cmd, args, atLeast(2)) {
return
}
Expand All @@ -165,14 +168,20 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {

withTx(m, c, func(c *server.Peer, ctx *connCtx) {
sha := sha1Hex(script)
ok := m.runLuaScript(c, sha, script, args)
ok := m.runLuaScript(c, sha, script, readOnly, args)
if ok {
m.scripts[sha] = script
}
})
}

func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
// Wrapper function for EVAL command
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
m.cmdEvalShared(c, cmd, false, args)
}

// Shared implementation for EVALSHA and EVALSHA_RO
func (m *Miniredis) cmdEvalshaShared(c *server.Peer, cmd string, readOnly bool, args []string) {
if !m.isValidCMD(c, cmd, args, atLeast(2)) {
return
}
Expand All @@ -192,10 +201,25 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
return
}

m.runLuaScript(c, sha, script, args)
m.runLuaScript(c, sha, script, readOnly, args)
})
}

// Wrapper function for EVALSHA command
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
m.cmdEvalshaShared(c, cmd, false, args)
}

// Wrapper function for EVALRO command
func (m *Miniredis) cmdEvalro(c *server.Peer, cmd string, args []string) {
m.cmdEvalShared(c, cmd, true, args)
}

// Wrapper function for EVALSHA_RO command
func (m *Miniredis) cmdEvalshaRo(c *server.Peer, cmd string, args []string) {
m.cmdEvalshaShared(c, cmd, true, args)
}

func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
if !m.isValidCMD(c, cmd, args, atLeast(1)) {
return
Expand Down
121 changes: 121 additions & 0 deletions cmd_scripting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,3 +597,124 @@ func TestLuaTX(t *testing.T) {
)
})
}

func TestEvalRo(t *testing.T) {
_, c := runWithClient(t)

t.Run("read-only command", func(t *testing.T) {
mustOK(t, c,
"SET", "readonly", "foo",
)

// Test EVALRO with read-only command (should work)
mustDo(t, c,
"EVAL_RO", "return redis.call('GET', KEYS[1])", "1", "readonly",
proto.String("foo"),
)
})

t.Run("write command", func(t *testing.T) {
// Test EVALRO with write command (should fail)
mustContain(t, c,
"EVAL_RO", "return redis.call('SET', KEYS[1], ARGV[1])", "1", "key1", "value1",
"Write commands are not allowed in read-only scripts",
)
})
}

func TestEvalshaRo(t *testing.T) {
_, c := runWithClient(t)

// First load a read-only script
script := "return redis.call('GET', KEYS[1])"
t.Run("read-only script", func(t *testing.T) {
mustDo(t, c,
"SCRIPT", "LOAD", script,
proto.String("d3c21d0c2b9ca22f82737626a27bcaf5d288f99f"),
)

mustOK(t, c,
"SET", "readonly", "foo",
)

// Test EVALSHA_RO with read-only command (should work)
mustDo(t, c,
"EVALSHA_RO", "d3c21d0c2b9ca22f82737626a27bcaf5d288f99f", "1", "readonly",
proto.String("foo"),
)

})

t.Run("write script", func(t *testing.T) {
// Load a write script
writeScript := "return redis.call('SET', KEYS[1], ARGV[1])"
mustDo(t, c,
"SCRIPT", "LOAD", writeScript,
proto.String("d8f2fad9f8e86a53d2a6ebd960b33c4972cacc37"),
)

// Test EVALSHA_RO with write command (should fail)
mustContain(t, c,
"EVALSHA_RO", "d8f2fad9f8e86a53d2a6ebd960b33c4972cacc37", "1", "key1", "value1",
"Write commands are not allowed in read-only scripts",
)
})
}

func TestEvalRoWriteCommandWithPcall(t *testing.T) {
_, c := runWithClient(t)

t.Run("return error", func(t *testing.T) {
// Test EVAL with pcall and write command (should fail)
mustContain(t, c,
"EVAL_RO", "return redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1])", "1", "key1", "value1",
"Unknown Redis command called from script",
)
})

t.Run("extra work after error", func(t *testing.T) {
script := `
local err = redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1]);
local res = "pcall:" .. err['err'];
return res;
`
// Test EVAL with pcall and write command (should fail)
mustContain(t, c,
"EVAL_RO", script, "1", "key1", "value1",
"pcall:ERR Unknown Redis command called from script",
)
})

t.Run("write command in read-only script", func(t *testing.T) {
// Test EVALRO with pcall and write command (should fail)
mustContain(t, c,
"EVAL_RO", "return redis.pcall('SET', KEYS[1], ARGV[1])", "1", "key1", "value1",
"Write commands are not allowed in read-only scripts",
)
})
}

func TestEvalWithPcall(t *testing.T) {
_, c := runWithClient(t)

t.Run("return error", func(t *testing.T) {
// Test EVAL with pcall and write command (should fail)
mustContain(t, c,
"EVAL", "return redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1])", "1", "key1", "value1",
"Unknown Redis command called from script",
)
})

t.Run("continue after error", func(t *testing.T) {
script := `
local err = redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1]);
local res = "pcall:" .. err['err'];
return res;
`
// Test EVAL with pcall and write command (should fail)
mustContain(t, c,
"EVAL", script, "1", "foo", "value1",
"pcall:ERR Unknown Redis command called from script",
)
})
}
24 changes: 22 additions & 2 deletions lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var luaRedisConstants = map[string]lua.LValue{
"LOG_WARNING": lua.LNumber(3),
}

func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFunction, map[string]lua.LValue) {
func mkLua(srv *server.Server, c *server.Peer, sha string, readOnly bool) (map[string]lua.LGFunction, map[string]lua.LValue) {
mkCall := func(failFast bool) func(l *lua.LState) int {
// one server.Ctx for a single Lua run
pCtx := &connCtx{}
Expand Down Expand Up @@ -52,6 +52,20 @@ func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFun
return 0
}

if readOnly && len(args) > 0 {
if srv.IsRegisteredCommand(args[0]) && !srv.IsReadOnlyCommand(args[0]) {
if failFast {
l.Error(lua.LString("Write commands are not allowed in read-only scripts"), 1)
return 0
}
// pcall() mode - return error table
res := &lua.LTable{}
res.RawSetString("err", lua.LString("Write commands are not allowed in read-only scripts"))
l.Push(res)
return 1
}
}

buf := &bytes.Buffer{}
wr := bufio.NewWriter(buf)
peer := server.NewPeer(wr)
Expand All @@ -71,7 +85,13 @@ func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFun
return 0
}
// pcall() mode
l.Push(lua.LNil)
res := &lua.LTable{}
if strings.Contains(err.Error(), "ERR unknown command") {
res.RawSetString("err", lua.LString("ERR Unknown Redis command called from script"))
} else {
res.RawSetString("err", lua.LString(err.Error()))
}
l.Push(res)
return 1
}

Expand Down
9 changes: 9 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@ func (s *Server) TotalCommands() int {
return s.infoCmds
}

// IsRegisteredCommand checks if a command is registered
func (s *Server) IsRegisteredCommand(cmd string) bool {
s.mu.Lock()
defer s.mu.Unlock()
cmdUp := strings.ToUpper(cmd)
_, ok := s.cmds[cmdUp]
return ok
}

// IsReadOnlyCommand checks if a command is marked as read-only
func (s *Server) IsReadOnlyCommand(cmd string) bool {
s.mu.Lock()
Expand Down
20 changes: 20 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,23 @@ func TestReadOnlyOption(t *testing.T) {
t.Error("Non-existent command should return false")
}
}

func TestIsRegisteredCommand(t *testing.T) {
srv, err := NewServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
defer srv.Close()

srv.Register("TESTGET", func(c *Peer, cmd string, args []string) {
c.WriteOK()
})

if !srv.IsRegisteredCommand("TESTGET") {
t.Error("TESTGET should be registered")
}

if srv.IsRegisteredCommand("NONEXISTENT") {
t.Error("NONEXISTENT should not be registered")
}
}