Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@ public interface BaseArtifactService {
Single<Integer> saveArtifact(
String appName, String userId, String sessionId, String filename, Part artifact);

/**
* Saves an artifact and returns it with fileData if available.
*
* <p>Implementations should override this default method for efficiency, as the default performs
* two I/O operations (save then load).
*
* @param appName the app name
* @param userId the user ID
* @param sessionId the session ID
* @param filename the filename
* @param artifact the artifact to save
* @return the saved artifact with fileData if available.
*/
default Single<Part> saveAndReloadArtifact(
String appName, String userId, String sessionId, String filename, Part artifact) {
return saveArtifact(appName, userId, sessionId, filename, artifact)
.flatMap(
version ->
loadArtifact(appName, userId, sessionId, filename, Optional.of(version))
.toSingle());
}

/**
* Gets an artifact.
*
Expand Down
103 changes: 75 additions & 28 deletions core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static java.util.Collections.max;

import com.google.auto.value.AutoValue;
import com.google.cloud.storage.Blob;
import com.google.cloud.storage.BlobId;
import com.google.cloud.storage.BlobInfo;
Expand All @@ -27,6 +28,7 @@
import com.google.common.base.Splitter;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.FileData;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Maybe;
Expand Down Expand Up @@ -108,34 +110,8 @@ private String getBlobName(
@Override
public Single<Integer> saveArtifact(
String appName, String userId, String sessionId, String filename, Part artifact) {
return listVersions(appName, userId, sessionId, filename)
.map(versions -> versions.isEmpty() ? 0 : max(versions) + 1)
.map(
nextVersion -> {
String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion);
BlobId blobId = BlobId.of(bucketName, blobName);

BlobInfo blobInfo =
BlobInfo.newBuilder(blobId)
.setContentType(artifact.inlineData().get().mimeType().orElse(null))
.build();

try {
byte[] dataToSave =
artifact
.inlineData()
.get()
.data()
.orElseThrow(
() ->
new IllegalArgumentException(
"Saveable artifact data must be non-empty."));
storageClient.create(blobInfo, dataToSave);
return nextVersion;
} catch (StorageException e) {
throw new VerifyException("Failed to save artifact to GCS", e);
}
});
return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact)
.map(SaveResult::version);
}

/**
Expand Down Expand Up @@ -275,4 +251,75 @@ public Single<ImmutableList<Integer>> listVersions(
return Single.just(ImmutableList.of());
}
}

@Override
public Single<Part> saveAndReloadArtifact(
String appName, String userId, String sessionId, String filename, Part artifact) {
return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact)
.flatMap(
blob -> {
Blob savedBlob = blob.blob();
String resultMimeType =
Optional.ofNullable(savedBlob.getContentType())
.or(
() ->
artifact.inlineData().flatMap(com.google.genai.types.Blob::mimeType))
.orElse("application/octet-stream");
return Single.just(
Part.builder()
.fileData(
FileData.builder()
.fileUri("gs://" + savedBlob.getBucket() + "/" + savedBlob.getName())
.mimeType(resultMimeType)
.build())
.build());
});
}

@AutoValue
abstract static class SaveResult {
static SaveResult create(Blob blob, int version) {
return new AutoValue_GcsArtifactService_SaveResult(blob, version);
}

abstract Blob blob();

abstract int version();
}

private Single<SaveResult> saveArtifactAndReturnBlob(
String appName, String userId, String sessionId, String filename, Part artifact) {
return listVersions(appName, userId, sessionId, filename)
.map(versions -> versions.isEmpty() ? 0 : max(versions) + 1)
.map(
nextVersion -> {
if (artifact.inlineData().isEmpty()) {
throw new IllegalArgumentException("Saveable artifact must have inline data.");
}

String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion);
BlobId blobId = BlobId.of(bucketName, blobName);

BlobInfo blobInfo =
BlobInfo.newBuilder(blobId)
.setContentType(artifact.inlineData().get().mimeType().orElse(null))
.build();

try {
byte[] dataToSave =
artifact
.inlineData()
.get()
.data()
.orElseThrow(
() ->
new IllegalArgumentException(
"Saveable artifact data must be non-empty."));
Blob blob = storageClient.create(blobInfo, dataToSave);
return SaveResult.create(blob, nextVersion);
} catch (StorageException e) {
throw new VerifyException("Failed to save artifact to GCS", e);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ public Single<ImmutableList<Integer>> listVersions(
return Single.just(IntStream.range(0, size).boxed().collect(toImmutableList()));
}

@Override
public Single<Part> saveAndReloadArtifact(
String appName, String userId, String sessionId, String filename, Part artifact) {
return saveArtifact(appName, userId, sessionId, filename, artifact)
.flatMap(
version ->
loadArtifact(appName, userId, sessionId, filename, Optional.of(version))
.toSingle());
}

private Map<String, List<Part>> getArtifactsMap(String appName, String userId, String sessionId) {
return artifacts
.computeIfAbsent(appName, unused -> new HashMap<>())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -76,6 +77,7 @@ private Blob mockBlob(String name, String contentType, byte[] content) {
when(blob.exists()).thenReturn(true);
BlobId blobId = BlobId.of(BUCKET_NAME, name);
when(blob.getBlobId()).thenReturn(blobId);
when(blob.getBucket()).thenReturn(BUCKET_NAME);
return blob;
}

Expand All @@ -89,6 +91,8 @@ public void save_firstVersion_savesCorrectly() {
BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build();

when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of());
Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3});
when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob);

int version =
service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet();
Expand All @@ -109,6 +113,8 @@ public void save_subsequentVersion_savesCorrectly() {

Blob blobV0 = mockBlob(blobNameV0, "text/plain", new byte[] {1});
when(mockBlobPage.iterateAll()).thenReturn(Collections.singletonList(blobV0));
Blob savedBlob = mockBlob(expectedBlobNameV1, "image/png", new byte[] {4, 5});
when(mockStorage.create(eq(expectedBlobInfoV1), eq(new byte[] {4, 5}))).thenReturn(savedBlob);

int version =
service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet();
Expand All @@ -126,6 +132,8 @@ public void save_userNamespace_savesCorrectly() {
BlobInfo.newBuilder(expectedBlobId).setContentType("application/json").build();

when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of());
Blob savedBlob = mockBlob(expectedBlobName, "application/json", new byte[] {1, 2, 3});
when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob);

int version =
service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, USER_FILENAME, artifact).blockingGet();
Expand Down Expand Up @@ -330,7 +338,36 @@ public void listVersions_noVersions_returnsEmptyList() {
assertThat(versions).isEmpty();
}

@Test
public void saveAndReloadArtifact_savesAndReturnsFileData() {
Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "application/octet-stream");
String expectedBlobName =
String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME);
BlobId expectedBlobId = BlobId.of(BUCKET_NAME, expectedBlobName);
BlobInfo expectedBlobInfo =
BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build();

when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of());
Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3});
when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob);

Optional<Part> result =
asOptional(
service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact));

assertThat(result).isPresent();
assertThat(result.get().fileData()).isPresent();
assertThat(result.get().fileData().get().fileUri())
.hasValue("gs://" + BUCKET_NAME + "/" + expectedBlobName);
assertThat(result.get().fileData().get().mimeType()).hasValue("application/octet-stream");
verify(mockStorage).create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}));
}

private static <T> Optional<T> asOptional(Maybe<T> maybe) {
return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet();
}

private static <T> Optional<T> asOptional(Single<T> single) {
return Optional.of(single.blockingGet());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.adk.artifacts;

import static com.google.common.truth.Truth.assertThat;

import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.util.Optional;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests for {@link InMemoryArtifactService}. */
@RunWith(JUnit4.class)
public class InMemoryArtifactServiceTest {

private static final String APP_NAME = "test-app";
private static final String USER_ID = "test-user";
private static final String SESSION_ID = "test-session";
private static final String FILENAME = "test-file.txt";

private InMemoryArtifactService service;

@Before
public void setUp() {
service = new InMemoryArtifactService();
}

@Test
public void saveArtifact_savesAndReturnsVersion() {
Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain");
int version =
service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet();
assertThat(version).isEqualTo(0);
}

@Test
public void loadArtifact_loadsLatest() {
Part artifact1 = Part.fromBytes(new byte[] {1}, "text/plain");
Part artifact2 = Part.fromBytes(new byte[] {1, 2}, "text/plain");
var unused1 =
service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact1).blockingGet();
var unused2 =
service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact2).blockingGet();
Optional<Part> result =
asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.empty()));
assertThat(result).hasValue(artifact2);
}

@Test
public void saveAndReloadArtifact_reloadsArtifact() {
Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain");
Optional<Part> result =
asOptional(
service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact));
assertThat(result).hasValue(artifact);
}

private static <T> Optional<T> asOptional(Maybe<T> maybe) {
return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet();
}

private static <T> Optional<T> asOptional(Single<T> single) {
return Optional.of(single.blockingGet());
}
}
Loading