diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index b2e82f4e0..156de8b25 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -96,8 +96,10 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, // We use NewEnterpriseClient unconditionally since we already parsed the API host gqlHTTPClient := &http.Client{ Transport: &bearerAuthTransport{ - transport: http.DefaultTransport, - token: cfg.Token, + transport: &github.GraphQLFeaturesTransport{ + Transport: http.DefaultTransport, + }, + token: cfg.Token, }, } gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) @@ -625,12 +627,6 @@ type bearerAuthTransport struct { func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { req = req.Clone(req.Context()) req.Header.Set("Authorization", "Bearer "+t.token) - - // Check for GraphQL-Features in context and add header if present - if features := github.GetGraphQLFeatures(req.Context()); len(features) > 0 { - req.Header.Set("GraphQL-Features", strings.Join(features, ", ")) - } - return t.transport.RoundTrip(req) } diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 3d57165d5..ce3f138df 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -1609,6 +1609,104 @@ func (d *mvpDescription) String() string { return sb.String() } +// linkedPullRequest represents a PR linked to an issue by Copilot. +type linkedPullRequest struct { + Number int + URL string + Title string + State string + CreatedAt time.Time +} + +// pollConfigKey is a context key for polling configuration. +type pollConfigKey struct{} + +// PollConfig configures the PR polling behavior. +type PollConfig struct { + MaxAttempts int + Delay time.Duration +} + +// ContextWithPollConfig returns a context with polling configuration. +// Use this in tests to reduce or disable polling. +func ContextWithPollConfig(ctx context.Context, config PollConfig) context.Context { + return context.WithValue(ctx, pollConfigKey{}, config) +} + +// getPollConfig returns the polling configuration from context, or defaults. +func getPollConfig(ctx context.Context) PollConfig { + if config, ok := ctx.Value(pollConfigKey{}).(PollConfig); ok { + return config + } + // Default: 9 attempts with 1s delay = 8s max wait + // Based on observed latency in remote server: p50 ~5s, p90 ~7s + return PollConfig{MaxAttempts: 9, Delay: 1 * time.Second} +} + +// findLinkedCopilotPR searches for a PR created by the copilot-swe-agent bot that references the given issue. +// It queries the issue's timeline for CrossReferencedEvent items from PRs authored by copilot-swe-agent. +// The createdAfter parameter filters to only return PRs created after the specified time. +func findLinkedCopilotPR(ctx context.Context, client *githubv4.Client, owner, repo string, issueNumber int, createdAfter time.Time) (*linkedPullRequest, error) { + // Query timeline items looking for CrossReferencedEvent from PRs by copilot-swe-agent + var query struct { + Repository struct { + Issue struct { + TimelineItems struct { + Nodes []struct { + TypeName string `graphql:"__typename"` + CrossReferencedEvent struct { + Source struct { + PullRequest struct { + Number int + URL string + Title string + State string + CreatedAt githubv4.DateTime + Author struct { + Login string + } + } `graphql:"... on PullRequest"` + } + } `graphql:"... on CrossReferencedEvent"` + } + } `graphql:"timelineItems(first: 20, itemTypes: [CROSS_REFERENCED_EVENT])"` + } `graphql:"issue(number: $number)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + + variables := map[string]any{ + "owner": githubv4.String(owner), + "name": githubv4.String(repo), + "number": githubv4.Int(issueNumber), //nolint:gosec // Issue numbers are always small positive integers + } + + if err := client.Query(ctx, &query, variables); err != nil { + return nil, err + } + + // Look for a PR from copilot-swe-agent created after the assignment time + for _, node := range query.Repository.Issue.TimelineItems.Nodes { + if node.TypeName != "CrossReferencedEvent" { + continue + } + pr := node.CrossReferencedEvent.Source.PullRequest + if pr.Number > 0 && pr.Author.Login == "copilot-swe-agent" { + // Only return PRs created after the assignment time + if pr.CreatedAt.Time.After(createdAfter) { + return &linkedPullRequest{ + Number: pr.Number, + URL: pr.URL, + Title: pr.Title, + State: pr.State, + CreatedAt: pr.CreatedAt.Time, + }, nil + } + } + } + + return nil, nil +} + func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.ServerTool { description := mvpDescription{ summary: "Assign Copilot to a specific issue in a GitHub repository.", @@ -1802,6 +1900,9 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server // The header will be read by the HTTP transport if it's configured to do so ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") + // Capture the time before assignment to filter out older PRs during polling + assignmentTime := time.Now().UTC() + if err := client.Mutate( ctxWithFeatures, &updateIssueMutation, @@ -1815,7 +1916,55 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server return nil, nil, fmt.Errorf("failed to update issue with agent assignment: %w", err) } - return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil + // Poll for a linked PR created by Copilot after the assignment + pollConfig := getPollConfig(ctx) + + var linkedPR *linkedPullRequest + for attempt := range pollConfig.MaxAttempts { + if attempt > 0 { + time.Sleep(pollConfig.Delay) + } + + pr, err := findLinkedCopilotPR(ctx, client, params.Owner, params.Repo, int(params.IssueNumber), assignmentTime) + if err != nil { + // Polling errors are non-fatal, continue to next attempt + continue + } + if pr != nil { + linkedPR = pr + break + } + } + + // Build the result + result := map[string]any{ + "message": "successfully assigned copilot to issue", + "issue_number": int(updateIssueMutation.UpdateIssue.Issue.Number), + "issue_url": string(updateIssueMutation.UpdateIssue.Issue.URL), + "owner": params.Owner, + "repo": params.Repo, + } + + // Add PR info if found during polling + if linkedPR != nil { + result["pull_request"] = map[string]any{ + "number": linkedPR.Number, + "url": linkedPR.URL, + "title": linkedPR.Title, + "state": linkedPR.State, + } + result["message"] = "successfully assigned copilot to issue - pull request created" + } else { + result["message"] = "successfully assigned copilot to issue - pull request pending" + result["note"] = "The pull request may still be in progress. Once created, the PR number can be used to check job status, or check the issue timeline for updates." + } + + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to marshal response: %s", err)), nil, nil + } + + return utils.NewToolResultText(string(r)), result, nil }) } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 0d9070ace..a338efcba 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -2765,8 +2765,12 @@ func TestAssignCopilotToIssue(t *testing.T) { // Create call request request := createMCPRequest(tc.requestArgs) + // Disable polling in tests to avoid timeouts + ctx := ContextWithPollConfig(context.Background(), PollConfig{MaxAttempts: 0}) + ctx = ContextWithDeps(ctx, deps) + // Call handler - result, err := handler(ContextWithDeps(context.Background(), deps), &request) + result, err := handler(ctx, &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2778,7 +2782,16 @@ func TestAssignCopilotToIssue(t *testing.T) { } require.False(t, result.IsError, fmt.Sprintf("expected there to be no tool error, text was %s", textContent.Text)) - require.Equal(t, textContent.Text, "successfully assigned copilot to issue") + + // Verify the JSON response contains expected fields + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "response should be valid JSON") + assert.Equal(t, float64(123), response["issue_number"]) + assert.Equal(t, "https://github.com/owner/repo/issues/123", response["issue_url"]) + assert.Equal(t, "owner", response["owner"]) + assert.Equal(t, "repo", response["repo"]) + assert.Contains(t, response["message"], "successfully assigned copilot to issue") }) } } diff --git a/pkg/github/transport.go b/pkg/github/transport.go new file mode 100644 index 000000000..0a4372b23 --- /dev/null +++ b/pkg/github/transport.go @@ -0,0 +1,47 @@ +package github + +import ( + "net/http" + "strings" +) + +// GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features +// header to requests based on context values. This is required for using +// non-GA GraphQL API features like the agent assignment API. +// +// This transport is used internally by the MCP server and is also exported +// for library consumers who need to build their own HTTP clients with +// GraphQL feature flag support. +// +// Usage: +// +// httpClient := &http.Client{ +// Transport: &github.GraphQLFeaturesTransport{ +// Transport: http.DefaultTransport, +// }, +// } +// gqlClient := githubv4.NewClient(httpClient) +// +// Then use withGraphQLFeatures(ctx, "feature_name") when calling GraphQL operations. +type GraphQLFeaturesTransport struct { + // Transport is the underlying HTTP transport. If nil, http.DefaultTransport is used. + Transport http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. +func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, error) { + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + + // Clone the request to avoid mutating the original + req = req.Clone(req.Context()) + + // Check for GraphQL-Features in context and add header if present + if features := GetGraphQLFeatures(req.Context()); len(features) > 0 { + req.Header.Set("GraphQL-Features", strings.Join(features, ", ")) + } + + return transport.RoundTrip(req) +} diff --git a/pkg/github/transport_test.go b/pkg/github/transport_test.go new file mode 100644 index 000000000..c98108255 --- /dev/null +++ b/pkg/github/transport_test.go @@ -0,0 +1,151 @@ +package github + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGraphQLFeaturesTransport(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + features []string + expectedHeader string + hasHeader bool + }{ + { + name: "no features in context", + features: nil, + expectedHeader: "", + hasHeader: false, + }, + { + name: "single feature in context", + features: []string{"issues_copilot_assignment_api_support"}, + expectedHeader: "issues_copilot_assignment_api_support", + hasHeader: true, + }, + { + name: "multiple features in context", + features: []string{"feature1", "feature2", "feature3"}, + expectedHeader: "feature1, feature2, feature3", + hasHeader: true, + }, + { + name: "empty features slice", + features: []string{}, + expectedHeader: "", + hasHeader: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var capturedHeader string + var headerExists bool + + // Create a test server that captures the request header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeader = r.Header.Get("GraphQL-Features") + headerExists = r.Header.Get("GraphQL-Features") != "" + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create the transport + transport := &GraphQLFeaturesTransport{ + Transport: http.DefaultTransport, + } + + // Create a request + ctx := context.Background() + if tc.features != nil { + ctx = withGraphQLFeatures(ctx, tc.features...) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + // Execute the request + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify the header + assert.Equal(t, tc.hasHeader, headerExists) + if tc.hasHeader { + assert.Equal(t, tc.expectedHeader, capturedHeader) + } + }) + } +} + +func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) { + t.Parallel() + + var capturedHeader string + + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeader = r.Header.Get("GraphQL-Features") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create the transport with nil Transport (should use DefaultTransport) + transport := &GraphQLFeaturesTransport{ + Transport: nil, + } + + // Create a request with features + ctx := withGraphQLFeatures(context.Background(), "test_feature") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + // Execute the request + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify the header was added + assert.Equal(t, "test_feature", capturedHeader) +} + +func TestGraphQLFeaturesTransport_DoesNotMutateOriginalRequest(t *testing.T) { + t.Parallel() + + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create the transport + transport := &GraphQLFeaturesTransport{ + Transport: http.DefaultTransport, + } + + // Create a request with features + ctx := withGraphQLFeatures(context.Background(), "test_feature") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + // Store the original header value + originalHeader := req.Header.Get("GraphQL-Features") + + // Execute the request + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify the original request was not mutated + assert.Equal(t, originalHeader, req.Header.Get("GraphQL-Features")) +}