Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

// Invalid bedrock config - missing region
// Invalid bedrock config - missing region & base url
bedrockCfg := &config.AWSBedrock{
Region: "",
AccessKey: "test-key",
Expand Down Expand Up @@ -218,7 +218,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), "create anthropic client")
require.Contains(t, string(body), "region required")
require.Contains(t, string(body), "region or base url required")
})

t.Run("/v1/messages", func(t *testing.T) {
Expand Down Expand Up @@ -281,15 +281,14 @@ func TestAWSBedrockIntegration(t *testing.T) {
srv.Start()
t.Cleanup(srv.Close)

// Configure Bedrock with test credentials and model names.
// The EndpointOverride will make requests go to the mock server instead of real AWS endpoints.
// We define region here to validate that with Region & BaseURL defined, the latter takes precedence.
bedrockCfg := &config.AWSBedrock{
Region: "us-west-2",
AccessKey: "test-access-key",
AccessKeySecret: "test-secret-key",
Model: "danthropic", // This model should override the request's given one.
SmallFastModel: "danthropic-mini", // Unused but needed for validation.
EndpointOverride: srv.URL,
Region: "us-west-2",
AccessKey: "test-access-key",
AccessKeySecret: "test-secret-key",
Model: "danthropic", // This model should override the request's given one.
SmallFastModel: "danthropic-mini", // Unused but needed for validation.
BaseURL: srv.URL, // Use the mock server.
}

recorderClient := &testutil.MockRecorder{}
Expand Down
7 changes: 4 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ type AWSBedrock struct {
Region string
AccessKey, AccessKeySecret string
Model, SmallFastModel string
// EndpointOverride allows overriding the Bedrock endpoint URL for testing.
// If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint.
EndpointOverride string
// If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint
// (https://bedrock-runtime.{region}.amazonaws.com).
// This is useful for routing requests through a proxy or for testing.
BaseURL string
}

type OpenAI struct {
Expand Down
54 changes: 14 additions & 40 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -163,34 +162,23 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio
if i.bedrockCfg != nil {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
bedrockOpt, err := i.withAWSBedrock(ctx, i.bedrockCfg)
bedrockOpts, err := i.withAWSBedrockOptions(ctx, i.bedrockCfg)
if err != nil {
return anthropic.MessageService{}, err
}
opts = append(opts, bedrockOpt)
opts = append(opts, bedrockOpts...)
i.augmentRequestForBedrock()

// If an endpoint override is set (for testing), add a custom HTTP client AFTER the bedrock config
// This overrides any HTTP client set by the bedrock middleware
if i.bedrockCfg.EndpointOverride != "" {
opts = append(opts, option.WithHTTPClient(&http.Client{
Transport: &redirectTransport{
base: http.DefaultTransport,
redirectToURL: i.bedrockCfg.EndpointOverride,
},
}))
}
}

return anthropic.NewMessageService(opts...), nil
}

func (i *interceptionBase) withAWSBedrock(ctx context.Context, cfg *aibconfig.AWSBedrock) (option.RequestOption, error) {
func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
if cfg == nil {
return nil, fmt.Errorf("nil config given")
}
if cfg.Region == "" {
return nil, fmt.Errorf("region required")
if cfg.Region == "" && cfg.BaseURL == "" {
return nil, fmt.Errorf("region or base url required")
}
if cfg.AccessKey == "" {
return nil, fmt.Errorf("access key required")
Expand Down Expand Up @@ -221,7 +209,15 @@ func (i *interceptionBase) withAWSBedrock(ctx context.Context, cfg *aibconfig.AW
return nil, fmt.Errorf("failed to load AWS Bedrock config: %w", err)
}

return bedrock.WithConfig(awsCfg), nil
var out []option.RequestOption
out = append(out, bedrock.WithConfig(awsCfg))

// If a custom base URL is set, override the default endpoint constructed by the bedrock middleware.
if cfg.BaseURL != "" {
out = append(out, option.WithBaseURL(cfg.BaseURL))
}

return out, nil
}

// augmentRequestForBedrock will change the model used for the request since AWS Bedrock doesn't support
Expand Down Expand Up @@ -261,28 +257,6 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err
}
}

// redirectTransport is an HTTP RoundTripper that redirects requests to a different endpoint.
// This is useful for testing when we need to redirect AWS Bedrock requests to a mock server.
type redirectTransport struct {
base http.RoundTripper
redirectToURL string
}

func (t *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Parse the redirect URL
redirectURL, err := url.Parse(t.redirectToURL)
if err != nil {
return nil, err
}

// Redirect the request to the mock server
req.URL.Scheme = redirectURL.Scheme
req.URL.Host = redirectURL.Host
req.Host = redirectURL.Host

return t.base.RoundTrip(req)
}

// accumulateUsage accumulates usage statistics from source into dest.
// It handles both [anthropic.Usage] and [anthropic.MessageDeltaUsage] types through [any].
// The function uses reflection to handle the differences between the types:
Expand Down
38 changes: 33 additions & 5 deletions intercept/messages/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ func TestAWSBedrockValidation(t *testing.T) {
expectError bool
errorMsg string
}{
// Valid cases.
{
name: "valid",
name: "valid with region",
cfg: &config.AWSBedrock{
Region: "us-east-1",
AccessKey: "test-key",
Expand All @@ -32,7 +33,33 @@ func TestAWSBedrockValidation(t *testing.T) {
},
},
{
name: "missing region",
name: "valid with base url",
cfg: &config.AWSBedrock{
BaseURL: "http://bedrock.internal",
AccessKey: "test-key",
AccessKeySecret: "test-secret",
Model: "test-model",
SmallFastModel: "test-small-model",
},
},
{
// There unfortunately isn't a way for us to determine precedence in a unit test,
// since the produced options take a `requestconfig.RequestConfig` input value
// which is internal to the anthropic SDK.
//
// See TestAWSBedrockIntegration which validates this.
name: "valid with base url & region",
cfg: &config.AWSBedrock{
Region: "us-east-1",
AccessKey: "test-key",
AccessKeySecret: "test-secret",
Model: "test-model",
SmallFastModel: "test-small-model",
},
},
// Invalid cases.
{
name: "missing region & base url",
cfg: &config.AWSBedrock{
Region: "",
AccessKey: "test-key",
Expand All @@ -41,7 +68,7 @@ func TestAWSBedrockValidation(t *testing.T) {
SmallFastModel: "test-small-model",
},
expectError: true,
errorMsg: "region required",
errorMsg: "region or base url required",
},
{
name: "missing access key",
Expand Down Expand Up @@ -95,7 +122,7 @@ func TestAWSBedrockValidation(t *testing.T) {
name: "all fields empty",
cfg: &config.AWSBedrock{},
expectError: true,
errorMsg: "region required",
errorMsg: "region or base url required",
},
{
name: "nil config",
Expand All @@ -108,12 +135,13 @@ func TestAWSBedrockValidation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
base := &interceptionBase{}
_, err := base.withAWSBedrock(context.Background(), tt.cfg)
opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)

if tt.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NotEmpty(t, opts)
require.NoError(t, err)
}
})
Expand Down
12 changes: 6 additions & 6 deletions trace_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,11 +863,11 @@ func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []e

func testBedrockCfg(url string) *config.AWSBedrock {
return &config.AWSBedrock{
Region: "us-west-2",
AccessKey: "test-access-key",
AccessKeySecret: "test-secret-key",
Model: "beddel", // This model should override the request's given one.
SmallFastModel: "modrock", // Unused but needed for validation.
EndpointOverride: url,
Region: "us-west-2",
AccessKey: "test-access-key",
AccessKeySecret: "test-secret-key",
Model: "beddel", // This model should override the request's given one.
SmallFastModel: "modrock", // Unused but needed for validation.
BaseURL: url,
}
}