diff --git a/.devcontainer/java/devcontainer.json b/.devcontainer/java/devcontainer.json index fcd9c3a..6b631e0 100644 --- a/.devcontainer/java/devcontainer.json +++ b/.devcontainer/java/devcontainer.json @@ -1,24 +1,20 @@ // For format details, see https://aka.ms/devcontainer.json. For config options, see the -// README at: https://github.com/devcontainers/templates/tree/main/src/java +// README at: https://github.com/devcontainers/templates/tree/main/src/universal { "name": "Default Java", // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile "image": "mcr.microsoft.com/devcontainers/java:latest", "features": { - "ghcr.io/devcontainers/features/azure-cli:1": {}, - "ghcr.io/devcontainers/features/docker-in-docker:2": {}, - "ghcr.io/azure/azure-dev/azd:0": {} + "ghcr.io/devcontainers/features/java:1": { + "version": "none", + "installMaven": "true" + }, + "ghcr.io/devcontainers/features/azure-cli:1": {} }, "customizations": { "vscode": { "extensions": [ - "ms-azuretools.vscode-cosmosdb", - "buildwithlayer.mongodb-integration-expert-qS6DB", - "mongodb.mongodb-vscode", - "ms-azuretools.vscode-documentdb", - "redhat.java", - "vscjava.vscode-maven", - "vscjava.vscode-gradle" + "ms-azuretools.vscode-documentdb" ] } } diff --git a/.gitignore b/.gitignore index 65b1ad6..9f22943 100644 --- a/.gitignore +++ b/.gitignore @@ -485,3 +485,8 @@ dist/ *.user *.suo *.sln.docstates + +# Java +*.class +*.jar +target/ diff --git a/ai/vector-search-java/README.md b/ai/vector-search-java/README.md new file mode 100644 index 0000000..929c2df --- /dev/null +++ b/ai/vector-search-java/README.md @@ -0,0 +1,177 @@ +# DocumentDB Vector Samples (Java) + +This project demonstrates vector search capabilities using Azure DocumentDB with Java. It includes implementations of three different vector index types: DiskANN, HNSW, and IVF. + +## Overview + +Vector search enables semantic similarity searching by converting text into high-dimensional vector representations (embeddings) and finding the most similar vectors in the database. This project shows how to: + +- Generate embeddings using Azure OpenAI +- Store vectors in DocumentDB +- Create and use different types of vector indexes +- Perform similarity searches with various algorithms + +## Prerequisites + +Before running this project, you need: + +### Azure Resources +1. **Azure subscription** with appropriate permissions +2. **[Azure Developer CLI (azd)](https://learn.microsoft.com/azure/developer/azure-developer-cli/)** installed + +### Development Environment +- [Java 21 or higher](https://learn.microsoft.com/java/openjdk/download) +- [Maven 3.6 or higher](https://maven.apache.org/download.cgi) +- [Git](https://git-scm.com/downloads) (for cloning the repository) +- [Visual Studio Code](https://code.visualstudio.com/) (recommended) or another Java IDE + +## Setup Instructions + +### Clone and Setup Project + +```bash +# Clone this repository +git clone https://github.com/Azure-Samples/documentdb-samples +``` + +### Deploy Azure Resources + +This project uses Azure Developer CLI (azd) to deploy all required Azure resources from the existing infrastructure-as-code files. + +#### Install Azure Developer CLI + +If you haven't already, install the Azure Developer CLI: + +**Windows:** +```powershell +winget install microsoft.azd +``` + +**macOS:** +```bash +brew tap azure/azd && brew install azd +``` + +**Linux:** +```bash +curl -fsSL https://aka.ms/install-azd.sh | bash +``` + +#### Deploy Resources + +Navigate to the root of the repository and run: + +```bash +# Login to Azure +azd auth login + +# Provision Azure resources +azd up +``` + +During provisioning, you'll be prompted for: +- **Environment name**: A unique name for your deployment (e.g., "my-vector-search") +- **Azure subscription**: Select your Azure subscription +- **Location**: Choose from `eastus2` or `swedencentral` (required for OpenAI models) + +The `azd up` command will: +- Create a resource group +- Deploy Azure OpenAI with text-embedding-3-small model +- Deploy Azure DocumentDB (MongoDB vCore) cluster +- Create a managed identity for secure access +- Configure all necessary permissions and networking +- Generate a `.env` file with all connection information at the repository root + +### Compile the Project + +```bash +# Move to Java vector search project +cd ai/vector-search-java + +# Compile the project +mvn clean compile +``` + +### Load Environment Variables + +After deployment completes, load the environment variables from the generated `.env` file. The `set -a` command ensures variables are exported to child processes (like the Maven JVM): + +```bash +# From the ai/vector-search-java directory +set -a && source ../../.env && set +a +``` + +You can verify the environment variables are set: + +```bash +echo $MONGO_CLUSTER_NAME +``` + +## Usage + +The project includes several Java classes that demonstrate different aspects of vector search. + +### Sign in to Azure for passwordless connection + +```bash +az login +``` + +### DiskANN Vector Search + +Run DiskANN (Disk-based Approximate Nearest Neighbor) search: + +```bash +mvn exec:java -Dexec.mainClass="com.azure.documentdb.samples.DiskAnn" +``` + +DiskANN is optimized for: +- Large datasets that don't fit in memory +- Efficient disk-based storage +- Good balance of speed and accuracy + +### HNSW Vector Search + +Run HNSW (Hierarchical Navigable Small World) search: + +```bash +mvn exec:java -Dexec.mainClass="com.azure.documentdb.samples.HNSW" +``` + +HNSW provides: +- Excellent search performance +- High recall rates +- Hierarchical graph structure +- Good for real-time applications + +### IVF Vector Search + +Run IVF (Inverted File) search: + +```bash +mvn exec:java -Dexec.mainClass="com.azure.documentdb.samples.IVF" +``` + +IVF features: +- Clusters vectors by similarity +- Fast search through cluster centroids +- Configurable accuracy vs speed trade-offs +- Efficient for large vector datasets + +## Further Resources + +- [Azure Developer CLI Documentation](https://learn.microsoft.com/azure/developer/azure-developer-cli/) +- [Azure DocumentDB Documentation](https://learn.microsoft.com/azure/documentdb/) +- [Azure OpenAI Service Documentation](https://learn.microsoft.com/azure/ai-services/openai/) +- [Vector Search in DocumentDB](https://learn.microsoft.com/azure/documentdb/vector-search) +- [MongoDB Java Driver Documentation](https://mongodb.github.io/mongo-java-driver/) +- [Azure SDK for Java Documentation](https://learn.microsoft.com/java/api/overview/azure/) + +## Support + +If you encounter issues: +1. Verify Java 21+ is installed: `java -version` +2. Verify Maven is installed: `mvn -version` +3. Ensure Azure CLI is logged in: `az login` +4. Verify environment variables are exported: `echo $MONGO_CLUSTER_NAME` +5. Check Azure service status and quotas diff --git a/ai/vector-search-java/pom.xml b/ai/vector-search-java/pom.xml new file mode 100644 index 0000000..3c83df2 --- /dev/null +++ b/ai/vector-search-java/pom.xml @@ -0,0 +1,43 @@ + + 4.0.0 + + com.azure.documentdb.samples + vector-search-quickstart + 1.0-SNAPSHOT + + + 21 + UTF-8 + + + + + org.mongodb + mongodb-driver-sync + 5.6.2 + + + com.azure + azure-identity + 1.18.1 + + + com.azure + azure-ai-openai + 1.0.0-beta.16 + + + tools.jackson.core + jackson-databind + 3.0.3 + + + org.slf4j + slf4j-nop + 2.0.17 + runtime + + + diff --git a/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/DiskAnn.java b/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/DiskAnn.java new file mode 100644 index 0000000..676630b --- /dev/null +++ b/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/DiskAnn.java @@ -0,0 +1,226 @@ +package com.azure.documentdb.samples; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.models.EmbeddingsOptions; +import com.azure.identity.DefaultAzureCredentialBuilder; +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.Indexes; +import org.bson.Document; +import tools.jackson.core.type.TypeReference; +import tools.jackson.databind.json.JsonMapper; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Vector search sample using DiskANN index. + */ +public class DiskAnn { + private static final String SAMPLE_QUERY = "quintessential lodging near running trails, eateries, retail"; + private static final String DATABASE_NAME = "Hotels"; + private static final String COLLECTION_NAME = "hotels_diskann"; + private static final String VECTOR_INDEX_NAME = "vectorIndex_diskann"; + + private final JsonMapper jsonMapper = JsonMapper.builder().build(); + + public static void main(String[] args) { + new DiskAnn().run(); + System.exit(0); + } + + public void run() { + try (var mongoClient = createMongoClient()) { + var openAIClient = createOpenAIClient(); + + var database = mongoClient.getDatabase(DATABASE_NAME); + var collection = database.getCollection(COLLECTION_NAME, Document.class); + + // Drop and recreate collection + collection.drop(); + database.createCollection(COLLECTION_NAME); + System.out.println("Created collection: " + COLLECTION_NAME); + + // Load and insert data + var hotelData = loadHotelData(); + insertDataInBatches(collection, hotelData); + + // Create standard indexes + createStandardIndexes(collection); + + // Create vector index + createVectorIndex(database); + + // Perform vector search + var queryEmbedding = createEmbedding(openAIClient, SAMPLE_QUERY); + performVectorSearch(collection, queryEmbedding); + + } catch (Exception e) { + System.err.println("Error: " + e.getMessage()); + e.printStackTrace(); + } + } + + private MongoClient createMongoClient() { + var clusterName = System.getenv("MONGO_CLUSTER_NAME"); + var managedIdentityPrincipalId = System.getenv("AZURE_MANAGED_IDENTITY_PRINCIPAL_ID"); + var azureCredential = new DefaultAzureCredentialBuilder().build(); + + MongoCredential.OidcCallback callback = (MongoCredential.OidcCallbackContext context) -> { + var token = azureCredential.getToken( + new com.azure.core.credential.TokenRequestContext() + .addScopes("https://ossrdbms-aad.database.windows.net/.default") + ).block(); + + if (token == null) { + throw new RuntimeException("Failed to obtain Azure AD token"); + } + + return new MongoCredential.OidcCallbackResult(token.getToken()); + }; + + var credential = MongoCredential.createOidcCredential(null) + .withMechanismProperty("OIDC_CALLBACK", callback); + + var connectionString = new ConnectionString( + String.format("mongodb+srv://%s@%s.mongocluster.cosmos.azure.com/?authMechanism=MONGODB-OIDC&tls=true&retrywrites=false&maxIdleTimeMS=120000", + managedIdentityPrincipalId, clusterName) + ); + + var settings = MongoClientSettings.builder() + .applyConnectionString(connectionString) + .credential(credential) + .build(); + + return MongoClients.create(settings); + } + + private OpenAIClient createOpenAIClient() { + var endpoint = System.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"); + var credential = new DefaultAzureCredentialBuilder().build(); + + return new OpenAIClientBuilder() + .endpoint(endpoint) + .credential(credential) + .buildClient(); + } + + private List> loadHotelData() throws IOException { + var dataFile = System.getenv("DATA_FILE_WITH_VECTORS"); + var filePath = Path.of(dataFile); + + System.out.println("Reading JSON file from " + filePath.toAbsolutePath()); + var jsonContent = Files.readString(filePath); + + return jsonMapper.readValue(jsonContent, new TypeReference>>() {}); + } + + private void insertDataInBatches(MongoCollection collection, List> hotelData) { + var batchSizeStr = System.getenv("LOAD_SIZE_BATCH"); + var batchSize = batchSizeStr != null ? Integer.parseInt(batchSizeStr) : 100; + var batches = partitionList(hotelData, batchSize); + + System.out.println("Processing in batches of " + batchSize + "..."); + + for (int i = 0; i < batches.size(); i++) { + var batch = batches.get(i); + var documents = batch.stream() + .map(Document::new) + .toList(); + + collection.insertMany(documents); + System.out.println("Batch " + (i + 1) + " complete: " + documents.size() + " inserted"); + } + } + + private void createStandardIndexes(MongoCollection collection) { + collection.createIndex(Indexes.ascending("HotelId")); + collection.createIndex(Indexes.ascending("Category")); + collection.createIndex(Indexes.ascending("Description")); + collection.createIndex(Indexes.ascending("Description_fr")); + } + + private void createVectorIndex(MongoDatabase database) { + var embeddedField = System.getenv("EMBEDDED_FIELD"); + var dimensionsStr = System.getenv("EMBEDDING_DIMENSIONS"); + var dimensions = dimensionsStr != null ? Integer.parseInt(dimensionsStr) : 1536; + + var indexDefinition = new Document() + .append("createIndexes", COLLECTION_NAME) + .append("indexes", List.of( + new Document() + .append("name", VECTOR_INDEX_NAME) + .append("key", new Document(embeddedField, "cosmosSearch")) + .append("cosmosSearchOptions", new Document() + .append("kind", "vector-diskann") + .append("dimensions", dimensions) + .append("similarity", "COS") + .append("maxDegree", 20) + .append("lBuild", 10) + ) + )); + + database.runCommand(indexDefinition); + System.out.println("Created vector index: " + VECTOR_INDEX_NAME); + } + + private List createEmbedding(OpenAIClient openAIClient, String text) { + var model = System.getenv("AZURE_OPENAI_EMBEDDING_MODEL"); + var options = new EmbeddingsOptions(List.of(text)); + + var response = openAIClient.getEmbeddings(model, options); + return response.getData().get(0).getEmbedding().stream() + .map(Float::doubleValue) + .toList(); + } + + private void performVectorSearch(MongoCollection collection, List queryEmbedding) { + var embeddedField = System.getenv("EMBEDDED_FIELD"); + + var searchStage = new Document("$search", new Document() + .append("cosmosSearch", new Document() + .append("vector", queryEmbedding) + .append("path", embeddedField) + .append("k", 5) + ) + ); + + var projectStage = new Document("$project", new Document() + .append("score", new Document("$meta", "searchScore")) + .append("document", "$$ROOT") + ); + + var pipeline = List.of(searchStage, projectStage); + + System.out.println("\nVector search results for: \"" + SAMPLE_QUERY + "\""); + + AggregateIterable results = collection.aggregate(pipeline); + var rank = 1; + + for (var result : results) { + var document = result.get("document", Document.class); + var hotelName = document.getString("HotelName"); + var score = result.getDouble("score"); + System.out.printf("%d. HotelName: %s, Score: %.4f%n", rank++, hotelName, score); + } + } + + private static List> partitionList(List list, int batchSize) { + var partitions = new ArrayList>(); + for (int i = 0; i < list.size(); i += batchSize) { + partitions.add(list.subList(i, Math.min(i + batchSize, list.size()))); + } + return partitions; + } +} diff --git a/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/HNSW.java b/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/HNSW.java new file mode 100644 index 0000000..146fc27 --- /dev/null +++ b/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/HNSW.java @@ -0,0 +1,226 @@ +package com.azure.documentdb.samples; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.models.EmbeddingsOptions; +import com.azure.identity.DefaultAzureCredentialBuilder; +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.Indexes; +import org.bson.Document; +import tools.jackson.core.type.TypeReference; +import tools.jackson.databind.json.JsonMapper; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Vector search sample using HNSW index. + */ +public class HNSW { + private static final String SAMPLE_QUERY = "quintessential lodging near running trails, eateries, retail"; + private static final String DATABASE_NAME = "Hotels"; + private static final String COLLECTION_NAME = "hotels_hnsw"; + private static final String VECTOR_INDEX_NAME = "vectorIndex_hnsw"; + + private final JsonMapper jsonMapper = JsonMapper.builder().build(); + + public static void main(String[] args) { + new HNSW().run(); + System.exit(0); + } + + public void run() { + try (var mongoClient = createMongoClient()) { + var openAIClient = createOpenAIClient(); + + var database = mongoClient.getDatabase(DATABASE_NAME); + var collection = database.getCollection(COLLECTION_NAME, Document.class); + + // Drop and recreate collection + collection.drop(); + database.createCollection(COLLECTION_NAME); + System.out.println("Created collection: " + COLLECTION_NAME); + + // Load and insert data + var hotelData = loadHotelData(); + insertDataInBatches(collection, hotelData); + + // Create standard indexes + createStandardIndexes(collection); + + // Create vector index + createVectorIndex(database); + + // Perform vector search + var queryEmbedding = createEmbedding(openAIClient, SAMPLE_QUERY); + performVectorSearch(collection, queryEmbedding); + + } catch (Exception e) { + System.err.println("Error: " + e.getMessage()); + e.printStackTrace(); + } + } + + private MongoClient createMongoClient() { + var clusterName = System.getenv("MONGO_CLUSTER_NAME"); + var managedIdentityPrincipalId = System.getenv("AZURE_MANAGED_IDENTITY_PRINCIPAL_ID"); + var azureCredential = new DefaultAzureCredentialBuilder().build(); + + MongoCredential.OidcCallback callback = (MongoCredential.OidcCallbackContext context) -> { + var token = azureCredential.getToken( + new com.azure.core.credential.TokenRequestContext() + .addScopes("https://ossrdbms-aad.database.windows.net/.default") + ).block(); + + if (token == null) { + throw new RuntimeException("Failed to obtain Azure AD token"); + } + + return new MongoCredential.OidcCallbackResult(token.getToken()); + }; + + var credential = MongoCredential.createOidcCredential(null) + .withMechanismProperty("OIDC_CALLBACK", callback); + + var connectionString = new ConnectionString( + String.format("mongodb+srv://%s@%s.mongocluster.cosmos.azure.com/?authMechanism=MONGODB-OIDC&tls=true&retrywrites=false&maxIdleTimeMS=120000", + managedIdentityPrincipalId, clusterName) + ); + + var settings = MongoClientSettings.builder() + .applyConnectionString(connectionString) + .credential(credential) + .build(); + + return MongoClients.create(settings); + } + + private OpenAIClient createOpenAIClient() { + var endpoint = System.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"); + var credential = new DefaultAzureCredentialBuilder().build(); + + return new OpenAIClientBuilder() + .endpoint(endpoint) + .credential(credential) + .buildClient(); + } + + private List> loadHotelData() throws IOException { + var dataFile = System.getenv("DATA_FILE_WITH_VECTORS"); + var filePath = Path.of(dataFile); + + System.out.println("Reading JSON file from " + filePath.toAbsolutePath()); + var jsonContent = Files.readString(filePath); + + return jsonMapper.readValue(jsonContent, new TypeReference>>() {}); + } + + private void insertDataInBatches(MongoCollection collection, List> hotelData) { + var batchSizeStr = System.getenv("LOAD_SIZE_BATCH"); + var batchSize = batchSizeStr != null ? Integer.parseInt(batchSizeStr) : 100; + var batches = partitionList(hotelData, batchSize); + + System.out.println("Processing in batches of " + batchSize + "..."); + + for (int i = 0; i < batches.size(); i++) { + var batch = batches.get(i); + var documents = batch.stream() + .map(Document::new) + .toList(); + + collection.insertMany(documents); + System.out.println("Batch " + (i + 1) + " complete: " + documents.size() + " inserted"); + } + } + + private void createStandardIndexes(MongoCollection collection) { + collection.createIndex(Indexes.ascending("HotelId")); + collection.createIndex(Indexes.ascending("Category")); + collection.createIndex(Indexes.ascending("Description")); + collection.createIndex(Indexes.ascending("Description_fr")); + } + + private void createVectorIndex(MongoDatabase database) { + var embeddedField = System.getenv("EMBEDDED_FIELD"); + var dimensionsStr = System.getenv("EMBEDDING_DIMENSIONS"); + var dimensions = dimensionsStr != null ? Integer.parseInt(dimensionsStr) : 1536; + + var indexDefinition = new Document() + .append("createIndexes", COLLECTION_NAME) + .append("indexes", List.of( + new Document() + .append("name", VECTOR_INDEX_NAME) + .append("key", new Document(embeddedField, "cosmosSearch")) + .append("cosmosSearchOptions", new Document() + .append("kind", "vector-hnsw") + .append("dimensions", dimensions) + .append("similarity", "COS") + .append("m", 16) + .append("efConstruction", 64) + ) + )); + + database.runCommand(indexDefinition); + System.out.println("Created vector index: " + VECTOR_INDEX_NAME); + } + + private List createEmbedding(OpenAIClient openAIClient, String text) { + var model = System.getenv("AZURE_OPENAI_EMBEDDING_MODEL"); + var options = new EmbeddingsOptions(List.of(text)); + + var response = openAIClient.getEmbeddings(model, options); + return response.getData().get(0).getEmbedding().stream() + .map(Float::doubleValue) + .toList(); + } + + private void performVectorSearch(MongoCollection collection, List queryEmbedding) { + var embeddedField = System.getenv("EMBEDDED_FIELD"); + + var searchStage = new Document("$search", new Document() + .append("cosmosSearch", new Document() + .append("vector", queryEmbedding) + .append("path", embeddedField) + .append("k", 5) + ) + ); + + var projectStage = new Document("$project", new Document() + .append("score", new Document("$meta", "searchScore")) + .append("document", "$$ROOT") + ); + + var pipeline = List.of(searchStage, projectStage); + + System.out.println("\nVector search results for: \"" + SAMPLE_QUERY + "\""); + + AggregateIterable results = collection.aggregate(pipeline); + var rank = 1; + + for (var result : results) { + var document = result.get("document", Document.class); + var hotelName = document.getString("HotelName"); + var score = result.getDouble("score"); + System.out.printf("%d. HotelName: %s, Score: %.4f%n", rank++, hotelName, score); + } + } + + private static List> partitionList(List list, int batchSize) { + var partitions = new ArrayList>(); + for (int i = 0; i < list.size(); i += batchSize) { + partitions.add(list.subList(i, Math.min(i + batchSize, list.size()))); + } + return partitions; + } +} diff --git a/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/IVF.java b/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/IVF.java new file mode 100644 index 0000000..e800107 --- /dev/null +++ b/ai/vector-search-java/src/main/java/com/azure/documentdb/samples/IVF.java @@ -0,0 +1,225 @@ +package com.azure.documentdb.samples; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.models.EmbeddingsOptions; +import com.azure.identity.DefaultAzureCredentialBuilder; +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.Indexes; +import org.bson.Document; +import tools.jackson.core.type.TypeReference; +import tools.jackson.databind.json.JsonMapper; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Vector search sample using IVF index. + */ +public class IVF { + private static final String SAMPLE_QUERY = "quintessential lodging near running trails, eateries, retail"; + private static final String DATABASE_NAME = "Hotels"; + private static final String COLLECTION_NAME = "hotels_ivf"; + private static final String VECTOR_INDEX_NAME = "vectorIndex_ivf"; + + private final JsonMapper jsonMapper = JsonMapper.builder().build(); + + public static void main(String[] args) { + new IVF().run(); + System.exit(0); + } + + public void run() { + try (var mongoClient = createMongoClient()) { + var openAIClient = createOpenAIClient(); + + var database = mongoClient.getDatabase(DATABASE_NAME); + var collection = database.getCollection(COLLECTION_NAME, Document.class); + + // Drop and recreate collection + collection.drop(); + database.createCollection(COLLECTION_NAME); + System.out.println("Created collection: " + COLLECTION_NAME); + + // Load and insert data + var hotelData = loadHotelData(); + insertDataInBatches(collection, hotelData); + + // Create standard indexes + createStandardIndexes(collection); + + // Create vector index + createVectorIndex(database); + + // Perform vector search + var queryEmbedding = createEmbedding(openAIClient, SAMPLE_QUERY); + performVectorSearch(collection, queryEmbedding); + + } catch (Exception e) { + System.err.println("Error: " + e.getMessage()); + e.printStackTrace(); + } + } + + private MongoClient createMongoClient() { + var clusterName = System.getenv("MONGO_CLUSTER_NAME"); + var managedIdentityPrincipalId = System.getenv("AZURE_MANAGED_IDENTITY_PRINCIPAL_ID"); + var azureCredential = new DefaultAzureCredentialBuilder().build(); + + MongoCredential.OidcCallback callback = (MongoCredential.OidcCallbackContext context) -> { + var token = azureCredential.getToken( + new com.azure.core.credential.TokenRequestContext() + .addScopes("https://ossrdbms-aad.database.windows.net/.default") + ).block(); + + if (token == null) { + throw new RuntimeException("Failed to obtain Azure AD token"); + } + + return new MongoCredential.OidcCallbackResult(token.getToken()); + }; + + var credential = MongoCredential.createOidcCredential(null) + .withMechanismProperty("OIDC_CALLBACK", callback); + + var connectionString = new ConnectionString( + String.format("mongodb+srv://%s@%s.mongocluster.cosmos.azure.com/?authMechanism=MONGODB-OIDC&tls=true&retrywrites=false&maxIdleTimeMS=120000", + managedIdentityPrincipalId, clusterName) + ); + + var settings = MongoClientSettings.builder() + .applyConnectionString(connectionString) + .credential(credential) + .build(); + + return MongoClients.create(settings); + } + + private OpenAIClient createOpenAIClient() { + var endpoint = System.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"); + var credential = new DefaultAzureCredentialBuilder().build(); + + return new OpenAIClientBuilder() + .endpoint(endpoint) + .credential(credential) + .buildClient(); + } + + private List> loadHotelData() throws IOException { + var dataFile = System.getenv("DATA_FILE_WITH_VECTORS"); + var filePath = Path.of(dataFile); + + System.out.println("Reading JSON file from " + filePath.toAbsolutePath()); + var jsonContent = Files.readString(filePath); + + return jsonMapper.readValue(jsonContent, new TypeReference>>() {}); + } + + private void insertDataInBatches(MongoCollection collection, List> hotelData) { + var batchSizeStr = System.getenv("LOAD_SIZE_BATCH"); + var batchSize = batchSizeStr != null ? Integer.parseInt(batchSizeStr) : 100; + var batches = partitionList(hotelData, batchSize); + + System.out.println("Processing in batches of " + batchSize + "..."); + + for (int i = 0; i < batches.size(); i++) { + var batch = batches.get(i); + var documents = batch.stream() + .map(Document::new) + .toList(); + + collection.insertMany(documents); + System.out.println("Batch " + (i + 1) + " complete: " + documents.size() + " inserted"); + } + } + + private void createStandardIndexes(MongoCollection collection) { + collection.createIndex(Indexes.ascending("HotelId")); + collection.createIndex(Indexes.ascending("Category")); + collection.createIndex(Indexes.ascending("Description")); + collection.createIndex(Indexes.ascending("Description_fr")); + } + + private void createVectorIndex(MongoDatabase database) { + var embeddedField = System.getenv("EMBEDDED_FIELD"); + var dimensionsStr = System.getenv("EMBEDDING_DIMENSIONS"); + var dimensions = dimensionsStr != null ? Integer.parseInt(dimensionsStr) : 1536; + + var indexDefinition = new Document() + .append("createIndexes", COLLECTION_NAME) + .append("indexes", List.of( + new Document() + .append("name", VECTOR_INDEX_NAME) + .append("key", new Document(embeddedField, "cosmosSearch")) + .append("cosmosSearchOptions", new Document() + .append("kind", "vector-ivf") + .append("dimensions", dimensions) + .append("similarity", "COS") + .append("numLists", 10) + ) + )); + + database.runCommand(indexDefinition); + System.out.println("Created vector index: " + VECTOR_INDEX_NAME); + } + + private List createEmbedding(OpenAIClient openAIClient, String text) { + var model = System.getenv("AZURE_OPENAI_EMBEDDING_MODEL"); + var options = new EmbeddingsOptions(List.of(text)); + + var response = openAIClient.getEmbeddings(model, options); + return response.getData().get(0).getEmbedding().stream() + .map(Float::doubleValue) + .toList(); + } + + private void performVectorSearch(MongoCollection collection, List queryEmbedding) { + var embeddedField = System.getenv("EMBEDDED_FIELD"); + + var searchStage = new Document("$search", new Document() + .append("cosmosSearch", new Document() + .append("vector", queryEmbedding) + .append("path", embeddedField) + .append("k", 5) + ) + ); + + var projectStage = new Document("$project", new Document() + .append("score", new Document("$meta", "searchScore")) + .append("document", "$$ROOT") + ); + + var pipeline = List.of(searchStage, projectStage); + + System.out.println("\nVector search results for: \"" + SAMPLE_QUERY + "\""); + + AggregateIterable results = collection.aggregate(pipeline); + var rank = 1; + + for (var result : results) { + var document = result.get("document", Document.class); + var hotelName = document.getString("HotelName"); + var score = result.getDouble("score"); + System.out.printf("%d. HotelName: %s, Score: %.4f%n", rank++, hotelName, score); + } + } + + private static List> partitionList(List list, int batchSize) { + var partitions = new ArrayList>(); + for (int i = 0; i < list.size(); i += batchSize) { + partitions.add(list.subList(i, Math.min(i + batchSize, list.size()))); + } + return partitions; + } +}