Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 36 additions & 29 deletions mcp/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ func writeEvent(w io.Writer, evt Event) (int, error) {
// TODO(rfindley): consider a different API here that makes failure modes more
// apparent.
func scanEvents(r io.Reader) iter.Seq2[Event, error] {
scanner := bufio.NewScanner(r)
const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size
scanner.Buffer(nil, maxTokenSize)
reader := bufio.NewReader(r)

// TODO: investigate proper behavior when events are out of order, or have
// non-standard names.
Expand All @@ -94,31 +92,43 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
evt Event
dataBuf *bytes.Buffer // if non-nil, preceding field was also data
)
flushData := func() {
yieldEvent := func() bool {
if dataBuf != nil {
evt.Data = dataBuf.Bytes()
dataBuf = nil
}
if evt.Empty() {
return true
}
if !yield(evt, nil) {
return false
}
evt = Event{}
return true
}
for scanner.Scan() {
line := scanner.Bytes()
for {
line, err := reader.ReadBytes('\n')
if err != nil && !errors.Is(err, io.EOF) {
yield(Event{}, fmt.Errorf("error reading event: %v", err))
return
}
line = bytes.TrimRight(line, "\r\n")
isEOF := errors.Is(err, io.EOF)

if len(line) == 0 {
flushData()
// \n\n is the record delimiter
if !evt.Empty() && !yield(evt, nil) {
if !yieldEvent() {
return
}
if isEOF {
return
}
evt = Event{}
continue
}
before, after, found := bytes.Cut(line, []byte{':'})
if !found {
yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line)))
yield(Event{}, fmt.Errorf("%w: malformed line in SSE stream: %q", errMalformedEvent, string(line)))
return
}
if !bytes.Equal(before, dataKey) {
flushData()
}
switch {
case bytes.Equal(before, eventKey):
evt.Name = strings.TrimSpace(string(after))
Expand All @@ -128,27 +138,19 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
evt.Retry = strings.TrimSpace(string(after))
case bytes.Equal(before, dataKey):
data := bytes.TrimSpace(after)
if dataBuf != nil {
dataBuf.WriteByte('\n')
dataBuf.Write(data)
} else {
if dataBuf == nil {
dataBuf = new(bytes.Buffer)
dataBuf.Write(data)
} else {
dataBuf.WriteByte('\n')
}
dataBuf.Write(data)
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize)
}
if !yield(Event{}, err) {

if isEOF {
yieldEvent()
return
}
}
flushData()
if !evt.Empty() {
yield(evt, nil)
}
}
}

Expand Down Expand Up @@ -310,6 +312,11 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string,
// index is no longer available.
var ErrEventsPurged = errors.New("data purged")

// errMalformedEvent is returned when an SSE event cannot be parsed due to format violations.
// This is a hard error indicating corrupted data or protocol violations, as opposed to
// transient I/O errors which may be retryable.
var errMalformedEvent = errors.New("malformed event")

// After implements [EventStore.After].
func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] {
// Return the data items to yield.
Expand Down
72 changes: 72 additions & 0 deletions mcp/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,78 @@ func TestScanEvents(t *testing.T) {
input: "invalid line\n\n",
wantErr: "malformed line",
},
{
name: "message with 2 data lines and another event",
input: "event: message\ndata: hello\ndata: hello\ndata: hello\n\nevent:keepalive",
want: []Event{
{Name: "message", Data: []byte("hello\nhello\nhello")},
{Name: "keepalive"},
},
},
{
name: "event with multiple lines",
input: "event: message\ndata: hello\ndata: hello\ndata: hello\nid:1",
want: []Event{
{Name: "message", ID: "1", Data: []byte("hello\nhello\nhello")},
},
},
{
name: "multiple events, out of order keys",
input: strings.Join([]string{
"event:message",
"data: hello0",
"\n",
"data: hello1",
"data: hello1",
"id:1",
"event:message",
"\n",
"event:message",
"data: hello3",
"data: hello3",
"id:3",
"\n",
"data: hello4",
"data: hello4",
"id:4",
"event:message",
}, "\n"),
want: []Event{
{Name: "message", Data: []byte("hello0")},
{Name: "message", ID: "1", Data: []byte("hello1\nhello1")},
{Name: "message", ID: "3", Data: []byte("hello3\nhello3")},
{Name: "message", ID: "4", Data: []byte("hello4\nhello4")},
},
},
{
name: "non-continuous data items in the event",
input: "event: foo\ndata: 123\nretry: 5\ndata: 456",
want: []Event{
{Name: "foo", Data: []byte("123\n456"), Retry: "5"},
},
},
{
name: "no-data events",
input: "event: foo\n\nevent: bar",
want: []Event{
{Name: "foo"},
{Name: "bar"},
},
},
{
name: "empty data event",
input: "event: foo\ndata:\n\nevent: bar",
want: []Event{
{Name: "foo"},
{Name: "bar"},
},
},
{

name: "malformed data event",
input: "someline",
wantErr: "malformed event",
},
}

for _, tt := range tests {
Expand Down
3 changes: 2 additions & 1 deletion mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

if req.Method != http.MethodGet {
http.Error(w, "invalid method", http.StatusMethodNotAllowed)
w.Header().Set("Allow", "GET, POST")
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}

Expand Down
8 changes: 8 additions & 0 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1859,6 +1859,14 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary
if ctx.Err() != nil {
return "", 0, true // don't reconnect: client cancelled
}

// Malformed events are hard errors that indicate corrupted data or protocol
// violations. These should fail the connection permanently.
if errors.Is(err, errMalformedEvent) {
c.fail(fmt.Errorf("%s: %v", requestSummary, err))
return "", 0, true
}

break
}

Expand Down
Loading