diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 840666396..8cc84b185 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,8 @@ ### New Features and Improvements +* Add support for unified hosts with experimental flag. + ### Bug Fixes ### Security Vulnerabilities diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/AccountClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/AccountClient.java index 5461ba07e..7408c0629 100755 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/AccountClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/AccountClient.java @@ -5,6 +5,7 @@ import com.databricks.sdk.core.ApiClient; import com.databricks.sdk.core.ConfigLoader; import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.HostType; import com.databricks.sdk.core.utils.AzureUtils; import com.databricks.sdk.service.billing.BillableUsageAPI; import com.databricks.sdk.service.billing.BillableUsageService; @@ -1111,6 +1112,14 @@ public DatabricksConfig config() { } public WorkspaceClient getWorkspaceClient(Workspace workspace) { + // For unified hosts, reuse the same host and set workspace ID + if (this.config.getHostType() == HostType.UNIFIED) { + DatabricksConfig workspaceConfig = this.config.clone(); + workspaceConfig.setWorkspaceId(String.valueOf(workspace.getWorkspaceId())); + return new WorkspaceClient(workspaceConfig); + } + + // For traditional account hosts, get workspace deployment URL String host = this.config.getDatabricksEnvironment().getDeploymentUrl(workspace.getDeploymentName()); DatabricksConfig config = this.config.newWithWorkspaceHost(host); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ClientType.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ClientType.java new file mode 100644 index 000000000..e03956e8c --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ClientType.java @@ -0,0 +1,13 @@ +package com.databricks.sdk.core; + +import com.databricks.sdk.support.InternalApi; + +/** Represents the type of Databricks client being used for API operations. */ +@InternalApi +public enum ClientType { + /** Workspace client (traditional or unified host with workspaceId). */ + WORKSPACE, + + /** Account client (traditional or unified host without workspaceId). */ + ACCOUNT +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java index 687ec1dd7..e4376f203 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java @@ -31,7 +31,7 @@ private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { } List cmd = new ArrayList<>(Arrays.asList(cliPath, "auth", "token", "--host", config.getHost())); - if (config.isAccountClient()) { + if (config.getClientType() == ClientType.ACCOUNT) { cmd.add("--account-id"); cmd.add(config.getAccountId()); } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index 572d3cb9b..0b8ee81f8 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -27,6 +27,20 @@ public class DatabricksConfig { @ConfigAttribute(env = "DATABRICKS_ACCOUNT_ID") private String accountId; + /** + * Workspace ID for unified host operations. Note: This API is experimental and may change or be + * removed in future releases without notice. + */ + @ConfigAttribute(env = "DATABRICKS_WORKSPACE_ID") + private String workspaceId; + + /** + * Flag to explicitly mark a host as a unified host. Note: This API is experimental and may change + * or be removed in future releases without notice. + */ + @ConfigAttribute(env = "DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST") + private Boolean experimentalIsUnifiedHost; + @ConfigAttribute(env = "DATABRICKS_TOKEN", auth = "pat", sensitive = true) private String token; @@ -43,10 +57,8 @@ public class DatabricksConfig { private String redirectUrl; /** - * The OpenID Connect discovery URL used to retrieve OIDC configuration and endpoints. - * - *

