diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 6c75f0e..dcaf62f 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -186,7 +186,7 @@ func TestAWSBedrockIntegration(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // Invalid bedrock config - missing region + // Invalid bedrock config - missing region & base url bedrockCfg := &config.AWSBedrock{ Region: "", AccessKey: "test-key", @@ -218,7 +218,7 @@ func TestAWSBedrockIntegration(t *testing.T) { body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "create anthropic client") - require.Contains(t, string(body), "region required") + require.Contains(t, string(body), "region or base url required") }) t.Run("/v1/messages", func(t *testing.T) { @@ -281,15 +281,14 @@ func TestAWSBedrockIntegration(t *testing.T) { srv.Start() t.Cleanup(srv.Close) - // Configure Bedrock with test credentials and model names. - // The EndpointOverride will make requests go to the mock server instead of real AWS endpoints. + // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. bedrockCfg := &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "danthropic", // This model should override the request's given one. - SmallFastModel: "danthropic-mini", // Unused but needed for validation. - EndpointOverride: srv.URL, + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "danthropic", // This model should override the request's given one. + SmallFastModel: "danthropic-mini", // Unused but needed for validation. + BaseURL: srv.URL, // Use the mock server. } recorderClient := &testutil.MockRecorder{} diff --git a/config/config.go b/config/config.go index 74209c6..3387007 100644 --- a/config/config.go +++ b/config/config.go @@ -46,9 +46,10 @@ type AWSBedrock struct { Region string AccessKey, AccessKeySecret string Model, SmallFastModel string - // EndpointOverride allows overriding the Bedrock endpoint URL for testing. - // If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint. - EndpointOverride string + // If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint + // (https://bedrock-runtime.{region}.amazonaws.com). + // This is useful for routing requests through a proxy or for testing. + BaseURL string } type OpenAI struct { diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 4c993e1..ba37067 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/http" - "net/url" "strings" "time" @@ -163,34 +162,23 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio if i.bedrockCfg != nil { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - bedrockOpt, err := i.withAWSBedrock(ctx, i.bedrockCfg) + bedrockOpts, err := i.withAWSBedrockOptions(ctx, i.bedrockCfg) if err != nil { return anthropic.MessageService{}, err } - opts = append(opts, bedrockOpt) + opts = append(opts, bedrockOpts...) i.augmentRequestForBedrock() - - // If an endpoint override is set (for testing), add a custom HTTP client AFTER the bedrock config - // This overrides any HTTP client set by the bedrock middleware - if i.bedrockCfg.EndpointOverride != "" { - opts = append(opts, option.WithHTTPClient(&http.Client{ - Transport: &redirectTransport{ - base: http.DefaultTransport, - redirectToURL: i.bedrockCfg.EndpointOverride, - }, - })) - } } return anthropic.NewMessageService(opts...), nil } -func (i *interceptionBase) withAWSBedrock(ctx context.Context, cfg *aibconfig.AWSBedrock) (option.RequestOption, error) { +func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { if cfg == nil { return nil, fmt.Errorf("nil config given") } - if cfg.Region == "" { - return nil, fmt.Errorf("region required") + if cfg.Region == "" && cfg.BaseURL == "" { + return nil, fmt.Errorf("region or base url required") } if cfg.AccessKey == "" { return nil, fmt.Errorf("access key required") @@ -221,7 +209,15 @@ func (i *interceptionBase) withAWSBedrock(ctx context.Context, cfg *aibconfig.AW return nil, fmt.Errorf("failed to load AWS Bedrock config: %w", err) } - return bedrock.WithConfig(awsCfg), nil + var out []option.RequestOption + out = append(out, bedrock.WithConfig(awsCfg)) + + // If a custom base URL is set, override the default endpoint constructed by the bedrock middleware. + if cfg.BaseURL != "" { + out = append(out, option.WithBaseURL(cfg.BaseURL)) + } + + return out, nil } // augmentRequestForBedrock will change the model used for the request since AWS Bedrock doesn't support @@ -261,28 +257,6 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err } } -// redirectTransport is an HTTP RoundTripper that redirects requests to a different endpoint. -// This is useful for testing when we need to redirect AWS Bedrock requests to a mock server. -type redirectTransport struct { - base http.RoundTripper - redirectToURL string -} - -func (t *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Parse the redirect URL - redirectURL, err := url.Parse(t.redirectToURL) - if err != nil { - return nil, err - } - - // Redirect the request to the mock server - req.URL.Scheme = redirectURL.Scheme - req.URL.Host = redirectURL.Host - req.Host = redirectURL.Host - - return t.base.RoundTrip(req) -} - // accumulateUsage accumulates usage statistics from source into dest. // It handles both [anthropic.Usage] and [anthropic.MessageDeltaUsage] types through [any]. // The function uses reflection to handle the differences between the types: diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index b880b11..5413a7d 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -21,8 +21,9 @@ func TestAWSBedrockValidation(t *testing.T) { expectError bool errorMsg string }{ + // Valid cases. { - name: "valid", + name: "valid with region", cfg: &config.AWSBedrock{ Region: "us-east-1", AccessKey: "test-key", @@ -32,7 +33,33 @@ func TestAWSBedrockValidation(t *testing.T) { }, }, { - name: "missing region", + name: "valid with base url", + cfg: &config.AWSBedrock{ + BaseURL: "http://bedrock.internal", + AccessKey: "test-key", + AccessKeySecret: "test-secret", + Model: "test-model", + SmallFastModel: "test-small-model", + }, + }, + { + // There unfortunately isn't a way for us to determine precedence in a unit test, + // since the produced options take a `requestconfig.RequestConfig` input value + // which is internal to the anthropic SDK. + // + // See TestAWSBedrockIntegration which validates this. + name: "valid with base url & region", + cfg: &config.AWSBedrock{ + Region: "us-east-1", + AccessKey: "test-key", + AccessKeySecret: "test-secret", + Model: "test-model", + SmallFastModel: "test-small-model", + }, + }, + // Invalid cases. + { + name: "missing region & base url", cfg: &config.AWSBedrock{ Region: "", AccessKey: "test-key", @@ -41,7 +68,7 @@ func TestAWSBedrockValidation(t *testing.T) { SmallFastModel: "test-small-model", }, expectError: true, - errorMsg: "region required", + errorMsg: "region or base url required", }, { name: "missing access key", @@ -95,7 +122,7 @@ func TestAWSBedrockValidation(t *testing.T) { name: "all fields empty", cfg: &config.AWSBedrock{}, expectError: true, - errorMsg: "region required", + errorMsg: "region or base url required", }, { name: "nil config", @@ -108,12 +135,13 @@ func TestAWSBedrockValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { base := &interceptionBase{} - _, err := base.withAWSBedrock(context.Background(), tt.cfg) + opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg) if tt.expectError { require.Error(t, err) require.Contains(t, err.Error(), tt.errorMsg) } else { + require.NotEmpty(t, opts) require.NoError(t, err) } }) diff --git a/trace_integration_test.go b/trace_integration_test.go index 10ec30c..0b1de49 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -863,11 +863,11 @@ func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []e func testBedrockCfg(url string) *config.AWSBedrock { return &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "beddel", // This model should override the request's given one. - SmallFastModel: "modrock", // Unused but needed for validation. - EndpointOverride: url, + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + BaseURL: url, } }