Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
706 changes: 680 additions & 26 deletions datafusion/common/src/tree_node.rs

Large diffs are not rendered by default.

130 changes: 129 additions & 1 deletion datafusion/physical-expr-adapter/src/schema_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
use datafusion_common::tree_node::ScopedTreeNode;
use datafusion_common::{
DataFusionError, Result, ScalarValue, exec_err,
metadata::FieldMetadata,
Expand Down Expand Up @@ -69,7 +70,7 @@ where
K: Borrow<str> + Eq + Hash,
V: Borrow<ScalarValue>,
{
expr.transform_down(|expr| {
expr.transform_down_in_scope(|expr| {
if let Some(column) = expr.as_any().downcast_ref::<Column>()
&& let Some(replacement_value) = replacements.get(column.name())
{
Expand Down Expand Up @@ -656,6 +657,7 @@ mod tests {
use datafusion_common::{assert_contains, record_batch};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{Column, Literal, col, lit};
use std::hash::Hash;

fn create_test_schema() -> (Schema, Schema) {
let physical_schema = Schema::new(vec![
Expand Down Expand Up @@ -1733,4 +1735,130 @@ mod tests {
assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32);
assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64);
}

/// A mock expression with an in-scope child and an out-of-scope child.
/// Used to verify that scoped traversal does not modify out-of-scope children.
///
/// Both children are reported by `children()`, but only `in_scope_child` is
/// reported by `children_in_scope()`.
#[derive(Debug, Clone)]
struct ScopedExprMock {
/// Child returned by both `children()` and `children_in_scope()`.
in_scope_child: Arc<dyn PhysicalExpr>,
/// Child returned by `children()` only; scoped rewrites carry it over unchanged.
out_of_scope_child: Arc<dyn PhysicalExpr>,
}

impl Hash for ScopedExprMock {
    /// Feeds both children into `state`, in-scope child first, so the hash
    /// covers the same fields that the `PartialEq` impl compares.
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        let Self {
            in_scope_child,
            out_of_scope_child,
        } = self;
        in_scope_child.hash(state);
        out_of_scope_child.hash(state);
    }
}

impl PartialEq for ScopedExprMock {
    /// Two mocks are equal when both of their children compare equal as
    /// expressions (compared by value, not by `Arc` pointer identity).
    fn eq(&self, other: &Self) -> bool {
        // Compare the in-scope children first and bail out early on mismatch,
        // so the out-of-scope comparison only runs when it can affect the result.
        if self.in_scope_child.as_ref() != other.in_scope_child.as_ref() {
            return false;
        }
        self.out_of_scope_child.as_ref() == other.out_of_scope_child.as_ref()
    }
}

// Marker impl: `Eq` adds no methods; the comparison itself is defined by the
// `PartialEq` impl for this type.
impl Eq for ScopedExprMock {}

impl std::fmt::Display for ScopedExprMock {
    /// Renders the mock as `scoped_mock(<in-scope child>, <out-of-scope child>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            in_scope_child,
            out_of_scope_child,
        } = self;
        write!(f, "scoped_mock({}, {})", in_scope_child, out_of_scope_child)
    }
}

impl PhysicalExpr for ScopedExprMock {
    /// Exposes the mock as `Any` so callers can downcast back to `ScopedExprMock`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// The mock reports the in-scope child's return field as its own.
    fn return_field(&self, input_schema: &Schema) -> Result<Arc<Field>> {
        let Self { in_scope_child, .. } = self;
        in_scope_child.return_field(input_schema)
    }

    /// Evaluation is intentionally unsupported; the mock only exists to
    /// exercise tree traversal and rewriting.
    fn evaluate(
        &self,
        _batch: &RecordBatch,
    ) -> Result<datafusion_expr::ColumnarValue> {
        unimplemented!("ScopedExprMock does not support evaluation")
    }

    /// All children, in the order `[in_scope_child, out_of_scope_child]`.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        let Self {
            in_scope_child,
            out_of_scope_child,
        } = self;
        vec![in_scope_child, out_of_scope_child]
    }

    /// Rebuilds the mock from exactly two children, given in the same order
    /// as returned by `children()`. Panics if the count is not two.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        assert_eq!(children.len(), 2);
        let mut remaining = children;
        // Popping from the back yields the children in reverse order:
        // the out-of-scope child (index 1) first, then the in-scope child (index 0).
        let out_of_scope_child = remaining.pop().unwrap();
        let in_scope_child = remaining.pop().unwrap();
        Ok(Arc::new(Self {
            in_scope_child,
            out_of_scope_child,
        }))
    }

    /// Only the in-scope child is visible to scoped traversals.
    fn children_in_scope(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.in_scope_child]
    }

    /// Rebuilds the mock from exactly one in-scope child, carrying the
    /// existing out-of-scope child over unchanged. Panics if the count is not one.
    fn with_new_children_in_scope(
        self: Arc<Self>,
        children_in_scope: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        assert_eq!(children_in_scope.len(), 1);
        let mut replacement = children_in_scope;
        let in_scope_child = replacement.pop().unwrap();
        let out_of_scope_child = Arc::clone(&self.out_of_scope_child);
        Ok(Arc::new(Self {
            in_scope_child,
            out_of_scope_child,
        }))
    }

    /// The SQL rendering is the same as the `Display` rendering.
    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}

#[test]
fn test_replace_columns_with_literals_does_not_modify_out_of_scope_children() {
    // Both children reference the same column `a@0`. Only the in-scope one
    // should be rewritten; the out-of-scope one must be left alone.
    let column_a = || -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("a", 0)) };
    let expr: Arc<dyn PhysicalExpr> = Arc::new(ScopedExprMock {
        in_scope_child: column_a(),
        out_of_scope_child: column_a(),
    });

    // Replace every in-scope reference to column "a" with the literal 42.
    let mut replacements = HashMap::new();
    replacements.insert("a", ScalarValue::Int32(Some(42)));

    let rewritten = replace_columns_with_literals(expr, &replacements).unwrap();

    // The root node must keep its concrete type.
    let rewritten_any = rewritten.as_any();
    let mock = rewritten_any
        .downcast_ref::<ScopedExprMock>()
        .expect("Should still be ScopedExprMock");

    // In-scope child: column "a" has become the literal 42.
    let in_scope_any = mock.in_scope_child.as_any();
    let in_scope_lit = in_scope_any
        .downcast_ref::<Literal>()
        .expect("in_scope_child should be replaced with Literal");
    assert_eq!(in_scope_lit.value(), &ScalarValue::Int32(Some(42)));

    // Out-of-scope child: still the untouched column `a@0`.
    let out_of_scope_any = mock.out_of_scope_child.as_any();
    let out_of_scope_col = out_of_scope_any
        .downcast_ref::<Column>()
        .expect("out_of_scope_child should still be Column");
    assert_eq!(out_of_scope_col.name(), "a");
    assert_eq!(out_of_scope_col.index(), 0);
}
}
49 changes: 49 additions & 0 deletions datafusion/physical-expr-common/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,37 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
/// Get a list of child PhysicalExpr that provide the input for this expr.
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;

