@@ -74,45 +74,55 @@ build_case_when_expr <- function(query, value) {
7474 )
7575}
7676
77+ # ' Build a match expression for x against a value (scalar, NA, or vector).
78+ # ' @param x Arrow Expression for the column to match against.
79+ # ' @param match_value Value to match - R scalar, vector, or NA. Expressions
80+ # ' are compared with equality.
81+ # ' @return Arrow Expression that is TRUE when x matches match_value.
82+ # ' @keywords internal
83+ # ' @noRd
84+ build_match_expr <- function (x , match_value ) {
85+ # Expressions: use equality directly
86+ if (inherits(match_value , " Expression" )) {
87+ return (x == match_value )
88+ }
89+
90+ # R scalar NA requires is.na() since x == NA returns NA in Arrow
91+ if (length(match_value ) == 1 && is.na(match_value )) {
92+ return (call_binding(" is.na" , x ))
93+ }
94+
95+ # R scalar: simple equality
96+ if (length(match_value ) == 1 ) {
97+ return (x == match_value )
98+ }
99+
100+ # R vector: use %in%, handling NA separately if present
101+ has_na <- any(is.na(match_value ))
102+ non_na_values <- match_value [! is.na(match_value )]
103+
104+ if (length(non_na_values ) == 0 ) {
105+ call_binding(" is.na" , x )
106+ } else if (has_na ) {
107+ call_binding(" %in%" , x , non_na_values ) | call_binding(" is.na" , x )
108+ } else {
109+ call_binding(" %in%" , x , match_value )
110+ }
111+ }
112+
77113# ' Build query/value lists from parallel from/to vectors.
78114# ' NA values in `from` use is.na() for matching.
79115# ' @param x Arrow Expression for the column to match against.
80116# ' @param from Vector of values to match.
81117# ' @param to Vector of replacement values (recycled to length of `from`).
82118# ' @return list(query, value) for use with build_case_when_expr().
83- # ' @examples
84- # ' x_expr <- Expression$field_ref("x")
85- # ' parse_from_to_mapping(x_expr, from = c("a", "b"), to = c("A", "B"))
86119# ' @keywords internal
87120# ' @noRd
88121parse_from_to_mapping <- function (x , from , to ) {
89122 n <- length(from )
90123 to <- vctrs :: vec_recycle(to , n )
91- query <- vector(" list" , n )
92- value <- vector(" list" , n )
93- for (i in seq_len(n )) {
94- from_i <- from [[i ]]
95- # Handle NA specially: use is.na() since x == NA returns NA in Arrow
96- if (length(from_i ) == 1 && is.na(from_i )) {
97- query [[i ]] <- call_binding(" is.na" , x )
98- } else if (length(from_i ) > 1 ) {
99- # Multiple values: use %in% to match any
100- # If NA is in the vector, also match NA using is.na()
101- if (any(is.na(from_i ))) {
102- non_na_values <- from_i [! is.na(from_i )]
103- if (length(non_na_values ) > 0 ) {
104- query [[i ]] <- call_binding(" %in%" , x , non_na_values ) | call_binding(" is.na" , x )
105- } else {
106- query [[i ]] <- call_binding(" is.na" , x )
107- }
108- } else {
109- query [[i ]] <- call_binding(" %in%" , x , from_i )
110- }
111- } else {
112- query [[i ]] <- x == from_i
113- }
114- value [[i ]] <- Expression $ scalar(to [[i ]])
115- }
124+ query <- map(from , ~ build_match_expr(x , .x ))
125+ value <- map(to , Expression $ scalar )
116126 list (query = query , value = value )
117127}
118128
@@ -123,10 +133,6 @@ parse_from_to_mapping <- function(x, from, to) {
123133# ' @param mask Data mask for evaluating formula expressions.
124134# ' @param fn Calling function name (for error messages).
125135# ' @return list(query, value) for use with build_case_when_expr().
126- # ' @examples
127- # ' x_expr <- Expression$field_ref("x")
128- # ' mask <- rlang::new_data_mask(rlang::current_env())
129- # ' parse_formula_mapping(x_expr, list("a" ~ "A", "b" ~ "B"), mask, "replace_values")
130136# ' @keywords internal
131137# ' @noRd
132138parse_formula_mapping <- function (x , formulas , mask , fn ) {
@@ -142,31 +148,8 @@ parse_formula_mapping <- function(x, formulas, mask, fn) {
142148 }
143149 # f[[2]] is LHS (value to match), f[[3]] is RHS (replacement)
144150 lhs <- arrow_eval(f [[2 ]], mask )
145- rhs <- arrow_eval(f [[3 ]], mask )
146- # Handle NA specially: use is.na() since x == NA returns NA in Arrow
147- if (inherits(lhs , " Expression" ) && lhs $ type_id() == Type [[" NA" ]]) {
148- # NA evaluated to an Arrow NA Expression
149- query [[i ]] <- call_binding(" is.na" , x )
150- } else if (! inherits(lhs , " Expression" ) && length(lhs ) == 1 && is.na(lhs )) {
151- # NA is a bare R value
152- query [[i ]] <- call_binding(" is.na" , x )
153- } else if (! inherits(lhs , " Expression" ) && length(lhs ) > 1 ) {
154- # Vector LHS: c("a", "b") ~ "X" matches any value in the vector
155- # If NA is in the vector, also match NA using is.na()
156- if (any(is.na(lhs ))) {
157- non_na_values <- lhs [! is.na(lhs )]
158- if (length(non_na_values ) > 0 ) {
159- query [[i ]] <- call_binding(" %in%" , x , non_na_values ) | call_binding(" is.na" , x )
160- } else {
161- query [[i ]] <- call_binding(" is.na" , x )
162- }
163- } else {
164- query [[i ]] <- call_binding(" %in%" , x , lhs )
165- }
166- } else {
167- query [[i ]] <- x == lhs
168- }
169- value [[i ]] <- rhs
151+ query [[i ]] <- build_match_expr(x , lhs )
152+ value [[i ]] <- arrow_eval(f [[3 ]], mask )
170153 }
171154 list (query = query , value = value )
172155}
0 commit comments