Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def begin_task(
if nt := active_context.numtracker:
nt.set_headers(pass_through_headers or {})

subscribers = []
if tiled_config := active_context.tiled_conf:
# Tiled queries the root node, so must create an authorized client
if isinstance(tiled_config.authentication, ServiceAccount):
Expand All @@ -204,6 +205,7 @@ def begin_task(
tiled_writer_token = active_context.run_engine.subscribe(
TiledWriter(tiled_client, batch_size=1)
)
subscribers.append((active_context.run_engine, tiled_writer_token))

def remove_callback_when_task_finished(
event: WorkerEvent, correlation_id: str | None
Expand All @@ -213,15 +215,21 @@ def remove_callback_when_task_finished(
and event.task_status.task_id == task.task_id
and event.task_status.task_complete
):
active_context.run_engine.unsubscribe(tiled_writer_token)
active_worker.worker_events.unsubscribe(remove_callback)
for ch, token in subscribers:
ch.unsubscribe(token)

remove_callback = active_worker.worker_events.subscribe(
remove_callback_when_task_finished
)
subscribers.append((active_worker.worker_events, remove_callback))

if task.task_id is not None:
active_worker.begin_task(task.task_id)
try:
active_worker.begin_task(task.task_id)
except KeyError:
for ch, token in subscribers:
ch.unsubscribe(token)
raise
return task


Expand Down
21 changes: 21 additions & 0 deletions tests/system_tests/test_blueapi_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,3 +617,24 @@ def on_event(event: AnyEvent) -> None:
assert outcome.result.message.startswith(
"403: Access policy rejects the provided access blob"
)


# Regression test for #1480
def test_task_submission_after_invalid_task(client_with_stomp: BlueapiClient):
with pytest.raises(KeyError):
# This task hasn't been submitted so should return an error...
client_with_stomp._rest.update_worker_task(WorkerTask(task_id="missing"))

# ...but should leave the serve in a state where it can still run tasks
res = client_with_stomp.run_task(
TaskRequest(
name="count",
params={
"detectors": [
"det",
],
},
instrument_session=AUTHORIZED_INSTRUMENT_SESSION,
)
)
assert isinstance(res.result, TaskResult)
27 changes: 27 additions & 0 deletions tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,33 @@ def test_begin_task_no_task_id(worker_mock: MagicMock):
worker_mock.assert_not_called()


@patch("blueapi.service.interface.from_uri")
@patch("blueapi.service.interface.config")
@patch("blueapi.service.interface.context")
@patch("blueapi.service.interface.worker")
def test_subscribers_removed_when_task_not_found(
worker_mock: MagicMock,
context_mock: MagicMock,
config_mock: MagicMock,
from_uri_mock: MagicMock,
):
# regression test for #1480
worker = worker_mock()
ctx = context_mock()
worker.begin_task.side_effect = KeyError()

with pytest.raises(KeyError):
interface.begin_task(WorkerTask(task_id="missing"))

ctx.run_engine.subscribe.assert_called_once()
tiled_token = ctx.run_engine.subscribe()
ctx.run_engine.unsubscribe.assert_called_once_with(tiled_token)

worker.worker_events.subscribe.assert_called_once()
remove_token = worker.worker_events.subscribe()
worker.worker_events.unsubscribe.assert_called_once_with(remove_token)


@patch("blueapi.service.interface.TaskWorker.get_tasks_by_status")
def test_get_tasks_by_status(get_tasks_by_status_mock: MagicMock):
pending_task1 = TrackableTask(task_id="0", task=Task(name="pending_task1"))
Expand Down
Loading