Skip to content

Commit 2e0f7e9

Browse files
authored
Implement the Go to Symbol in Workspace (Ctrl-T) command (#1327)
1 parent 8fbfa3e commit 2e0f7e9

File tree

5 files changed

+318
-4
lines changed

5 files changed

+318
-4
lines changed

internal/ast/ast.go

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10064,8 +10064,10 @@ type SourceFile struct {
1006410064

1006510065
// Fields set by language service
1006610066

10067-
tokenCacheMu sync.Mutex
10068-
tokenCache map[core.TextRange]*Node
10067+
tokenCacheMu sync.Mutex
10068+
tokenCache map[core.TextRange]*Node
10069+
declarationMapMu sync.Mutex
10070+
declarationMap map[string][]*Node
1006910071
}
1007010072

1007110073
func (f *NodeFactory) NewSourceFile(opts SourceFileParseOptions, text string, statements *NodeList) *Node {
@@ -10248,6 +10250,136 @@ func IsSourceFile(node *Node) bool {
1024810250
return node.Kind == KindSourceFile
1024910251
}
1025010252

10253+
func (node *SourceFile) GetDeclarationMap() map[string][]*Node {
10254+
node.declarationMapMu.Lock()
10255+
defer node.declarationMapMu.Unlock()
10256+
if node.declarationMap == nil {
10257+
node.declarationMap = node.computeDeclarationMap()
10258+
}
10259+
return node.declarationMap
10260+
}
10261+
10262+
func (node *SourceFile) computeDeclarationMap() map[string][]*Node {
10263+
result := make(map[string][]*Node)
10264+
10265+
addDeclaration := func(declaration *Node) {
10266+
name := getDeclarationName(declaration)
10267+
if name != "" {
10268+
result[name] = append(result[name], declaration)
10269+
}
10270+
}
10271+
10272+
var visit func(*Node) bool
10273+
visit = func(node *Node) bool {
10274+
switch node.Kind {
10275+
case KindFunctionDeclaration, KindFunctionExpression, KindMethodDeclaration, KindMethodSignature:
10276+
declarationName := getDeclarationName(node)
10277+
if declarationName != "" {
10278+
declarations := result[declarationName]
10279+
var lastDeclaration *Node
10280+
if len(declarations) != 0 {
10281+
lastDeclaration = declarations[len(declarations)-1]
10282+
}
10283+
// Check whether this declaration belongs to an "overload group".
10284+
if lastDeclaration != nil && node.Parent == lastDeclaration.Parent && node.Symbol() == lastDeclaration.Symbol() {
10285+
// Overwrite the last declaration if it was an overload and this one is an implementation.
10286+
if node.Body() != nil && lastDeclaration.Body() == nil {
10287+
declarations[len(declarations)-1] = node
10288+
}
10289+
} else {
10290+
result[declarationName] = append(result[declarationName], node)
10291+
}
10292+
}
10293+
node.ForEachChild(visit)
10294+
case KindClassDeclaration, KindClassExpression, KindInterfaceDeclaration, KindTypeAliasDeclaration, KindEnumDeclaration, KindModuleDeclaration,
10295+
KindImportEqualsDeclaration, KindImportClause, KindNamespaceImport, KindGetAccessor, KindSetAccessor, KindTypeLiteral:
10296+
addDeclaration(node)
10297+
node.ForEachChild(visit)
10298+
case KindImportSpecifier, KindExportSpecifier:
10299+
if node.PropertyName() != nil {
10300+
addDeclaration(node)
10301+
}
10302+
case KindParameter:
10303+
// Only consider parameter properties
10304+
if !HasSyntacticModifier(node, ModifierFlagsParameterPropertyModifier) {
10305+
break
10306+
}
10307+
fallthrough
10308+
case KindVariableDeclaration, KindBindingElement:
10309+
name := node.Name()
10310+
if name != nil {
10311+
if IsBindingPattern(name) {
10312+
node.Name().ForEachChild(visit)
10313+
} else {
10314+
if node.Initializer() != nil {
10315+
visit(node.Initializer())
10316+
}
10317+
addDeclaration(node)
10318+
}
10319+
}
10320+
case KindEnumMember, KindPropertyDeclaration, KindPropertySignature:
10321+
addDeclaration(node)
10322+
case KindExportDeclaration:
10323+
// Handle named exports case e.g.:
10324+
// export {a, b as B} from "mod";
10325+
exportClause := node.AsExportDeclaration().ExportClause
10326+
if exportClause != nil {
10327+
if IsNamedExports(exportClause) {
10328+
for _, element := range exportClause.AsNamedExports().Elements.Nodes {
10329+
visit(element)
10330+
}
10331+
} else {
10332+
visit(exportClause.AsNamespaceExport().Name())
10333+
}
10334+
}
10335+
case KindImportDeclaration:
10336+
importClause := node.AsImportDeclaration().ImportClause
10337+
if importClause != nil {
10338+
// Handle default import case e.g.:
10339+
// import d from "mod";
10340+
if importClause.Name() != nil {
10341+
addDeclaration(importClause.Name())
10342+
}
10343+
// Handle named bindings in imports e.g.:
10344+
// import * as NS from "mod";
10345+
// import {a, b as B} from "mod";
10346+
namedBindings := importClause.AsImportClause().NamedBindings
10347+
if namedBindings != nil {
10348+
if namedBindings.Kind == KindNamespaceImport {
10349+
addDeclaration(namedBindings)
10350+
} else {
10351+
for _, element := range namedBindings.AsNamedImports().Elements.Nodes {
10352+
visit(element)
10353+
}
10354+
}
10355+
}
10356+
}
10357+
default:
10358+
node.ForEachChild(visit)
10359+
}
10360+
return false
10361+
}
10362+
node.ForEachChild(visit)
10363+
return result
10364+
}
10365+
10366+
func getDeclarationName(declaration *Node) string {
10367+
name := GetNonAssignedNameOfDeclaration(declaration)
10368+
if name != nil {
10369+
if IsComputedPropertyName(name) {
10370+
if IsStringOrNumericLiteralLike(name.Expression()) {
10371+
return name.Expression().Text()
10372+
}
10373+
if IsPropertyAccessExpression(name.Expression()) {
10374+
return name.Expression().Name().Text()
10375+
}
10376+
} else if IsPropertyName(name) {
10377+
return name.Text()
10378+
}
10379+
}
10380+
return ""
10381+
}
10382+
1025110383
type SourceFileLike interface {
1025210384
Text() string
1025310385
LineMap() []core.TextPos

internal/ls/symbols.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package ls
2+
3+
import (
4+
"context"
5+
"slices"
6+
"strings"
7+
"unicode"
8+
"unicode/utf8"
9+
10+
"github.com/microsoft/typescript-go/internal/ast"
11+
"github.com/microsoft/typescript-go/internal/collections"
12+
"github.com/microsoft/typescript-go/internal/compiler"
13+
"github.com/microsoft/typescript-go/internal/core"
14+
"github.com/microsoft/typescript-go/internal/lsp/lsproto"
15+
"github.com/microsoft/typescript-go/internal/scanner"
16+
"github.com/microsoft/typescript-go/internal/stringutil"
17+
)
18+
19+
type DeclarationInfo struct {
20+
name string
21+
declaration *ast.Node
22+
matchScore int
23+
}
24+
25+
func ProvideWorkspaceSymbols(ctx context.Context, programs []*compiler.Program, converters *Converters, query string) ([]lsproto.SymbolInformation, error) {
26+
// Obtain set of non-declaration source files from all active programs.
27+
var sourceFiles collections.Set[*ast.SourceFile]
28+
for _, program := range programs {
29+
for _, sourceFile := range program.SourceFiles() {
30+
if !sourceFile.IsDeclarationFile {
31+
sourceFiles.Add(sourceFile)
32+
}
33+
}
34+
}
35+
// Create DeclarationInfos for all declarations in the source files.
36+
var infos []DeclarationInfo
37+
for sourceFile := range sourceFiles.Keys() {
38+
if ctx != nil && ctx.Err() != nil {
39+
return []lsproto.SymbolInformation{}, nil
40+
}
41+
declarationMap := sourceFile.GetDeclarationMap()
42+
for name, declarations := range declarationMap {
43+
score := getMatchScore(name, query)
44+
if score >= 0 {
45+
for _, declaration := range declarations {
46+
infos = append(infos, DeclarationInfo{name, declaration, score})
47+
}
48+
}
49+
}
50+
}
51+
// Sort the DeclarationInfos and return the top 256 matches.
52+
slices.SortFunc(infos, compareDeclarationInfos)
53+
count := min(len(infos), 256)
54+
symbols := make([]lsproto.SymbolInformation, count)
55+
for i, info := range infos[0:count] {
56+
node := core.OrElse(ast.GetNameOfDeclaration(info.declaration), info.declaration)
57+
sourceFile := ast.GetSourceFileOfNode(node)
58+
pos := scanner.SkipTrivia(sourceFile.Text(), node.Pos())
59+
var symbol lsproto.SymbolInformation
60+
symbol.Name = info.name
61+
symbol.Kind = getSymbolKindFromNode(info.declaration)
62+
symbol.Location = converters.ToLSPLocation(sourceFile, core.NewTextRange(pos, node.End()))
63+
symbols[i] = symbol
64+
}
65+
return symbols, nil
66+
}
67+
68+
// Return a score for matching `s` against `pattern`. In order to match, `s` must contain each of the characters in
69+
// `pattern` in the same order. Upper case characters in `pattern` must match exactly, whereas lower case characters
70+
// in `pattern` match either case in `s`. If `s` doesn't match, -1 is returned. Otherwise, the returned score is the
71+
// number of characters in `s` that weren't matched. Thus, zero represents an exact match, and higher values represent
72+
// increasingly less specific partial matches.
73+
func getMatchScore(s string, pattern string) int {
74+
score := 0
75+
for _, p := range pattern {
76+
exact := unicode.IsUpper(p)
77+
for {
78+
c, size := utf8.DecodeRuneInString(s)
79+
if size == 0 {
80+
return -1
81+
}
82+
s = s[size:]
83+
if exact && c == p || !exact && unicode.ToLower(c) == unicode.ToLower(p) {
84+
break
85+
}
86+
score++
87+
}
88+
}
89+
return score
90+
}
91+
92+
// Sort DeclarationInfos by ascending match score, then ascending case insensitive name, then
93+
// ascending case sensitive name, and finally by source file name and position.
94+
func compareDeclarationInfos(d1, d2 DeclarationInfo) int {
95+
if d1.matchScore != d2.matchScore {
96+
return d1.matchScore - d2.matchScore
97+
}
98+
if c := stringutil.CompareStringsCaseInsensitive(d1.name, d2.name); c != 0 {
99+
return c
100+
}
101+
if c := strings.Compare(d1.name, d2.name); c != 0 {
102+
return c
103+
}
104+
s1 := ast.GetSourceFileOfNode(d1.declaration)
105+
s2 := ast.GetSourceFileOfNode(d2.declaration)
106+
if s1 != s2 {
107+
return strings.Compare(string(s1.Path()), string(s2.Path()))
108+
}
109+
return d1.declaration.Pos() - d2.declaration.Pos()
110+
}
111+
112+
func getSymbolKindFromNode(node *ast.Node) lsproto.SymbolKind {
113+
switch node.Kind {
114+
case ast.KindModuleDeclaration:
115+
return lsproto.SymbolKindNamespace
116+
case ast.KindClassDeclaration, ast.KindClassExpression, ast.KindTypeAliasDeclaration:
117+
return lsproto.SymbolKindClass
118+
case ast.KindMethodDeclaration, ast.KindMethodSignature:
119+
return lsproto.SymbolKindMethod
120+
case ast.KindPropertyDeclaration, ast.KindPropertySignature, ast.KindGetAccessor, ast.KindSetAccessor:
121+
return lsproto.SymbolKindProperty
122+
case ast.KindConstructor, ast.KindConstructSignature:
123+
return lsproto.SymbolKindConstructor
124+
case ast.KindEnumDeclaration:
125+
return lsproto.SymbolKindEnum
126+
case ast.KindInterfaceDeclaration:
127+
return lsproto.SymbolKindInterface
128+
case ast.KindFunctionDeclaration, ast.KindFunctionExpression:
129+
return lsproto.SymbolKindFunction
130+
case ast.KindEnumMember:
131+
return lsproto.SymbolKindEnumMember
132+
case ast.KindTypeParameter:
133+
return lsproto.SymbolKindTypeParameter
134+
}
135+
return lsproto.SymbolKindVariable
136+
}

internal/lsp/server.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,8 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R
496496
return s.handleDocumentRangeFormat(ctx, req)
497497
case *lsproto.DocumentOnTypeFormattingParams:
498498
return s.handleDocumentOnTypeFormat(ctx, req)
499+
case *lsproto.WorkspaceSymbolParams:
500+
return s.handleWorkspaceSymbol(ctx, req)
499501
default:
500502
switch req.Method {
501503
case lsproto.MethodShutdown:
@@ -573,6 +575,9 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) {
573575
FirstTriggerCharacter: "{",
574576
MoreTriggerCharacter: &[]string{"}", ";", "\n"},
575577
},
578+
WorkspaceSymbolProvider: &lsproto.BooleanOrWorkspaceSymbolOptions{
579+
Boolean: ptrTo(true),
580+
},
576581
},
577582
})
578583
}
@@ -770,6 +775,17 @@ func (s *Server) handleDocumentOnTypeFormat(ctx context.Context, req *lsproto.Re
770775
return nil
771776
}
772777

