Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,29 @@ def test_assess_tuning_resources(client):
assert isinstance(response, types.TuningResourceUsageAssessmentResult)


def test_assess_tuning_validity(client):
    """Run a tuning validity assessment and check the type of the result."""
    # Two-turn Gemini example template: a user question and the model answer,
    # whose texts carry the brace-delimited {name} and {capital} fields.
    user_turn = {
        "role": "user",
        "parts": [{"text": "What is the capital of {name}?"}],
    }
    model_turn = {
        "role": "model",
        "parts": [{"text": "{capital}"}],
    }
    template = types.GeminiTemplateConfig(
        gemini_example=types.GeminiExample(contents=[user_turn, model_turn]),
    )

    result = client.datasets.assess_tuning_validity(
        dataset_name=DATASET,
        dataset_usage="SFT_VALIDATION",
        model_name="gemini-2.5-flash-001",
        template_config=template,
    )

    assert isinstance(result, types.TuningValidationAssessmentResult)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand All @@ -88,7 +111,7 @@ async def test_assess_dataset_async(client):
{
"role": "user",
"parts": [{"text": "What is the capital of {name}?"}],
}
},
],
),
),
Expand All @@ -114,3 +137,27 @@ async def test_assess_tuning_resources_async(client):
),
)
assert isinstance(response, types.TuningResourceUsageAssessmentResult)


@pytest.mark.asyncio
async def test_assess_tuning_validity_async(client):
    """Run a tuning validity assessment via the async client and check the result type."""
    # Two-turn Gemini example template: a user question and the model answer,
    # whose texts carry the brace-delimited {name} and {capital} fields.
    user_turn = {
        "role": "user",
        "parts": [{"text": "What is the capital of {name}?"}],
    }
    model_turn = {
        "role": "model",
        "parts": [{"text": "{capital}"}],
    }
    template = types.GeminiTemplateConfig(
        gemini_example=types.GeminiExample(contents=[user_turn, model_turn]),
    )

    result = await client.aio.datasets.assess_tuning_validity(
        dataset_name=DATASET,
        dataset_usage="SFT_VALIDATION",
        model_name="gemini-2.5-flash-001",
        template_config=template,
    )

    assert isinstance(result, types.TuningValidationAssessmentResult)
124 changes: 124 additions & 0 deletions vertexai/_genai/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,68 @@ def assess_tuning_resources(
response["tuningResourceUsageAssessmentResult"],
)

def assess_tuning_validity(
    self,
    *,
    dataset_name: str,
    model_name: str,
    dataset_usage: str,
    template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
    config: Optional[types.AssessDatasetConfigOrDict] = None,
) -> types.TuningValidationAssessmentResult:
    """Assess whether the assembled dataset is valid for tuning a given model.

    Starts a multimodal dataset assessment carrying a tuning validation
    config, waits for the returned operation, and converts the
    "tuningValidationAssessmentResult" entry of the operation response.

    Args:
        dataset_name:
            Required. The name of the dataset to assess the tuning validity
            for.
        model_name:
            Required. The name of the model to assess the tuning validity
            for.
        dataset_usage:
            Required. The dataset usage to assess the tuning validity for.
            Must be one of the following: SFT_TRAINING, SFT_VALIDATION.
        template_config:
            Optional. The template config used to assemble the dataset
            before assessing the tuning validity. If not provided, the
            template config attached to the dataset will be used. Required
            if no template config is attached to the dataset.
        config:
            Optional. A configuration for assessing the tuning validity,
            given either as a dict or as a config object. If not provided,
            the default configuration will be used.

    Returns:
        A `types.TuningValidationAssessmentResult` built from the operation
        response. Its `errors` field lists the errors that occurred during
        the tuning validity assessment.
    """
    if not isinstance(config, dict):
        # A missing (falsy) config falls back to the default configuration.
        config = config or types.AssessDatasetConfig()
    else:
        config = types.AssessDatasetConfig(**config)

    validation_config = types.TuningValidationAssessmentConfig(
        model_name=model_name,
        dataset_usage=dataset_usage,
    )
    read_config = types.GeminiRequestReadConfig(
        template_config=template_config,
    )

    operation = self._assess_multimodal_dataset(
        name=dataset_name,
        tuning_validation_assessment_config=validation_config,
        gemini_request_read_config=read_config,
        config=config,
    )
    # The wait is bounded by the timeout carried on the assessment config.
    finished = self._wait_for_operation(
        operation=operation,
        timeout_seconds=config.timeout,
    )
    raw_result = finished["tuningValidationAssessmentResult"]
    return _datasets_utils.create_from_response(
        types.TuningValidationAssessmentResult,
        raw_result,
    )


class AsyncDatasets(_api_module.BaseModule):

Expand Down Expand Up @@ -1875,3 +1937,65 @@ async def assess_tuning_resources(
types.TuningResourceUsageAssessmentResult,
response["tuningResourceUsageAssessmentResult"],
)

async def assess_tuning_validity(
    self,
    *,
    dataset_name: str,
    model_name: str,
    dataset_usage: str,
    template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
    config: Optional[types.AssessDatasetConfigOrDict] = None,
) -> types.TuningValidationAssessmentResult:
    """Assess whether the assembled dataset is valid for tuning a given model.

    Awaits the start of a multimodal dataset assessment carrying a tuning
    validation config, awaits the returned operation, and converts the
    "tuningValidationAssessmentResult" entry of the operation response.

    Args:
        dataset_name:
            Required. The name of the dataset to assess the tuning validity
            for.
        model_name:
            Required. The name of the model to assess the tuning validity
            for.
        dataset_usage:
            Required. The dataset usage to assess the tuning validity for.
            Must be one of the following: SFT_TRAINING, SFT_VALIDATION.
        template_config:
            Optional. The template config used to assemble the dataset
            before assessing the tuning validity. If not provided, the
            template config attached to the dataset will be used. Required
            if no template config is attached to the dataset.
        config:
            Optional. A configuration for assessing the tuning validity,
            given either as a dict or as a config object. If not provided,
            the default configuration will be used.

    Returns:
        A `types.TuningValidationAssessmentResult` built from the operation
        response. Its `errors` field lists the errors that occurred during
        the tuning validity assessment.
    """
    if not isinstance(config, dict):
        # A missing (falsy) config falls back to the default configuration.
        config = config or types.AssessDatasetConfig()
    else:
        config = types.AssessDatasetConfig(**config)

    validation_config = types.TuningValidationAssessmentConfig(
        model_name=model_name,
        dataset_usage=dataset_usage,
    )
    read_config = types.GeminiRequestReadConfig(
        template_config=template_config,
    )

    operation = await self._assess_multimodal_dataset(
        name=dataset_name,
        tuning_validation_assessment_config=validation_config,
        gemini_request_read_config=read_config,
        config=config,
    )
    # The wait is bounded by the timeout carried on the assessment config.
    finished = await self._wait_for_operation(
        operation=operation,
        timeout_seconds=config.timeout,
    )
    raw_result = finished["tuningValidationAssessmentResult"]
    return _datasets_utils.create_from_response(
        types.TuningValidationAssessmentResult,
        raw_result,
    )
Loading