diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index c8b1e07245..335d00477e 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -187,6 +187,7 @@ def begin_task( if nt := active_context.numtracker: nt.set_headers(pass_through_headers or {}) + subscribers = [] if tiled_config := active_context.tiled_conf: # Tiled queries the root node, so must create an authorized client if isinstance(tiled_config.authentication, ServiceAccount): @@ -204,6 +205,7 @@ def begin_task( tiled_writer_token = active_context.run_engine.subscribe( TiledWriter(tiled_client, batch_size=1) ) + subscribers.append((active_context.run_engine, tiled_writer_token)) def remove_callback_when_task_finished( event: WorkerEvent, correlation_id: str | None @@ -213,15 +215,21 @@ def remove_callback_when_task_finished( and event.task_status.task_id == task.task_id and event.task_status.task_complete ): - active_context.run_engine.unsubscribe(tiled_writer_token) - active_worker.worker_events.unsubscribe(remove_callback) + for channel, token in subscribers: + channel.unsubscribe(token) remove_callback = active_worker.worker_events.subscribe( remove_callback_when_task_finished ) + subscribers.append((active_worker.worker_events, remove_callback)) if task.task_id is not None: - active_worker.begin_task(task.task_id) + try: + active_worker.begin_task(task.task_id) + except KeyError: + for channel, token in subscribers: + channel.unsubscribe(token) + raise return task diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 3e075a7767..4763cebfac 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -617,3 +617,24 @@ def on_event(event: AnyEvent) -> None: assert outcome.result.message.startswith( "403: Access policy rejects the provided access blob" ) + + +# Regression test for #1480 +def test_task_submission_after_invalid_task(client_with_stomp: BlueapiClient): + with pytest.raises(KeyError): + # This task hasn't been submitted so should return an error... + client_with_stomp._rest.update_worker_task(WorkerTask(task_id="missing")) + + # ...but should leave the serve in a state where it can still run tasks + res = client_with_stomp.run_task( + TaskRequest( + name="count", + params={ + "detectors": [ + "det", + ], + }, + instrument_session=AUTHORIZED_INSTRUMENT_SESSION, + ) + ) + assert isinstance(res.result, TaskResult) diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index cc619259bd..ccef2f0971 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -242,6 +242,33 @@ def test_begin_task_no_task_id(worker_mock: MagicMock): worker_mock.assert_not_called() +@patch("blueapi.service.interface.from_uri") +@patch("blueapi.service.interface.config") +@patch("blueapi.service.interface.context") +@patch("blueapi.service.interface.worker") +def test_subscribers_removed_when_task_not_found( + worker_mock: MagicMock, + context_mock: MagicMock, + config_mock: MagicMock, + from_uri_mock: MagicMock, +): + # regression test for #1480 + worker = worker_mock() + ctx = context_mock() + worker.begin_task.side_effect = KeyError() + + with pytest.raises(KeyError): + interface.begin_task(WorkerTask(task_id="missing")) + + ctx.run_engine.subscribe.assert_called_once() + tiled_token = ctx.run_engine.subscribe() + ctx.run_engine.unsubscribe.assert_called_once_with(tiled_token) + + worker.worker_events.subscribe.assert_called_once() + remove_token = worker.worker_events.subscribe() + worker.worker_events.unsubscribe.assert_called_once_with(remove_token) + + @patch("blueapi.service.interface.TaskWorker.get_tasks_by_status") def test_get_tasks_by_status(get_tasks_by_status_mock: MagicMock): pending_task1 = TrackableTask(task_id="0", task=Task(name="pending_task1"))