diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index 61a742668a..e26ab388ae 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -64,7 +64,7 @@ def __stream__(self) -> Iterator[_T]: if sse.event and sse.event.startswith("thread."): data = sse.json() - if sse.event == "error" and is_mapping(data) and data.get("error"): + if sse.event == "thread.error" and is_mapping(data) and data.get("error"): message = None error = data.get("error") if is_mapping(error): @@ -167,7 +167,7 @@ async def __stream__(self) -> AsyncIterator[_T]: if sse.event and sse.event.startswith("thread."): data = sse.json() - if sse.event == "error" and is_mapping(data) and data.get("error"): + if sse.event == "thread.error" and is_mapping(data) and data.get("error"): message = None error = data.get("error") if is_mapping(error): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 04f8e51abd..f21238cbeb 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -7,6 +7,7 @@ from openai import OpenAI, AsyncOpenAI from openai._streaming import Stream, AsyncStream, ServerSentEvent +from openai._exceptions import APIError @pytest.mark.asyncio @@ -216,6 +217,66 @@ def body() -> Iterator[bytes]: assert sse.json() == {"content": "известни"} +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_thread_error_event(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None: + """Test that thread.error events are properly handled and raise APIError.""" + def body() -> Iterator[bytes]: + yield b"event: thread.error\n" + yield b'data: {"error": {"message": "Something went wrong"}}\n' + yield b"\n" + + # Create a proper request object for the response + request = httpx.Request("POST", "https://api.openai.com/v1/test") + + if sync: + response = httpx.Response(200, content=body(), request=request) + stream = Stream(cast_to=object, client=client, response=response) + + with pytest.raises(APIError) as exc_info: + next(iter(stream)) + + assert "Something went wrong" in str(exc_info.value) + else: + response = httpx.Response(200, content=to_aiter(body()), request=request) + stream = AsyncStream(cast_to=object, client=async_client, response=response) + + with pytest.raises(APIError) as exc_info: + await stream.__anext__() + + assert "Something went wrong" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_thread_non_error_event(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None: + """Test that thread.* events (non-error) are processed correctly.""" + def body() -> Iterator[bytes]: + yield b"event: thread.run.created\n" + yield b'data: {"id": "run_123", "status": "in_progress"}\n' + yield b"\n" + yield b"data: [DONE]\n" + yield b"\n" + + # Create a proper request object for the response + request = httpx.Request("POST", "https://api.openai.com/v1/test") + + if sync: + response = httpx.Response(200, content=body(), request=request) + stream = Stream(cast_to=object, client=client, response=response) + + result = next(iter(stream)) + assert result["event"] == "thread.run.created" + assert result["data"]["id"] == "run_123" + else: + response = httpx.Response(200, content=to_aiter(body()), request=request) + stream = AsyncStream(cast_to=object, client=async_client, response=response) + + result = await stream.__anext__() + assert result["event"] == "thread.run.created" + assert result["data"]["id"] == "run_123" + + async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: for chunk in iter: yield chunk @@ -246,3 +307,19 @@ def make_event_iterator( return AsyncStream( cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content)) )._iter_events() + + +def make_stream_iterator( + content: Iterator[bytes], + *, + sync: bool, + client: OpenAI, + async_client: AsyncOpenAI, +) -> Iterator[object] | AsyncIterator[object]: + """Create a Stream or AsyncStream iterator for testing the full stream processing.""" + if sync: + return iter(Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))) + + return AsyncStream( + cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content)) + ).__aiter__()