Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions audit/log_auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import "log/slog"

// LogAuditor implements proxy.Auditor by logging to slog
type LogAuditor struct {
logger *slog.Logger
logger *slog.Logger
sessionID string
}

// NewLogAuditor creates a new LogAuditor
Expand All @@ -14,19 +15,27 @@ func NewLogAuditor(logger *slog.Logger) *LogAuditor {
}
}

// NewLogAuditorWithSession creates a new LogAuditor that includes a session ID on every log line.
func NewLogAuditorWithSession(logger *slog.Logger, sessionID string) *LogAuditor {
return &LogAuditor{
logger: logger,
sessionID: sessionID,
}
}

// AuditRequest logs the request using structured logging
func (a *LogAuditor) AuditRequest(req Request) {
fields := []any{
"method", req.Method,
"url", req.URL,
"host", req.Host,
}
if a.sessionID != "" {
fields = append(fields, "session_id", a.sessionID)
}
if req.Allowed {
a.logger.Info("ALLOW",
"method", req.Method,
"url", req.URL,
"host", req.Host,
"rule", req.Rule)
a.logger.Info("ALLOW", append(fields, "rule", req.Rule)...)
} else {
a.logger.Warn("DENY",
"method", req.Method,
"url", req.URL,
"host", req.Host,
)
a.logger.Warn("DENY", fields...)
}
}
6 changes: 3 additions & 3 deletions audit/multi_auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func (m *MultiAuditor) AuditRequest(req Request) {
// provided configuration. It always includes a LogAuditor for stderr logging,
// and conditionally adds a SocketAuditor if audit logs are enabled and the
// workspace agent's log proxy socket exists.
func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string) (Auditor, error) {
stderrAuditor := NewLogAuditor(logger)
func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string, sessionID string) (Auditor, error) {
stderrAuditor := NewLogAuditorWithSession(logger, sessionID)
auditors := []Auditor{stderrAuditor}

if !disableAuditLogs {
Expand All @@ -48,7 +48,7 @@ func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs boo
}
agentWillProxy := !os.IsNotExist(err)
if agentWillProxy {
socketAuditor := NewSocketAuditor(logger, logProxySocketPath)
socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID)
go socketAuditor.Loop(ctx)
auditors = append(auditors, socketAuditor)
} else {
Expand Down
8 changes: 4 additions & 4 deletions audit/multi_auditor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestSetupAuditor_DisabledAuditLogs(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
ctx := context.Background()

auditor, err := SetupAuditor(ctx, logger, true, "")
auditor, err := SetupAuditor(ctx, logger, true, "", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand All @@ -50,7 +50,7 @@ func TestSetupAuditor_EmptySocketPath(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
ctx := context.Background()

_, err := SetupAuditor(ctx, logger, false, "")
_, err := SetupAuditor(ctx, logger, false, "", "")
if err == nil {
t.Fatal("expected error for empty socket path, got nil")
}
Expand All @@ -62,7 +62,7 @@ func TestSetupAuditor_SocketDoesNotExist(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
ctx := context.Background()

auditor, err := SetupAuditor(ctx, logger, false, "/nonexistent/socket/path")
auditor, err := SetupAuditor(ctx, logger, false, "/nonexistent/socket/path", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -100,7 +100,7 @@ func TestSetupAuditor_SocketExists(t *testing.T) {
t.Fatalf("failed to close temp file: %v", err)
}

auditor, err := SetupAuditor(ctx, logger, false, socketPath)
auditor, err := SetupAuditor(ctx, logger, false, socketPath, "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down
15 changes: 11 additions & 4 deletions audit/socket_auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type SocketAuditor struct {
batchSize int
batchTimerDuration time.Duration
socketPath string
// sessionID is included in every batch sent to the workspace agent so that
// all audit events from a single boundary invocation can be correlated.
sessionID string

droppedChannelFull atomic.Int64
droppedBatchFull atomic.Int64
Expand All @@ -45,7 +48,7 @@ type SocketAuditor struct {
// NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's
// boundary log proxy socket after SocketAuditor.Loop is called. The socket path
// is read from EnvAuditSocketPath, falling back to defaultAuditSocketPath.
func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor {
func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID string) *SocketAuditor {
// This channel buffer size intends to allow enough buffering for bursty
// AI agent network requests while a batch is being sent to the workspace
// agent.
Expand All @@ -60,6 +63,7 @@ func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor {
batchSize: defaultBatchSize,
batchTimerDuration: defaultBatchTimerDuration,
socketPath: socketPath,
sessionID: sessionID,
}
}

Expand Down Expand Up @@ -100,14 +104,17 @@ type flushErr struct {
func (e *flushErr) Error() string { return e.err.Error() }

// flush sends the current batch of logs to the given connection.
func flush(conn net.Conn, logs []*agentproto.BoundaryLog) *flushErr {
func flush(conn net.Conn, logs []*agentproto.BoundaryLog, sessionID string) *flushErr {
if len(logs) == 0 {
return nil
}

msg := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Logs{
Logs: &agentproto.ReportBoundaryLogsRequest{Logs: logs},
Logs: &agentproto.ReportBoundaryLogsRequest{
Logs: logs,
SessionId: sessionID,
},
},
}
if err := codec.WriteMessage(conn, codec.TagV2, msg); err != nil {
Expand Down Expand Up @@ -188,7 +195,7 @@ func (s *SocketAuditor) Loop(ctx context.Context) {
return
}

if err := flush(conn, batch); err != nil {
if err := flush(conn, batch, s.sessionID); err != nil {
if err.permanent {
// Data error: discard batch to avoid infinite retries.
s.logger.Warn("dropping batch due to data error on flush attempt",
Expand Down
46 changes: 44 additions & 2 deletions audit/socket_auditor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,17 +451,59 @@ func TestSocketAuditor_Loop_ShutdownFlushIncludesDrops(t *testing.T) {
func TestFlush_EmptyBatch(t *testing.T) {
t.Parallel()

err := flush(nil, nil)
err := flush(nil, nil, "")
if err != nil {
t.Errorf("expected nil error for empty batch, got %v", err)
}

err = flush(nil, []*agentproto.BoundaryLog{})
err = flush(nil, []*agentproto.BoundaryLog{}, "")
if err != nil {
t.Errorf("expected nil error for empty slice, got %v", err)
}
}

func TestSocketAuditor_Loop_SessionIDInBatch(t *testing.T) {
t.Parallel()

const wantSessionID = "test-boundary-session-uuid"

clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
})

logger := slog.New(slog.NewTextHandler(io.Discard, nil))
auditor := &SocketAuditor{
dial: func() (net.Conn, error) {
return clientConn, nil
},
logger: logger,
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
batchSize: defaultBatchSize,
batchTimerDuration: time.Hour, // disable timer; flush via batch size
sessionID: wantSessionID,
}

cr := newConnReader()
go readFromConn(t, serverConn, cr)
go auditor.Loop(t.Context())

// Fill a full batch so it flushes immediately.
for i := 0; i < auditor.batchSize; i++ {
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
}

select {
case req := <-cr.logs:
if req.SessionId != wantSessionID {
t.Errorf("expected session_id=%q, got %q", wantSessionID, req.SessionId)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for flush")
}
}

// setupSocketAuditor creates a SocketAuditor for tests that only exercise
// the queueing behavior (no connection needed).
func setupSocketAuditor(t *testing.T) *SocketAuditor {
Expand Down
36 changes: 36 additions & 0 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,40 @@ func BaseCommand(version string) *serpent.Command {
Value: &cliConfig.LogProxySocketPath,
YAML: "", // CLI only, not loaded from YAML
},
{
Flag: "session-id",
Env: "BOUNDARY_SESSION_ID",
Description: "Session ID to use for this boundary invocation. Generated as a UUID if not provided.",
Value: &cliConfig.SessionID,
YAML: "", // CLI only
},
{
Flag: "session-id-header",
Env: "BOUNDARY_SESSION_ID_HEADER",
Description: fmt.Sprintf("HTTP header name used to inject the session ID into forwarded requests. Default: %q.", config.DefaultSessionIDHeader),
Value: &cliConfig.SessionIDHeader,
YAML: "session_id_header",
},
{
Flag: "disable-session-id-header",
Env: "BOUNDARY_DISABLE_SESSION_ID_HEADER",
Description: "Disable injection of the session ID header on forwarded requests.",
Value: &cliConfig.DisableSessionIDHeader,
YAML: "disable_session_id_header",
},
{
Flag: "session-id-inject-domain",
Env: "BOUNDARY_SESSION_ID_INJECT_DOMAIN",
Description: "Match rule (repeatable) selecting which requests receive the session ID header. Merged with session_id_inject_domains from config file. Uses the same syntax as --allow (e.g. \"domain=dev.coder.com path=/api/v2/aibridge/*\"). If no rules are configured the header is never injected.",
Value: &cliConfig.SessionIDMatch,
YAML: "", // CLI only, not loaded from YAML
},
{
Flag: "", // No CLI flag, YAML only
Description: "Session ID match rules from config file (YAML only). Merged with --session-id-inject-domain CLI flags.",
Value: &cliConfig.SessionIDMatchList,
YAML: "session_id_inject_domains",
},
{
Flag: "version",
Description: "Print version information and exit.",
Expand Down Expand Up @@ -199,6 +233,8 @@ func BaseCommand(version string) *serpent.Command {
return fmt.Errorf("could not set up logging: %v", err)
}

logger.Info("boundary session started", "session_id", appConfig.SessionID)

appConfigInJSON, err := json.Marshal(appConfig)
if err != nil {
return err
Expand Down
Loading
Loading