diff --git a/src/Storages/IStorageCluster.cpp b/src/Storages/IStorageCluster.cpp index b7d72ce08ce8..5888b2644879 100644 --- a/src/Storages/IStorageCluster.cpp +++ b/src/Storages/IStorageCluster.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -24,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -115,11 +117,14 @@ class SearcherVisitor : public InDepthQueryTreeVisitorWithContext; using Base::Base; - explicit SearcherVisitor(std::unordered_set types_, ContextPtr context) : Base(context), types(types_) {} + explicit SearcherVisitor(std::unordered_set types_, size_t entry_, ContextPtr context) + : Base(context) + , types(types_) + , entry(entry_) {} bool needChildVisit(QueryTreeNodePtr & /*parent*/, QueryTreeNodePtr & /*child*/) { - return getSubqueryDepth() <= 2 && !passed_node; + return getSubqueryDepth() <= 2 && !passed_node && !current_entry; } void enterImpl(QueryTreeNodePtr & node) @@ -130,13 +135,19 @@ class SearcherVisitor : public InDepthQueryTreeVisitorWithContextgetNodeType(); if (types.contains(node_type)) - passed_node = node; + { + ++current_entry; + if (current_entry == entry) + passed_node = node; + } } QueryTreeNodePtr getNode() const { return passed_node; } private: std::unordered_set types; + size_t entry; + size_t current_entry = 0; QueryTreeNodePtr passed_node; }; @@ -203,15 +214,24 @@ Converts localtable as t ON s3.key == t.key -to +to (object_storage_cluster_join_mode='local') SELECT s3.c1, s3.c2, s3.key FROM s3Cluster(...) AS s3 + +or (object_storage_cluster_join_mode='global') + + SELECT s3.c1, s3.c2, t.c3 + FROM + s3Cluster(...) as s3 + JOIN + values('key UInt32, data String', (1, 'one'), (2, 'two'), ...) as t + ON s3.key == t.key */ void IStorageCluster::updateQueryWithJoinToSendIfNeeded( ASTPtr & query_to_send, - QueryTreeNodePtr query_tree, + SelectQueryInfo query_info, const ContextPtr & context) { auto object_storage_cluster_join_mode = context->getSettingsRef()[Setting::object_storage_cluster_join_mode]; @@ -219,17 +239,17 @@ void IStorageCluster::updateQueryWithJoinToSendIfNeeded( { case ObjectStorageClusterJoinMode::LOCAL: { - auto info = getQueryTreeInfo(query_tree, context); + auto info = getQueryTreeInfo(query_info.query_tree, context); if (info.has_join || info.has_cross_join || info.has_local_columns_in_where) { - auto modified_query_tree = query_tree->clone(); + auto modified_query_tree = query_info.query_tree->clone(); - SearcherVisitor left_table_expression_searcher({QueryTreeNodeType::TABLE, QueryTreeNodeType::TABLE_FUNCTION}, context); + SearcherVisitor left_table_expression_searcher({QueryTreeNodeType::TABLE, QueryTreeNodeType::TABLE_FUNCTION}, 1, context); left_table_expression_searcher.visit(modified_query_tree); auto table_function_node = left_table_expression_searcher.getNode(); if (!table_function_node) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Can't find table function node"); + throw Exception(ErrorCodes::LOGICAL_ERROR, "Can't find left table function node"); QueryTreeNodePtr query_tree_distributed; @@ -242,7 +262,7 @@ void IStorageCluster::updateQueryWithJoinToSendIfNeeded( } else if (info.has_cross_join) { - SearcherVisitor join_searcher({QueryTreeNodeType::CROSS_JOIN}, context); + SearcherVisitor join_searcher({QueryTreeNodeType::CROSS_JOIN}, 1, context); join_searcher.visit(modified_query_tree); auto cross_join_node = join_searcher.getNode(); if (!cross_join_node) @@ -297,8 +317,25 @@ void IStorageCluster::updateQueryWithJoinToSendIfNeeded( return; } case ObjectStorageClusterJoinMode::GLOBAL: - // TODO - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "`Global` mode for `object_storage_cluster_join_mode` setting is unimplemented for now"); + { + auto info = getQueryTreeInfo(query_info.query_tree, context); + + if (info.has_join || info.has_cross_join || info.has_local_columns_in_where) + { + auto modified_query_tree = query_info.query_tree->clone(); + + rewriteJoinToGlobalJoin(modified_query_tree, context); + modified_query_tree = buildQueryTreeForShard( + query_info.planner_context, + modified_query_tree, + /*allow_global_join_for_right_table*/ true, + /*find_cross_join*/ true); + query_to_send = queryNodeToDistributedSelectQuery(modified_query_tree); + send_external_tables = true; + } + + return; + } case ObjectStorageClusterJoinMode::ALLOW: // Do nothing special return; } @@ -336,7 +373,7 @@ void IStorageCluster::read( SharedHeader sample_block; ASTPtr query_to_send = query_info.query; - updateQueryWithJoinToSendIfNeeded(query_to_send, query_info.query_tree, context); + updateQueryWithJoinToSendIfNeeded(query_to_send, query_info, context); if (settings[Setting::allow_experimental_analyzer]) { @@ -374,6 +411,10 @@ void IStorageCluster::read( auto this_ptr = std::static_pointer_cast(shared_from_this()); + std::optional external_tables = std::nullopt; + if (send_external_tables && query_info.planner_context && query_info.planner_context->getMutableQueryContext()) + external_tables = query_info.planner_context->getMutableQueryContext()->getExternalTables(); + auto reading = std::make_unique( column_names, query_info, @@ -384,7 +425,8 @@ void IStorageCluster::read( std::move(query_to_send), processed_stage, cluster, - log); + log, + external_tables); query_plan.addStep(std::move(reading)); } @@ -502,7 +544,7 @@ void ReadFromCluster::initializePipeline(QueryPipelineBuilder & pipeline, const new_context, /*throttler=*/nullptr, scalars, - Tables(), + external_tables.has_value() ? *external_tables : Tables(), processed_stage, nullptr, RemoteQueryExecutor::Extension{.task_iterator = extension->task_iterator, .replica_info = std::move(replica_info)}); @@ -540,7 +582,7 @@ IStorageCluster::QueryTreeInfo IStorageCluster::getQueryTreeInfo(QueryTreeNodePt info.has_cross_join = true; } - SearcherVisitor left_table_expression_searcher({QueryTreeNodeType::TABLE, QueryTreeNodeType::TABLE_FUNCTION}, context); + SearcherVisitor left_table_expression_searcher({QueryTreeNodeType::TABLE, QueryTreeNodeType::TABLE_FUNCTION}, 1, context); left_table_expression_searcher.visit(query_tree); auto table_function_node = left_table_expression_searcher.getNode(); if (!table_function_node) @@ -573,11 +615,14 @@ QueryProcessingStage::Enum IStorageCluster::getQueryProcessingStage( { if (!context->getSettingsRef()[Setting::allow_experimental_analyzer]) throw Exception(ErrorCodes::NOT_IMPLEMENTED, - "object_storage_cluster_join_mode!='allow' is not supported without allow_experimental_analyzer=true"); + "object_storage_cluster_join_mode!='allow' is not supported without allow_experimental_analyzer=false"); - auto info = getQueryTreeInfo(query_info.query_tree, context); - if (info.has_join || info.has_cross_join || info.has_local_columns_in_where) - return QueryProcessingStage::Enum::FetchColumns; + if (object_storage_cluster_join_mode == ObjectStorageClusterJoinMode::LOCAL) + { + auto info = getQueryTreeInfo(query_info.query_tree, context); + if (info.has_join || info.has_cross_join || info.has_local_columns_in_where) + return QueryProcessingStage::Enum::FetchColumns; + } } /// Initiator executes query on remote node. diff --git a/src/Storages/IStorageCluster.h b/src/Storages/IStorageCluster.h index 69ebd0b777fb..6108673d60d4 100644 --- a/src/Storages/IStorageCluster.h +++ b/src/Storages/IStorageCluster.h @@ -63,7 +63,7 @@ class IStorageCluster : public IStorage protected: virtual void updateQueryToSendIfNeeded(ASTPtr & /*query*/, const StorageSnapshotPtr & /*storage_snapshot*/, const ContextPtr & /*context*/) {} - void updateQueryWithJoinToSendIfNeeded(ASTPtr & query_to_send, QueryTreeNodePtr query_tree, const ContextPtr & context); + void updateQueryWithJoinToSendIfNeeded(ASTPtr & query_to_send, SelectQueryInfo query_info, const ContextPtr & context); virtual void updateConfigurationIfNeeded(ContextPtr /* context */) {} @@ -108,6 +108,7 @@ class IStorageCluster : public IStorage LoggerPtr log; String cluster_name; + bool send_external_tables = false; struct QueryTreeInfo { @@ -137,7 +138,8 @@ class ReadFromCluster : public SourceStepWithFilter ASTPtr query_to_send_, QueryProcessingStage::Enum processed_stage_, ClusterPtr cluster_, - LoggerPtr log_) + LoggerPtr log_, + std::optional external_tables_) : SourceStepWithFilter( std::move(sample_block), column_names_, @@ -149,6 +151,7 @@ class ReadFromCluster : public SourceStepWithFilter , processed_stage(processed_stage_) , cluster(std::move(cluster_)) , log(log_) + , external_tables(external_tables_) { } @@ -160,6 +163,7 @@ class ReadFromCluster : public SourceStepWithFilter LoggerPtr log; std::optional extension; + std::optional external_tables; void createExtension(const ActionsDAG::Node * predicate); ContextPtr updateSettings(const Settings & settings); diff --git a/src/Storages/buildQueryTreeForShard.cpp b/src/Storages/buildQueryTreeForShard.cpp index 939dcfdfaa1a..cf43fa25c2f1 100644 --- a/src/Storages/buildQueryTreeForShard.cpp +++ b/src/Storages/buildQueryTreeForShard.cpp @@ -42,6 +42,7 @@ namespace Setting extern const SettingsBool prefer_global_in_and_join; extern const SettingsBool enable_add_distinct_to_in_subqueries; extern const SettingsInt64 optimize_const_name_size; + extern const SettingsObjectStorageClusterJoinMode object_storage_cluster_join_mode; } namespace ErrorCodes @@ -120,8 +121,9 @@ class DistributedProductModeRewriteInJoinVisitor : public InDepthQueryTreeVisito using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - explicit DistributedProductModeRewriteInJoinVisitor(const ContextPtr & context_) + explicit DistributedProductModeRewriteInJoinVisitor(const ContextPtr & context_, bool find_cross_join_) : Base(context_) + , find_cross_join(find_cross_join_) {} struct InFunctionOrJoin @@ -157,9 +159,11 @@ class DistributedProductModeRewriteInJoinVisitor : public InDepthQueryTreeVisito { auto * function_node = node->as(); auto * join_node = node->as(); + CrossJoinNode * cross_join_node = find_cross_join ? node->as() : nullptr; if ((function_node && isNameOfGlobalInFunction(function_node->getFunctionName())) || - (join_node && join_node->getLocality() == JoinLocality::Global)) + (join_node && join_node->getLocality() == JoinLocality::Global) || + cross_join_node) { InFunctionOrJoin in_function_or_join_entry; in_function_or_join_entry.query_node = node; @@ -223,7 +227,9 @@ class DistributedProductModeRewriteInJoinVisitor : public InDepthQueryTreeVisito replacement_table_expression->setTableExpressionModifiers(*table_expression_modifiers); replacement_map.emplace(table_node.get(), std::move(replacement_table_expression)); } - else if ((distributed_product_mode == DistributedProductMode::GLOBAL || getSettings()[Setting::prefer_global_in_and_join]) && + else if ((distributed_product_mode == DistributedProductMode::GLOBAL || + getSettings()[Setting::prefer_global_in_and_join] || + (find_cross_join && getSettings()[Setting::object_storage_cluster_join_mode] == ObjectStorageClusterJoinMode::GLOBAL)) && !in_function_or_join_stack.empty()) { auto * in_or_join_node_to_modify = in_function_or_join_stack.back().query_node.get(); @@ -257,6 +263,8 @@ class DistributedProductModeRewriteInJoinVisitor : public InDepthQueryTreeVisito std::vector in_function_or_join_stack; std::unordered_map replacement_map; std::vector global_in_or_join_nodes; + + bool find_cross_join = false; }; /** Replaces large constant values with `__getScalar` function calls to avoid @@ -504,14 +512,18 @@ QueryTreeNodePtr getSubqueryFromTableExpression( } -QueryTreeNodePtr buildQueryTreeForShard(const PlannerContextPtr & planner_context, QueryTreeNodePtr query_tree_to_modify, bool allow_global_join_for_right_table) +QueryTreeNodePtr buildQueryTreeForShard( + const PlannerContextPtr & planner_context, + QueryTreeNodePtr query_tree_to_modify, + bool allow_global_join_for_right_table, + bool find_cross_join) { CollectColumnSourceToColumnsVisitor collect_column_source_to_columns_visitor; collect_column_source_to_columns_visitor.visit(query_tree_to_modify); const auto & column_source_to_columns = collect_column_source_to_columns_visitor.getColumnSourceToColumns(); - DistributedProductModeRewriteInJoinVisitor visitor(planner_context->getQueryContext()); + DistributedProductModeRewriteInJoinVisitor visitor(planner_context->getQueryContext(), find_cross_join); visitor.visit(query_tree_to_modify); auto replacement_map = visitor.getReplacementMap(); @@ -550,6 +562,24 @@ QueryTreeNodePtr buildQueryTreeForShard(const PlannerContextPtr & planner_contex replacement_map.emplace(join_table_expression.get(), std::move(temporary_table_expression_node)); continue; } + if (auto * cross_join_node = global_in_or_join_node.query_node->as()) + { + auto tables_count = cross_join_node->getTableExpressions().size(); + for (size_t i = 1; i < tables_count; ++i) + { + QueryTreeNodePtr join_table_expression = cross_join_node->getTableExpressions()[i]; + + auto subquery_node = getSubqueryFromTableExpression(join_table_expression, column_source_to_columns, planner_context->getQueryContext()); + + auto temporary_table_expression_node = executeSubqueryNode(subquery_node, + planner_context->getMutableQueryContext(), + global_in_or_join_node.subquery_depth); + temporary_table_expression_node->setAlias(join_table_expression->getAlias()); + + replacement_map.emplace(join_table_expression.get(), std::move(temporary_table_expression_node)); + } + continue; + } if (auto * in_function_node = global_in_or_join_node.query_node->as()) { auto & in_function_subquery_node = in_function_node->getArguments().getNodes().at(1); @@ -661,7 +691,8 @@ class RewriteJoinToGlobalJoinVisitor : public InDepthQueryTreeVisitorWithContext { if (auto * join_node = node->as()) { - bool prefer_local_join = getContext()->getSettingsRef()[Setting::parallel_replicas_prefer_local_join]; + bool prefer_local_join = getContext()->getSettingsRef()[Setting::parallel_replicas_prefer_local_join] + && getContext()->getSettingsRef()[Setting::object_storage_cluster_join_mode] != ObjectStorageClusterJoinMode::GLOBAL; bool should_use_global_join = !prefer_local_join || !allStoragesAreMergeTree(join_node->getRightTableExpression()); if (should_use_global_join) join_node->setLocality(JoinLocality::Global); diff --git a/src/Storages/buildQueryTreeForShard.h b/src/Storages/buildQueryTreeForShard.h index 90cbfd36f660..bcbac10b55e0 100644 --- a/src/Storages/buildQueryTreeForShard.h +++ b/src/Storages/buildQueryTreeForShard.h @@ -16,7 +16,11 @@ using PlannerContextPtr = std::shared_ptr; class Context; using ContextPtr = std::shared_ptr; -QueryTreeNodePtr buildQueryTreeForShard(const PlannerContextPtr & planner_context, QueryTreeNodePtr query_tree_to_modify, bool allow_global_join_for_right_table); +QueryTreeNodePtr buildQueryTreeForShard( + const PlannerContextPtr & planner_context, + QueryTreeNodePtr query_tree_to_modify, + bool allow_global_join_for_right_table, + bool find_cross_join = false); void rewriteJoinToGlobalJoin(QueryTreeNodePtr query_tree_to_modify, ContextPtr context); diff --git a/tests/integration/test_s3_cluster/test.py b/tests/integration/test_s3_cluster/test.py index a1397da6eea6..29900398c3c2 100644 --- a/tests/integration/test_s3_cluster/test.py +++ b/tests/integration/test_s3_cluster/test.py @@ -125,7 +125,7 @@ def started_cluster(): yield cluster finally: - shutil.rmtree(os.path.join(SCRIPT_DIR, "data/generated/")) + shutil.rmtree(os.path.join(SCRIPT_DIR, "data/generated/"), ignore_errors=True) cluster.shutdown() @@ -1034,7 +1034,8 @@ def test_hive_partitioning(started_cluster, allow_experimental_analyzer): node.query("SET allow_experimental_analyzer = DEFAULT") -def test_joins(started_cluster): +@pytest.mark.parametrize("join_mode", ["local", "global"]) +def test_joins(started_cluster, join_mode): node = started_cluster.instances["s0_0_0"] # Table join_table only exists on the node 's0_0_0'. @@ -1068,7 +1069,7 @@ def test_joins(started_cluster): join_table AS t2 ON t1.value = t2.id ORDER BY t1.name - SETTINGS object_storage_cluster_join_mode='local'; + SETTINGS object_storage_cluster_join_mode='{join_mode}'; """ ) @@ -1091,7 +1092,7 @@ def test_joins(started_cluster): 'name String, value UInt32, polygon Array(Array(Tuple(Float64, Float64)))') AS t1 ON t1.value = t2.id ORDER BY t1.name - SETTINGS object_storage_cluster_join_mode='local'; + SETTINGS object_storage_cluster_join_mode='{join_mode}'; """ ) @@ -1109,7 +1110,7 @@ def test_joins(started_cluster): ON t1.value = t2.id WHERE (t1.value % 2) ORDER BY t1.name - SETTINGS object_storage_cluster_join_mode='local'; + SETTINGS object_storage_cluster_join_mode='{join_mode}'; """ ) @@ -1128,7 +1129,7 @@ def test_joins(started_cluster): ON t1.value = t2.id WHERE (t2.id % 2) ORDER BY t1.name - SETTINGS object_storage_cluster_join_mode='local'; + SETTINGS object_storage_cluster_join_mode='{join_mode}'; """ ) @@ -1146,27 +1147,29 @@ def test_joins(started_cluster): ON t1.value = t2.id WHERE (t1.value % 2) AND ((t2.id % 3) == 2) ORDER BY t1.name - SETTINGS object_storage_cluster_join_mode='local'; + SETTINGS object_storage_cluster_join_mode='{join_mode}'; """ ) res = list(map(str.split, result5.splitlines())) assert len(res) == 6 + # With WHERE clause with global subquery result6 = node.query( f""" SELECT name FROM s3Cluster('cluster_simple', 'http://minio1:9001/root/data/{{clickhouse,database}}/*', 'minio', '{minio_secret_key}', 'CSV', 'name String, value UInt32, polygon Array(Array(Tuple(Float64, Float64)))') - WHERE value IN (SELECT id FROM join_table) + WHERE value GLOBAL IN (SELECT id FROM join_table) ORDER BY name - SETTINGS object_storage_cluster_join_mode='local'; + SETTINGS object_storage_cluster_join_mode='{join_mode}'; """ ) res = list(map(str.split, result6.splitlines())) assert len(res) == 25 + # With WHERE clause without columns in condition result7 = node.query( f""" SELECT count() FROM @@ -1177,11 +1180,12 @@ def test_joins(started_cluster): join_table AS t2 ON 1 GROUP BY ALL - SETTINGS object_storage_cluster_join_mode='local'; + SETTINGS object_storage_cluster_join_mode='{join_mode}'; """ ) assert result7.strip() == "625" + # With WHERE clause without columns in condition and with local column in SELECT result8 = node.query( f""" SELECT count(), t2.id FROM @@ -1192,7 +1196,7 @@ def test_joins(started_cluster): join_table AS t2 ON 1 GROUP BY ALL - SETTINGS object_storage_cluster_join_mode='local'; + SETTINGS object_storage_cluster_join_mode='{join_mode}'; """ ) res = list(map(str.split, result8.splitlines()))