diff --git a/file/error.go b/file/error.go index 8ff85dfa5..b7838d4e1 100644 --- a/file/error.go +++ b/file/error.go @@ -3,7 +3,6 @@ package file import ( "fmt" "strings" - "unicode/utf8" ) type Error struct { @@ -19,43 +18,49 @@ func (e *Error) Error() string { return e.format() } +var tabReplacer = strings.NewReplacer("\t", " ") + func (e *Error) Bind(source Source) *Error { + src := source.String() + e.Line = 1 - for i, r := range source { - if i == e.From { + var runeCount, lineStart, lineOffset int + for i, r := range src { + if runeCount == e.From { break } if r == '\n' { + lineStart = i e.Line++ e.Column = 0 + lineOffset = 0 } else { e.Column++ } + runeCount++ + lineOffset++ + } + + lineEnd := lineStart + strings.IndexByte(src[lineStart:], '\n') + if lineEnd < lineStart { + lineEnd = len(src) + } + if lineStart == lineEnd { + return e } - if snippet, found := source.Snippet(e.Line); found { - snippet := strings.Replace(snippet, "\t", " ", -1) - srcLine := "\n | " + snippet - var bytes = []byte(snippet) - var indLine = "\n | " - for i := 0; i < e.Column && len(bytes) > 0; i++ { - _, sz := utf8.DecodeRune(bytes) - bytes = bytes[sz:] - if sz > 1 { - goto noind - } else { - indLine += "." - } - } - if _, sz := utf8.DecodeRune(bytes); sz > 1 { - goto noind - } else { - indLine += "^" - } - srcLine += indLine - noind: - e.Snippet = srcLine + const prefix = "\n | " + line := src[lineStart:lineEnd] + snippet := new(strings.Builder) + snippet.Grow(2*len(prefix) + len(line) + lineOffset + 1) + snippet.WriteString(prefix) + tabReplacer.WriteString(snippet, line) + snippet.WriteString(prefix) + for i := 0; i < lineOffset; i++ { + snippet.WriteByte('.') } + snippet.WriteByte('^') + e.Snippet = snippet.String() return e } diff --git a/file/source.go b/file/source.go index 8e2b2d154..b11bb5f9d 100644 --- a/file/source.go +++ b/file/source.go @@ -1,48 +1,36 @@ package file -import ( - "strings" - "unicode/utf8" -) +import "strings" -type Source []rune +type Source struct { + raw string +} func NewSource(contents string) Source { - return []rune(contents) + return Source{ + raw: contents, + } } func (s Source) String() string { - return string(s) + return s.raw } func (s Source) Snippet(line int) (string, bool) { - if s == nil { + if s.raw == "" { return "", false } - lines := strings.Split(string(s), "\n") - lineOffsets := make([]int, len(lines)) - var offset int - for i, line := range lines { - offset = offset + utf8.RuneCountInString(line) + 1 - lineOffsets[i] = offset - } - charStart, found := getLineOffset(lineOffsets, line) - if !found || len(s) == 0 { - return "", false + var start int + for i := 1; i < line; i++ { + pos := strings.IndexByte(s.raw[start:], '\n') + if pos < 0 { + return "", false + } + start += pos + 1 } - charEnd, found := getLineOffset(lineOffsets, line+1) - if found { - return string(s[charStart : charEnd-1]), true - } - return string(s[charStart:]), true -} - -func getLineOffset(lineOffsets []int, line int) (int, bool) { - if line == 1 { - return 0, true - } else if line > 1 && line <= len(lineOffsets) { - offset := lineOffsets[line-2] - return offset, true + end := start + strings.IndexByte(s.raw[start:], '\n') + if end < start { + end = len(s.raw) } - return -1, false + return s.raw[start:end], true } diff --git a/parser/lexer/bench_test.go b/parser/lexer/bench_test.go new file mode 100644 index 000000000..ccd6f6914 --- /dev/null +++ b/parser/lexer/bench_test.go @@ -0,0 +1,23 @@ +package lexer + +import ( + "testing" + + "github.com/expr-lang/expr/file" +) + +func BenchmarkParser(b *testing.B) { + const source = ` + /* + Showing worst case scenario + */ + let value = trim("contains escapes \n\"\\ \U0001F600 and non ASCII ñ"); // inline comment + len(value) == 0x2A + // let's introduce an error too + whatever + ` + b.ReportAllocs() + for i := 0; i < b.N; i++ { + Lex(file.NewSource(source)) + } +} diff --git a/parser/lexer/lexer.go b/parser/lexer/lexer.go index e6b06c09d..8f067525f 100644 --- a/parser/lexer/lexer.go +++ b/parser/lexer/lexer.go @@ -3,18 +3,19 @@ package lexer import ( "fmt" "strings" + "unicode/utf8" "github.com/expr-lang/expr/file" ) +const minTokens = 10 + func Lex(source file.Source) ([]Token, error) { + raw := source.String() l := &lexer{ - source: source, - tokens: make([]Token, 0), - start: 0, - end: 0, + raw: raw, + tokens: make([]Token, 0, minTokens), } - l.commit() for state := root; state != nil; { state = state(l) @@ -28,10 +29,15 @@ func Lex(source file.Source) ([]Token, error) { } type lexer struct { - source file.Source + raw string tokens []Token - start, end int err *file.Error + start, end pos + eof bool +} + +type pos struct { + byte, rune int } const eof rune = -1 @@ -41,23 +47,39 @@ func (l *lexer) commit() { } func (l *lexer) next() rune { - if l.end >= len(l.source) { - l.end++ + if l.end.byte >= len(l.raw) { + l.eof = true return eof } - r := l.source[l.end] - l.end++ + r, sz := utf8.DecodeRuneInString(l.raw[l.end.byte:]) + l.end.rune++ + l.end.byte += sz return r } func (l *lexer) peek() rune { - r := l.next() - l.backup() - return r + if l.end.byte < len(l.raw) { + r, _ := utf8.DecodeRuneInString(l.raw[l.end.byte:]) + return r + } + return eof +} + +func (l *lexer) peekByte() (byte, bool) { + if l.end.byte >= 0 && l.end.byte < len(l.raw) { + return l.raw[l.end.byte], true + } + return 0, false } func (l *lexer) backup() { - l.end-- + if l.eof { + l.eof = false + } else if l.end.rune > 0 { + _, sz := utf8.DecodeLastRuneInString(l.raw[:l.end.byte]) + l.end.byte -= sz + l.end.rune-- + } } func (l *lexer) emit(t Kind) { @@ -66,7 +88,7 @@ func (l *lexer) emit(t Kind) { func (l *lexer) emitValue(t Kind, value string) { l.tokens = append(l.tokens, Token{ - Location: file.Location{From: l.start, To: l.end}, + Location: file.Location{From: l.start.rune, To: l.end.rune}, Kind: t, Value: value, }) @@ -74,11 +96,11 @@ func (l *lexer) emitValue(t Kind, value string) { } func (l *lexer) emitEOF() { - from := l.end - 2 + from := l.end.rune - 1 if from < 0 { from = 0 } - to := l.end - 1 + to := l.end.rune - 0 if to < 0 { to = 0 } @@ -94,60 +116,37 @@ func (l *lexer) skip() { } func (l *lexer) word() string { - // TODO: boundary check is NOT needed here, but for some reason CI fuzz tests are failing. - if l.start > len(l.source) || l.end > len(l.source) { - return "__invalid__" - } - return string(l.source[l.start:l.end]) + return l.raw[l.start.byte:l.end.byte] } func (l *lexer) accept(valid string) bool { - if strings.ContainsRune(valid, l.next()) { + if strings.ContainsRune(valid, l.peek()) { + l.next() return true } - l.backup() return false } func (l *lexer) acceptRun(valid string) { - for strings.ContainsRune(valid, l.next()) { + for l.accept(valid) { } - l.backup() } func (l *lexer) skipSpaces() { - r := l.peek() - for ; r == ' '; r = l.peek() { - l.next() - } + l.acceptRun(" ") l.skip() } -func (l *lexer) acceptWord(word string) bool { - pos := l.end - - l.skipSpaces() - - for _, ch := range word { - if l.next() != ch { - l.end = pos - return false - } - } - if r := l.peek(); r != ' ' && r != eof { - l.end = pos - return false - } - - return true -} - func (l *lexer) error(format string, args ...any) stateFn { if l.err == nil { // show first error + end := l.end.rune + if l.eof { + end++ + } l.err = &file.Error{ Location: file.Location{ - From: l.end - 1, - To: l.end, + From: end - 1, + To: end, }, Message: fmt.Sprintf(format, args...), } @@ -225,6 +224,6 @@ func (l *lexer) scanRawString(quote rune) (n int) { ch = l.next() n++ } - l.emitValue(String, string(l.source[l.start+1:l.end-1])) + l.emitValue(String, l.raw[l.start.byte+1:l.end.byte-1]) return } diff --git a/parser/lexer/lexer_test.go b/parser/lexer/lexer_test.go index db02d2acf..5171f4255 100644 --- a/parser/lexer/lexer_test.go +++ b/parser/lexer/lexer_test.go @@ -335,6 +335,7 @@ literal not terminated (1:10) früh ♥︎ unrecognized character: U+2665 '♥' (1:6) | früh ♥︎ + | .....^ ` func TestLex_error(t *testing.T) { diff --git a/parser/lexer/token.go b/parser/lexer/token.go index 459fa6905..c809c690e 100644 --- a/parser/lexer/token.go +++ b/parser/lexer/token.go @@ -31,17 +31,13 @@ func (t Token) String() string { } func (t Token) Is(kind Kind, values ...string) bool { - if len(values) == 0 { - return kind == t.Kind + if kind != t.Kind { + return false } - for _, v := range values { if v == t.Value { - goto found + return true } } - return false - -found: - return kind == t.Kind + return len(values) == 0 } diff --git a/parser/lexer/utils.go b/parser/lexer/utils.go index 5c9e6b59d..fdb8beaa1 100644 --- a/parser/lexer/utils.go +++ b/parser/lexer/utils.go @@ -36,7 +36,8 @@ func unescape(value string) (string, error) { if size >= math.MaxInt { return "", fmt.Errorf("too large string") } - buf := make([]byte, 0, size) + buf := new(strings.Builder) + buf.Grow(int(size)) for len(value) > 0 { c, multibyte, rest, err := unescapeChar(value) if err != nil { @@ -44,13 +45,13 @@ func unescape(value string) (string, error) { } value = rest if c < utf8.RuneSelf || !multibyte { - buf = append(buf, byte(c)) + buf.WriteByte(byte(c)) } else { n := utf8.EncodeRune(runeTmp[:], c) - buf = append(buf, runeTmp[:n]...) + buf.Write(runeTmp[:n]) } } - return string(buf), nil + return buf.String(), nil } // unescapeChar takes a string input and returns the following info: diff --git a/vm/vm_test.go b/vm/vm_test.go index 91752a419..817fc6cc2 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/internal/testify/require" "github.com/expr-lang/expr" @@ -609,10 +610,10 @@ func TestVM_DirectCallOpcodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { program := vm.NewProgram( - nil, // source - nil, // node - nil, // locations - 0, // variables + file.Source{}, // source + nil, // node + nil, // locations + 0, // variables tt.consts, tt.bytecode, tt.args, @@ -735,10 +736,10 @@ func TestVM_IndexAndCountOperations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { program := vm.NewProgram( - nil, // source - nil, // node - nil, // locations - 0, // variables + file.Source{}, // source + nil, // node + nil, // locations + 0, // variables tt.consts, tt.bytecode, tt.args, @@ -1176,10 +1177,10 @@ func TestVM_DirectBasicOpcodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { program := vm.NewProgram( - nil, // source - nil, // node - nil, // locations - 0, // variables + file.Source{}, // source + nil, // node + nil, // locations + 0, // variables tt.consts, tt.bytecode, tt.args,