diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 1e7c02e424256..7dafd58160ae4 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -31,6 +31,15 @@ macro_rules! handle_transform_recursion { }}; } +/// These macros are used to determine continuation during transforming traversals. +macro_rules! handle_transform_recursion_in_scope { + ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_children_in_scope($F_CHILD))? + .transform_parent($F_UP) + }}; +} + /// API for inspecting and rewriting tree data structures. /// /// The `TreeNode` API is used to express algorithms separately from traversing @@ -435,6 +444,286 @@ pub trait TreeNode: Sized { ) -> Result>; } +/// API for inspecting and rewriting tree data structures. +/// +/// See [`TreeNode`] for more details. +/// +/// This add the notion of scopes to [`TreeNode`] and allow you to operate in that. +/// +/// Scope is left for implementers to define, for `PhysicalExpr` child is defined in scope if it have the same input schema as current `PhysicalExpr`. +pub trait ScopedTreeNode: TreeNode { + /// Visit the tree node with a [`TreeNodeVisitor`], performing a + /// depth-first walk of the node and its children that are in the same scope. + /// + /// [`TreeNodeVisitor::f_down()`] is called in top-down order (before + /// children are visited), [`TreeNodeVisitor::f_up()`] is called in + /// bottom-up order (after children are visited). + /// + /// # Return Value + /// Specifies how the tree walk ended. See [`TreeNodeRecursion`] for details. + /// + /// # See Also: + /// * [`Self::apply_in_scope`] for inspecting nodes with a closure + /// * [`Self::rewrite_in_scope`] to rewrite owned `ScopedTreeNode`s + /// + /// # Example + /// Consider the following tree structure: + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// Here, the nodes would be visited using the following order: + /// ```text + /// TreeNodeVisitor::f_down(ParentNode) + /// TreeNodeVisitor::f_down(ChildNode1) + /// TreeNodeVisitor::f_up(ChildNode1) + /// TreeNodeVisitor::f_down(ChildNode2) + /// TreeNodeVisitor::f_up(ChildNode2) + /// TreeNodeVisitor::f_up(ParentNode) + /// ``` + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn visit_in_scope<'n, V: TreeNodeVisitor<'n, Node = Self>>( + &'n self, + visitor: &mut V, + ) -> Result { + visitor + .f_down(self)? + .visit_children(|| { + self.apply_children_in_scope(|c| c.visit_in_scope(visitor)) + })? + .visit_parent(|| visitor.f_up(self)) + } + + /// Rewrite the tree node with a [`TreeNodeRewriter`], performing a + /// depth-first walk of the node and its children that are in the same scope. + /// + /// [`TreeNodeRewriter::f_down()`] is called in top-down order (before + /// children are visited), [`TreeNodeRewriter::f_up()`] is called in + /// bottom-up order (after children are visited). + /// + /// Note: If using the default [`TreeNodeRewriter::f_up`] or + /// [`TreeNodeRewriter::f_down`] that do nothing, consider using + /// [`Self::transform_down_in_scope`] instead. + /// + /// # Return Value + /// The returns value specifies how the tree walk should proceed. See + /// [`TreeNodeRecursion`] for details. If an [`Err`] is returned, the + /// recursion stops immediately. + /// + /// # See Also + /// * [`Self::visit_in_scope`] for inspecting (without modification) `ScopedTreeNode`s + /// * [Self::transform_down_up_in_scope] for a combined top-down and bottom-up traversal. + /// * [Self::transform_down_in_scope] for a top-down (pre-order) traversal. + /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. + /// + /// # Example + /// Consider the following tree structure: + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// Here, the nodes would be visited using the following order: + /// ```text + /// TreeNodeRewriter::f_down(ParentNode) + /// TreeNodeRewriter::f_down(ChildNode1) + /// TreeNodeRewriter::f_up(ChildNode1) + /// TreeNodeRewriter::f_down(ChildNode2) + /// TreeNodeRewriter::f_up(ChildNode2) + /// TreeNodeRewriter::f_up(ParentNode) + /// ``` + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn rewrite_in_scope>( + self, + rewriter: &mut R, + ) -> Result> { + handle_transform_recursion_in_scope!( + rewriter.f_down(self), + |c| c.rewrite_in_scope(rewriter), + |n| { rewriter.f_up(n) } + ) + } + + /// Applies `f` to the node then each of its children that are in the + /// same scope, recursively (a top-down, pre-order traversal). + /// + /// The return [`TreeNodeRecursion`] controls the recursion and can cause + /// an early return. + /// + /// # See Also + /// * [`Self::transform_down_in_scope`] for the equivalent transformation API. + /// * [`Self::visit_in_scope`] for both top-down and bottom up traversal. + fn apply_in_scope<'n, F: FnMut(&'n Self) -> Result>( + &'n self, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_impl< + 'n, + N: ScopedTreeNode, + F: FnMut(&'n N) -> Result, + >( + node: &'n N, + f: &mut F, + ) -> Result { + f(node)?.visit_children(|| node.apply_children_in_scope(|c| apply_impl(c, f))) + } + + apply_impl(self, &mut f) + } + + /// Recursively rewrite the node's children in scope and then the node + /// using `f` (a bottom-up post-order traversal). + /// + /// A synonym of [`Self::transform_up_in_scope`]. + fn transform_in_scope Result>>( + self, + f: F, + ) -> Result> { + self.transform_up_in_scope(f) + } + + /// Recursively rewrite the tree using `f` in a top-down (pre-order) + /// fashion, limited to children in the same scope. + /// + /// `f` is applied to the node first, and then its children in scope. + /// + /// # See Also + /// * [`TreeNode::transform_down`] for the same transformation but in all children ignoring scope + /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. + /// * [Self::transform_down_up_in_scope] for a combined traversal with closures + /// * [`Self::rewrite_in_scope`] for a combined traversal with a visitor + fn transform_down_in_scope Result>>( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_impl< + N: ScopedTreeNode, + F: FnMut(N) -> Result>, + >( + node: N, + f: &mut F, + ) -> Result> { + f(node)?.transform_children(|n| { + n.map_children_in_scope(|c| transform_down_impl(c, f)) + }) + } + + transform_down_impl(self, &mut f) + } + + /// Recursively rewrite the node using `f` in a bottom-up (post-order) + /// fashion, limited to children in the same scope. + /// + /// `f` is applied to the node's children in scope first, and then to the node itself. + /// + /// # See Also + /// * [`Self::transform_down_in_scope`] for a top-down (pre-order) traversal. + /// * [Self::transform_down_up_in_scope] for a combined traversal with closures + /// * [`Self::rewrite_in_scope`] for a combined traversal with a visitor + fn transform_up_in_scope Result>>( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_impl Result>>( + node: N, + f: &mut F, + ) -> Result> { + node.map_children_in_scope(|c| transform_up_impl(c, f))? + .transform_parent(f) + } + + transform_up_impl(self, &mut f) + } + + /// Transforms the node using `f_down` while traversing the tree top-down + /// (pre-order), and using `f_up` while traversing the tree bottom-up + /// (post-order), limited to children in the same scope. + /// + /// Same as [`TreeNode::transform_down_up`] but limited to the same scope. + /// + /// # See Also + /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. + /// * [Self::transform_down_in_scope] for a top-down (pre-order) traversal. + /// * [`Self::rewrite_in_scope`] for a combined traversal with a visitor + fn transform_down_up_in_scope< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( + self, + mut f_down: FD, + mut f_up: FU, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_up_impl< + N: ScopedTreeNode, + FD: FnMut(N) -> Result>, + FU: FnMut(N) -> Result>, + >( + node: N, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> { + handle_transform_recursion_in_scope!( + f_down(node), + |c| transform_down_up_impl(c, f_down, f_up), + f_up + ) + } + + transform_down_up_impl(self, &mut f_down, &mut f_up) + } + + /// Returns true if `f` returns true for any node in the tree + /// that is in the same scope. + /// + /// Stops recursion as soon as a matching node is found + fn exists_in_scope Result>(&self, mut f: F) -> Result { + let mut found = false; + self.apply_in_scope(|n| { + Ok(if f(n)? { + found = true; + TreeNodeRecursion::Stop + } else { + TreeNodeRecursion::Continue + }) + }) + .map(|_| found) + } + + /// Low-level API used to implement other APIs. + /// + /// If you want to implement the [`ScopedTreeNode`] trait for your own type, you + /// should implement this method and [`Self::map_children_in_scope`]. + /// + /// Users should use one of the higher level APIs described on [`Self`]. + /// + /// Description: Apply `f` to inspect node's children that are in the same scope as this node (but not the node + /// itself), scope is defined by the node. + fn apply_children_in_scope<'n, F: FnMut(&'n Self) -> Result>( + &'n self, + f: F, + ) -> Result; + + /// Low-level API used to implement other APIs. + /// + /// If you want to implement the [`ScopedTreeNode`] trait for your own type, you + /// should implement this method and [`Self::apply_children_in_scope`]. + /// + /// Users should use one of the higher level APIs described on [`Self`]. + /// + /// Description: Apply `f` to rewrite the node's children in scope (but not the node itself). + fn map_children_in_scope Result>>( + self, + f: F, + ) -> Result>; +} + /// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively /// inspecting [`TreeNode`]s via [`TreeNode::visit`]. /// @@ -1293,6 +1582,57 @@ impl TreeNode for Arc { } } +/// Helper trait for implementing [`ScopedTreeNode`] that have children stored as +/// `Arc`s. If some trait object, such as `dyn T`, implements this trait, +/// its related `Arc` will automatically implement [`ScopedTreeNode`]. +pub trait DynScopedTreeNode: DynTreeNode { + /// Returns all children of the specified `ScopedTreeNode`. + fn arc_children_in_scope(&self) -> Vec<&Arc>; + + /// Constructs a new node with the specified children in scope. + fn with_new_arc_children_in_scope( + &self, + arc_self: Arc, + new_children: Vec>, + ) -> Result>; +} + +/// Blanket implementation for any `Arc` where `T` implements [`DynScopedTreeNode`] +/// (such as [`Arc`]). +impl ScopedTreeNode for Arc { + fn apply_children_in_scope<'n, F: FnMut(&'n Self) -> Result>( + &'n self, + f: F, + ) -> Result { + self.arc_children_in_scope().into_iter().apply_until_stop(f) + } + + fn map_children_in_scope Result>>( + self, + f: F, + ) -> Result> { + let children_in_scope = self.arc_children_in_scope(); + if !children_in_scope.is_empty() { + let new_children_in_scope = children_in_scope + .into_iter() + .cloned() + .map_until_stop_and_collect(f)?; + // Propagate up `new_children_in_scope.transformed` and `new_children_in_scope.tnr` + // along with the node containing transformed children. + if new_children_in_scope.transformed { + let arc_self = Arc::clone(&self); + new_children_in_scope.map_data(|new_children_in_scope| { + self.with_new_arc_children_in_scope(arc_self, new_children_in_scope) + }) + } else { + Ok(Transformed::new(self, false, new_children_in_scope.tnr)) + } + } else { + Ok(Transformed::no(self)) + } + } +} + /// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for /// trees that contain nodes with payloads. This approach ensures safe execution of algorithms /// involving payloads, by enforcing rules for detaching and reattaching child nodes. @@ -1338,24 +1678,31 @@ pub(crate) mod tests { use crate::Result; use crate::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, + ScopedTreeNode, Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; #[derive(Debug, Eq, Hash, PartialEq, Clone)] pub struct TestTreeNode { pub(crate) children: Vec>, + pub(crate) children_in_same_scope: Vec>, pub(crate) data: T, } - impl TestTreeNode { + impl TestTreeNode { + /// Creates a node where all children are in the same scope. pub(crate) fn new(children: Vec>, data: T) -> Self { - Self { children, data } + Self { + children_in_same_scope: children.clone(), + children, + data, + } } pub(crate) fn new_leaf(data: T) -> Self { Self { children: vec![], + children_in_same_scope: vec![], data, } } @@ -1363,6 +1710,39 @@ pub(crate) mod tests { pub(crate) fn is_leaf(&self) -> bool { self.children.is_empty() } + + /// Strip children_in_new_scope recursively - used to compare trees + /// in TreeNode tests where children_in_new_scope is not relevant. + fn strip_scope(self) -> Self { + Self { + children: self.children.into_iter().map(|c| c.strip_scope()).collect(), + children_in_same_scope: vec![], + data: self.data, + } + } + } + + impl TestTreeNode { + /// Creates a node with explicit `children` (all, in order). + /// `out_of_scope_children` are children that start a new scope + /// (i.e., NOT in the current node's scope). The remaining children + /// are computed as `children_in_same_scope`. + pub(crate) fn new_mixed( + all_children: Vec>, + out_of_scope_children: Vec>, + data: T, + ) -> Self { + let children_in_same_scope = all_children + .iter() + .filter(|c| !out_of_scope_children.contains(c)) + .cloned() + .collect(); + Self { + children: all_children, + children_in_same_scope, + data, + } + } } impl TreeNode for TestTreeNode { @@ -1387,6 +1767,31 @@ pub(crate) mod tests { } } + impl ScopedTreeNode for TestTreeNode { + fn apply_children_in_scope< + 'n, + F: FnMut(&'n Self) -> Result, + >( + &'n self, + f: F, + ) -> Result { + self.children_in_same_scope.apply_elements(f) + } + + fn map_children_in_scope Result>>( + self, + f: F, + ) -> Result> { + Ok(self + .children_in_same_scope + .map_elements(f)? + .update_data(|new_children| Self { + children_in_same_scope: new_children, + ..self + })) + } + } + impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode { fn apply_elements Result>( &'a self, @@ -1403,28 +1808,42 @@ pub(crate) mod tests { } } - // J - // | - // I - // | - // F - // / \ - // E G - // | | - // C H - // / \ - // B D - // | - // A + // J + // | + // I + // | + // F (mixed) + // / \ + // E (new scope) G (same scope as F) + // | | + // C (mixed) H + // / \ + // B (Same scope as C) D (new scope) + // | + // A + // + // TreeNode (children) traversal visits ALL nodes: J, I, F, E, C, B, D, A, G, H + // (new/new_mixed set both children and children_in_same_scope) + // + // ScopedTreeNode (children_in_same_scope) traversal visits: J, I, F, G, H + // (skips E which is out of F's scope; skips C, B, D, A which are under E) fn test_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); - let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_c = TestTreeNode::new_mixed( + vec![node_b, node_d.clone()], + vec![node_d], + "c".to_string(), + ); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); - let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_f = TestTreeNode::new_mixed( + vec![node_e.clone(), node_g], + vec![node_e], + "f".to_string(), + ); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); TestTreeNode::new(vec![node_i], "j".to_string()) } @@ -1970,7 +2389,7 @@ pub(crate) mod tests { } } - fn transform_yes>( + fn transform_yes + Clone>( transformation_name: N, ) -> impl FnMut(TestTreeNode) -> Result>> { move |node| { @@ -1983,7 +2402,7 @@ pub(crate) mod tests { fn transform_and_event_on< N: Display, - T: PartialEq + Display + From, + T: PartialEq + Display + From + Clone, D: Into, >( transformation_name: N, @@ -2004,13 +2423,30 @@ pub(crate) mod tests { } } + /// Like `transform_yes`, but preserves `children_in_same_scope` from the original node, + /// so scope boundaries are not lost during scoped traversal. + fn transform_yes_scoped + Clone>( + transformation_name: N, + ) -> impl FnMut(TestTreeNode) -> Result>> { + move |node| { + Ok(Transformed::yes(TestTreeNode { + children_in_same_scope: node.children_in_same_scope.clone(), + children: node.children, + data: format!("{}({})", transformation_name, node.data).into(), + })) + } + } + macro_rules! rewrite_test { ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { #[test] fn $NAME() -> Result<()> { let tree = test_tree(); let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP)); - assert_eq!(tree.rewrite(&mut rewriter)?, $EXPECTED_TREE); + let actual = tree.rewrite(&mut rewriter)?; + let actual_stripped = actual.update_data(|d| d.strip_scope()); + let expected_stripped = ($EXPECTED_TREE).update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, expected_stripped); Ok(()) } @@ -2022,7 +2458,10 @@ pub(crate) mod tests { #[test] fn $NAME() -> Result<()> { let tree = test_tree(); - assert_eq!(tree.transform_down_up($F_DOWN, $F_UP,)?, $EXPECTED_TREE); + let actual = tree.transform_down_up($F_DOWN, $F_UP)?; + let actual_stripped = actual.update_data(|d| d.strip_scope()); + let expected_stripped = ($EXPECTED_TREE).update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, expected_stripped); Ok(()) } @@ -2034,7 +2473,10 @@ pub(crate) mod tests { #[test] fn $NAME() -> Result<()> { let tree = test_tree(); - assert_eq!(tree.transform_down($F)?, $EXPECTED_TREE); + let actual = tree.transform_down($F)?; + let actual_stripped = actual.update_data(|d| d.strip_scope()); + let expected_stripped = ($EXPECTED_TREE).update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, expected_stripped); Ok(()) } @@ -2046,7 +2488,10 @@ pub(crate) mod tests { #[test] fn $NAME() -> Result<()> { let tree = test_tree(); - assert_eq!(tree.transform_up($F)?, $EXPECTED_TREE); + let actual = tree.transform_up($F)?; + let actual_stripped = actual.update_data(|d| d.strip_scope()); + let expected_stripped = ($EXPECTED_TREE).update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, expected_stripped); Ok(()) } @@ -2423,7 +2868,13 @@ pub(crate) mod tests { fn test_large_tree() { let mut item = TestTreeNode::new_leaf("initial".to_string()); for i in 0..3000 { - item = TestTreeNode::new(vec![item], format!("parent-{i}")); + // Avoid TestTreeNode::new() here which clones children into + // children_in_same_scope - that would be O(n^2) for a deep chain. + item = TestTreeNode { + children: vec![item], + children_in_same_scope: vec![], + data: format!("parent-{i}"), + }; } let mut visitor = @@ -2431,4 +2882,207 @@ pub(crate) mod tests { item.visit(&mut visitor).unwrap(); } + + // ===================================================================== + // ScopedTreeNode tests + // + // The scoped tree has 3 nested scopes. Scoped traversal of each scope + // should produce the same result as non-scoped traversal of a tree + // containing only that scope's nodes. + // + // A + // | + // B + // / \ + // C F (mixed) + // / \ / \ + // D E G H (new scope) + // | + // / \ + // I J (mixed) + // / \ + // K L (new scope) + // | + // M + // + // ScopedTreeNode traversal from A: A, B, C, D, E, F, G + // ScopedTreeNode traversal from H: H, I, J, K + // ScopedTreeNode traversal from L: L, M + // ===================================================================== + + /// Full tree with scope boundaries. + fn scoped_test_tree() -> TestTreeNode { + let d = TestTreeNode::new_leaf("d".to_string()); + let e = TestTreeNode::new_leaf("e".to_string()); + let g = TestTreeNode::new_leaf("g".to_string()); + let c = TestTreeNode::new(vec![d, e], "c".to_string()); + + let m = TestTreeNode::new_leaf("m".to_string()); + let l = TestTreeNode::new(vec![m], "l".to_string()); + + let i = TestTreeNode::new_leaf("i".to_string()); + let k = TestTreeNode::new_leaf("k".to_string()); + let j = + TestTreeNode::new_mixed(vec![k.clone(), l.clone()], vec![l], "j".to_string()); + let h = TestTreeNode::new(vec![i, j], "h".to_string()); + + let f = + TestTreeNode::new_mixed(vec![g.clone(), h.clone()], vec![h], "f".to_string()); + let b = TestTreeNode::new(vec![c, f], "b".to_string()); + TestTreeNode::new(vec![b], "a".to_string()) + } + + /// Build a non-scoped tree containing only the in-scope nodes, + /// by following `children_in_same_scope` recursively. + fn extract_scope_tree(node: &TestTreeNode) -> TestTreeNode { + let children: Vec<_> = node + .children_in_same_scope + .iter() + .map(extract_scope_tree) + .collect(); + TestTreeNode::new(children, node.data.clone()) + } + + fn collect_scoped_data(node: &TestTreeNode) -> Vec { + let mut result = vec![node.data.clone()]; + for child in &node.children_in_same_scope { + result.extend(collect_scoped_data(child)); + } + result + } + + fn collect_children_data(node: &TestTreeNode) -> Vec { + let mut result = vec![node.data.clone()]; + for child in &node.children { + result.extend(collect_children_data(child)); + } + result + } + + /// Collect references to every node in the tree (DFS through all children). + fn all_nodes(node: &TestTreeNode) -> Vec<&TestTreeNode> { + let mut result = vec![node]; + for child in &node.children { + result.extend(all_nodes(child)); + } + result + } + + /// For a given node, assert that all scoped traversal functions produce + /// the same result as their non-scoped counterparts on the in-scope subtree. + fn assert_all_scoped_traversals_match(scoped: &TestTreeNode) { + let equivalent = extract_scope_tree(scoped); + + // visit_in_scope == visit + let mut sv = TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); + scoped.visit_in_scope(&mut sv).unwrap(); + let mut ev = TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); + equivalent.visit(&mut ev).unwrap(); + assert_eq!(sv.visits, ev.visits, "visit mismatch for {}", scoped.data); + + // apply_in_scope == apply + let mut s_apply = vec![]; + scoped + .apply_in_scope(|n| { + s_apply.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + let mut e_apply = vec![]; + equivalent + .apply(|n| { + e_apply.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + assert_eq!(s_apply, e_apply, "apply mismatch for {}", scoped.data); + + // exists_in_scope == exists for each in-scope node + for name in &e_apply { + assert_eq!( + scoped.exists_in_scope(|n| Ok(&n.data == name)).unwrap(), + equivalent.exists(|n| Ok(&n.data == name)).unwrap(), + "exists mismatch for node {name} in scope {}", + scoped.data, + ); + } + + // transform_down_in_scope == transform_down + let s_td = scoped + .clone() + .transform_down_in_scope(transform_yes_scoped("tx")) + .unwrap(); + let e_td = equivalent + .clone() + .transform_down(transform_yes("tx")) + .unwrap(); + assert_eq!( + collect_scoped_data(&s_td.data), + collect_children_data(&e_td.data), + "transform_down mismatch for {}", + scoped.data, + ); + + // transform_up_in_scope == transform_up + let s_tu = scoped + .clone() + .transform_up_in_scope(transform_yes_scoped("tx")) + .unwrap(); + let e_tu = equivalent + .clone() + .transform_up(transform_yes("tx")) + .unwrap(); + assert_eq!( + collect_scoped_data(&s_tu.data), + collect_children_data(&e_tu.data), + "transform_up mismatch for {}", + scoped.data, + ); + + // transform_down_up_in_scope == transform_down_up + let s_tdu = scoped + .clone() + .transform_down_up_in_scope( + transform_yes_scoped("f_down"), + transform_yes_scoped("f_up"), + ) + .unwrap(); + let e_tdu = equivalent + .clone() + .transform_down_up(transform_yes("f_down"), transform_yes("f_up")) + .unwrap(); + assert_eq!( + collect_scoped_data(&s_tdu.data), + collect_children_data(&e_tdu.data), + "transform_down_up mismatch for {}", + scoped.data, + ); + + // rewrite_in_scope == rewrite + let mut sr = TestRewriter::new( + Box::new(transform_yes_scoped("f_down")), + Box::new(transform_yes_scoped("f_up")), + ); + let s_rw = scoped.clone().rewrite_in_scope(&mut sr).unwrap(); + let mut er = TestRewriter::new( + Box::new(transform_yes("f_down")), + Box::new(transform_yes("f_up")), + ); + let e_rw = equivalent.rewrite(&mut er).unwrap(); + assert_eq!( + collect_scoped_data(&s_rw.data), + collect_children_data(&e_rw.data), + "rewrite mismatch for {}", + scoped.data, + ); + } + + #[test] + fn test_scoped_traversal_matches_non_scoped() -> Result<()> { + let tree = scoped_test_tree(); + for node in all_nodes(&tree) { + assert_all_scoped_traversals_match(node); + } + Ok(()) + } } diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 3a255ae05f76f..aac450af766f0 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; +use datafusion_common::tree_node::ScopedTreeNode; use datafusion_common::{ DataFusionError, Result, ScalarValue, exec_err, metadata::FieldMetadata, @@ -69,7 +70,7 @@ where K: Borrow + Eq + Hash, V: Borrow, { - expr.transform_down(|expr| { + expr.transform_down_in_scope(|expr| { if let Some(column) = expr.as_any().downcast_ref::() && let Some(replacement_value) = replacements.get(column.name()) { @@ -656,6 +657,7 @@ mod tests { use datafusion_common::{assert_contains, record_batch}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{Column, Literal, col, lit}; + use std::hash::Hash; fn create_test_schema() -> (Schema, Schema) { let physical_schema = Schema::new(vec![ @@ -1733,4 +1735,130 @@ mod tests { assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32); assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64); } + + /// A mock expression with an in-scope child and an out-of-scope child. + /// Used to verify that scoped traversal does not modify out-of-scope children. + #[derive(Debug, Clone)] + struct ScopedExprMock { + in_scope_child: Arc, + out_of_scope_child: Arc, + } + + impl Hash for ScopedExprMock { + fn hash(&self, state: &mut H) { + self.in_scope_child.hash(state); + self.out_of_scope_child.hash(state); + } + } + + impl PartialEq for ScopedExprMock { + fn eq(&self, other: &Self) -> bool { + self.in_scope_child.as_ref() == other.in_scope_child.as_ref() + && self.out_of_scope_child.as_ref() == other.out_of_scope_child.as_ref() + } + } + + impl Eq for ScopedExprMock {} + + impl std::fmt::Display for ScopedExprMock { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "scoped_mock({}, {})", + self.in_scope_child, self.out_of_scope_child + ) + } + } + + impl PhysicalExpr for ScopedExprMock { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn return_field(&self, input_schema: &Schema) -> Result> { + self.in_scope_child.return_field(input_schema) + } + + fn evaluate( + &self, + _batch: &RecordBatch, + ) -> Result { + unimplemented!("ScopedExprMock does not support evaluation") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.in_scope_child, &self.out_of_scope_child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 2); + let mut iter = children.into_iter(); + Ok(Arc::new(Self { + in_scope_child: iter.next().unwrap(), + out_of_scope_child: iter.next().unwrap(), + })) + } + + fn children_in_scope(&self) -> Vec<&Arc> { + vec![&self.in_scope_child] + } + + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + Ok(Arc::new(Self { + in_scope_child: children_in_scope.into_iter().next().unwrap(), + out_of_scope_child: Arc::clone(&self.out_of_scope_child), + })) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } + } + + #[test] + fn test_replace_columns_with_literals_does_not_modify_out_of_scope_children() { + // The in-scope child references column "a" which should be replaced + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + // The out-of-scope child also references a column "a" but should NOT be replaced + let out_of_scope_child: Arc = Arc::new(Column::new("a", 0)); + + let expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let mut replacements = HashMap::new(); + replacements.insert("a", ScalarValue::Int32(Some(42))); + + let result = replace_columns_with_literals(expr, &replacements).unwrap(); + + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope child "a" should be replaced with literal 42 + let in_scope_lit = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be replaced with Literal"); + assert_eq!(in_scope_lit.value(), &ScalarValue::Int32(Some(42))); + + // The out-of-scope child "a@0" should be UNCHANGED (still a Column) + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should still be Column"); + assert_eq!(out_of_scope_col.name(), "a"); + assert_eq!(out_of_scope_col.index(), 0); + } } diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 7107b0a9004d3..0c3b98e16ef2e 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -165,12 +165,37 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { /// Get a list of child PhysicalExpr that provide the input for this expr. fn children(&self) -> Vec<&Arc>; + /// Get a list of child PhysicalExpr that provide the input for this expr that are in the same scope as this expression. + /// + /// Due to the majority of expressions being in the same scope the default implementation is to call to [`Self::children`] + /// + /// To know if specific child is considered in the same scope you can answer this simple question: + /// If that child is a `Column` would that column can be evaluated with the same input schema + /// Expressions like `plus`, `sum`, etc have all children in scope. + /// Lambda expressions like `array_filter(list, value -> value + 1)`, have the `list` in the same scope and the lambda function in different scope + fn children_in_scope(&self) -> Vec<&Arc> { + self.children() + } + /// Returns a new PhysicalExpr where all children were replaced by new exprs. fn with_new_children( self: Arc, children: Vec>, ) -> Result>; + /// Returns a new PhysicalExpr where all scoped children were replaced by new exprs. + /// + /// See [`Self::children_in_scope`] for definition of what child considered a scope + /// + /// Due to the majority of expressions being in the same scope the default implementation is to call to [`Self::with_new_children`] + /// + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + self.with_new_children(children_in_scope) + } + /// Computes the output interval for the expression, given the input /// intervals. /// @@ -486,6 +511,30 @@ pub fn with_new_children_if_necessary( Ok(expr) } } +/// Returns a copy of this expr if we change any child according to the pointer comparison. +/// The size of `children_in_scope` must be equal to the size of [`PhysicalExpr::children_in_scope()`]. +pub fn with_new_children_in_scope_if_necessary( + expr: Arc, + children_in_scope: Vec>, +) -> Result> { + let old_children_in_scope = expr.children_in_scope(); + assert_eq_or_internal_err!( + children_in_scope.len(), + old_children_in_scope.len(), + "PhysicalExpr: Wrong number of children in scope" + ); + + if children_in_scope.is_empty() + || children_in_scope + .iter() + .zip(old_children_in_scope.iter()) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) + { + Ok(expr.with_new_children_in_scope(children_in_scope)?) + } else { + Ok(expr) + } +} /// Returns [`Display`] able a list of [`PhysicalExpr`] /// diff --git a/datafusion/physical-expr-common/src/tree_node.rs b/datafusion/physical-expr-common/src/tree_node.rs index 6c7d04a22535f..c61ce91e4024c 100644 --- a/datafusion/physical-expr-common/src/tree_node.rs +++ b/datafusion/physical-expr-common/src/tree_node.rs @@ -20,10 +20,12 @@ use std::fmt::{self, Display, Formatter}; use std::sync::Arc; -use crate::physical_expr::{PhysicalExpr, with_new_children_if_necessary}; +use crate::physical_expr::{ + PhysicalExpr, with_new_children_if_necessary, with_new_children_in_scope_if_necessary, +}; use datafusion_common::Result; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; +use datafusion_common::tree_node::{ConcreteTreeNode, DynScopedTreeNode, DynTreeNode}; impl DynTreeNode for dyn PhysicalExpr { fn arc_children(&self) -> Vec<&Arc> { @@ -39,6 +41,20 @@ impl DynTreeNode for dyn PhysicalExpr { } } +impl DynScopedTreeNode for dyn PhysicalExpr { + fn arc_children_in_scope(&self) -> Vec<&Arc> { + self.children_in_scope() + } + + fn with_new_arc_children_in_scope( + &self, + arc_self: Arc, + new_children: Vec>, + ) -> Result> { + with_new_children_in_scope_if_necessary(arc_self, new_children) + } +} + /// A node object encapsulating a [`PhysicalExpr`] node with a payload. Since there are /// two ways to access child plans—directly from the plan and through child nodes—it's /// recommended to perform mutable operations via [`Self::update_expr_from_children`]. diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 4ac40df2201e5..5becfed9d5b47 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -40,7 +40,7 @@ use std::{any::Any, sync::Arc}; use crate::expressions::case::literal_lookup_table::LiteralLookupTable; use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TreeNodeRecursion}; use datafusion_physical_expr_common::datum::compare_with_eq; use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; @@ -130,7 +130,7 @@ impl CaseBody { // Determine the set of columns that are used in all the expressions of the case body. let mut used_column_indices = IndexSet::::new(); let mut collect_column_indices = |expr: &Arc| { - expr.apply(|expr| { + expr.apply_in_scope(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { used_column_indices.insert(column.index()); } @@ -161,7 +161,7 @@ impl CaseBody { // using the column index mapping. let project = |expr: &Arc| -> Result> { Arc::clone(expr) - .transform_down(|e| { + .transform_down_in_scope(|e| { if let Some(column) = e.as_any().downcast_ref::() { let original = column.index(); let projected = *column_index_map.get(&original).unwrap(); @@ -1397,7 +1397,7 @@ fn replace_with_null( input_schema: &Schema, ) -> Result, DataFusionError> { let with_null = Arc::clone(expr) - .transform_down(|e| { + .transform_down_in_scope(|e| { if e.as_ref().dyn_eq(expr_to_replace) { let data_type = e.data_type(input_schema)?; let null_literal = lit(ScalarValue::try_new_null(&data_type)?); @@ -1425,9 +1425,9 @@ mod tests { use crate::expressions; use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null}; - use arrow::buffer::Buffer; + use arrow::buffer::{BooleanBuffer, Buffer}; use arrow::datatypes::DataType::Float64; - use arrow::datatypes::Field; + use arrow::datatypes::{ArrowNativeType, Field, Fields}; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -1731,6 +1731,271 @@ mod tests { Ok(()) } + #[test] + fn case_with_expression_that_have_different_scope() -> Result<()> { + /// Represents the column at a given index in a RecordBatch that is inside a Spark lambda function + /// + /// This is the same as the datafusion [`datafusion::physical_expr::expressions::Column`] except that it store the entire info so that it can be used in lambda execution + #[derive(Debug, Clone)] + pub struct AllListElementMatchMiniLambda { + child: Arc, + predicate_on_list_elements: Arc, + } + + impl Hash for AllListElementMatchMiniLambda { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.predicate_on_list_elements.hash(state); + } + } + impl PartialEq for AllListElementMatchMiniLambda { + fn eq(&self, other: &Self) -> bool { + self.child.as_ref() == other.child.as_ref() + && self.predicate_on_list_elements.as_ref() + == other.predicate_on_list_elements.as_ref() + } + } + + impl Eq for AllListElementMatchMiniLambda {} + + impl AllListElementMatchMiniLambda { + pub fn new( + child: Arc, + predicate_on_list_element: Arc, + ) -> Self { + Self { + child, + predicate_on_list_elements: predicate_on_list_element, + } + } + } + + impl std::fmt::Display for AllListElementMatchMiniLambda { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "all_match({:?}, {:?})", + self.child, self.predicate_on_list_elements + ) + } + } + + impl PhysicalExpr for AllListElementMatchMiniLambda { + fn as_any(&self) -> &dyn Any { + self + } + + fn return_field( + &self, + input_schema: &Schema, + ) -> Result { + let is_child_nullable = self.child.nullable(input_schema)?; + Ok(Arc::new(Field::new( + "match", + DataType::Boolean, + is_child_nullable, + ))) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let child = self.child.evaluate(batch)?; + let DataType::List(child_list_field) = + self.child.data_type(batch.schema_ref())? + else { + unreachable!() + }; + + let child = child.to_array_of_size(batch.num_rows())?; + let list = child.as_list::(); + + let lambda_schema = Arc::new(Schema::new(Fields::from(vec![ + Field::new("index", DataType::UInt32, false), + child_list_field.as_ref().clone(), + ]))); + + assert_eq!( + list.value_offsets()[0].as_usize(), + 0, + "this is mock implementation, it does not support sliced list" + ); + assert_eq!( + list.value_offsets().last().unwrap().as_usize(), + list.values().len(), + "this is mock implementation, it does not support sliced list" + ); + + let list_values = list.values(); + + let new_batch = RecordBatch::try_new( + Arc::clone(&lambda_schema), + vec![ + Arc::new( + list.offsets() + .lengths() + .flat_map(|list_len| 0..list_len as u32) + .collect::(), + ), + Arc::clone(list_values), + ], + )?; + + let any_match = self.predicate_on_list_elements.evaluate(&new_batch)?; + let any_match = any_match.to_array_of_size(list_values.len())?; + let any_match = any_match.as_boolean(); + + let all_match_per_list = list + .offsets() + .windows(2) + .map(|start_and_end| { + let length = start_and_end[1] - start_and_end[0]; + let list_matches = + any_match.slice(start_and_end[0] as usize, length as usize); + + list_matches.true_count() == list_matches.len() + }) + .collect::(); + + let result = Arc::new(BooleanArray::new( + all_match_per_list, + list.nulls().cloned(), + )); + + Ok(ColumnarValue::Array(result)) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child, &self.predicate_on_list_elements] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 2); + let mut iter = children.into_iter(); + Ok(Arc::new(Self { + child: iter.next().unwrap(), + predicate_on_list_elements: iter.next().unwrap(), + })) + } + + fn children_in_scope(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + let mut iter = children_in_scope.into_iter(); + Ok(Arc::new(Self { + child: iter.next().unwrap(), + // TODO - but what if child has changed to not be list or the data type has changed?? + predicate_on_list_elements: Arc::clone( + &self.predicate_on_list_elements, + ), + })) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } + } + + let input_schema = Arc::new(Schema::new(vec![ + Arc::new(Field::new("col_1", DataType::Utf8, true)), + Arc::new(Field::new("col_2", DataType::Utf8, true)), + Arc::new(Field::new("col_3", DataType::Utf8, true)), + Arc::new(Field::new( + "list", + DataType::new_list(DataType::UInt32, true), + true, + )), + ])); + + let input_list = ListArray::from_iter_primitive::(vec![ + // all even place numbers are even + Some(vec![Some(0), Some(1), Some(2)]), + None, + // Not all even place are even but all odd place are odd + Some(vec![Some(0), Some(1), Some(1)]), + // Not odd and not even in corresponding places + Some(vec![Some(1), Some(2)]), + ]); + + let batch = RecordBatch::try_new( + input_schema, + vec![ + new_null_array(&DataType::Utf8, input_list.len()), + new_null_array(&DataType::Utf8, input_list.len()), + new_null_array(&DataType::Utf8, input_list.len()), + Arc::new(input_list), + ], + ) + .unwrap(); + let schema = batch.schema(); + + fn create_when_expr(is_even: bool) -> Arc { + let idx_col: Arc = Arc::new(Column::new("idx", 0)); + let item_col: Arc = Arc::new(Column::new("item", 1)); + Arc::new(AllListElementMatchMiniLambda::new( + Arc::new(Column::new("list", 3)), + create_both_odd_or_even(&idx_col, &item_col, is_even), + )) + } + + fn create_both_odd_or_even( + idx_column: &Arc, + list_item_column: &Arc, + is_even: bool, + ) -> Arc { + let equal_value = if is_even { 0 } else { 1 }; + let idx_equal = module_2_equal_value(idx_column, equal_value); + let item_equal = module_2_equal_value(list_item_column, equal_value); + + case( + None, + vec![(idx_equal, item_equal)], + // if idx not equal than true + Some(lit(true)), + ) + .unwrap() + } + + fn module_2_equal_value( + left: &Arc, + equal_value: u32, + ) -> Arc { + let modulo2 = BinaryExpr::new(Arc::clone(left), Operator::Modulo, lit(2u32)); + let equal_value = + BinaryExpr::new(Arc::new(modulo2), Operator::Eq, lit(equal_value)); + + Arc::new(equal_value) + } + + let expr = generate_case_when_with_type_coercion( + None, + vec![ + (create_when_expr(true), lit("both even")), + (create_when_expr(false), lit("both odd")), + ], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + + let expected = + &StringArray::from(vec![Some("both even"), None, Some("both odd"), None]); + + assert_eq!(expected, result.as_string::()); + + Ok(()) + } + #[test] fn case_with_expr_when_null() -> Result<()> { let batch = case_test_batch()?; @@ -2226,7 +2491,7 @@ mod tests { .unwrap(); let expr3 = Arc::clone(&expr) - .transform_down(|e| { + .transform_down_in_scope(|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { ScalarValue::Utf8(Some(str_value)) => { diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index a03b58e0b594d..555eca7c3bf81 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -22,7 +22,7 @@ use crate::{LexOrdering, PhysicalSortExpr, create_physical_expr}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult}; use datafusion_common::{DFSchema, HashMap}; use datafusion_common::{Result, plan_err}; use datafusion_expr::execution_props::ExecutionProps; @@ -38,7 +38,7 @@ pub fn add_offset_to_expr( expr: Arc, offset: isize, ) -> Result> { - expr.transform_down(|e| match e.as_any().downcast_ref::() { + expr.transform_down_in_scope(|e| match e.as_any().downcast_ref::() { Some(col) => { let Some(idx) = col.index().checked_add_signed(offset) else { return plan_err!("Column index overflow"); diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index e133e5a849cd8..3c5d4e547a61c 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -27,7 +27,7 @@ use crate::utils::collect_columns; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult}; use datafusion_common::{ Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err, plan_err, @@ -920,7 +920,7 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up(|expr| { + .transform_up_in_scope(|expr| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } @@ -1043,7 +1043,7 @@ impl ProjectionMapping { let mut map = IndexMap::<_, ProjectionTargets>::new(); for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; - let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { + let source_expr = expr.transform_down_in_scope(|e| match e.as_any().downcast_ref::() { Some(col) => { // Sometimes, an expression and its name in the input_schema // doesn't match. This can cause problems, so we make sure @@ -1162,7 +1162,7 @@ pub fn project_ordering( ) -> Option { let mut projected_exprs = vec![]; for PhysicalSortExpr { expr, options } in ordering.iter() { - let transformed = Arc::clone(expr).transform_up(|expr| { + let transformed = Arc::clone(expr).transform_up_in_scope(|expr| { let Some(col) = expr.as_any().downcast_ref::() else { return Ok(Transformed::no(expr)); }; @@ -1200,7 +1200,7 @@ pub(crate) mod tests { use super::*; use crate::equivalence::{EquivalenceProperties, convert_to_orderings}; use crate::expressions::{BinaryExpr, col}; - use crate::utils::tests::TestScalarUDF; + use crate::utils::tests::{ScopedExprMock, TestScalarUDF}; use crate::{PhysicalExprRef, ScalarFunctionExpr}; use arrow::compute::SortOptions; @@ -3038,4 +3038,161 @@ pub(crate) mod tests { Ok(()) } + + #[test] + fn test_update_expr_does_not_modify_out_of_scope_children() -> Result<()> { + // Outer schema: [a, b, c] + // Expression: ScopedExprMock(in_scope=a@0, out_of_scope=x@0) + // Projection: [c@2 as c_new, a@0 as a_new, b@1 as b_new] + // After unproject: in_scope should become c@2, out_of_scope should stay x@0 + let in_scope_child: Arc = Arc::new(Column::new("a_new", 1)); + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); + + let expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let projected_exprs = vec![ + ProjectionExpr { + expr: Arc::new(Column::new("c", 2)), + alias: "c_new".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "a_new".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "b_new".to_string(), + }, + ]; + + let result = + update_expr(&expr, &projected_exprs, true)?.expect("Should be valid"); + + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope child "a_new@1" should be unprojected to "a@0" + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 0); + + // The out-of-scope child "x@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); + + Ok(()) + } + + #[test] + fn test_project_ordering_does_not_modify_out_of_scope_children() { + // Schema: [a: Int32, b: Int32] + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); + + let scoped_expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( + scoped_expr, + SortOptions::new(false, false), + )]) + .unwrap(); + + let result = project_ordering(&ordering, &schema).expect("Should project"); + + let projected_expr = &result.first().expr; + let mock = projected_expr + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope column "a" should be reindexed (stays at 0 in this case) + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 0); + + // The out-of-scope child "x@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); + } + + #[test] + fn test_projection_mapping_does_not_modify_out_of_scope_children() -> Result<()> { + // Input schema: [a: Int32, b: Int32] + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); + + let scoped_expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + // Project: [ScopedExprMock as "result"] + let projection_exprs = vec![(scoped_expr, "result".to_string())]; + + let mapping = ProjectionMapping::try_new(projection_exprs, &input_schema)?; + + // The source expression in the mapping should have its in-scope column + // validated but the out-of-scope column left untouched + let (source_expr, _targets) = mapping.iter().next().unwrap(); + let mock = source_expr + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // In-scope child: "a@0" should still be "a@0" (name matches schema) + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 0); + + // Out-of-scope child: "x@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 6a8b49ac52523..f5a2508d68249 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -16,6 +16,7 @@ // under the License. mod guarantee; +use datafusion_common::tree_node::ScopedTreeNode; pub use guarantee::{Guarantee, LiteralGuarantee}; use std::borrow::Borrow; @@ -312,7 +313,7 @@ pub fn reassign_expr_columns( expr: Arc, schema: &Schema, ) -> Result> { - expr.transform_down(|expr| { + expr.transform_down_in_scope(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { let index = schema.index_of(column.name())?; @@ -333,6 +334,7 @@ pub(crate) mod tests { use super::*; use crate::expressions::{Literal, binary, cast, col, in_list, lit}; + use std::hash::Hash; use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::{DataType, Field}; @@ -342,8 +344,94 @@ pub(crate) mod tests { ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; + use arrow::array::RecordBatch; use petgraph::visit::Bfs; + /// A mock expression that has two children but only one is "in scope". + /// This simulates a lambda-like expression where the `in_scope_child` + /// references columns in the outer schema and the `out_of_scope_child` + /// references columns in a different (lambda) schema. + #[derive(Debug, Clone)] + pub(crate) struct ScopedExprMock { + pub in_scope_child: Arc, + pub out_of_scope_child: Arc, + } + + impl Hash for ScopedExprMock { + fn hash(&self, state: &mut H) { + self.in_scope_child.hash(state); + self.out_of_scope_child.hash(state); + } + } + + impl PartialEq for ScopedExprMock { + fn eq(&self, other: &Self) -> bool { + self.in_scope_child.as_ref() == other.in_scope_child.as_ref() + && self.out_of_scope_child.as_ref() == other.out_of_scope_child.as_ref() + } + } + + impl Eq for ScopedExprMock {} + + impl Display for ScopedExprMock { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "scoped_mock({}, {})", + self.in_scope_child, self.out_of_scope_child + ) + } + } + + impl PhysicalExpr for ScopedExprMock { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn return_field(&self, input_schema: &Schema) -> Result> { + self.in_scope_child.return_field(input_schema) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + unimplemented!("ScopedExprMock does not support evaluation") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.in_scope_child, &self.out_of_scope_child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 2); + let mut iter = children.into_iter(); + Ok(Arc::new(Self { + in_scope_child: iter.next().unwrap(), + out_of_scope_child: iter.next().unwrap(), + })) + } + + fn children_in_scope(&self) -> Vec<&Arc> { + vec![&self.in_scope_child] + } + + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + Ok(Arc::new(Self { + in_scope_child: children_in_scope.into_iter().next().unwrap(), + out_of_scope_child: Arc::clone(&self.out_of_scope_child), + })) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self, f) + } + } + #[derive(Debug, PartialEq, Eq, Hash)] pub struct TestScalarUDF { pub(crate) signature: Signature, @@ -647,4 +735,49 @@ pub(crate) mod tests { Ok(()) } + + #[test] + fn test_reassign_expr_columns_does_not_modify_out_of_scope_children() { + // Outer schema: [a: Int32, b: Int32] + // Lambda schema: [x: Int32] (different scope, should not be touched) + let outer_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + // The in-scope child references "b" at index 5 (wrong index for outer_schema) + let in_scope_child: Arc = Arc::new(Column::new("b", 5)); + // The out-of-scope child references "x" at index 0 in the lambda schema + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); + + let expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let result = reassign_expr_columns(expr, &outer_schema).unwrap(); + + // The in-scope "b" column should be reassigned to index 1 (its position in outer_schema) + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "b"); + assert_eq!(in_scope_col.index(), 1); // reassigned to correct index + + // The out-of-scope "x" column should be UNCHANGED (still index 0) + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); // not modified + } } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs index 7e82b9e8239e0..7e60c9a937d9d 100644 --- a/datafusion/physical-plan/src/filter_pushdown.rs +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -40,7 +40,7 @@ use std::sync::Arc; use arrow_schema::SchemaRef; use datafusion_common::{ Result, - tree_node::{Transformed, TreeNode}, + tree_node::{ScopedTreeNode, Transformed}, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -364,7 +364,7 @@ impl FilterRemapper { filter: &Arc, ) -> Result>> { let mut all_valid = true; - let transformed = Arc::clone(filter).transform_down(|expr| { + let transformed = Arc::clone(filter).transform_down_in_scope(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { if self.allowed_indices.contains(&col.index()) && let Ok(new_index) = self.child_schema.index_of(col.name()) @@ -553,3 +553,195 @@ impl FilterDescription { .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::RecordBatch; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::ColumnarValue; + use datafusion_physical_expr::expressions::Column; + use std::hash::Hash; + + /// A mock expression with an in-scope child and an out-of-scope child. + /// Used to verify that scoped traversal does not modify out-of-scope children. + #[derive(Debug, Clone)] + struct ScopedExprMock { + in_scope_child: Arc, + out_of_scope_child: Arc, + } + + impl Hash for ScopedExprMock { + fn hash(&self, state: &mut H) { + self.in_scope_child.hash(state); + self.out_of_scope_child.hash(state); + } + } + + impl PartialEq for ScopedExprMock { + fn eq(&self, other: &Self) -> bool { + self.in_scope_child.as_ref() == other.in_scope_child.as_ref() + && self.out_of_scope_child.as_ref() == other.out_of_scope_child.as_ref() + } + } + + impl Eq for ScopedExprMock {} + + impl std::fmt::Display for ScopedExprMock { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "scoped_mock({}, {})", + self.in_scope_child, self.out_of_scope_child + ) + } + } + + impl PhysicalExpr for ScopedExprMock { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn return_field(&self, input_schema: &Schema) -> Result> { + self.in_scope_child.return_field(input_schema) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + unimplemented!("ScopedExprMock does not support evaluation") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.in_scope_child, &self.out_of_scope_child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 2); + let mut iter = children.into_iter(); + Ok(Arc::new(Self { + in_scope_child: iter.next().unwrap(), + out_of_scope_child: iter.next().unwrap(), + })) + } + + fn children_in_scope(&self) -> Vec<&Arc> { + vec![&self.in_scope_child] + } + + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + Ok(Arc::new(Self { + in_scope_child: children_in_scope.into_iter().next().unwrap(), + out_of_scope_child: Arc::clone(&self.out_of_scope_child), + })) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } + } + + #[test] + fn test_filter_remapper_does_not_modify_out_of_scope_children() { + // Child schema: [a: Int32, b: Int32] + let child_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let remapper = FilterRemapper::new(Arc::clone(&child_schema)); + + // The in-scope child references column "a@0" which exists in child schema + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + // The out-of-scope child also references "a@0" but should NOT be remapped + let out_of_scope_child: Arc = Arc::new(Column::new("a", 0)); + + let filter: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let result = remapper + .try_remap(&filter) + .unwrap() + .expect("Should remap successfully"); + + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope child "a@0" should be remapped (stays "a@0" since schema matches) + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 0); + + // The out-of-scope child "a@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should still be Column"); + assert_eq!(out_of_scope_col.name(), "a"); + assert_eq!(out_of_scope_col.index(), 0); + } + + #[test] + fn test_filter_remapper_remaps_in_scope_but_not_out_of_scope() { + // Parent schema: [a: Int32, b: Int32, c: Int32] + // Child schema: [b: Int32, a: Int32] (different order, so "a" remaps from 0 -> 1) + let child_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let remapper = FilterRemapper::new(Arc::clone(&child_schema)); + + // The in-scope child references "a@0" - should be remapped to "a@1" in child schema + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + // The out-of-scope child references "x@0" in the lambda schema - should NOT be touched + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); + + let filter: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let result = remapper + .try_remap(&filter) + .unwrap() + .expect("Should remap successfully"); + + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope child "a@0" should be remapped to "a@1" (position in child schema) + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 1); + + // The out-of-scope child "x@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should still be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); + } +}