778+
func (s *Server) handleWorkspaceSymbol(ctx context.Context, req *lsproto.RequestMessage) error {
779+
programs := core.Map(s.projectService.Projects(), (*project.Project).GetProgram)
780+
params := req.Params.(*lsproto.WorkspaceSymbolParams)
781+
symbols, err := ls.ProvideWorkspaceSymbols(ctx, programs, s.projectService.Converters(), params.Query)
782+
if err != nil {
783+
return err
784+
}
785+
s.sendResult(req.ID, symbols)
786+
return nil
787+
}
788+
773789
func (s *Server) Log(msg ...any) {
774790
fmt.Fprintln(s.stderr, msg...)
775791
}

internal/project/service.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ func (s *Service) DefaultLibraryPath() string {
129129
return s.host.DefaultLibraryPath()
130130
}
131131

132+
func (s *Service) Converters() *ls.Converters {
133+
return s.converters
134+
}
135+
132136
// TypingsInstaller implements ProjectHost.
133137
func (s *Service) TypingsInstaller() *TypingsInstaller {
134138
if s.typingsInstaller != nil {

internal/stringutil/compare.go

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package stringutil
22

3-
import "strings"
3+
import (
4+
"strings"
5+
"unicode"
6+
"unicode/utf8"
7+
)
48

59
func EquateStringCaseInsensitive(a, b string) bool {
610
// !!!
@@ -31,7 +35,29 @@ func CompareStringsCaseInsensitive(a string, b string) Comparison {
3135
if a == b {
3236
return ComparisonEqual
3337
}
34-
return strings.Compare(strings.ToUpper(a), strings.ToUpper(b))
38+
for {
39+
ca, sa := utf8.DecodeRuneInString(a)
40+
cb, sb := utf8.DecodeRuneInString(b)
41+
if sa == 0 {
42+
if sb == 0 {
43+
return ComparisonEqual
44+
}
45+
return ComparisonLessThan
46+
}
47+
if sb == 0 {
48+
return ComparisonGreaterThan
49+
}
50+
lca := unicode.ToLower(ca)
51+
lcb := unicode.ToLower(cb)
52+
if lca != lcb {
53+
if lca < lcb {
54+
return ComparisonLessThan
55+
}
56+
return ComparisonGreaterThan
57+
}
58+
a = a[sa:]
59+
b = b[sb:]
60+
}
3561
}
3662

3763
func CompareStringsCaseSensitive(a string, b string) Comparison {

0 commit comments

Comments
 (0)