/// Get a list of child PhysicalExpr that provide the input for this expr and that are in the same scope as this expression.
Copy link
Copy Markdown
Contributor

@alamb alamb Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this API is confusing because "scope" isn't something that is defined

I think a better description of what this API is doing is "children in the outermost scope"

All built-in expressions have the (implicit) property that their children are evaluated in the same input schema. I don't see how CASE is any different in theory (though I understand it is what is causing your current problems)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think if we add this API it is easy to misuse and you'll be chasing down corner case bugs with other expression types that have the same problem

///
/// Because the majority of expressions have all of their children in the same scope,
/// the default implementation simply calls [`Self::children`].
///
/// To decide whether a specific child is in the same scope, ask this simple question:
/// if that child were a `Column`, could that column be evaluated with the same input schema?
/// Expressions like `plus`, `sum`, etc. have all of their children in scope.
/// Lambda expressions like `array_filter(list, value -> value + 1)` have the `list` in the
/// same scope and the lambda function in a different scope.
fn children_in_scope(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
self.children()
}

/// Returns a new PhysicalExpr where all children were replaced by new exprs.
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>>;

/// Returns a new PhysicalExpr where all in-scope children were replaced by new exprs.
///
/// See [`Self::children_in_scope`] for the definition of which children are considered in scope.
///
/// Because the majority of expressions have all of their children in the same scope,
/// the default implementation simply calls [`Self::with_new_children`].
///
/// NOTE(review): the default forwards `children_in_scope` unchanged to
/// [`Self::with_new_children`], so an implementation that overrides
/// [`Self::children_in_scope`] to return only a subset of its children presumably
/// needs to override this method as well — confirm.
fn with_new_children_in_scope(
self: Arc<Self>,
children_in_scope: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
self.with_new_children(children_in_scope)
}

