Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def _get_pipeline_name_from_task_spec(
return spec.name or None


@dataclasses.dataclass(frozen=True, kw_only=True)
class ExecutionStatusSummary:
total_executions: int
ended_executions: int
has_ended: bool


# ==== PipelineJobService
@dataclasses.dataclass(kw_only=True)
class PipelineRunResponse:
Expand All @@ -64,6 +71,7 @@ class PipelineRunResponse:
created_at: datetime.datetime | None = None
pipeline_name: str | None = None
execution_status_stats: dict[str, int] | None = None
execution_summary: ExecutionStatusSummary | None = None

@classmethod
def from_db(cls, pipeline_run: bts.PipelineRun) -> "PipelineRunResponse":
Expand Down Expand Up @@ -266,21 +274,32 @@ def _create_pipeline_run_response(
)
response.pipeline_name = pipeline_name
if include_execution_stats:
response.execution_status_stats = self._get_execution_status_stats(
stats, summary = self._get_execution_stats_and_summary(
session=session,
root_execution_id=pipeline_run.root_execution_id,
)
response.execution_status_stats = stats
response.execution_summary = summary
return response

def _get_execution_status_stats(
def _get_execution_stats_and_summary(
self,
session: orm.Session,
root_execution_id: bts.IdType,
) -> dict[str, int]:
) -> tuple[dict[str, int], ExecutionStatusSummary]:
stats = self._calculate_execution_status_stats(
session=session, root_execution_id=root_execution_id
)
return {status.value: count for status, count in stats.items()}
total = sum(stats.values())
ended = sum(c for s, c in stats.items() if s in bts.CONTAINER_STATUSES_ENDED)
summary = ExecutionStatusSummary(
total_executions=total,
ended_executions=ended,
has_ended=(ended == total),
)
# e.g. {"SUCCEEDED": 3, "RUNNING": 1, "FAILED": 2}
status_stats = {s.value: c for s, c in stats.items()}
return status_stats, summary

def _calculate_execution_status_stats(
self, session: orm.Session, root_execution_id: bts.IdType
Expand Down Expand Up @@ -477,22 +496,6 @@ class ArtifactNodeIdResponse:
id: bts.IdType


@dataclasses.dataclass(kw_only=True)
class ExecutionStatusSummary:
total_executions: int = 0
ended_executions: int = 0
has_ended: bool = False

def count_execution_status(
self, *, status: bts.ContainerExecutionStatus, count: int
) -> None:
self.total_executions += count
if status in bts.CONTAINER_STATUSES_ENDED:
self.ended_executions += count

self.has_ended = self.ended_executions == self.total_executions


@dataclasses.dataclass
class GetGraphExecutionStateResponse:
child_execution_status_stats: dict[bts.IdType, dict[str, int]]
Expand Down
98 changes: 49 additions & 49 deletions tests/test_api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,52 @@


class TestExecutionStatusSummary:
def test_initial_state(self):
summary = api_server_sql.ExecutionStatusSummary()
assert summary.total_executions == 0
assert summary.ended_executions == 0
assert summary.has_ended is False

def test_accumulate_all_ended_statuses(self):
"""Add each ended status with 2^i count for robust uniqueness."""
summary = api_server_sql.ExecutionStatusSummary()
ended_statuses = sorted(bts.CONTAINER_STATUSES_ENDED, key=lambda s: s.value)
expected_total = 0
expected_ended = 0
for i, status in enumerate(ended_statuses):
count = 2**i
summary.count_execution_status(status=status, count=count)
expected_total += count
expected_ended += count
assert summary.total_executions == expected_total
assert summary.ended_executions == expected_ended
assert summary.has_ended is True

def test_accumulate_all_in_progress_statuses(self):
"""Add each in-progress status with 2^i count for robust uniqueness."""
summary = api_server_sql.ExecutionStatusSummary()
in_progress_statuses = sorted(
def test_all_ended_statuses(self) -> None:
stats = {
status: 2**i
for i, status in enumerate(
sorted(bts.CONTAINER_STATUSES_ENDED, key=lambda s: s.value)
)
}
total = sum(stats.values())
summary = api_server_sql.ExecutionStatusSummary(
total_executions=total,
ended_executions=total,
has_ended=True,
)
assert summary.total_executions == total
assert summary.ended_executions == total
assert summary.has_ended is True

def test_all_in_progress_statuses(self) -> None:
in_progress = sorted(
set(bts.ContainerExecutionStatus) - bts.CONTAINER_STATUSES_ENDED,
key=lambda s: s.value,
)
expected_total = 0
for i, status in enumerate(in_progress_statuses):
count = 2**i
summary.count_execution_status(status=status, count=count)
expected_total += count
assert summary.total_executions == expected_total
assert summary.ended_executions == 0
assert summary.has_ended is False
stats = {status: 2**i for i, status in enumerate(in_progress)}
total = sum(stats.values())
summary = api_server_sql.ExecutionStatusSummary(
total_executions=total,
ended_executions=0,
has_ended=False,
)
assert summary.total_executions == total
assert summary.ended_executions == 0
assert summary.has_ended is False

def test_accumulate_all_statuses(self):
"""Add every status with 2^i count. Summary math must be exact."""
summary = api_server_sql.ExecutionStatusSummary()
def test_mixed_statuses(self) -> None:
all_statuses = sorted(bts.ContainerExecutionStatus, key=lambda s: s.value)
expected_total = 0
expected_ended = 0
for i, status in enumerate(all_statuses):
count = 2**i
expected_total += count
if status in bts.CONTAINER_STATUSES_ENDED:
expected_ended += count
summary.count_execution_status(status=status, count=count)
assert summary.total_executions == expected_total
assert summary.ended_executions == expected_ended
assert summary.has_ended == (expected_ended == expected_total)
stats = {status: 2**i for i, status in enumerate(all_statuses)}
total = sum(stats.values())
ended = sum(c for s, c in stats.items() if s in bts.CONTAINER_STATUSES_ENDED)
summary = api_server_sql.ExecutionStatusSummary(
total_executions=total,
ended_executions=ended,
has_ended=(ended == total),
)
assert summary.total_executions == total
assert summary.ended_executions == ended
assert summary.has_ended is False


def _make_task_spec(pipeline_name: str = "test-pipeline") -> structures.TaskSpec:
Expand Down Expand Up @@ -1735,8 +1729,14 @@ def test_list_with_execution_stats(self) -> None:
with session_factory() as session:
result = service.list(session=session, include_execution_stats=True)
assert len(result.pipeline_runs) == 1
assert result.pipeline_runs[0].root_execution_id == root_id
stats = result.pipeline_runs[0].execution_status_stats
run = result.pipeline_runs[0]
assert run.root_execution_id == root_id
stats = run.execution_status_stats
assert stats is not None
assert stats["SUCCEEDED"] == 1
assert stats["RUNNING"] == 1
summary = run.execution_summary
assert summary is not None
assert summary.total_executions == 2
assert summary.ended_executions == 1
assert summary.has_ended is False