diff --git a/internal/ast/ast.go b/internal/ast/ast.go index 4b558549ba..bf187a1820 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -2785,6 +2785,10 @@ func (node *SwitchStatement) computeSubtreeFacts() SubtreeFacts { propagateSubtreeFacts(node.CaseBlock) } +func IsSwitchStatement(node *Node) bool { + return node.Kind == KindSwitchStatement +} + // CaseBlock type CaseBlock struct { diff --git a/internal/binder/binder.go b/internal/binder/binder.go index 05f8bc6382..8e15ad94f3 100644 --- a/internal/binder/binder.go +++ b/internal/binder/binder.go @@ -2758,7 +2758,7 @@ func (b *Binder) errorOrSuggestionOnRange(isError bool, startNode *ast.Node, end // If so, the node _must_ be in the current file (as that's the only way anything could have traversed to it to yield it as the error node) // This version of `createDiagnosticForNode` uses the binder's context to account for this, and always yields correct diagnostics even in these situations. func (b *Binder) createDiagnosticForNode(node *ast.Node, message *diagnostics.Message, args ...any) *ast.Diagnostic { - return ast.NewDiagnostic(b.file, GetErrorRangeForNode(b.file, node), message, args...) + return ast.NewDiagnostic(b.file, scanner.GetErrorRangeForNode(b.file, node), message, args...) } func (b *Binder) addDiagnostic(diagnostic *ast.Diagnostic) { @@ -2855,83 +2855,3 @@ func isAssignmentDeclaration(decl *ast.Node) bool { func isEffectiveModuleDeclaration(node *ast.Node) bool { return ast.IsModuleDeclaration(node) || ast.IsIdentifier(node) } - -func getErrorRangeForArrowFunction(sourceFile *ast.SourceFile, node *ast.Node) core.TextRange { - pos := scanner.SkipTrivia(sourceFile.Text(), node.Pos()) - body := node.AsArrowFunction().Body - if body != nil && body.Kind == ast.KindBlock { - startLine, _ := scanner.GetLineAndCharacterOfPosition(sourceFile, body.Pos()) - endLine, _ := scanner.GetLineAndCharacterOfPosition(sourceFile, body.End()) - if startLine < endLine { - // The arrow function spans multiple lines, - // make the error span be the first line, inclusive. - return core.NewTextRange(pos, scanner.GetEndLinePosition(sourceFile, startLine)) - } - } - return core.NewTextRange(pos, node.End()) -} - -func GetErrorRangeForNode(sourceFile *ast.SourceFile, node *ast.Node) core.TextRange { - errorNode := node - switch node.Kind { - case ast.KindSourceFile: - pos := scanner.SkipTrivia(sourceFile.Text(), 0) - if pos == len(sourceFile.Text()) { - return core.NewTextRange(0, 0) - } - return scanner.GetRangeOfTokenAtPosition(sourceFile, pos) - // This list is a work in progress. Add missing node kinds to improve their error spans - case ast.KindFunctionDeclaration, ast.KindMethodDeclaration: - if node.Flags&ast.NodeFlagsReparsed != 0 { - errorNode = node - break - } - fallthrough - case ast.KindVariableDeclaration, ast.KindBindingElement, ast.KindClassDeclaration, ast.KindClassExpression, ast.KindInterfaceDeclaration, - ast.KindModuleDeclaration, ast.KindEnumDeclaration, ast.KindEnumMember, ast.KindFunctionExpression, - ast.KindGetAccessor, ast.KindSetAccessor, ast.KindTypeAliasDeclaration, ast.KindJSTypeAliasDeclaration, ast.KindPropertyDeclaration, - ast.KindPropertySignature, ast.KindNamespaceImport: - errorNode = ast.GetNameOfDeclaration(node) - case ast.KindArrowFunction: - return getErrorRangeForArrowFunction(sourceFile, node) - case ast.KindCaseClause, ast.KindDefaultClause: - start := scanner.SkipTrivia(sourceFile.Text(), node.Pos()) - end := node.End() - statements := node.AsCaseOrDefaultClause().Statements.Nodes - if len(statements) != 0 { - end = statements[0].Pos() - } - return core.NewTextRange(start, end) - case ast.KindReturnStatement, ast.KindYieldExpression: - pos := scanner.SkipTrivia(sourceFile.Text(), node.Pos()) - return scanner.GetRangeOfTokenAtPosition(sourceFile, pos) - case ast.KindSatisfiesExpression: - pos := scanner.SkipTrivia(sourceFile.Text(), node.AsSatisfiesExpression().Expression.End()) - return scanner.GetRangeOfTokenAtPosition(sourceFile, pos) - case ast.KindConstructor: - if node.Flags&ast.NodeFlagsReparsed != 0 { - errorNode = node - break - } - scanner := scanner.GetScannerForSourceFile(sourceFile, node.Pos()) - start := scanner.TokenStart() - for scanner.Token() != ast.KindConstructorKeyword && scanner.Token() != ast.KindStringLiteral && scanner.Token() != ast.KindEndOfFile { - scanner.Scan() - } - return core.NewTextRange(start, scanner.TokenEnd()) - // !!! - // case KindJSDocSatisfiesTag: - // pos := scanner.SkipTrivia(sourceFile.Text(), node.tagName.pos) - // return scanner.GetRangeOfTokenAtPosition(sourceFile, pos) - } - if errorNode == nil { - // If we don't have a better node, then just set the error on the first token of - // construct. - return scanner.GetRangeOfTokenAtPosition(sourceFile, node.Pos()) - } - pos := errorNode.Pos() - if !ast.NodeIsMissing(errorNode) && !ast.IsJsxText(errorNode) { - pos = scanner.SkipTrivia(sourceFile.Text(), pos) - } - return core.NewTextRange(pos, errorNode.End()) -} diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 61bde64b91..4d31ae1ef0 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -30263,6 +30263,22 @@ func (c *Checker) getSymbolAtLocation(node *ast.Node, ignoreErrors bool) *ast.Sy } } +func (c *Checker) getIndexSignaturesAtLocation(node *ast.Node) []*ast.Node { + var signatures []*ast.Node + if ast.IsIdentifier(node) && ast.IsPropertyAccessExpression(node.Parent) && node.Parent.Name() == node { + keyType := c.getLiteralTypeFromPropertyName(node) + objectType := c.getTypeOfExpression(node.Parent.Expression()) + for _, t := range objectType.Distributed() { + for _, info := range c.getApplicableIndexInfos(t, keyType) { + if info.declaration != nil { + signatures = core.AppendIfUnique(signatures, info.declaration) + } + } + } + } + return signatures +} + func (c *Checker) getSymbolOfNameOrPropertyAccessExpression(name *ast.Node) *ast.Symbol { if ast.IsDeclarationName(name) { return c.getSymbolOfNode(name.Parent) diff --git a/internal/checker/exports.go b/internal/checker/exports.go index f3e20ed211..5481b1f31c 100644 --- a/internal/checker/exports.go +++ b/internal/checker/exports.go @@ -142,3 +142,11 @@ func (c *Checker) GetTypeOfPropertyOfType(t *Type, name string) *Type { func (c *Checker) GetContextualTypeForArgumentAtIndex(node *ast.Node, argIndex int) *Type { return c.getContextualTypeForArgumentAtIndex(node, argIndex) } + +func (c *Checker) GetIndexSignaturesAtLocation(node *ast.Node) []*ast.Node { + return c.getIndexSignaturesAtLocation(node) +} + +func (c *Checker) GetResolvedSymbol(node *ast.Node) *ast.Symbol { + return c.getResolvedSymbol(node) +} diff --git a/internal/checker/services.go b/internal/checker/services.go index ad5e39b857..ed7286203c 100644 --- a/internal/checker/services.go +++ b/internal/checker/services.go @@ -618,3 +618,68 @@ func (c *Checker) GetTypeParameterAtPosition(s *Signature, pos int) *Type { } return t } + +func (c *Checker) GetContextualDeclarationsForObjectLiteralElement(objectLiteral *ast.Node, name string) []*ast.Node { + var result []*ast.Node + if t := c.getApparentTypeOfContextualType(objectLiteral, ContextFlagsNone); t != nil { + for _, t := range t.Distributed() { + prop := c.getPropertyOfType(t, name) + if prop != nil { + for _, declaration := range prop.Declarations { + result = core.AppendIfUnique(result, declaration) + } + } else { + for _, info := range c.getApplicableIndexInfos(t, c.getStringLiteralType(name)) { + if info.declaration != nil { + result = core.AppendIfUnique(result, info.declaration) + } + } + } + } + } + return result +} + +var knownGenericTypeNames = map[string]struct{}{ + "Array": {}, + "ArrayLike": {}, + "ReadonlyArray": {}, + "Promise": {}, + "PromiseLike": {}, + "Iterable": {}, + "IterableIterator": {}, + "AsyncIterable": {}, + "Set": {}, + "WeakSet": {}, + "ReadonlySet": {}, + "Map": {}, + "WeakMap": {}, + "ReadonlyMap": {}, + "Partial": {}, + "Required": {}, + "Readonly": {}, + "Pick": {}, + "Omit": {}, + "NonNullable": {}, +} + +func isKnownGenericTypeName(name string) bool { + _, exists := knownGenericTypeNames[name] + return exists +} + +func (c *Checker) GetFirstTypeArgumentFromKnownType(t *Type) *Type { + if t.objectFlags&ObjectFlagsReference != 0 && isKnownGenericTypeName(t.symbol.Name) { + symbol := c.getGlobalSymbol(t.symbol.Name, ast.SymbolFlagsType, nil) + if symbol != nil && symbol == t.Target().symbol { + return core.FirstOrNil(c.getTypeArguments(t)) + } + } + if t.alias != nil && isKnownGenericTypeName(t.alias.symbol.Name) { + symbol := c.getGlobalSymbol(t.alias.symbol.Name, ast.SymbolFlagsType, nil) + if symbol != nil && symbol == t.alias.symbol { + return core.FirstOrNil(t.alias.typeArguments) + } + } + return nil +} diff --git a/internal/checker/utilities.go b/internal/checker/utilities.go index 321d069819..fab2c52440 100644 --- a/internal/checker/utilities.go +++ b/internal/checker/utilities.go @@ -23,7 +23,7 @@ func NewDiagnosticForNode(node *ast.Node, message *diagnostics.Message, args ... var loc core.TextRange if node != nil { file = ast.GetSourceFileOfNode(node) - loc = binder.GetErrorRangeForNode(file, node) + loc = scanner.GetErrorRangeForNode(file, node) } return ast.NewDiagnostic(file, loc, message, args...) } diff --git a/internal/ls/definition.go b/internal/ls/definition.go index d90c8c0b66..7cc4be72b3 100644 --- a/internal/ls/definition.go +++ b/internal/ls/definition.go @@ -18,41 +18,137 @@ func (l *LanguageService) ProvideDefinition(ctx context.Context, documentURI lsp return nil, nil } - checker, done := program.GetTypeCheckerForFile(ctx, file) + c, done := program.GetTypeCheckerForFile(ctx, file) defer done() - calledDeclaration := tryGetSignatureDeclaration(checker, node) - if calledDeclaration != nil { - name := ast.GetNameOfDeclaration(calledDeclaration) - if name != nil { - return l.createLocationsFromDeclarations([]*ast.Node{name}) + if node.Kind == ast.KindOverrideKeyword { + if sym := getSymbolForOverriddenMember(c, node); sym != nil { + return l.createLocationsFromDeclarations(sym.Declarations), nil } } - if symbol := checker.GetSymbolAtLocation(node); symbol != nil { + if ast.IsJumpStatementTarget(node) { + if label := getTargetLabel(node.Parent, node.Text()); label != nil { + return l.createLocationsFromDeclarations([]*ast.Node{label}), nil + } + } + + if node.Kind == ast.KindCaseKeyword || node.Kind == ast.KindDefaultKeyword && ast.IsDefaultClause(node.Parent) { + if stmt := ast.FindAncestor(node.Parent, ast.IsSwitchStatement); stmt != nil { + file := ast.GetSourceFileOfNode(stmt) + return l.createLocationFromFileAndRange(file, scanner.GetRangeOfTokenAtPosition(file, stmt.Pos())), nil + } + } + + if node.Kind == ast.KindReturnKeyword || node.Kind == ast.KindYieldKeyword || node.Kind == ast.KindAwaitKeyword { + if fn := ast.FindAncestor(node, ast.IsFunctionLikeDeclaration); fn != nil { + return l.createLocationsFromDeclarations([]*ast.Node{fn}), nil + } + } + + if calledDeclaration := tryGetSignatureDeclaration(c, node); calledDeclaration != nil { + return l.createLocationsFromDeclarations([]*ast.Node{calledDeclaration}), nil + } + + if ast.IsIdentifier(node) && ast.IsShorthandPropertyAssignment(node.Parent) { + return l.createLocationsFromDeclarations(c.GetResolvedSymbol(node).Declarations), nil + } + + node = getDeclarationNameForKeyword(node) + + if symbol := c.GetSymbolAtLocation(node); symbol != nil { + if symbol.Flags&ast.SymbolFlagsClass != 0 && symbol.Flags&(ast.SymbolFlagsFunction|ast.SymbolFlagsVariable) == 0 && node.Kind == ast.KindConstructorKeyword { + if constructor := symbol.Members[ast.InternalSymbolNameConstructor]; constructor != nil { + symbol = constructor + } + } if symbol.Flags&ast.SymbolFlagsAlias != 0 { - if resolved, ok := checker.ResolveAlias(symbol); ok { + if resolved, ok := c.ResolveAlias(symbol); ok { symbol = resolved } } + if symbol.Flags&(ast.SymbolFlagsProperty|ast.SymbolFlagsMethod|ast.SymbolFlagsAccessor) != 0 && symbol.Parent != nil && symbol.Parent.Flags&ast.SymbolFlagsObjectLiteral != 0 { + if objectLiteral := core.FirstOrNil(symbol.Parent.Declarations); objectLiteral != nil { + if declarations := c.GetContextualDeclarationsForObjectLiteralElement(objectLiteral, symbol.Name); len(declarations) != 0 { + return l.createLocationsFromDeclarations(declarations), nil + } + } + } + return l.createLocationsFromDeclarations(symbol.Declarations), nil + } - return l.createLocationsFromDeclarations(symbol.Declarations) + if indexInfos := c.GetIndexSignaturesAtLocation(node); len(indexInfos) != 0 { + return l.createLocationsFromDeclarations(indexInfos), nil } + return nil, nil } -func (l *LanguageService) createLocationsFromDeclarations(declarations []*ast.Node) (*lsproto.Definition, error) { +func (l *LanguageService) ProvideTypeDefinition(ctx context.Context, documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Definition, error) { + program, file := l.getProgramAndFile(documentURI) + node := astnav.GetTouchingPropertyName(file, int(l.converters.LineAndCharacterToPosition(file, position))) + if node.Kind == ast.KindSourceFile { + return nil, nil + } + + c, done := program.GetTypeCheckerForFile(ctx, file) + defer done() + + node = getDeclarationNameForKeyword(node) + + if symbol := c.GetSymbolAtLocation(node); symbol != nil { + symbolType := getTypeOfSymbolAtLocation(c, symbol, node) + declarations := getDeclarationsFromType(symbolType) + if typeArgument := c.GetFirstTypeArgumentFromKnownType(symbolType); typeArgument != nil { + declarations = core.Concatenate(getDeclarationsFromType(typeArgument), declarations) + } + if len(declarations) != 0 { + return l.createLocationsFromDeclarations(declarations), nil + } + if symbol.Flags&ast.SymbolFlagsValue == 0 && symbol.Flags&ast.SymbolFlagsType != 0 { + return l.createLocationsFromDeclarations(symbol.Declarations), nil + } + } + + return nil, nil +} + +func getDeclarationNameForKeyword(node *ast.Node) *ast.Node { + if node.Kind >= ast.KindFirstKeyword && node.Kind <= ast.KindLastKeyword { + if ast.IsVariableDeclarationList(node.Parent) { + if decl := core.FirstOrNil(node.Parent.AsVariableDeclarationList().Declarations.Nodes); decl != nil && decl.Name() != nil { + return decl.Name() + } + } else if node.Parent.Name() != nil && node.Pos() < node.Parent.Name().Pos() { + return node.Parent.Name() + } + } + return node +} + +func (l *LanguageService) createLocationsFromDeclarations(declarations []*ast.Node) *lsproto.Definition { + someHaveBody := core.Some(declarations, func(node *ast.Node) bool { return node.Body() != nil }) locations := make([]lsproto.Location, 0, len(declarations)) for _, decl := range declarations { - file := ast.GetSourceFileOfNode(decl) - loc := decl.Loc - pos := scanner.GetTokenPosOfNode(decl, file, false /*includeJSDoc*/) - locations = append(locations, lsproto.Location{ + if !someHaveBody || decl.Body() != nil { + file := ast.GetSourceFileOfNode(decl) + name := core.OrElse(ast.GetNameOfDeclaration(decl), decl) + locations = append(locations, lsproto.Location{ + Uri: FileNameToDocumentURI(file.FileName()), + Range: *l.createLspRangeFromNode(name, file), + }) + } + } + return &lsproto.Definition{Locations: &locations} +} + +func (l *LanguageService) createLocationFromFileAndRange(file *ast.SourceFile, textRange core.TextRange) *lsproto.Definition { + return &lsproto.Definition{ + Location: &lsproto.Location{ Uri: FileNameToDocumentURI(file.FileName()), - Range: l.converters.ToLSPRange(file, core.NewTextRange(pos, loc.End())), - }) + Range: *l.createLspRangeFromBounds(textRange.Pos(), textRange.End(), file), + }, } - return &lsproto.Definition{Locations: &locations}, nil } /** Returns a CallLikeExpression where `node` is the target being invoked. */ @@ -60,12 +156,10 @@ func getAncestorCallLikeExpression(node *ast.Node) *ast.Node { target := ast.FindAncestor(node, func(n *ast.Node) bool { return !isRightSideOfPropertyAccess(n) }) - callLike := target.Parent if callLike != nil && ast.IsCallLikeExpression(callLike) && ast.GetInvokedExpression(callLike) == target { return callLike } - return nil } @@ -75,7 +169,6 @@ func tryGetSignatureDeclaration(typeChecker *checker.Checker, node *ast.Node) *a if callLike != nil { signature = typeChecker.GetResolvedSignature(callLike) } - // Don't go to a function type, go to the value having that type. var declaration *ast.Node if signature != nil && signature.Declaration() != nil { @@ -84,6 +177,60 @@ func tryGetSignatureDeclaration(typeChecker *checker.Checker, node *ast.Node) *a return declaration } } - return nil } + +func getSymbolForOverriddenMember(typeChecker *checker.Checker, node *ast.Node) *ast.Symbol { + classElement := ast.FindAncestor(node, ast.IsClassElement) + if classElement == nil || classElement.Name() == nil { + return nil + } + baseDeclaration := ast.FindAncestor(classElement, ast.IsClassLike) + if baseDeclaration == nil { + return nil + } + baseTypeNode := ast.GetClassExtendsHeritageElement(baseDeclaration) + if baseTypeNode == nil { + return nil + } + expression := ast.SkipParentheses(baseTypeNode.Expression()) + var base *ast.Symbol + if ast.IsClassExpression(expression) { + base = expression.Symbol() + } else { + base = typeChecker.GetSymbolAtLocation(expression) + } + if base == nil { + return nil + } + name := ast.GetTextOfPropertyName(classElement.Name()) + if ast.HasStaticModifier(classElement) { + return typeChecker.GetPropertyOfType(typeChecker.GetTypeOfSymbol(base), name) + } + return typeChecker.GetPropertyOfType(typeChecker.GetDeclaredTypeOfSymbol(base), name) +} + +func getTypeOfSymbolAtLocation(c *checker.Checker, symbol *ast.Symbol, node *ast.Node) *checker.Type { + t := c.GetTypeOfSymbolAtLocation(symbol, node) + // If the type is just a function's inferred type, go-to-type should go to the return type instead since + // go-to-definition takes you to the function anyway. + if t.Symbol() == symbol || t.Symbol() != nil && symbol.ValueDeclaration != nil && ast.IsVariableDeclaration(symbol.ValueDeclaration) && symbol.ValueDeclaration.Initializer() == t.Symbol().ValueDeclaration { + sigs := c.GetCallSignatures(t) + if len(sigs) == 1 { + return c.GetReturnTypeOfSignature(sigs[0]) + } + } + return t +} + +func getDeclarationsFromType(t *checker.Type) []*ast.Node { + var result []*ast.Node + for _, t := range t.Distributed() { + if t.Symbol() != nil { + for _, decl := range t.Symbol().Declarations { + result = core.AppendIfUnique(result, decl) + } + } + } + return result +} diff --git a/internal/ls/findallreferences.go b/internal/ls/findallreferences.go index 4b21242a80..fead25582a 100644 --- a/internal/ls/findallreferences.go +++ b/internal/ls/findallreferences.go @@ -640,7 +640,7 @@ func getReferencedSymbolsSpecial(node *ast.Node, sourceFiles []*ast.SourceFile) if isLabelOfLabeledStatement(node) { // it is a label definition and not a target, search within the parent labeledStatement - return getLabelReferencesInNode(node.Parent, node.AsIdentifier()) + return getLabelReferencesInNode(node.Parent, node) } if isThis(node) { @@ -654,9 +654,9 @@ func getReferencedSymbolsSpecial(node *ast.Node, sourceFiles []*ast.SourceFile) return nil } -func getLabelReferencesInNode(container *ast.Node, targetLabel *ast.Identifier) []*SymbolAndEntries { +func getLabelReferencesInNode(container *ast.Node, targetLabel *ast.Node) []*SymbolAndEntries { sourceFile := ast.GetSourceFileOfNode(container) - labelName := targetLabel.Text + labelName := targetLabel.Text() references := core.MapNonNil(getPossibleSymbolReferenceNodes(sourceFile, labelName, container), func(node *ast.Node) *referenceEntry { // Only pick labels that are either the target label, or have a target that is the target label if node == targetLabel.AsNode() || (isJumpStatementTarget(node) && getTargetLabel(node, labelName) == targetLabel) { @@ -664,7 +664,7 @@ func getLabelReferencesInNode(container *ast.Node, targetLabel *ast.Identifier) } return nil }) - return []*SymbolAndEntries{NewSymbolAndEntries(definitionKindLabel, targetLabel.AsNode(), nil, references)} + return []*SymbolAndEntries{NewSymbolAndEntries(definitionKindLabel, targetLabel, nil, references)} } func getReferencesForThisKeyword(thisOrSuperKeyword *ast.Node, sourceFiles []*ast.SourceFile) []*SymbolAndEntries { diff --git a/internal/ls/utilities.go b/internal/ls/utilities.go index 9f82ce42f9..e98c2a3db8 100644 --- a/internal/ls/utilities.go +++ b/internal/ls/utilities.go @@ -393,7 +393,7 @@ func isInRightSideOfInternalImportEqualsDeclaration(node *ast.Node) bool { } func (l *LanguageService) createLspRangeFromNode(node *ast.Node, file *ast.SourceFile) *lsproto.Range { - return l.createLspRangeFromBounds(node.Pos(), node.End(), file) + return l.createLspRangeFromBounds(scanner.GetTokenPosOfNode(node, file, false /*includeJSDoc*/), node.End(), file) } func (l *LanguageService) createLspRangeFromBounds(start, end int, file *ast.SourceFile) *lsproto.Range { @@ -1418,11 +1418,11 @@ func getPropertySymbolOfObjectBindingPatternWithoutPropertyName(symbol *ast.Symb return nil } -func getTargetLabel(referenceNode *ast.Node, labelName string) *ast.Identifier { +func getTargetLabel(referenceNode *ast.Node, labelName string) *ast.Node { // todo: rewrite as `ast.FindAncestor` for referenceNode != nil { if referenceNode.Kind == ast.KindLabeledStatement && referenceNode.AsLabeledStatement().Label.Text() == labelName { - return referenceNode.AsLabeledStatement().Label.AsIdentifier() + return referenceNode.AsLabeledStatement().Label } referenceNode = referenceNode.Parent } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index a39d84aeea..ba2dcf28f7 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -482,6 +482,8 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R return s.handleHover(ctx, req) case *lsproto.DefinitionParams: return s.handleDefinition(ctx, req) + case *lsproto.TypeDefinitionParams: + return s.handleTypeDefinition(ctx, req) case *lsproto.CompletionParams: return s.handleCompletion(ctx, req) case *lsproto.ReferenceParams: @@ -552,6 +554,9 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) { DefinitionProvider: &lsproto.BooleanOrDefinitionOptions{ Boolean: ptrTo(true), }, + TypeDefinitionProvider: &lsproto.BooleanOrTypeDefinitionOptionsOrTypeDefinitionRegistrationOptions{ + Boolean: ptrTo(true), + }, ReferencesProvider: &lsproto.BooleanOrReferenceOptions{ Boolean: ptrTo(true), }, @@ -696,6 +701,19 @@ func (s *Server) handleDefinition(ctx context.Context, req *lsproto.RequestMessa return nil } +func (s *Server) handleTypeDefinition(ctx context.Context, req *lsproto.RequestMessage) error { + params := req.Params.(*lsproto.TypeDefinitionParams) + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + definition, err := languageService.ProvideTypeDefinition(ctx, params.TextDocument.Uri, params.Position) + if err != nil { + return err + } + s.sendResult(req.ID, definition) + return nil +} + func (s *Server) handleReferences(ctx context.Context, req *lsproto.RequestMessage) error { // findAllReferences params := req.Params.(*lsproto.ReferenceParams) diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index 41450e1b5a..d8edd0fa02 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -2318,21 +2318,94 @@ func GetTokenPosOfNode(node *ast.Node, sourceFile *ast.SourceFile, includeJSDoc if ast.NodeIsMissing(node) { return node.Pos() } - if ast.IsJSDocNode(node) || node.Kind == ast.KindJsxText { // JsxText cannot actually contain comments, even though the scanner will think it sees comments return SkipTriviaEx(sourceFile.Text(), node.Pos(), &SkipTriviaOptions{StopAtComments: true}) } - if includeJSDoc && len(node.JSDoc(sourceFile)) > 0 { return GetTokenPosOfNode(node.JSDoc(sourceFile)[0], sourceFile, false /*includeJSDoc*/) } + return SkipTriviaEx(sourceFile.Text(), node.Pos(), &SkipTriviaOptions{InJSDoc: node.Flags&ast.NodeFlagsJSDoc != 0}) +} - return SkipTriviaEx( - sourceFile.Text(), - node.Pos(), - &SkipTriviaOptions{InJSDoc: node.Flags&ast.NodeFlagsJSDoc != 0}, - ) +func getErrorRangeForArrowFunction(sourceFile *ast.SourceFile, node *ast.Node) core.TextRange { + pos := SkipTrivia(sourceFile.Text(), node.Pos()) + body := node.AsArrowFunction().Body + if body != nil && body.Kind == ast.KindBlock { + startLine, _ := GetLineAndCharacterOfPosition(sourceFile, body.Pos()) + endLine, _ := GetLineAndCharacterOfPosition(sourceFile, body.End()) + if startLine < endLine { + // The arrow function spans multiple lines, + // make the error span be the first line, inclusive. + return core.NewTextRange(pos, GetEndLinePosition(sourceFile, startLine)) + } + } + return core.NewTextRange(pos, node.End()) +} + +func GetErrorRangeForNode(sourceFile *ast.SourceFile, node *ast.Node) core.TextRange { + errorNode := node + switch node.Kind { + case ast.KindSourceFile: + pos := SkipTrivia(sourceFile.Text(), 0) + if pos == len(sourceFile.Text()) { + return core.NewTextRange(0, 0) + } + return GetRangeOfTokenAtPosition(sourceFile, pos) + // This list is a work in progress. Add missing node kinds to improve their error spans + case ast.KindFunctionDeclaration, ast.KindMethodDeclaration: + if node.Flags&ast.NodeFlagsReparsed != 0 { + errorNode = node + break + } + fallthrough + case ast.KindVariableDeclaration, ast.KindBindingElement, ast.KindClassDeclaration, ast.KindClassExpression, ast.KindInterfaceDeclaration, + ast.KindModuleDeclaration, ast.KindEnumDeclaration, ast.KindEnumMember, ast.KindFunctionExpression, + ast.KindGetAccessor, ast.KindSetAccessor, ast.KindTypeAliasDeclaration, ast.KindJSTypeAliasDeclaration, ast.KindPropertyDeclaration, + ast.KindPropertySignature, ast.KindNamespaceImport: + errorNode = ast.GetNameOfDeclaration(node) + case ast.KindArrowFunction: + return getErrorRangeForArrowFunction(sourceFile, node) + case ast.KindCaseClause, ast.KindDefaultClause: + start := SkipTrivia(sourceFile.Text(), node.Pos()) + end := node.End() + statements := node.AsCaseOrDefaultClause().Statements.Nodes + if len(statements) != 0 { + end = statements[0].Pos() + } + return core.NewTextRange(start, end) + case ast.KindReturnStatement, ast.KindYieldExpression: + pos := SkipTrivia(sourceFile.Text(), node.Pos()) + return GetRangeOfTokenAtPosition(sourceFile, pos) + case ast.KindSatisfiesExpression: + pos := SkipTrivia(sourceFile.Text(), node.AsSatisfiesExpression().Expression.End()) + return GetRangeOfTokenAtPosition(sourceFile, pos) + case ast.KindConstructor: + if node.Flags&ast.NodeFlagsReparsed != 0 { + errorNode = node + break + } + scanner := GetScannerForSourceFile(sourceFile, node.Pos()) + start := scanner.TokenStart() + for scanner.Token() != ast.KindConstructorKeyword && scanner.Token() != ast.KindStringLiteral && scanner.Token() != ast.KindEndOfFile { + scanner.Scan() + } + return core.NewTextRange(start, scanner.TokenEnd()) + // !!! + // case KindJSDocSatisfiesTag: + // pos := scanner.SkipTrivia(sourceFile.Text(), node.tagName.pos) + // return scanner.GetRangeOfTokenAtPosition(sourceFile, pos) + } + if errorNode == nil { + // If we don't have a better node, then just set the error on the first token of + // construct. + return GetRangeOfTokenAtPosition(sourceFile, node.Pos()) + } + pos := errorNode.Pos() + if !ast.NodeIsMissing(errorNode) && !ast.IsJsxText(errorNode) { + pos = SkipTrivia(sourceFile.Text(), pos) + } + return core.NewTextRange(pos, errorNode.End()) } func ComputeLineOfPosition(lineStarts []core.TextPos, pos int) int {