diff --git a/pkg/tools/builtin/lsp.go b/pkg/tools/builtin/lsp.go index 259bb1c7d..f6450bbaa 100644 --- a/pkg/tools/builtin/lsp.go +++ b/pkg/tools/builtin/lsp.go @@ -393,111 +393,70 @@ Line and character positions are 1-based.` // WorkspaceArgs is empty - the workspace tool takes no arguments. type WorkspaceArgs struct{} -// lspToolDef defines a tool with its metadata inline for cleaner registration. -type lspToolDef struct { - name string - title string - readOnly bool - description string - params any - handler tools.ToolHandler +// lspTool is a shorthand for constructing a tools.Tool with common LSP defaults. +func lspTool(name, title, description string, readOnly bool, params any, handler tools.ToolHandler) tools.Tool { + return tools.Tool{ + Name: name, + Category: "lsp", + Description: description, + Parameters: params, + Handler: handler, + Annotations: tools.ToolAnnotations{ + Title: title, + ReadOnlyHint: readOnly, + }, + } } func (t *LSPTool) Tools(context.Context) ([]tools.Tool, error) { h := t.handler - defs := []lspToolDef{ - { - name: ToolNameLSPWorkspace, title: "Get Workspace Info", readOnly: true, - params: tools.MustSchemaFor[WorkspaceArgs](), handler: tools.NewHandler(h.workspace), - description: `Get workspace info and LSP server capabilities. Use at session start to discover available features. Takes no arguments.`, - }, - { - name: ToolNameLSPHover, title: "Get Symbol Info", readOnly: true, - params: tools.MustSchemaFor[PositionArgs](), handler: tools.NewHandler(h.hover), - description: `Get type signature, documentation, and hover info for a symbol at a given position.`, - }, - { - name: ToolNameLSPDefinition, title: "Go to Definition", readOnly: true, - params: tools.MustSchemaFor[PositionArgs](), handler: tools.NewHandler(h.definition), - description: `Find the definition location of a symbol. Returns file path and line number.`, - }, - { - name: ToolNameLSPReferences, title: "Find References", readOnly: true, - params: tools.MustSchemaFor[ReferencesArgs](), handler: tools.NewHandler(h.references), - description: `Find all references to a symbol across the codebase. IMPORTANT: You MUST use this before modifying any symbol definition. Set include_declaration to false to exclude the definition itself.`, - }, - { - name: ToolNameLSPDocumentSymbols, title: "List File Symbols", readOnly: true, - params: tools.MustSchemaFor[FileArgs](), handler: tools.NewHandler(h.documentSymbols), - description: `List all symbols (functions, types, methods, variables, etc.) defined in a file as a hierarchical list.`, - }, - { - name: ToolNameLSPWorkspaceSymbols, title: "Search Workspace Symbols", readOnly: true, - params: tools.MustSchemaFor[WorkspaceSymbolsArgs](), handler: tools.NewHandler(h.workspaceSymbols), - description: `Search for symbols across the workspace using fuzzy matching. Primary tool for locating symbols.`, - }, - { - name: ToolNameLSPDiagnostics, title: "Get Diagnostics", readOnly: true, - params: tools.MustSchemaFor[FileArgs](), handler: tools.NewHandler(h.getDiagnostics), - description: `Get compiler errors, warnings, and hints for a file. IMPORTANT: You MUST call this after every code modification on edited files. Use lsp_code_actions for suggested fixes.`, - }, - { - name: ToolNameLSPRename, title: "Rename Symbol", readOnly: false, - params: tools.MustSchemaFor[RenameArgs](), handler: tools.NewHandler(h.rename), - description: `Rename a symbol across the entire workspace. WRITE operation - modifies files on disk. Run lsp_diagnostics on modified files afterward.`, - }, - { - name: ToolNameLSPCodeActions, title: "Get Code Actions", readOnly: true, - params: tools.MustSchemaFor[CodeActionsArgs](), handler: tools.NewHandler(h.codeActions), - description: `Get available code actions (quick fixes, refactorings) for a line or range. Use after lsp_diagnostics reports errors.`, - }, - { - name: ToolNameLSPFormat, title: "Format File", readOnly: false, - params: tools.MustSchemaFor[FileArgs](), handler: tools.NewHandler(h.format), - description: `Format a file according to language standards. WRITE operation - modifies the file on disk. Only format after lsp_diagnostics reports no errors.`, - }, - { - name: ToolNameLSPCallHierarchy, title: "Call Hierarchy", readOnly: true, - params: tools.MustSchemaFor[CallHierarchyArgs](), handler: tools.NewHandler(h.callHierarchy), - description: `Analyze the call hierarchy of a function or method. Direction: 'incoming' (who calls this) or 'outgoing' (what this calls).`, - }, - { - name: ToolNameLSPTypeHierarchy, title: "Type Hierarchy", readOnly: true, - params: tools.MustSchemaFor[TypeHierarchyArgs](), handler: tools.NewHandler(h.typeHierarchy), - description: `Analyze the type hierarchy. Direction: 'supertypes' (parent types) or 'subtypes' (child types).`, - }, - { - name: ToolNameLSPImplementations, title: "Find Implementations", readOnly: true, - params: tools.MustSchemaFor[PositionArgs](), handler: tools.NewHandler(h.implementations), - description: `Find all concrete implementations of an interface or abstract method. IMPORTANT: You MUST use this before modifying interfaces to find all implementations needing updates.`, - }, - { - name: ToolNameLSPSignatureHelp, title: "Signature Help", readOnly: true, - params: tools.MustSchemaFor[PositionArgs](), handler: tools.NewHandler(h.signatureHelp), - description: `Get function signature and parameter information at a call site. Position the cursor inside a function call's parentheses.`, - }, - { - name: ToolNameLSPInlayHints, title: "Inlay Hints", readOnly: true, - params: tools.MustSchemaFor[InlayHintsArgs](), handler: tools.NewHandler(h.inlayHints), - description: `Get inlay hints (type annotations, parameter names) for a file or line range. Omit start_line/end_line to get hints for the entire file.`, - }, - } - - result := make([]tools.Tool, len(defs)) - for i, def := range defs { - result[i] = tools.Tool{ - Name: def.name, - Category: "lsp", - Description: def.description, - Parameters: def.params, - Handler: def.handler, - Annotations: tools.ToolAnnotations{ - Title: def.title, - ReadOnlyHint: def.readOnly, - }, - } - } - return result, nil + return []tools.Tool{ + lspTool(ToolNameLSPWorkspace, "Get Workspace Info", + `Get workspace info and LSP server capabilities. Use at session start to discover available features. Takes no arguments.`, + true, tools.MustSchemaFor[WorkspaceArgs](), tools.NewHandler(h.workspace)), + lspTool(ToolNameLSPHover, "Get Symbol Info", + `Get type signature, documentation, and hover info for a symbol at a given position.`, + true, tools.MustSchemaFor[PositionArgs](), tools.NewHandler(h.hover)), + lspTool(ToolNameLSPDefinition, "Go to Definition", + `Find the definition location of a symbol. Returns file path and line number.`, + true, tools.MustSchemaFor[PositionArgs](), tools.NewHandler(h.definition)), + lspTool(ToolNameLSPReferences, "Find References", + `Find all references to a symbol across the codebase. IMPORTANT: You MUST use this before modifying any symbol definition. Set include_declaration to false to exclude the definition itself.`, + true, tools.MustSchemaFor[ReferencesArgs](), tools.NewHandler(h.references)), + lspTool(ToolNameLSPDocumentSymbols, "List File Symbols", + `List all symbols (functions, types, methods, variables, etc.) defined in a file as a hierarchical list.`, + true, tools.MustSchemaFor[FileArgs](), tools.NewHandler(h.documentSymbols)), + lspTool(ToolNameLSPWorkspaceSymbols, "Search Workspace Symbols", + `Search for symbols across the workspace using fuzzy matching. Primary tool for locating symbols.`, + true, tools.MustSchemaFor[WorkspaceSymbolsArgs](), tools.NewHandler(h.workspaceSymbols)), + lspTool(ToolNameLSPDiagnostics, "Get Diagnostics", + `Get compiler errors, warnings, and hints for a file. IMPORTANT: You MUST call this after every code modification on edited files. Use lsp_code_actions for suggested fixes.`, + true, tools.MustSchemaFor[FileArgs](), tools.NewHandler(h.getDiagnostics)), + lspTool(ToolNameLSPRename, "Rename Symbol", + `Rename a symbol across the entire workspace. WRITE operation - modifies files on disk. Run lsp_diagnostics on modified files afterward.`, + false, tools.MustSchemaFor[RenameArgs](), tools.NewHandler(h.rename)), + lspTool(ToolNameLSPCodeActions, "Get Code Actions", + `Get available code actions (quick fixes, refactorings) for a line or range. Use after lsp_diagnostics reports errors.`, + true, tools.MustSchemaFor[CodeActionsArgs](), tools.NewHandler(h.codeActions)), + lspTool(ToolNameLSPFormat, "Format File", + `Format a file according to language standards. WRITE operation - modifies the file on disk. Only format after lsp_diagnostics reports no errors.`, + false, tools.MustSchemaFor[FileArgs](), tools.NewHandler(h.format)), + lspTool(ToolNameLSPCallHierarchy, "Call Hierarchy", + `Analyze the call hierarchy of a function or method. Direction: 'incoming' (who calls this) or 'outgoing' (what this calls).`, + true, tools.MustSchemaFor[CallHierarchyArgs](), tools.NewHandler(h.callHierarchy)), + lspTool(ToolNameLSPTypeHierarchy, "Type Hierarchy", + `Analyze the type hierarchy. Direction: 'supertypes' (parent types) or 'subtypes' (child types).`, + true, tools.MustSchemaFor[TypeHierarchyArgs](), tools.NewHandler(h.typeHierarchy)), + lspTool(ToolNameLSPImplementations, "Find Implementations", + `Find all concrete implementations of an interface or abstract method. IMPORTANT: You MUST use this before modifying interfaces to find all implementations needing updates.`, + true, tools.MustSchemaFor[PositionArgs](), tools.NewHandler(h.implementations)), + lspTool(ToolNameLSPSignatureHelp, "Signature Help", + `Get function signature and parameter information at a call site. Position the cursor inside a function call's parentheses.`, + true, tools.MustSchemaFor[PositionArgs](), tools.NewHandler(h.signatureHelp)), + lspTool(ToolNameLSPInlayHints, "Inlay Hints", + `Get inlay hints (type annotations, parameter names) for a file or line range. Omit start_line/end_line to get hints for the entire file.`, + true, tools.MustSchemaFor[InlayHintsArgs](), tools.NewHandler(h.inlayHints)), + }, nil } // lspHandler implementation @@ -615,71 +574,76 @@ func (h *lspHandler) ensureInitialized() error { } } - if !h.initialized.Load() { - rootURI := "file://" + h.workingDir - initParams := map[string]any{ - "processId": os.Getpid(), - "rootUri": rootURI, - "capabilities": map[string]any{ - "textDocument": map[string]any{ - "hover": map[string]any{"contentFormat": []string{"markdown", "plaintext"}}, - "definition": map[string]any{}, - "references": map[string]any{}, - "implementation": map[string]any{}, - "documentSymbol": map[string]any{}, - "publishDiagnostics": map[string]any{}, - "rename": map[string]any{"prepareSupport": true}, - "codeAction": map[string]any{ - "codeActionLiteralSupport": map[string]any{ - "codeActionKind": map[string]any{ - "valueSet": []string{"quickfix", "refactor", "refactor.extract", "refactor.inline", "refactor.rewrite", "source", "source.organizeImports"}, - }, - }, - }, - "formatting": map[string]any{}, - "callHierarchy": map[string]any{"dynamicRegistration": true}, - "typeHierarchy": map[string]any{"dynamicRegistration": true}, - "signatureHelp": map[string]any{ - "signatureInformation": map[string]any{ - "documentationFormat": []string{"markdown", "plaintext"}, - "parameterInformation": map[string]any{"labelOffsetSupport": true}, + if h.initialized.Load() { + return nil + } + + return h.initializeLocked() +} + +// initializeLocked performs the LSP initialize/initialized handshake. +// The caller must hold h.mu and the process must be running. +func (h *lspHandler) initializeLocked() error { + rootURI := "file://" + h.workingDir + + result, err := h.sendRequestLocked("initialize", map[string]any{ + "processId": os.Getpid(), + "rootUri": rootURI, + "capabilities": map[string]any{ + "textDocument": map[string]any{ + "hover": map[string]any{"contentFormat": []string{"markdown", "plaintext"}}, + "definition": map[string]any{}, + "references": map[string]any{}, + "implementation": map[string]any{}, + "documentSymbol": map[string]any{}, + "publishDiagnostics": map[string]any{}, + "rename": map[string]any{"prepareSupport": true}, + "codeAction": map[string]any{ + "codeActionLiteralSupport": map[string]any{ + "codeActionKind": map[string]any{ + "valueSet": []string{"quickfix", "refactor", "refactor.extract", "refactor.inline", "refactor.rewrite", "source", "source.organizeImports"}, }, }, - "inlayHint": map[string]any{"dynamicRegistration": true}, }, - "workspace": map[string]any{ - "symbol": map[string]any{}, - "applyEdit": true, - "workspaceEdit": map[string]any{"documentChanges": true}, + "formatting": map[string]any{}, + "callHierarchy": map[string]any{"dynamicRegistration": true}, + "typeHierarchy": map[string]any{"dynamicRegistration": true}, + "signatureHelp": map[string]any{ + "signatureInformation": map[string]any{ + "documentationFormat": []string{"markdown", "plaintext"}, + "parameterInformation": map[string]any{"labelOffsetSupport": true}, + }, }, + "inlayHint": map[string]any{"dynamicRegistration": true}, }, - } - - result, err := h.sendRequestLocked("initialize", initParams) - if err != nil { - return fmt.Errorf("failed to initialize LSP: %w", err) - } - - // Parse the initialization result to get server info and capabilities - var initResult struct { - Capabilities lspServerCapabilities `json:"capabilities"` - ServerInfo *lspServerInfo `json:"serverInfo,omitempty"` - } - if err := json.Unmarshal(result, &initResult); err != nil { - slog.Debug("Failed to parse initialize result", "error", err) - } else { - h.capabilities = &initResult.Capabilities - h.serverInfo = initResult.ServerInfo - } + "workspace": map[string]any{ + "symbol": map[string]any{}, + "applyEdit": true, + "workspaceEdit": map[string]any{"documentChanges": true}, + }, + }, + }) + if err != nil { + return fmt.Errorf("failed to initialize LSP: %w", err) + } - if err := h.sendNotificationLocked("initialized", map[string]any{}); err != nil { - return fmt.Errorf("failed to send initialized notification: %w", err) - } + var initResult struct { + Capabilities lspServerCapabilities `json:"capabilities"` + ServerInfo *lspServerInfo `json:"serverInfo,omitempty"` + } + if err := json.Unmarshal(result, &initResult); err != nil { + slog.Debug("Failed to parse initialize result", "error", err) + } else { + h.capabilities = &initResult.Capabilities + h.serverInfo = initResult.ServerInfo + } - h.initialized.Store(true) - slog.Debug("LSP server initialized", "rootUri", rootURI) + if err := h.sendNotificationLocked("initialized", map[string]any{}); err != nil { + return fmt.Errorf("failed to send initialized notification: %w", err) } + h.initialized.Store(true) + slog.Debug("LSP server initialized", "rootUri", rootURI) return nil } @@ -793,8 +757,11 @@ func (h *lspHandler) hover(ctx context.Context, args PositionArgs) (*tools.ToolC return tools.ResultSuccess(formatHoverContents(hover.Contents)), nil } -func (h *lspHandler) definition(ctx context.Context, args PositionArgs) (*tools.ToolCallResult, error) { - uri, err := h.prepareFileRequest(ctx, args.File) +// locationRequest issues a textDocument/ position request and formats +// the result as locations. Used by definition and implementations which share +// exactly the same shape. +func (h *lspHandler) locationRequest(ctx context.Context, method, file string, line, character int, emptyMsg string) (*tools.ToolCallResult, error) { + uri, err := h.prepareFileRequest(ctx, file) if err != nil { return tools.ResultError(err.Error()), nil } @@ -802,23 +769,29 @@ func (h *lspHandler) definition(ctx context.Context, args PositionArgs) (*tools. h.mu.Lock() defer h.mu.Unlock() - params := map[string]any{ + result, err := h.sendRequestLocked("textDocument/"+method, map[string]any{ "textDocument": map[string]any{"uri": uri}, - "position": map[string]any{"line": args.Line - 1, "character": args.Character - 1}, - } - - result, err := h.sendRequestLocked("textDocument/definition", params) + "position": map[string]any{"line": line - 1, "character": character - 1}, + }) if err != nil { - return tools.ResultError(fmt.Sprintf("Definition request failed: %s", err)), nil + return tools.ResultError(fmt.Sprintf("%s request failed: %s", method, err)), nil } - if len(result) == 0 || string(result) == "null" { - return tools.ResultSuccess("No definition found at this position"), nil + if len(result) == 0 || string(result) == "null" || string(result) == "[]" { + return tools.ResultSuccess(emptyMsg), nil } return tools.ResultSuccess(formatLocations(result)), nil } +func (h *lspHandler) definition(ctx context.Context, args PositionArgs) (*tools.ToolCallResult, error) { + return h.locationRequest(ctx, "definition", args.File, args.Line, args.Character, "No definition found at this position") +} + +func (h *lspHandler) implementations(ctx context.Context, args PositionArgs) (*tools.ToolCallResult, error) { + return h.locationRequest(ctx, "implementation", args.File, args.Line, args.Character, "No implementations found") +} + func (h *lspHandler) references(ctx context.Context, args ReferencesArgs) (*tools.ToolCallResult, error) { uri, err := h.prepareFileRequest(ctx, args.File) if err != nil { @@ -973,7 +946,7 @@ func (h *lspHandler) codeActions(ctx context.Context, args CodeActionsArgs) (*to fileDiags := h.diagnostics[uri] h.diagnosticsMu.RUnlock() - var rangeDiags []lspDiagnostic + rangeDiags := make([]lspDiagnostic, 0) for _, d := range fileDiags { diagLine := d.Range.Start.Line + 1 if diagLine >= args.StartLine && diagLine <= endLine { @@ -1038,7 +1011,7 @@ func (h *lspHandler) format(ctx context.Context, args FileArgs) (*tools.ToolCall return tools.ResultError(fmt.Sprintf("Failed to apply formatting: %s", err)), nil } - if err := h.NotifyFileChange(ctx, uri); err != nil { + if err := h.notifyFileChangeLocked(uri); err != nil { slog.Debug("Failed to notify LSP of format changes", "error", err) } @@ -1155,32 +1128,6 @@ func (h *lspHandler) typeHierarchy(ctx context.Context, args TypeHierarchyArgs) return tools.ResultSuccess(result.String()), nil } -func (h *lspHandler) implementations(ctx context.Context, args PositionArgs) (*tools.ToolCallResult, error) { - uri, err := h.prepareFileRequest(ctx, args.File) - if err != nil { - return tools.ResultError(err.Error()), nil - } - - h.mu.Lock() - defer h.mu.Unlock() - - params := map[string]any{ - "textDocument": map[string]any{"uri": uri}, - "position": map[string]any{"line": args.Line - 1, "character": args.Character - 1}, - } - - result, err := h.sendRequestLocked("textDocument/implementation", params) - if err != nil { - return tools.ResultError(fmt.Sprintf("Implementations request failed: %s", err)), nil - } - - if len(result) == 0 || string(result) == "null" || string(result) == "[]" { - return tools.ResultSuccess("No implementations found"), nil - } - - return tools.ResultSuccess(formatLocations(result)), nil -} - func (h *lspHandler) signatureHelp(ctx context.Context, args PositionArgs) (*tools.ToolCallResult, error) { uri, err := h.prepareFileRequest(ctx, args.File) if err != nil { @@ -1249,7 +1196,9 @@ func (h *lspHandler) inlayHints(ctx context.Context, args InlayHintsArgs) (*tool return tools.ResultSuccess(formatInlayHints(args.File, startLine, endLine, hints)), nil } -// applyWorkspaceEdit applies a workspace edit and returns a summary +// applyWorkspaceEdit applies a workspace edit to files on disk and notifies +// the LSP server of the changes so its in-memory state stays in sync. +// The caller must hold h.mu. func (h *lspHandler) applyWorkspaceEdit(edit *lspWorkspaceEdit, newName string) *tools.ToolCallResult { var totalChanges int var modifiedFiles []string @@ -1283,6 +1232,17 @@ func (h *lspHandler) applyWorkspaceEdit(edit *lspWorkspaceEdit, newName string) return tools.ResultSuccess("No changes were needed") } + // Notify the LSP server about each modified file that it has open, + // so subsequent operations (diagnostics, hover, etc.) see the new content. + for _, file := range modifiedFiles { + uri := pathToURI(file) + if h.isFileOpen(uri) { + if err := h.notifyFileChangeLocked(uri); err != nil { + slog.Debug("Failed to notify LSP of rename changes", "file", file, "error", err) + } + } + } + var result strings.Builder fmt.Fprintf(&result, "Renamed to '%s'\n", newName) fmt.Fprintf(&result, "Modified %d file(s):\n", len(modifiedFiles)) @@ -1586,6 +1546,15 @@ func (h *lspHandler) NotifyFileChange(_ context.Context, uri string) error { return fmt.Errorf("file not open: %s", uri) } + h.mu.Lock() + defer h.mu.Unlock() + + return h.notifyFileChangeLocked(uri) +} + +// notifyFileChangeLocked re-reads a file from disk and sends a +// textDocument/didChange notification. The caller must hold h.mu. +func (h *lspHandler) notifyFileChangeLocked(uri string) error { filePath := strings.TrimPrefix(uri, "file://") content, err := os.ReadFile(filePath) @@ -1598,9 +1567,6 @@ func (h *lspHandler) NotifyFileChange(_ context.Context, uri string) error { version := h.openFiles[uri] h.openFilesMu.Unlock() - h.mu.Lock() - defer h.mu.Unlock() - changeParams := map[string]any{ "textDocument": map[string]any{"uri": uri, "version": version}, "contentChanges": []map[string]any{{"text": string(content)}},