Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/uipath-platform/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "uipath-platform"
version = "0.1.29"
version = "0.1.30"
description = "HTTP client library for programmatic access to UiPath Platform"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import sqlparse
from httpx import Response
from sqlparse.sql import Parenthesis, Where
from sqlparse.tokens import DML, Keyword, Wildcard
from sqlparse.sql import Function, Identifier, IdentifierList, Parenthesis, Where
from sqlparse.tokens import DML, Keyword, Whitespace, Wildcard
from uipath.core.tracing import traced

from ..common._base_service import BaseService
Expand Down Expand Up @@ -49,6 +49,7 @@
"GROUPING",
"PARTITION",
]
# Upper-cased SQL function names that the projection checks treat as
# aggregates; function names are upper-cased before being compared to these.
_AGGREGATE_FUNCTIONS = ("COUNT", "SUM", "AVG", "MIN", "MAX")


class EntitiesService(BaseService):
Expand Down Expand Up @@ -177,6 +178,7 @@ def retrieve_by_name(
spec = self._retrieve_by_name_spec(entity_name)
headers = self._folder_key_headers(folder_key)
response = self.request(spec.method, spec.endpoint, headers=headers)

return Entity.model_validate(response.json())

@traced(name="entity_retrieve_by_name", run_type="uipath")
Expand All @@ -196,6 +198,7 @@ async def retrieve_by_name_async(
spec = self._retrieve_by_name_spec(entity_name)
headers = self._folder_key_headers(folder_key)
response = await self.request_async(spec.method, spec.endpoint, headers=headers)

return Entity.model_validate(response.json())

@traced(name="list_entities", run_type="uipath")
Expand Down Expand Up @@ -1333,18 +1336,91 @@ def _validate_sql_query(self, sql_query: str) -> None:

has_where = any(isinstance(t, Where) for t in stmt.tokens)
has_limit = "LIMIT" in keywords
if not has_where and not has_limit:
raise ValueError("Queries without WHERE must include a LIMIT clause.")
has_from = "FROM" in keywords

if not has_from:
raise ValueError("Queries must include a FROM clause.")

projection = self._projection_tokens(stmt)
has_wildcard = any(t.ttype is Wildcard for t in projection)
if has_wildcard and not has_where:
raise ValueError("SELECT * without filtering is not allowed.")

if self._projection_has_count_star(projection):
raise ValueError(
"COUNT(*) is not supported. Use COUNT(column_name) instead."
)

has_aggregate = self._projection_has_aggregate(projection)

if not has_where and not has_limit and not has_aggregate:
raise ValueError("Queries without WHERE must include a LIMIT clause.")

has_bare_wildcard = self._projection_has_bare_wildcard(projection)
if has_bare_wildcard:
raise ValueError("SELECT * is not allowed. Specify column names instead.")
if not has_where and self._projection_column_count(projection) > 4:
raise ValueError(
"Selecting more than 4 columns without filtering is not allowed."
)

@staticmethod
def _projection_has_aggregate(
    projection: list[sqlparse.sql.Token],
) -> bool:
    """Check whether the projection contains an aggregate function call.

    Args:
        projection: Non-flattened sqlparse nodes that make up the SELECT
            list.

    Returns:
        True when a function whose name is in ``_AGGREGATE_FUNCTIONS`` is
        found at the top level of the projection or inside any nesting of
        ``Identifier`` / ``IdentifierList`` wrappers; False otherwise.
    """
    pending = list(projection)
    while pending:
        node = pending.pop()
        if isinstance(node, Function):
            # get_name() can return None for unusual function nodes;
            # treat those as "not an aggregate" instead of crashing.
            name = node.get_name()
            if name is not None and name.upper() in _AGGREGATE_FUNCTIONS:
                return True
        elif isinstance(node, (Identifier, IdentifierList)):
            # An aliased aggregate inside a list is nested two levels deep
            # (IdentifierList -> Identifier -> Function), so keep unwrapping
            # these wrappers rather than looking only one level down.
            # Function arguments and parenthesised expressions are
            # deliberately not descended into.
            pending.extend(node.tokens)
    return False

@staticmethod
def _projection_has_count_star(
    projection: list[sqlparse.sql.Token],
) -> bool:
    """Check whether the projection contains ``COUNT(*)``.

    Args:
        projection: Non-flattened sqlparse nodes that make up the SELECT
            list.

    Returns:
        True when any ``COUNT`` function found anywhere in the projection
        tree has a wildcard token among its flattened tokens; False
        otherwise.
    """

    def _is_count_star(func: Function) -> bool:
        # get_name() can return None; such a node cannot be COUNT.
        name = func.get_name()
        if name is None or name.upper() != "COUNT":
            return False
        return any(t.ttype is Wildcard for t in func.flatten())

    pending = list(projection)
    while pending:
        node = pending.pop()
        if isinstance(node, Function) and _is_count_star(node):
            return True
        # Walk every grouped node so that aliased calls inside a list
        # (IdentifierList -> Identifier -> Function) and COUNT(*) wrapped
        # in another expression are found as well. This check only ever
        # rejects queries, so looking deeper errs on the strict side.
        if node.is_group:
            pending.extend(node.tokens)
    return False

@staticmethod
def _projection_has_bare_wildcard(
projection: list[sqlparse.sql.Token],
) -> bool:
"""Check for a bare ``*`` or qualified ``table.*`` outside a function."""

def _identifier_has_wildcard(ident: Identifier) -> bool:
return any(t.ttype is Wildcard for t in ident.tokens)

for node in projection:
if node.ttype is Wildcard:
return True
if isinstance(node, Identifier) and _identifier_has_wildcard(node):
return True
if isinstance(node, IdentifierList):
for child in node.tokens:
if child.ttype is Wildcard:
return True
Comment on lines +1414 to +1417
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Treat qualified table.* as disallowed wildcard

_projection_has_bare_wildcard only checks tokens whose type is Wildcard directly or inside an IdentifierList, but it never inspects Identifier nodes. In sqlparse, table.* is represented as an Identifier, so queries like SELECT Customers.* FROM Customers LIMIT 10 bypass the SELECT * without filtering guard and also slip past the >4 columns limit because _projection_column_count treats that projection as one column. This reopens unrestricted full-row reads that the validator is meant to block.

Useful? React with 👍 / 👎.

if isinstance(child, Identifier) and _identifier_has_wildcard(
child
):
return True
return False

@staticmethod
def _has_subquery(stmt: sqlparse.sql.Statement) -> bool:
"""Recursively walk the AST looking for SELECT inside parentheses."""
Expand All @@ -1369,27 +1445,33 @@ def _walk(token: sqlparse.sql.Token) -> bool:
def _projection_tokens(
    stmt: sqlparse.sql.Statement,
) -> list[sqlparse.sql.Token]:
    """Extract non-flattened AST nodes between the first SELECT and FROM.

    Args:
        stmt: Parsed sqlparse statement to inspect.

    Returns:
        The top-level tokens of ``stmt`` that follow the first ``SELECT``
        and precede the first ``FROM`` or ``INTO`` keyword, with
        ``DISTINCT`` keywords and whitespace tokens removed. Grouped nodes
        are returned as-is, not flattened.
    """
    tokens: list[sqlparse.sql.Token] = []
    collecting = False
    for token in stmt.tokens:
        if token.ttype is DML and token.normalized == "SELECT":
            collecting = True
            continue
        if token.ttype is Keyword and token.normalized in ("FROM", "INTO"):
            break
        if token.ttype is Keyword and token.normalized == "DISTINCT":
            continue
        # ``not in`` (rather than ``is not``) also filters sub-types such
        # as Whitespace.Newline, so multi-line projections do not leak
        # newline tokens. Grouped nodes have ttype None and are kept.
        if collecting and token.ttype not in Whitespace:
            tokens.append(token)
    return tokens

@staticmethod
def _projection_column_count(
    projection: list[sqlparse.sql.Token],
) -> int:
    """Count the columns selected by the projection.

    The count is taken from the first recognisable projection node.

    Args:
        projection: Non-flattened sqlparse nodes that make up the SELECT
            list.

    Returns:
        The number of identifiers in the first ``IdentifierList``; 1 when
        the first recognisable node is a single ``Identifier``,
        ``Function`` or wildcard; 0 when no such node is present.
    """
    for node in projection:
        if isinstance(node, IdentifierList):
            # get_identifiers() yields the list members without the
            # separating punctuation and whitespace.
            return len(list(node.get_identifiers()))
        if isinstance(node, (Identifier, Function)):
            return 1
        if node.ttype is Wildcard:
            return 1
    return 0


# Resolve the forward reference to EntitiesService in EntitySetResolution.
Expand Down
38 changes: 25 additions & 13 deletions packages/uipath-platform/src/uipath/platform/entities/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_origin,
)

from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, create_model

if TYPE_CHECKING:
from ._entities_service import EntitiesService
Expand Down Expand Up @@ -140,7 +140,7 @@ class FieldMetadata(BaseModel):
reference_field: Optional["EntityField"] = Field(
default=None, alias="referenceField"
)
reference_type: ReferenceType = Field(alias="referenceType")
reference_type: Optional[ReferenceType] = Field(default=None, alias="referenceType")
sql_type: "FieldDataType" = Field(alias="sqlType")
is_required: bool = Field(alias="isRequired")
display_name: str = Field(alias="displayName")
Expand Down Expand Up @@ -212,14 +212,21 @@ class SourceJoinCriteria(BaseModel):
model_config = ConfigDict(
validate_by_name=True,
validate_by_alias=True,
extra="allow",
)
id: Optional[str] = None
entity_id: Optional[str] = Field(default=None, alias="entityId")
join_field_name: Optional[str] = Field(default=None, alias="joinFieldName")
join_type: Optional[str] = Field(default=None, alias="joinType")
related_source_object_id: Optional[str] = Field(
default=None, alias="relatedSourceObjectId"
)
related_source_object_field_name: Optional[str] = Field(
default=None, alias="relatedSourceObjectFieldName"
)
related_source_field_name: Optional[str] = Field(
default=None, alias="relatedSourceFieldName"
)
id: str
entity_id: str = Field(alias="entityId")
join_field_name: str = Field(alias="joinFieldName")
join_type: str = Field(alias="joinType")
related_source_object_id: str = Field(alias="relatedSourceObjectId")
related_source_object_field_name: str = Field(alias="relatedSourceObjectFieldName")
related_source_field_name: str = Field(alias="relatedSourceFieldName")


class ChoiceSetValue(BaseModel):
Expand Down Expand Up @@ -326,11 +333,16 @@ class Entity(BaseModel):
entity_type: str = Field(alias="entityType")
description: Optional[str] = Field(default=None, alias="description")
fields: Optional[List[FieldMetadata]] = Field(default=None, alias="fields")
external_fields: Optional[List[ExternalSourceFields]] = Field(
default=None, alias="externalFields"
external_fields: Optional[
List[ExternalField | ExternalSourceFields | Dict[str, Any]]
] = Field(
default=None,
alias="externalFields",
)
source_join_criteria: Optional[List[SourceJoinCriteria]] = Field(
default=None, alias="sourceJoinCriteria"
source_join_criteria: Optional[List[SourceJoinCriteria | Dict[str, Any]]] = Field(
default=None,
validation_alias=AliasChoices("sourceJoinCriteria", "sourceJoinCriterias"),
alias="sourceJoinCriteria",
)
record_count: Optional[int] = Field(default=None, alias="recordCount")
storage_size_in_mb: Optional[float] = Field(default=None, alias="storageSizeInMB")
Expand Down
44 changes: 42 additions & 2 deletions packages/uipath-platform/tests/services/test_entities_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,11 @@ def test_retrieve_records_without_start_and_limit(
[
"SELECT id FROM Customers WHERE id = 1",
"SELECT id, name FROM Customers LIMIT 10",
"SELECT * FROM Customers WHERE status = 'Active'",
"SELECT COUNT(id) FROM Customers",
"SELECT SUM(amount) FROM Orders",
"SELECT AVG(price) FROM Products",
"SELECT MIN(created), MAX(created) FROM Events",
"SELECT COUNT(id), name FROM Customers LIMIT 10",
"SELECT id, name, email, phone FROM Customers LIMIT 5",
"SELECT DISTINCT id FROM Customers WHERE id > 100",
"SELECT id FROM Customers WHERE name = 'foo;bar'",
Expand Down Expand Up @@ -356,9 +360,45 @@ def test_validate_sql_query_allows_supported_select_queries(
"SELECT id FROM Customers",
"Queries without WHERE must include a LIMIT clause.",
),
(
"SELECT UPPER(name) FROM Customers",
"Queries without WHERE must include a LIMIT clause.",
),
(
"SELECT COALESCE(name, 'N/A') FROM Customers",
"Queries without WHERE must include a LIMIT clause.",
),
(
"SELECT 1 LIMIT 1",
"Queries must include a FROM clause.",
),
(
"SELECT COUNT(*) FROM Customers",
"COUNT(*) is not supported. Use COUNT(column_name) instead.",
),
(
"SELECT COUNT(*), name FROM Customers LIMIT 10",
"COUNT(*) is not supported. Use COUNT(column_name) instead.",
),
(
"SELECT * FROM Customers LIMIT 10",
"SELECT * without filtering is not allowed.",
"SELECT * is not allowed. Specify column names instead.",
),
(
"SELECT Customers.* FROM Customers LIMIT 10",
"SELECT * is not allowed. Specify column names instead.",
),
(
"SELECT t.* FROM Customers t LIMIT 10",
"SELECT * is not allowed. Specify column names instead.",
),
(
"SELECT * FROM Customers WHERE status = 'Active'",
"SELECT * is not allowed. Specify column names instead.",
),
(
"SELECT Customers.* FROM Customers WHERE status = 'Active'",
"SELECT * is not allowed. Specify column names instead.",
),
(
"SELECT id, name, email, phone, address FROM Customers LIMIT 10",
Expand Down
2 changes: 1 addition & 1 deletion packages/uipath-platform/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion packages/uipath/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading