
Commit 3fad33f

POC: consistent NaN treatment for pyarrow dtypes

1 parent 35b0d1d

7 files changed: +81 −19 lines changed
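
For context, the inconsistency this POC targets can be sketched in a few lines (illustration only, not part of the diff): pyarrow's from_pandas=True coerces float NaN to null, while the default constructor keeps NaN as an ordinary float value.

    import numpy as np
    import pyarrow as pa

    values = np.array([1.0, np.nan])
    pa.array(values, from_pandas=True).null_count  # 1 -- NaN coerced to null
    pa.array(values).null_count                    # 0 -- NaN kept as a value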

pandas/_libs/parsers.pyx

Lines changed: 1 addition & 1 deletion
@@ -1456,7 +1456,7 @@ def _maybe_upcast(
         if isinstance(arr, IntegerArray) and arr.isna().all():
             # use null instead of int64 in pyarrow
             arr = arr.to_numpy(na_value=None)
-            arr = ArrowExtensionArray(pa.array(arr, from_pandas=True))
+            arr = ArrowExtensionArray(pa.array(arr))
 
     return arr
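
The from_pandas flag is redundant in this hunk: the all-NA IntegerArray is materialized with na_value=None, and None maps to null either way. A minimal sketch of the equivalence (illustration only):

    import numpy as np
    import pyarrow as pa

    arr = np.array([None, None], dtype=object)
    pa.array(arr).null_count                    # 2
    pa.array(arr, from_pandas=True).null_count  # 2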

pandas/core/arrays/arrow/array.py

Lines changed: 41 additions & 13 deletions
@@ -16,6 +16,7 @@
 import numpy as np
 
 from pandas._libs import lib
+from pandas._libs.missing import NA
 from pandas._libs.tslibs import (
     Timedelta,
     Timestamp,
@@ -360,7 +361,7 @@ def _from_sequence_of_strings(
             # duration to string casting behavior
             mask = isna(scalars)
             if not isinstance(strings, (pa.Array, pa.ChunkedArray)):
-                strings = pa.array(strings, type=pa.string(), from_pandas=True)
+                strings = pa.array(strings, type=pa.string())
             strings = pc.if_else(mask, None, strings)
             try:
                 scalars = strings.cast(pa.int64())
@@ -381,7 +382,7 @@ def _from_sequence_of_strings(
             if isinstance(strings, (pa.Array, pa.ChunkedArray)):
                 scalars = strings
             else:
-                scalars = pa.array(strings, type=pa.string(), from_pandas=True)
+                scalars = pa.array(strings, type=pa.string())
             scalars = pc.if_else(pc.equal(scalars, "1.0"), "1", scalars)
             scalars = pc.if_else(pc.equal(scalars, "0.0"), "0", scalars)
             scalars = scalars.cast(pa.bool_())
@@ -393,6 +394,13 @@ def _from_sequence_of_strings(
             from pandas.core.tools.numeric import to_numeric
 
             scalars = to_numeric(strings, errors="raise")
+            if not pa.types.is_decimal(pa_type):
+                # TODO: figure out why doing this cast breaks with decimal dtype
+                #  in test_from_sequence_of_strings_pa_array
+                mask = strings.is_null()
+                scalars = pa.array(scalars, mask=np.array(mask), type=pa_type)
+                # TODO: could we just do strings.cast(pa_type)?
+
         else:
             raise NotImplementedError(
                 f"Converting strings to {pa_type} is not implemented."
@@ -435,7 +443,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
         """
         if isinstance(value, pa.Scalar):
             pa_scalar = value
-        elif isna(value):
+        elif isna(value) and not lib.is_float(value):
             pa_scalar = pa.scalar(None, type=pa_type)
         else:
             # Workaround https://github.com/apache/arrow/issues/37291
@@ -452,7 +460,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
                 value = value.as_unit(pa_type.unit)
                 value = value._value
 
-            pa_scalar = pa.scalar(value, type=pa_type, from_pandas=True)
+            pa_scalar = pa.scalar(value, type=pa_type)
 
         if pa_type is not None and pa_scalar.type != pa_type:
             pa_scalar = pa_scalar.cast(pa_type)
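
With the added "and not lib.is_float(value)" guard, a float NaN now falls through to pa.scalar and is stored as a valid (non-null) value; only non-float NA scalars keep mapping to null. The pyarrow behavior this relies on:

    import pyarrow as pa

    pa.scalar(float("nan"), type=pa.float64()).is_valid  # True -- a real value
    pa.scalar(None, type=pa.float64()).is_valid          # False -- null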
@@ -484,6 +492,13 @@ def _box_pa_array(
             if copy:
                 value = value.copy()
             pa_array = value.__arrow_array__()
+
+        elif hasattr(value, "__arrow_array__"):
+            # e.g. StringArray
+            if copy:
+                value = value.copy()
+            pa_array = value.__arrow_array__()
+
         else:
             if (
                 isinstance(value, np.ndarray)
@@ -510,19 +525,32 @@ def _box_pa_array(
                 value = to_timedelta(value, unit=pa_type.unit).as_unit(pa_type.unit)
                 value = value.to_numpy()
 
+            mask = None
+            if getattr(value, "dtype", None) is None or value.dtype.kind not in "mfM":
+                # similar to isna(value) but exclude NaN
+                # TODO: cythonize!
+                mask = np.array([x is NA or x is None for x in value], dtype=bool)
+
+            from_pandas = False
+            if pa.types.is_integer(pa_type):
+                # If user specifically asks to cast a numpy float array with NaNs
+                # to pyarrow integer, we'll treat those NaNs as NA
+                from_pandas = True
             try:
-                pa_array = pa.array(value, type=pa_type, from_pandas=True)
+                pa_array = pa.array(
+                    value, type=pa_type, mask=mask, from_pandas=from_pandas
+                )
             except (pa.ArrowInvalid, pa.ArrowTypeError):
                 # GH50430: let pyarrow infer type, then cast
-                pa_array = pa.array(value, from_pandas=True)
+                pa_array = pa.array(value, mask=mask, from_pandas=from_pandas)
 
             if pa_type is None and pa.types.is_duration(pa_array.type):
                 # Workaround https://github.com/apache/arrow/issues/37291
                 from pandas.core.tools.timedeltas import to_timedelta
 
                 value = to_timedelta(value)
                 value = value.to_numpy()
-                pa_array = pa.array(value, type=pa_type, from_pandas=True)
+                pa_array = pa.array(value, type=pa_type)
 
             if pa.types.is_duration(pa_array.type) and pa_array.null_count > 0:
                 # GH52843: upstream bug for duration types when originally
11691197
if not len(values):
11701198
return np.zeros(len(self), dtype=bool)
11711199

1172-
result = pc.is_in(self._pa_array, value_set=pa.array(values, from_pandas=True))
1200+
result = pc.is_in(self._pa_array, value_set=pa.array(values))
11731201
# pyarrow 2.0.0 returned nulls, so we explicitly specify dtype to convert nulls
11741202
# to False
11751203
return np.array(result, dtype=np.bool_)
@@ -2002,7 +2030,7 @@ def __setitem__(self, key, value) -> None:
20022030
raise ValueError("Length of indexer and values mismatch")
20032031
chunks = [
20042032
*self._pa_array[:key].chunks,
2005-
pa.array([value], type=self._pa_array.type, from_pandas=True),
2033+
pa.array([value], type=self._pa_array.type),
20062034
*self._pa_array[key + 1 :].chunks,
20072035
]
20082036
data = pa.chunked_array(chunks).combine_chunks()
@@ -2056,7 +2084,7 @@ def _rank_calc(
20562084
pa_type = pa.float64()
20572085
else:
20582086
pa_type = pa.uint64()
2059-
result = pa.array(ranked, type=pa_type, from_pandas=True)
2087+
result = pa.array(ranked, type=pa_type)
20602088
return result
20612089

20622090
data = self._pa_array.combine_chunks()
@@ -2308,7 +2336,7 @@ def _to_numpy_and_type(value) -> tuple[np.ndarray, pa.DataType | None]:
23082336
right, right_type = _to_numpy_and_type(right)
23092337
pa_type = left_type or right_type
23102338
result = np.where(cond, left, right)
2311-
return pa.array(result, type=pa_type, from_pandas=True)
2339+
return pa.array(result, type=pa_type)
23122340

23132341
@classmethod
23142342
def _replace_with_mask(
@@ -2351,7 +2379,7 @@ def _replace_with_mask(
23512379
replacements = replacements.as_py()
23522380
result = np.array(values, dtype=object)
23532381
result[mask] = replacements
2354-
return pa.array(result, type=values.type, from_pandas=True)
2382+
return pa.array(result, type=values.type)
23552383

23562384
# ------------------------------------------------------------------
23572385
# GroupBy Methods
@@ -2430,7 +2458,7 @@ def _groupby_op(
24302458
return type(self)(pa_result)
24312459
else:
24322460
# DatetimeArray, TimedeltaArray
2433-
pa_result = pa.array(result, from_pandas=True)
2461+
pa_result = pa.array(result)
24342462
return type(self)(pa_result)
24352463

24362464
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:

pandas/core/arrays/string_.py

Lines changed: 7 additions & 1 deletion
@@ -474,6 +474,12 @@ def _str_map_str_or_object(
         if self.dtype.storage == "pyarrow":
             import pyarrow as pa
 
+            # TODO: shouldn't this already be caught by the passed mask?
+            #  it isn't in test_extract_expand_capture_groups_index
+            # mask = mask | np.array(
+            #     [x is libmissing.NA for x in result], dtype=bool
+            # )
+
             result = pa.array(
                 result, mask=mask, type=pa.large_string(), from_pandas=True
             )
@@ -726,7 +732,7 @@ def __arrow_array__(self, type=None):
 
         values = self._ndarray.copy()
         values[self.isna()] = None
-        return pa.array(values, type=type, from_pandas=True)
+        return pa.array(values, type=type)
 
     def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]:  # type: ignore[override]
         arr = self._ndarray
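
Dropping from_pandas is safe in __arrow_array__ because missing entries have already been replaced with None on the copied ndarray, and None converts to null under either setting. A short sketch (illustration only):

    import numpy as np
    import pyarrow as pa

    values = np.array(["a", None], dtype=object)
    pa.array(values, type=pa.string()).null_count  # 1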

pandas/core/generic.py

Lines changed: 18 additions & 1 deletion
@@ -9874,7 +9874,7 @@ def where(
     def where(
         self,
         cond,
-        other=np.nan,
+        other=lib.no_default,
         *,
         inplace: bool = False,
         axis: Axis | None = None,
@@ -10032,6 +10032,23 @@ def where(
                 stacklevel=2,
             )
 
+        if other is lib.no_default:
+            if self.ndim == 1:
+                if isinstance(self.dtype, ExtensionDtype):
+                    other = self.dtype.na_value
+                else:
+                    other = np.nan
+            else:
+                if self._mgr.nblocks == 1 and isinstance(
+                    self._mgr.blocks[0].values.dtype, ExtensionDtype
+                ):
+                    # FIXME: checking this is kludgy!
+                    other = self._mgr.blocks[0].values.dtype.na_value
+                else:
+                    # FIXME: the same problem we had with Series will now
+                    #  show up column-by-column!
+                    other = np.nan
+
         other = common.apply_if_callable(other, self)
         return self._where(cond, other, inplace=inplace, axis=axis, level=level)
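
Under this POC (hypothetical behavior, not released pandas), calling where() without an explicit "other" now fills masked-out positions with the dtype's na_value instead of unconditionally with np.nan:

    import pandas as pd

    ser = pd.Series([1.0, 2.0], dtype="float64[pyarrow]")
    # the masked-out position becomes pd.NA (the dtype's na_value), not NaN
    ser.where([True, False])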

pandas/tests/extension/test_arrow.py

Lines changed: 1 addition & 1 deletion
@@ -731,7 +731,7 @@ def test_EA_types(self, engine, data, dtype_backend, request):
                 pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
             )
         df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
-        csv_output = df.to_csv(index=False, na_rep=np.nan)
+        csv_output = df.to_csv(index=False, na_rep=np.nan)  # should be NA?
         if pa.types.is_binary(pa_dtype):
             csv_output = BytesIO(csv_output)
         else:

pandas/tests/groupby/test_reductions.py

Lines changed: 4 additions & 2 deletions
@@ -381,8 +381,10 @@ def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
     df = DataFrame(
         {
             "a": [2, 1, 1, 2, 3, 3],
-            "b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
-            "c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
+            # TODO: test that has mixed na_value and NaN either working for
+            #  float or raising for int?
+            "b": [na_value, 3.0, na_value, 4.0, na_value, na_value],
+            "c": [na_value, 3.0, na_value, 4.0, na_value, na_value],
         },
         dtype=any_real_nullable_dtype,
     )

pandas/tests/series/methods/test_rank.py

Lines changed: 9 additions & 0 deletions
@@ -276,6 +276,13 @@ def test_rank_tie_methods(self, ser, results, dtype, using_infer_string):
 
         ser = ser if dtype is None else ser.astype(dtype)
         result = ser.rank(method=method)
+        if dtype == "float64[pyarrow]":
+            # the NaNs are not treated as NA
+            exp = exp.copy()
+            if method == "average":
+                exp[np.isnan(ser)] = 9.5
+            elif method == "dense":
+                exp[np.isnan(ser)] = 6
         tm.assert_series_equal(result, Series(exp, dtype=expected_dtype(dtype, method)))
 
     @pytest.mark.parametrize("na_option", ["top", "bottom", "keep"])
@@ -321,6 +328,8 @@ def test_rank_tie_methods_on_infs_nans(
             order = [ranks[1], ranks[0], ranks[2]]
         elif na_option == "bottom":
             order = [ranks[0], ranks[2], ranks[1]]
+        elif dtype == "float64[pyarrow]":
+            order = [ranks[0], [NA] * chunk, ranks[1]]
         else:
             order = [ranks[0], [np.nan] * chunk, ranks[1]]
         expected = order if ascending else order[::-1]
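
These test adjustments encode the POC's semantics (hypothetical, not released pandas): a NaN stored in a float64[pyarrow] array is a real value, so it participates in ranking (Arrow sorts NaN above every other float) rather than propagating as <NA>.

    import numpy as np
    import pandas as pd

    ser = pd.Series([1.0, np.nan, 2.0], dtype="float64[pyarrow]")
    ser.rank()  # under the POC: [1.0, 3.0, 2.0], with no <NA> produced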
