diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index fefbf134bd11..66b173dc3635 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -40,7 +40,9 @@ use datafusion_expr::{ }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_functions_nested::expr_fn::{array_has, array_max, array_min}; +use datafusion_functions_nested::expr_fn::{ + array_has, array_max, array_min, array_position, cardinality, +}; mod binary_op; mod function; @@ -610,7 +612,7 @@ impl SqlToRel<'_, S> { _ => { let left_expr = self.sql_to_expr(*left, schema, planner_context)?; let right_expr = self.sql_to_expr(*right, schema, planner_context)?; - plan_any_op(left_expr, right_expr, &compare_op) + plan_quantified_op(&left_expr, &right_expr, &compare_op, false) } }, SQLExpr::AllOp { @@ -626,7 +628,11 @@ impl SqlToRel<'_, S> { schema, planner_context, ), - _ => not_impl_err!("ALL only supports subquery comparison currently"), + _ => { + let left_expr = self.sql_to_expr(*left, schema, planner_context)?; + let right_expr = self.sql_to_expr(*right, schema, planner_context)?; + plan_quantified_op(&left_expr, &right_expr, &compare_op, true) + } }, #[expect(deprecated)] SQLExpr::Wildcard(_token) => Ok(Expr::Wildcard { @@ -1234,58 +1240,80 @@ impl SqlToRel<'_, S> { } } -/// Builds a CASE expression that handles NULL semantics for `x ANY(arr)`: +/// Plans `left_expr ANY/ALL(right_expr)` with proper SQL NULL semantics. /// -/// ```text -/// CASE -/// WHEN (arr) IS NOT NULL THEN -/// WHEN arr IS NOT NULL THEN FALSE -- empty or all-null array -/// ELSE NULL -- NULL array -/// END -/// ``` -fn any_op_with_null_handling(bound: Expr, comparison: Expr, arr: Expr) -> Result { - when(bound.is_not_null(), comparison) - .when(arr.is_not_null(), lit(false)) - .otherwise(lit(ScalarValue::Boolean(None))) -} - -/// Plans a ` ANY()` expression for non-subquery operands. -fn plan_any_op( - left_expr: Expr, - right_expr: Expr, +/// When `is_all` is false (ANY): returns TRUE if any element satisfies the condition. +/// When `is_all` is true (ALL): returns TRUE if all elements satisfy the condition. +/// +/// CASE/WHEN structure: +/// WHEN arr IS NULL → NULL +/// WHEN empty → is_all (ANY:false, ALL:true) +/// WHEN lhs IS NULL → NULL +/// WHEN decisive_condition → !is_all (ANY:true match found, ALL:false violation found) +/// WHEN has_nulls → NULL +/// ELSE → is_all (ANY:false, ALL:true) +fn plan_quantified_op( + left_expr: &Expr, + right_expr: &Expr, compare_op: &BinaryOperator, + is_all: bool, ) -> Result { - match compare_op { - BinaryOperator::Eq => Ok(array_has(right_expr, left_expr)), - BinaryOperator::NotEq => { - let min = array_min(right_expr.clone()); - let max = array_max(right_expr.clone()); - // NOT EQ is true when either bound differs from left - let comparison = min - .not_eq(left_expr.clone()) - .or(max.clone().not_eq(left_expr)); - any_op_with_null_handling(max, comparison, right_expr) + let null_arr_check = right_expr.clone().is_null(); + let empty_check = cardinality(right_expr.clone()).eq(lit(0u64)); + let null_lhs_check = left_expr.clone().is_null(); + // DataFusion's array_position uses is_null() checks internally (not equality), + // so it can locate NULL elements even though NULL = NULL is NULL in standard SQL. + let has_nulls = array_position(right_expr.clone(), lit(ScalarValue::Null), lit(1i64)) + .is_not_null(); + + let decisive_condition = match (compare_op, is_all) { + (BinaryOperator::Eq, false) => array_has(right_expr.clone(), left_expr.clone()), + (BinaryOperator::NotEq, true) => array_has(right_expr.clone(), left_expr.clone()), + (BinaryOperator::Eq, true) | (BinaryOperator::NotEq, false) => { + let all_equal = array_min(right_expr.clone()) + .eq(left_expr.clone()) + .and(array_max(right_expr.clone()).eq(left_expr.clone())); + Expr::Not(Box::new(all_equal)) } - BinaryOperator::Gt => { - let min = array_min(right_expr.clone()); - any_op_with_null_handling(min.clone(), min.lt(left_expr), right_expr) + (BinaryOperator::Gt, false) => { + left_expr.clone().gt(array_min(right_expr.clone())) } - BinaryOperator::Lt => { - let max = array_max(right_expr.clone()); - any_op_with_null_handling(max.clone(), max.gt(left_expr), right_expr) + (BinaryOperator::Gt, true) => Expr::Not(Box::new( + left_expr.clone().gt(array_max(right_expr.clone())), + )), + (BinaryOperator::Lt, false) => { + left_expr.clone().lt(array_max(right_expr.clone())) } - BinaryOperator::GtEq => { - let min = array_min(right_expr.clone()); - any_op_with_null_handling(min.clone(), min.lt_eq(left_expr), right_expr) + (BinaryOperator::Lt, true) => Expr::Not(Box::new( + left_expr.clone().lt(array_min(right_expr.clone())), + )), + (BinaryOperator::GtEq, false) => { + left_expr.clone().gt_eq(array_min(right_expr.clone())) } - BinaryOperator::LtEq => { - let max = array_max(right_expr.clone()); - any_op_with_null_handling(max.clone(), max.gt_eq(left_expr), right_expr) + (BinaryOperator::GtEq, true) => Expr::Not(Box::new( + left_expr.clone().gt_eq(array_max(right_expr.clone())), + )), + (BinaryOperator::LtEq, false) => { + left_expr.clone().lt_eq(array_max(right_expr.clone())) } - _ => plan_err!( - "Unsupported AnyOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" - ), - } + (BinaryOperator::LtEq, true) => Expr::Not(Box::new( + left_expr.clone().lt_eq(array_min(right_expr.clone())), + )), + _ => { + let quantifier = if is_all { "AllOp" } else { "AnyOp" }; + return plan_err!( + "Unsupported {quantifier}: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" + ); + } + }; + + let null_bool = lit(ScalarValue::Boolean(None)); + when(null_arr_check, null_bool.clone()) + .when(empty_check, lit(is_all)) + .when(null_lhs_check, null_bool.clone()) + .when(decisive_condition, lit(!is_all)) + .when(has_nulls, null_bool) + .otherwise(lit(is_all)) } #[cfg(test)] diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 0dad48b16897..eb6c99804d0f 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -367,7 +367,7 @@ fn roundtrip_statement_postgres_any_array_expr() -> Result<(), DataFusionError> sql: "select left from array where 1 = any(left);", parser_dialect: GenericDialect {}, unparser_dialect: UnparserPostgreSqlDialect {}, - expected: @r#"SELECT "array"."left" FROM "array" WHERE 1 = ANY("array"."left")"#, + expected: @r#"SELECT "array"."left" FROM "array" WHERE CASE WHEN "array"."left" IS NULL THEN NULL WHEN (cardinality("array"."left") = 0) THEN false WHEN 1 IS NULL THEN NULL WHEN 1 = ANY("array"."left") THEN true WHEN array_position("array"."left", NULL, 1) IS NOT NULL THEN NULL ELSE false END"#, ); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 25136ca777c7..a04fe30758ea 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -7165,16 +7165,18 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: generate_series() projection=[value] +06)----------Filter: __common_expr_3 IS NULL AND Boolean(NULL) OR __common_expr_3 IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) IS NOT DISTINCT FROM Boolean(true) AND __common_expr_3 IS NOT NULL +07)------------Projection: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) AS __common_expr_3 +08)--------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] -06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: __common_expr_3@0 IS NULL AND NULL OR __common_expr_3@0 IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) IS NOT DISTINCT FROM true AND __common_expr_3@0 IS NOT NULL, projection=[] +06)----------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8View)), 1, 32) as __common_expr_3] +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -7402,26 +7404,26 @@ select 5 <= any(make_array()); false # Mixed NULL + non-NULL array where no non-NULL element satisfies the condition -# These return false (NULLs are skipped by array_min/array_max) +# These return NULL because NULLs leave the result indeterminate query B select 5 > any(make_array(6, NULL)); ---- -false +NULL query B select 5 < any(make_array(3, NULL)); ---- -false +NULL query B select 5 >= any(make_array(6, NULL)); ---- -false +NULL query B select 5 <= any(make_array(3, NULL)); ---- -false +NULL # Mixed NULL + non-NULL array where a non-NULL element satisfies the condition query B @@ -7452,33 +7454,38 @@ true query B select 5 <> any(make_array(5, NULL)); ---- -false +NULL -# All-NULL array: all operators should return false +# All-NULL array: all operators should return NULL (unknown comparison) query B select 5 > any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL query B select 5 < any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL query B select 5 >= any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL query B select 5 <= any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL query B select 5 <> any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL + +query B +select 5 = any(make_array(NULL::INT, NULL::INT)); +---- +NULL # NULL left operand: should return NULL for non-empty arrays query B @@ -7538,9 +7545,243 @@ select 5 <> any(NULL::INT[]); ---- NULL +query B +select 5 = any(NULL::INT[]); +---- +NULL + +# NULL = ANY with non-empty array +query B +select NULL = any(make_array(1, 2, 3)); +---- +NULL + +# = ANY with no match, no NULLs +query B +select 5 = any(make_array(1, 2, 3)); +---- +false + +# = ANY with mixed NULL (satisfying) returns TRUE +query B +select 5 = any(make_array(5, NULL)); +---- +true + +# = ANY with mixed NULL (non-satisfying): NULLs leave result indeterminate +query B +select 5 = any(make_array(1, 2, NULL)); +---- +NULL + statement ok DROP TABLE any_op_test; +## all operator + +# = ALL: true when all elements equal val +query B +select 5 = ALL(make_array(5, 5, 5)); +---- +true + +query B +select 5 = ALL(make_array(5, 5, 3)); +---- +false + +# <> ALL: true when val differs from every element +query B +select 5 <> ALL(make_array(1, 2, 3)); +---- +true + +query B +select 5 <> ALL(make_array(1, 2, 5)); +---- +false + +# > ALL: true when val greater than all elements +query B +select 10 > ALL(make_array(1, 2, 3)); +---- +true + +query B +select 3 > ALL(make_array(1, 2, 3)); +---- +false + +# < ALL: true when val less than all elements +query B +select 0 < ALL(make_array(1, 2, 3)); +---- +true + +query B +select 2 < ALL(make_array(1, 2, 3)); +---- +false + +# >= ALL: true when val >= all elements +query B +select 5 >= ALL(make_array(1, 2, 5)); +---- +true + +query B +select 4 >= ALL(make_array(1, 2, 5)); +---- +false + +# <= ALL: true when val <= all elements +query B +select 1 <= ALL(make_array(1, 2, 5)); +---- +true + +query B +select 2 <= ALL(make_array(1, 2, 5)); +---- +false + +# Empty arrays: all operators return TRUE (vacuous truth) +query B +select 5 = ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 <> ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 > ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 < ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 >= ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 <= ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +# NULL LHS with empty array returns TRUE (vacuous truth) +query B +select NULL = ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +# NULL LHS with non-empty array returns NULL +query B +select NULL = ALL(make_array(1, 2, 3)); +---- +NULL + +query B +select NULL > ALL(make_array(1, 2, 3)); +---- +NULL + +query B +select NULL <> ALL(make_array(1, 2, 3)); +---- +NULL + +# All-NULL arrays: returns NULL +query B +select 5 = ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 <> ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 > ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 < ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 >= ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 <= ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +# Mixed NULL + non-NULL (non-NULL elements satisfy, but NULLs present → NULL) +query B +select 5 > ALL(make_array(3, NULL)); +---- +NULL + +query B +select 5 >= ALL(make_array(5, NULL)); +---- +NULL + +query B +select 1 < ALL(make_array(3, NULL)); +---- +NULL + +query B +select 1 <= ALL(make_array(1, NULL)); +---- +NULL + +# Mixed NULL + non-NULL (not satisfying condition → FALSE wins over NULL) +query B +select 5 > ALL(make_array(6, NULL)); +---- +false + +query B +select 5 < ALL(make_array(3, NULL)); +---- +false + +query B +select 5 = ALL(make_array(5, 3, NULL)); +---- +false + +# NULL array input returns NULL +query B +select 5 = ALL(NULL::INT[]); +---- +NULL + +query B +select 5 > ALL(NULL::INT[]); +---- +NULL + +query B +select 5 < ALL(NULL::INT[]); +---- +NULL + ## array_distinct #TODO: https://github.com/apache/datafusion/issues/7142