Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package com.google.adk.agents;

import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;

import com.google.adk.agents.Callbacks.AfterAgentCallback;
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
Expand All @@ -36,12 +38,17 @@
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import org.jspecify.annotations.Nullable;

/** Base class for all agents. */
public abstract class BaseAgent {

// Pattern for valid agent names.
private static final String IDENTIFIER_REGEX = "^_?[a-zA-Z0-9]*([. _-][a-zA-Z0-9]+)*$";
private static final Pattern IDENTIFIER_PATTERN = Pattern.compile(IDENTIFIER_REGEX);

/** The agent's name. Must be a unique identifier within the agent tree. */
private final String name;

Expand Down Expand Up @@ -79,6 +86,7 @@ public BaseAgent(
@Nullable List<? extends BaseAgent> subAgents,
@Nullable List<? extends BeforeAgentCallback> beforeAgentCallback,
@Nullable List<? extends AfterAgentCallback> afterAgentCallback) {
validateAgentName(name);
this.name = name;
this.description = description;
this.parentAgent = null;
Expand All @@ -96,6 +104,20 @@ public BaseAgent(
}
}

private static void validateAgentName(String name) {
if (isNullOrEmpty(name)) {
throw new IllegalArgumentException("Agent name cannot be null or empty.");
}
if (!IDENTIFIER_PATTERN.matcher(name).matches()) {
throw new IllegalArgumentException(
format("Agent name '%s' does not match regex '%s'.", name, IDENTIFIER_REGEX));
}
if (name.equals("user")) {
throw new IllegalArgumentException(
"Agent name cannot be 'user'; reserved for end-user input.");
}
}

/**
* Gets the agent's unique name.
*
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/java/com/google/adk/agents/RunConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ public abstract Builder setInputAudioTranscription(

public RunConfig build() {
RunConfig runConfig = autoBuild();
if (runConfig.maxLlmCalls() == Integer.MAX_VALUE) {
throw new IllegalArgumentException("maxLlmCalls should be less than Integer.MAX_VALUE.");
}
if (runConfig.maxLlmCalls() < 0) {
logger.warn(
"maxLlmCalls is negative. This will result in no enforcement on total"
Expand Down
15 changes: 15 additions & 0 deletions core/src/test/java/com/google/adk/agents/BaseAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.adk.agents;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.adk.agents.Callbacks.AfterAgentCallback;
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
Expand Down Expand Up @@ -336,4 +337,18 @@ public void runLive_invokesRunLiveImpl() {
assertThat(results.get(0).content()).hasValue(runLiveImplContent);
assertThat(runLiveCallback.wasCalled()).isTrue();
}

@Test
public void constructor_invalidName_throwsIllegalArgumentException() {
assertThrows(
IllegalArgumentException.class,
() -> new TestBaseAgent("invalid name?", "description", null, null, null));
}

@Test
public void constructor_userName_throwsIllegalArgumentException() {
assertThrows(
IllegalArgumentException.class,
() -> new TestBaseAgent("user", "description", null, null, null));
}
}
8 changes: 8 additions & 0 deletions core/src/test/java/com/google/adk/agents/RunConfigTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.adk.agents;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.common.collect.ImmutableList;
import com.google.genai.types.AudioTranscriptionConfig;
Expand Down Expand Up @@ -114,4 +115,11 @@ public void testInputAudioTranscriptionOnly() {
assertThat(runConfig.streamingMode()).isEqualTo(RunConfig.StreamingMode.BIDI);
assertThat(runConfig.responseModalities()).containsExactly(new Modality(Modality.Known.AUDIO));
}

@Test
public void testMaxLlmCalls_integerMaxValue_throwsIllegalArgumentException() {
assertThrows(
IllegalArgumentException.class,
() -> RunConfig.builder().setMaxLlmCalls(Integer.MAX_VALUE).build());
}
}