/// Computes the output interval for the expression, given the input
/// intervals.
///
Expand Down Expand Up @@ -486,6 +511,30 @@ pub fn with_new_children_if_necessary(
Ok(expr)
}
}
/// Rebuilds `expr` with `children_in_scope` only when that is actually needed.
///
/// The new children are compared against the expression's current in-scope
/// children by pointer identity ([`Arc::ptr_eq`]). If every pointer matches,
/// `expr` itself is returned; otherwise (or when there are no in-scope
/// children at all) [`PhysicalExpr::with_new_children_in_scope`] is called.
///
/// The size of `children_in_scope` must be equal to the size of
/// [`PhysicalExpr::children_in_scope()`]; a mismatch is reported through
/// `assert_eq_or_internal_err!`.
pub fn with_new_children_in_scope_if_necessary(
    expr: Arc<dyn PhysicalExpr>,
    children_in_scope: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
    // Scope the borrow of `expr` so it can be moved or returned afterwards.
    let unchanged = {
        let old_children_in_scope = expr.children_in_scope();
        assert_eq_or_internal_err!(
            children_in_scope.len(),
            old_children_in_scope.len(),
            "PhysicalExpr: Wrong number of children in scope"
        );

        // An empty child list is never treated as "unchanged": the expression
        // is always rebuilt in that case.
        !children_in_scope.is_empty()
            && children_in_scope
                .iter()
                .zip(old_children_in_scope)
                .all(|(new_child, old_child)| Arc::ptr_eq(new_child, old_child))
    };

    if unchanged {
        Ok(expr)
    } else {
        expr.with_new_children_in_scope(children_in_scope)
    }
}

/// Returns [`Display`] able a list of [`PhysicalExpr`]
///
Expand Down
20 changes: 18 additions & 2 deletions datafusion/physical-expr-common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
use std::fmt::{self, Display, Formatter};
use std::sync::Arc;

use crate::physical_expr::{PhysicalExpr, with_new_children_if_necessary};
use crate::physical_expr::{
PhysicalExpr, with_new_children_if_necessary, with_new_children_in_scope_if_necessary,
};

use datafusion_common::Result;
use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode};
use datafusion_common::tree_node::{ConcreteTreeNode, DynScopedTreeNode, DynTreeNode};

impl DynTreeNode for dyn PhysicalExpr {
fn arc_children(&self) -> Vec<&Arc<Self>> {
Expand All @@ -39,6 +41,20 @@ impl DynTreeNode for dyn PhysicalExpr {
}
}

impl DynScopedTreeNode for dyn PhysicalExpr {
fn arc_children_in_scope(&self) -> Vec<&Arc<Self>> {
self.children_in_scope()
}

fn with_new_arc_children_in_scope(
&self,
arc_self: Arc<Self>,
new_children: Vec<Arc<Self>>,
) -> Result<Arc<Self>> {
with_new_children_in_scope_if_necessary(arc_self, new_children)
}
}

/// A node object encapsulating a [`PhysicalExpr`] node with a payload. Since there are
/// two ways to access child plans—directly from the plan and through child nodes—it's
/// recommended to perform mutable operations via [`Self::update_expr_from_children`].
Expand Down
Loading
Loading