Note: This API is experimental and may change or be removed in future releases - * without notice. + * The OpenID Connect discovery URL used to retrieve OIDC configuration and endpoints. Note: This + * API is experimental and may change or be removed in future releases without notice. */ @ConfigAttribute(env = "DATABRICKS_DISCOVERY_URL") private String discoveryUrl; @@ -233,8 +245,16 @@ public synchronized Map authenticate() throws DatabricksExceptio if (headerFactory == null) { // Calling authenticate without resolve ConfigLoader.fixHostIfNeeded(this); - headerFactory = credentialsProvider.configure(this); + HeaderFactory rawHeaderFactory = credentialsProvider.configure(this); setAuthType(credentialsProvider.authType()); + + // For unified hosts with workspace operations, wrap the header factory + // to inject the X-Databricks-Org-Id header + if (getHostType() == HostType.UNIFIED && workspaceId != null && !workspaceId.isEmpty()) { + headerFactory = new UnifiedHostHeaderFactory(rawHeaderFactory, workspaceId); + } else { + headerFactory = rawHeaderFactory; + } } return headerFactory.headers(); } catch (DatabricksException e) { @@ -298,6 +318,24 @@ public DatabricksConfig setAccountId(String accountId) { return this; } + public String getWorkspaceId() { + return workspaceId; + } + + public DatabricksConfig setWorkspaceId(String workspaceId) { + this.workspaceId = workspaceId; + return this; + } + + public Boolean getExperimentalIsUnifiedHost() { + return experimentalIsUnifiedHost; + } + + public DatabricksConfig setExperimentalIsUnifiedHost(Boolean experimentalIsUnifiedHost) { + this.experimentalIsUnifiedHost = experimentalIsUnifiedHost; + return this; + } + public String getDatabricksCliPath() { return this.databricksCliPath; } @@ -679,12 +717,49 @@ public boolean isAws() { } public boolean isAccountClient() { + if (getHostType() == HostType.UNIFIED) { + throw new DatabricksException( + "Cannot determine account client status for unified hosts. " + + "Use getHostType() or getClientType() instead. " + + "For unified hosts, client type depends on whether workspaceId is set."); + } if (host == null) { return false; } return host.startsWith("https://accounts.") || host.startsWith("https://accounts-dod."); } + /** Returns the host type based on configuration settings and host URL. */ + public HostType getHostType() { + if (experimentalIsUnifiedHost != null && experimentalIsUnifiedHost) { + return HostType.UNIFIED; + } + if (host == null) { + return HostType.WORKSPACE; + } + if (host.startsWith("https://accounts.") || host.startsWith("https://accounts-dod.")) { + return HostType.ACCOUNTS; + } + return HostType.WORKSPACE; + } + + /** Returns the client type based on host type and workspace ID configuration. */ + public ClientType getClientType() { + HostType hostType = getHostType(); + switch (hostType) { + case UNIFIED: + // For unified hosts, client type depends on whether workspaceId is set + return (workspaceId != null && !workspaceId.isEmpty()) + ? ClientType.WORKSPACE + : ClientType.ACCOUNT; + case ACCOUNTS: + return ClientType.ACCOUNT; + case WORKSPACE: + default: + return ClientType.WORKSPACE; + } + } + public OpenIDConnectEndpoints getOidcEndpoints() throws IOException { if (discoveryUrl == null) { return fetchDefaultOidcEndpoints(); @@ -705,10 +780,25 @@ private OpenIDConnectEndpoints fetchOidcEndpointsFromDiscovery() { return null; } + private OpenIDConnectEndpoints getUnifiedOidcEndpoints(String accountId) throws IOException { + if (accountId == null || accountId.isEmpty()) { + throw new DatabricksException( + "account_id is required for unified host OIDC endpoint discovery"); + } + String prefix = getHost() + "/oidc/accounts/" + accountId; + return new OpenIDConnectEndpoints(prefix + "/v1/token", prefix + "/v1/authorize"); + } + private OpenIDConnectEndpoints fetchDefaultOidcEndpoints() throws IOException { if (getHost() == null) { return null; } + + // For unified hosts, use account-based OIDC endpoints + if (getHostType() == HostType.UNIFIED) { + return getUnifiedOidcEndpoints(getAccountId()); + } + if (isAzure() && getAzureClientId() != null) { Request request = new Request("GET", getHost() + "/oidc/oauth2/v2.0/authorize"); request.setRedirectionBehavior(false); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index 59ec6eca0..8d4e593f4 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -150,7 +150,8 @@ private void addOIDCCredentialsProviders(DatabricksConfig config) { namedIdTokenSource.idTokenSource, config.getHttpClient()) .audience(config.getTokenAudience()) - .accountId(config.isAccountClient() ? config.getAccountId() : null) + .accountId( + config.getClientType() == ClientType.ACCOUNT ? config.getAccountId() : null) .scopes(config.getScopes()) .build(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleCredentialsCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleCredentialsCredentialsProvider.java index 755c1b331..463d2bab9 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleCredentialsCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleCredentialsCredentialsProvider.java @@ -66,7 +66,7 @@ public HeaderFactory configure(DatabricksConfig config) { Map headers = new HashMap<>(); headers.put("Authorization", String.format("Bearer %s", idToken.getTokenValue())); - if (config.isAccountClient()) { + if (config.getClientType() == ClientType.ACCOUNT) { AccessToken token; try { token = finalServiceAccountCredentials.createScoped(GCP_SCOPES).refreshAccessToken(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleIdCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleIdCredentialsProvider.java index c51dfd4cc..376d691c5 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleIdCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleIdCredentialsProvider.java @@ -69,7 +69,7 @@ public HeaderFactory configure(DatabricksConfig config) { throw new DatabricksException(message, e); } - if (config.isAccountClient()) { + if (config.getClientType() == ClientType.ACCOUNT) { try { headers.put( SA_ACCESS_TOKEN_HEADER, gcpScopedCredentials.refreshAccessToken().getTokenValue()); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/HostType.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/HostType.java new file mode 100644 index 000000000..005807839 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/HostType.java @@ -0,0 +1,16 @@ +package com.databricks.sdk.core; + +import com.databricks.sdk.support.InternalApi; + +/** Represents the type of Databricks host being used. */ +@InternalApi +public enum HostType { + /** Traditional workspace host. */ + WORKSPACE, + + /** Traditional accounts host. */ + ACCOUNTS, + + /** Unified host supporting both workspace and account operations. */ + UNIFIED +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/UnifiedHostHeaderFactory.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/UnifiedHostHeaderFactory.java new file mode 100644 index 000000000..2889329e5 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/UnifiedHostHeaderFactory.java @@ -0,0 +1,36 @@ +package com.databricks.sdk.core; + +import java.util.HashMap; +import java.util.Map; + +/** + * HeaderFactory wrapper that adds X-Databricks-Org-Id header for unified host workspace operations. + */ +class UnifiedHostHeaderFactory implements HeaderFactory { + private final HeaderFactory delegate; + private final String workspaceId; + + /** + * Creates a new unified host header factory. + * + * @param delegate The underlying header factory (e.g., OAuth, PAT) + * @param workspaceId The workspace ID to inject in the X-Databricks-Org-Id header + */ + public UnifiedHostHeaderFactory(HeaderFactory delegate, String workspaceId) { + if (delegate == null) { + throw new IllegalArgumentException("delegate cannot be null"); + } + if (workspaceId == null || workspaceId.isEmpty()) { + throw new IllegalArgumentException("workspaceId cannot be null or empty"); + } + this.delegate = delegate; + this.workspaceId = workspaceId; + } + + @Override + public Map headers() { + Map headers = new HashMap<>(delegate.headers()); + headers.put("X-Databricks-Org-Id", workspaceId); + return headers; + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/AccountClientTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/AccountClientTest.java new file mode 100644 index 000000000..ca20fe5a2 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/AccountClientTest.java @@ -0,0 +1,75 @@ +package com.databricks.sdk; + +import static org.junit.jupiter.api.Assertions.*; + +import com.databricks.sdk.core.ClientType; +import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.HostType; +import com.databricks.sdk.service.provisioning.Workspace; +import org.junit.jupiter.api.Test; + +public class AccountClientTest { + + @Test + public void testGetWorkspaceClientForTraditionalAccount() { + DatabricksConfig accountConfig = + new DatabricksConfig() + .setHost("https://accounts.cloud.databricks.com") + .setAccountId("test-account") + .setToken("test-token"); + + AccountClient accountClient = new AccountClient(accountConfig); + + Workspace workspace = new Workspace(); + workspace.setWorkspaceId(123L); + workspace.setDeploymentName("test-workspace"); + + WorkspaceClient workspaceClient = accountClient.getWorkspaceClient(workspace); + + // Should have a different host + assertNotEquals(accountConfig.getHost(), workspaceClient.config().getHost()); + assertTrue(workspaceClient.config().getHost().contains("test-workspace")); + } + + @Test + public void testGetWorkspaceClientForUnifiedHost() { + String unifiedHost = "https://unified.databricks.com"; + DatabricksConfig accountConfig = + new DatabricksConfig() + .setHost(unifiedHost) + .setExperimentalIsUnifiedHost(true) + .setAccountId("test-account") + .setToken("test-token"); + + AccountClient accountClient = new AccountClient(accountConfig); + + Workspace workspace = new Workspace(); + workspace.setWorkspaceId(123456L); + workspace.setDeploymentName("test-workspace"); + + WorkspaceClient workspaceClient = accountClient.getWorkspaceClient(workspace); + + // Should have the same host + assertEquals(unifiedHost, workspaceClient.config().getHost()); + + // Should have workspace ID set + assertEquals("123456", workspaceClient.config().getWorkspaceId()); + + // Should be workspace client type (on unified host) + assertEquals(ClientType.WORKSPACE, workspaceClient.config().getClientType()); + + // Host type should still be unified + assertEquals(HostType.UNIFIED, workspaceClient.config().getHostType()); + } + + @Test + public void testGetWorkspaceClientForUnifiedHostType() { + // Verify unified host type is correctly detected + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true); + + assertEquals(HostType.UNIFIED, config.getHostType()); + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java index d805de323..eab6d991c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java @@ -358,4 +358,67 @@ public void testConfigFileScopes(String testName, String profile, List e List scopes = config.getScopes(); assertIterableEquals(expectedScopes, scopes); } + + // --- Unified Host Tests (added for SPOG support) --- + + @Test + public void testGetHostTypeWorkspace() { + assertEquals( + HostType.WORKSPACE, + new DatabricksConfig().setHost("https://adb-123.azuredatabricks.net").getHostType()); + } + + @Test + public void testGetHostTypeAccounts() { + assertEquals( + HostType.ACCOUNTS, + new DatabricksConfig().setHost("https://accounts.cloud.databricks.com").getHostType()); + } + + @Test + public void testGetHostTypeUnified() { + assertEquals( + HostType.UNIFIED, + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true) + .getHostType()); + } + + @Test + public void testGetClientTypeWorkspace() { + assertEquals( + ClientType.WORKSPACE, + new DatabricksConfig().setHost("https://adb-123.azuredatabricks.net").getClientType()); + } + + @Test + public void testGetClientTypeAccount() { + assertEquals( + ClientType.ACCOUNT, + new DatabricksConfig().setHost("https://accounts.cloud.databricks.com").getClientType()); + } + + @Test + public void testGetClientTypeWorkspaceOnUnified() { + // For unified hosts with workspaceId, client type is WORKSPACE + assertEquals( + ClientType.WORKSPACE, + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true) + .setWorkspaceId("123456") + .getClientType()); + } + + @Test + public void testGetClientTypeAccountOnUnified() { + // For unified hosts without workspaceId, client type is ACCOUNT + assertEquals( + ClientType.ACCOUNT, + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true) + .getClientType()); + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/UnifiedHostTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/UnifiedHostTest.java new file mode 100644 index 000000000..6ecb9a048 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/UnifiedHostTest.java @@ -0,0 +1,265 @@ +package com.databricks.sdk.core; + +import static org.junit.jupiter.api.Assertions.*; + +import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints; +import com.databricks.sdk.core.utils.Environment; +import java.io.IOException; +import java.util.*; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +/** + * Tests for unified host support (SPOG). + * + *

Covers host type detection, client type determination, header injection, and OIDC endpoint + * resolution for unified hosts. + */ +public class UnifiedHostTest { + + // --- Host Type Detection Tests --- + + @Test + public void testHostTypeWorkspace() { + DatabricksConfig config = + new DatabricksConfig().setHost("https://adb-123456789.0.azuredatabricks.net"); + assertEquals(HostType.WORKSPACE, config.getHostType()); + } + + @Test + public void testHostTypeAccounts() { + DatabricksConfig config = + new DatabricksConfig().setHost("https://accounts.cloud.databricks.com"); + assertEquals(HostType.ACCOUNTS, config.getHostType()); + } + + @Test + public void testHostTypeAccountsDod() { + DatabricksConfig config = + new DatabricksConfig().setHost("https://accounts-dod.cloud.databricks.us"); + assertEquals(HostType.ACCOUNTS, config.getHostType()); + } + + @Test + public void testHostTypeUnifiedExplicitFlag() { + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true); + assertEquals(HostType.UNIFIED, config.getHostType()); + } + + @Test + public void testHostTypeUnifiedOverridesAccounts() { + // Even if host looks like accounts, explicit flag takes precedence + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://accounts.cloud.databricks.com") + .setExperimentalIsUnifiedHost(true); + assertEquals(HostType.UNIFIED, config.getHostType()); + } + + @Test + public void testHostTypeNullHost() { + DatabricksConfig config = new DatabricksConfig(); + assertEquals(HostType.WORKSPACE, config.getHostType()); + } + + // --- Client Type Detection Tests --- + + private static Stream provideClientTypeTestCases() { + return Stream.of( + Arguments.of( + "Workspace host", + "https://adb-123.azuredatabricks.net", + null, + false, + ClientType.WORKSPACE), + Arguments.of( + "Account host", + "https://accounts.cloud.databricks.com", + null, + false, + ClientType.ACCOUNT), + Arguments.of( + "Unified without workspace ID", + "https://unified.databricks.com", + null, + true, + ClientType.ACCOUNT), + Arguments.of( + "Unified with workspace ID", + "https://unified.databricks.com", + "123456", + true, + ClientType.WORKSPACE)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideClientTypeTestCases") + public void testClientType( + String testName, String host, String workspaceId, boolean isUnified, ClientType expected) { + DatabricksConfig config = new DatabricksConfig().setHost(host).setWorkspaceId(workspaceId); + if (isUnified) { + config.setExperimentalIsUnifiedHost(true); + } + assertEquals(expected, config.getClientType()); + } + + // --- OIDC Endpoint Tests --- + + @Test + public void testOidcEndpointsForUnifiedHost() throws IOException { + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true) + .setAccountId("test-account-123"); + + OpenIDConnectEndpoints endpoints = config.getOidcEndpoints(); + + assertEquals( + "https://unified.databricks.com/oidc/accounts/test-account-123/v1/authorize", + endpoints.getAuthorizationEndpoint()); + assertEquals( + "https://unified.databricks.com/oidc/accounts/test-account-123/v1/token", + endpoints.getTokenEndpoint()); + } + + @Test + public void testOidcEndpointsForUnifiedHostMissingAccountId() { + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true); + // No account ID set + + DatabricksException exception = + assertThrows(DatabricksException.class, () -> config.getOidcEndpoints()); + assertTrue(exception.getMessage().contains("account_id is required")); + } + + // --- isAccountClient() Deprecation Tests --- + + @Test + public void testIsAccountClientThrowsForUnifiedHost() { + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true); + + DatabricksException exception = + assertThrows(DatabricksException.class, config::isAccountClient); + assertTrue(exception.getMessage().contains("Cannot determine account client status")); + assertTrue(exception.getMessage().contains("getHostType()")); + } + + @Test + public void testIsAccountClientWorksFineForTraditionalHosts() { + assertTrue( + new DatabricksConfig().setHost("https://accounts.cloud.databricks.com").isAccountClient()); + + assertFalse( + new DatabricksConfig().setHost("https://adb-123.azuredatabricks.net").isAccountClient()); + } + + // --- Environment Variable Tests --- + + @Test + public void testUnifiedHostFromEnvironmentVariables() { + Map env = new HashMap<>(); + env.put("DATABRICKS_HOST", "https://unified.databricks.com"); + env.put("DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST", "true"); + env.put("DATABRICKS_WORKSPACE_ID", "987654321"); + env.put("DATABRICKS_ACCOUNT_ID", "account-abc"); + + DatabricksConfig config = new DatabricksConfig(); + config.resolve(new Environment(env, new ArrayList<>(), System.getProperty("os.name"))); + + assertEquals(HostType.UNIFIED, config.getHostType()); + assertEquals("987654321", config.getWorkspaceId()); + assertEquals("account-abc", config.getAccountId()); + assertEquals(ClientType.WORKSPACE, config.getClientType()); + } + + // --- UnifiedHostHeaderFactory Tests --- + + @Test + public void testUnifiedHostHeaderFactoryAddsHeader() { + Map baseHeaders = new HashMap<>(); + baseHeaders.put("Authorization", "Bearer token123"); + + HeaderFactory baseFactory = () -> baseHeaders; + UnifiedHostHeaderFactory unifiedFactory = new UnifiedHostHeaderFactory(baseFactory, "ws-456"); + + Map headers = unifiedFactory.headers(); + + assertEquals("Bearer token123", headers.get("Authorization")); + assertEquals("ws-456", headers.get("X-Databricks-Org-Id")); + } + + @Test + public void testUnifiedHostHeaderFactoryRequiresDelegate() { + assertThrows( + IllegalArgumentException.class, () -> new UnifiedHostHeaderFactory(null, "ws-123")); + } + + @Test + public void testUnifiedHostHeaderFactoryRequiresWorkspaceId() { + HeaderFactory baseFactory = () -> new HashMap<>(); + assertThrows( + IllegalArgumentException.class, () -> new UnifiedHostHeaderFactory(baseFactory, null)); + assertThrows( + IllegalArgumentException.class, () -> new UnifiedHostHeaderFactory(baseFactory, "")); + } + + // --- Header Injection Integration Tests --- + + @Test + public void testHeaderInjectionForWorkspaceOnUnified() { + String workspaceId = "123456789"; + + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true) + .setWorkspaceId(workspaceId) + .setToken("test-token"); + + Map headers = config.authenticate(); + + assertEquals("Bearer test-token", headers.get("Authorization")); + assertEquals(workspaceId, headers.get("X-Databricks-Org-Id")); + } + + @Test + public void testNoHeaderInjectionForAccountOnUnified() { + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://unified.databricks.com") + .setExperimentalIsUnifiedHost(true) + .setToken("test-token"); + // No workspace ID set + + Map headers = config.authenticate(); + + assertEquals("Bearer test-token", headers.get("Authorization")); + assertNull(headers.get("X-Databricks-Org-Id")); + } + + @Test + public void testNoHeaderInjectionForTraditionalWorkspace() { + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://adb-123.azuredatabricks.net") + .setToken("test-token"); + + Map headers = config.authenticate(); + + assertEquals("Bearer test-token", headers.get("Authorization")); + assertNull(headers.get("X-Databricks-Org-Id")); + } +}