Skip to content

Commit 3e3b413

Browse files
committed
refactor for simplicity again
1 parent ade94d9 commit 3e3b413

File tree

1 file changed

+40
-57
lines changed

1 file changed

+40
-57
lines changed

r/R/dplyr-funcs-conditional.R

Lines changed: 40 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -74,45 +74,55 @@ build_case_when_expr <- function(query, value) {
7474
)
7575
}
7676

77+
#' Build a match expression for x against a value (scalar, NA, or vector).
#'
#' Dispatch, in order:
#' * Arrow Expression: `x == match_value`.
#' * Length-1 NA: `is.na(x)`, because `x == NA` returns NA in Arrow.
#' * Any other length-1 value: `x == match_value`.
#' * Vector without NA (including a zero-length vector):
#'   `x %in% match_value`.
#' * Vector consisting only of NA: `is.na(x)`.
#' * Vector mixing NA and other values:
#'   `x %in% <non-NA values> | is.na(x)`.
#' @param x Arrow Expression for the column to match against.
#' @param match_value Value to match - R scalar, vector, or NA. Expressions
#'   are compared with equality.
#' @return Arrow Expression that is TRUE when x matches match_value.
#' @keywords internal
#' @noRd
build_match_expr <- function(x, match_value) {
  # Expressions: use equality directly
  if (inherits(match_value, "Expression")) {
    return(x == match_value)
  }

  # R scalar NA requires is.na() since x == NA returns NA in Arrow
  if (length(match_value) == 1 && is.na(match_value)) {
    return(call_binding("is.na", x))
  }

  # R scalar: simple equality
  if (length(match_value) == 1) {
    return(x == match_value)
  }

  # R vector: use %in%, handling NA separately if present
  na_mask <- is.na(match_value)
  has_na <- any(na_mask)

  # Test has_na first. A zero-length vector has no NA and no non-NA values,
  # so checking "no non-NA values" first would wrongly turn an empty match
  # set into is.na(x).
  # NOTE(review): an empty set is expected to match no rows - confirm the
  # `%in%` binding accepts a zero-length value set.
  if (!has_na) {
    return(call_binding("%in%", x, match_value))
  }

  non_na_values <- match_value[!na_mask]
  if (length(non_na_values) == 0) {
    # Every element is NA: only null rows can match
    call_binding("is.na", x)
  } else {
    call_binding("%in%", x, non_na_values) | call_binding("is.na", x)
  }
}
112+
77113
#' Build query/value lists from parallel from/to vectors.
78114
#' NA values in `from` use is.na() for matching.
79115
#' @param x Arrow Expression for the column to match against.
80116
#' @param from Vector of values to match.
81117
#' @param to Vector of replacement values (recycled to length of `from`).
82118
#' @return list(query, value) for use with build_case_when_expr().
83-
#' @examples
84-
#' x_expr <- Expression$field_ref("x")
85-
#' parse_from_to_mapping(x_expr, from = c("a", "b"), to = c("A", "B"))
86119
#' @keywords internal
87120
#' @noRd
88121
parse_from_to_mapping <- function(x, from, to) {
89122
n <- length(from)
90123
to <- vctrs::vec_recycle(to, n)
91-
query <- vector("list", n)
92-
value <- vector("list", n)
93-
for (i in seq_len(n)) {
94-
from_i <- from[[i]]
95-
# Handle NA specially: use is.na() since x == NA returns NA in Arrow
96-
if (length(from_i) == 1 && is.na(from_i)) {
97-
query[[i]] <- call_binding("is.na", x)
98-
} else if (length(from_i) > 1) {
99-
# Multiple values: use %in% to match any
100-
# If NA is in the vector, also match NA using is.na()
101-
if (any(is.na(from_i))) {
102-
non_na_values <- from_i[!is.na(from_i)]
103-
if (length(non_na_values) > 0) {
104-
query[[i]] <- call_binding("%in%", x, non_na_values) | call_binding("is.na", x)
105-
} else {
106-
query[[i]] <- call_binding("is.na", x)
107-
}
108-
} else {
109-
query[[i]] <- call_binding("%in%", x, from_i)
110-
}
111-
} else {
112-
query[[i]] <- x == from_i
113-
}
114-
value[[i]] <- Expression$scalar(to[[i]])
115-
}
124+
query <- map(from, ~ build_match_expr(x, .x))
125+
value <- map(to, Expression$scalar)
116126
list(query = query, value = value)
117127
}
118128

@@ -123,10 +133,6 @@ parse_from_to_mapping <- function(x, from, to) {
123133
#' @param mask Data mask for evaluating formula expressions.
124134
#' @param fn Calling function name (for error messages).
125135
#' @return list(query, value) for use with build_case_when_expr().
126-
#' @examples
127-
#' x_expr <- Expression$field_ref("x")
128-
#' mask <- rlang::new_data_mask(rlang::current_env())
129-
#' parse_formula_mapping(x_expr, list("a" ~ "A", "b" ~ "B"), mask, "replace_values")
130136
#' @keywords internal
131137
#' @noRd
132138
parse_formula_mapping <- function(x, formulas, mask, fn) {
@@ -142,31 +148,8 @@ parse_formula_mapping <- function(x, formulas, mask, fn) {
142148
}
143149
# f[[2]] is LHS (value to match), f[[3]] is RHS (replacement)
144150
lhs <- arrow_eval(f[[2]], mask)
145-
rhs <- arrow_eval(f[[3]], mask)
146-
# Handle NA specially: use is.na() since x == NA returns NA in Arrow
147-
if (inherits(lhs, "Expression") && lhs$type_id() == Type[["NA"]]) {
148-
# NA evaluated to an Arrow NA Expression
149-
query[[i]] <- call_binding("is.na", x)
150-
} else if (!inherits(lhs, "Expression") && length(lhs) == 1 && is.na(lhs)) {
151-
# NA is a bare R value
152-
query[[i]] <- call_binding("is.na", x)
153-
} else if (!inherits(lhs, "Expression") && length(lhs) > 1) {
154-
# Vector LHS: c("a", "b") ~ "X" matches any value in the vector
155-
# If NA is in the vector, also match NA using is.na()
156-
if (any(is.na(lhs))) {
157-
non_na_values <- lhs[!is.na(lhs)]
158-
if (length(non_na_values) > 0) {
159-
query[[i]] <- call_binding("%in%", x, non_na_values) | call_binding("is.na", x)
160-
} else {
161-
query[[i]] <- call_binding("is.na", x)
162-
}
163-
} else {
164-
query[[i]] <- call_binding("%in%", x, lhs)
165-
}
166-
} else {
167-
query[[i]] <- x == lhs
168-
}
169-
value[[i]] <- rhs
151+
query[[i]] <- build_match_expr(x, lhs)
152+
value[[i]] <- arrow_eval(f[[3]], mask)
170153
}
171154
list(query = query, value = value)
172155
}

0 commit comments

Comments
 (0)