17 changes: 15 additions & 2 deletions flox/__init__.py
@@ -3,15 +3,16 @@
"""Top-level module for flox ."""

from . import cache
-from .aggregations import Aggregation, Scan  # noqa
+from .aggregations import Aggregation, Scan
from .core import (
    groupby_reduce,
    groupby_scan,
    rechunk_for_blockwise,
    rechunk_for_cohorts,
    ReindexStrategy,
    ReindexArrayType,
-)  # noqa
+)
+from .options import set_options


def _get_version():
@@ -24,3 +25,15 @@ def _get_version():


__version__ = _get_version()
+
+__all__ = [
+    "Aggregation",
+    "Scan",
+    "groupby_reduce",
+    "groupby_scan",
+    "rechunk_for_blockwise",
+    "rechunk_for_cohorts",
+    "set_options",
+    "ReindexStrategy",
+    "ReindexArrayType",
+]
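
With `__all__` defined, star-imports and API docs now see one explicit list of public names. A quick sanity check of the effect (a sketch; assumes a flox build with this change applied):

```python
import flox

# Every name exported via `from flox import *` is a real attribute:
for name in flox.__all__:
    assert hasattr(flox, name), name
print(sorted(flox.__all__))
```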
67 changes: 57 additions & 10 deletions flox/core.py
@@ -48,6 +48,7 @@
)
from .cache import memoize
from .lib import ArrayLayer, dask_array_type, sparse_array_type
+from .options import OPTIONS
from .xrutils import (
    _contains_cftime_datetimes,
    _to_pytimedelta,
@@ -111,6 +112,7 @@
# _simple_combine.
DUMMY_AXIS = -2


+logger = logging.getLogger("flox")


@@ -215,8 +217,11 @@ def identity(x: T) -> T:
    return x


-def _issorted(arr: np.ndarray) -> bool:
-    return bool((arr[:-1] <= arr[1:]).all())
+def _issorted(arr: np.ndarray, ascending=True) -> bool:
+    if ascending:
+        return bool((arr[:-1] <= arr[1:]).all())
+    else:
+        return bool((arr[:-1] >= arr[1:]).all())
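
The extended helper now detects monotonicity in either direction. `_issorted` is a private helper in `flox.core`; it is used here only to make the semantics concrete:

```python
import numpy as np

from flox.core import _issorted  # private helper, illustrative use only

assert _issorted(np.array([1, 2, 2, 3]))  # non-strictly increasing
assert _issorted(np.array([3, 2, 2, 1]), ascending=False)  # non-strictly decreasing
assert not _issorted(np.array([1, 3, 2]))
```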


def _is_arg_reduction(func: T_Agg) -> bool:
Expand Down Expand Up @@ -299,7 +304,7 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
def _get_optimal_chunks_for_groups(chunks, labels):
chunkidx = np.cumsum(chunks) - 1
# what are the groups at chunk boundaries
labels_at_chunk_bounds = _unique(labels[chunkidx])
labels_at_chunk_bounds = pd.unique(labels[chunkidx])
# what's the last index of all groups
last_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="last")
# what's the last index of groups at the chunk boundaries.
@@ -317,6 +322,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
        Δl = abs(c - l)
        if c == 0 or newchunkidx[-1] > l:
            continue
+        f = f.item()  # noqa
+        l = l.item()  # noqa
        if Δf < Δl and f > newchunkidx[-1]:
            newchunkidx.append(f)
        else:
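
The test table later in this PR doubles as a worked example: chunk boundaries are nudged to the nearest group edge, preferring whichever of the first/last group index moves a boundary least. A sketch using the private helper (labels taken from `test_rechunk_for_blockwise` below):

```python
import numpy as np

from flox.core import _get_optimal_chunks_for_groups  # private helper

labels = np.array([1, 1, 1, 2, 2, 3, 3, 5, 5, 5])
# Chunks of 2 split groups 1, 3, and 5 across chunk boundaries; the
# optimal chunking realigns boundaries with group edges instead.
assert _get_optimal_chunks_for_groups((2,) * 5, labels) == (3, 2, 2, 3)
```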
@@ -708,7 +715,9 @@ def rechunk_for_cohorts(
    return array.rechunk({axis: newchunks})


-def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray:
+def rechunk_for_blockwise(
+    array: DaskArray, axis: T_Axis, labels: np.ndarray, *, force: bool = True
+) -> tuple[T_MethodOpt, DaskArray]:
    """
    Rechunks array so that group boundaries line up with chunk boundaries, allowing
    embarrassingly parallel group reductions.
@@ -731,14 +740,47 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
    DaskArray
        Rechunked array
    """
-    # TODO: this should be unnecessary?
-    labels = factorize_((labels,), axes=())[0]

    chunks = array.chunks[axis]
-    newchunks = _get_optimal_chunks_for_groups(chunks, labels)
+    if len(chunks) == 1:
+        return "blockwise", array
+
+    # import dask
+    # from dask.utils import parse_bytes
+    # factor = parse_bytes(dask.config.get("array.chunk-size")) / (
+    #     math.prod(array.chunksize) * array.dtype.itemsize
+    # )
+    # if factor > BLOCKWISE_DEFAULT_ARRAY_CHUNK_SIZE_FACTOR:
+    #     new_constant_chunks = math.ceil(factor) * max(chunks)
+    #     q, r = divmod(array.shape[axis], new_constant_chunks)
+    #     new_input_chunks = (new_constant_chunks,) * q + (r,)
+    # else:
+    new_input_chunks = chunks
+
+    # FIXME: this should be unnecessary?
+    labels = factorize_((labels,), axes=())[0]
+    newchunks = _get_optimal_chunks_for_groups(new_input_chunks, labels)
    if newchunks == chunks:
-        return array
+        return "blockwise", array
+
+    Δn = abs(len(newchunks) - len(new_input_chunks))
+    if pass_num_chunks_threshold := (
+        Δn / len(new_input_chunks) < OPTIONS["rechunk_blockwise_num_chunks_threshold"]
+    ):
+        logger.debug("blockwise rechunk passes num chunks threshold")
+    if pass_chunk_size_threshold := (
+        # we just pick the max because number of chunks may have changed.
+        (abs(max(newchunks) - max(new_input_chunks)) / max(new_input_chunks))
+        < OPTIONS["rechunk_blockwise_chunk_size_threshold"]
+    ):
+        logger.debug("blockwise rechunk passes chunk size change threshold")
+
+    if force or (pass_num_chunks_threshold and pass_chunk_size_threshold):
+        logger.debug("Rechunking to enable blockwise.")
+        return "blockwise", array.rechunk({axis: newchunks})
    else:
-        return array.rechunk({axis: newchunks})
+        logger.debug("Didn't meet thresholds to do automatic rechunking for blockwise reductions.")
+        return None, array
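
The function now returns a `(method, array)` pair, so callers learn whether a blockwise reduction is viable. A usage sketch mirroring the updated test below (requires dask; labels and chunks come from the test table):

```python
import dask.array
import numpy as np

from flox.core import rechunk_for_blockwise

labels = np.array([1, 1, 1, 2, 2, 3, 3, 5, 5, 5])
array = dask.array.ones(labels.size, chunks=(5, 5))

# force=True (the default) always rechunks; force=False consults the
# rechunk_blockwise_* thresholds in flox.options.OPTIONS first.
method, rechunked = rechunk_for_blockwise(array, -1, labels, force=False)
# (5, 5) already aligns with the group boundaries, so nothing changes:
assert method == "blockwise" and rechunked.chunks == ((5, 5),)
```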


def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
Expand Down Expand Up @@ -2704,6 +2746,11 @@ def groupby_reduce(
    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
    has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)

+    if method is None and is_duck_dask_array(array) and not any_by_dask and by_.ndim == 1 and _issorted(by_):
+        # Let's try rechunking for sorted 1D by.
+        (single_axis,) = axis_
+        method, array = rechunk_for_blockwise(array, single_axis, by_, force=False)

    is_first_last = _is_first_last_reduction(func)
    if is_first_last:
        if has_dask and nax != 1:
@@ -2891,7 +2938,7 @@

    # if preferred method is already blockwise, no need to rechunk
    if preferred_method != "blockwise" and method == "blockwise" and by_.ndim == 1:
-        array = rechunk_for_blockwise(array, axis=-1, labels=by_)
+        _, array = rechunk_for_blockwise(array, axis=-1, labels=by_)

    result, groups = partial_agg(
        array=array,
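Net effect in `groupby_reduce`: with `method=None`, a dask-backed array, and sorted 1-D labels, flox now attempts a threshold-guarded rechunk up front and may settle on `method="blockwise"` without user intervention. A hedged sketch with synthetic data:

```python
import dask.array
import numpy as np

from flox.core import groupby_reduce

by = np.repeat(np.arange(4), 25)  # sorted 1-D labels
array = dask.array.ones(by.size, chunks=10)

# method=None lets flox decide; because `by` is sorted, it may rechunk so
# each group lives in one chunk and then reduce blockwise.
result, groups = groupby_reduce(array, by, func="sum")
print(groups, result.compute())  # expect: [0 1 2 3] [25. 25. 25. 25.]
```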
64 changes: 64 additions & 0 deletions flox/options.py
@@ -0,0 +1,64 @@
"""
Started from xarray options.py; vendored from cf-xarray
"""

import copy
from collections.abc import MutableMapping
from typing import Any

OPTIONS: MutableMapping[str, Any] = {
# Thresholds below which we will automatically rechunk to blockwise if it makes sense
# 1. Fractional change in number of chunks after rechunking
"rechunk_blockwise_num_chunks_threshold": 0.25,
# 2. Fractional change in max chunk size after rechunking
"rechunk_blockwise_chunk_size_threshold": 1.5,
# 3. If input arrays have chunk size smaller than `dask.array.chunk-size`,
# then adjust chunks to meet that size first.
# "rechunk.blockwise.chunk_size_factor": 1.5,
}


class set_options: # numpydoc ignore=PR01,PR02
"""
+    Set options for flox in a controlled context.

+    Parameters
+    ----------
+    rechunk_blockwise_num_chunks_threshold : float
+        Rechunk if fractional change in number of chunks after rechunking
+        is less than this amount.
+    rechunk_blockwise_chunk_size_threshold : float
+        Rechunk if fractional change in max chunk size after rechunking
+        is less than this threshold.
+
+    Examples
+    --------
+
+    You can use ``set_options`` either as a context manager:
+
+    >>> import flox
+    >>> with flox.set_options(rechunk_blockwise_num_chunks_threshold=1):
+    ...     pass
+
+    Or to set global options:
+
+    >>> flox.set_options(rechunk_blockwise_num_chunks_threshold=1)
"""

def __init__(self, **kwargs):
self.old = {}
for k in kwargs:
if k not in OPTIONS:
raise ValueError(f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}")
self.old[k] = OPTIONS[k]
self._apply_update(kwargs)

def _apply_update(self, options_dict):
options_dict = copy.deepcopy(options_dict)
OPTIONS.update(options_dict)

def __enter__(self):
return

def __exit__(self, type, value, traceback):
self._apply_update(self.old)
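
Because `__exit__` re-applies the saved values, the global `OPTIONS` mapping is restored when the context closes. A small demonstration of the semantics defined above (assumes this PR's flox; the bad key is hypothetical):

```python
import flox
from flox.options import OPTIONS

before = OPTIONS["rechunk_blockwise_num_chunks_threshold"]
with flox.set_options(rechunk_blockwise_num_chunks_threshold=1):
    assert OPTIONS["rechunk_blockwise_num_chunks_threshold"] == 1
assert OPTIONS["rechunk_blockwise_num_chunks_threshold"] == before  # restored

# Unknown keys are rejected up front with a ValueError:
try:
    flox.set_options(not_a_real_option=True)  # hypothetical bad key
except ValueError as err:
    print(err)
```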
3 changes: 2 additions & 1 deletion flox/xarray.py
@@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
+import toolz
import xarray as xr
from packaging.version import Version

@@ -589,7 +590,7 @@ def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_Data
    DataArray or Dataset
        Xarray object with rechunked arrays.
    """
-    return _rechunk(rechunk_array_for_blockwise, obj, dim, labels)
+    return _rechunk(toolz.compose(toolz.last, rechunk_array_for_blockwise), obj, dim, labels)


def _rechunk(func, obj, dim, labels, **kwargs):
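`toolz.compose(toolz.last, rechunk_array_for_blockwise)` keeps the xarray-facing API unchanged: the array-level function now returns a `(method, array)` tuple, and `toolz.last` keeps only the array. A self-contained sketch of the idiom (`returns_pair` is a stand-in, not flox code):

```python
import toolz

def returns_pair(x):
    # stand-in for rechunk_array_for_blockwise, which returns (method, array)
    return ("blockwise", x + 1)

take_last = toolz.compose(toolz.last, returns_pair)
assert take_last(1) == 2  # the method element is dropped
```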
44 changes: 29 additions & 15 deletions tests/test_core.py
@@ -14,8 +14,8 @@
from numpy_groupies.aggregate_numpy import aggregate

import flox
+from flox import set_options, xrutils
from flox import xrdtypes as dtypes
-from flox import xrutils
from flox.aggregations import Aggregation, _initialize_aggregation
from flox.core import (
HAS_NUMBAGG,
@@ -31,6 +31,7 @@
    find_group_cohorts,
    groupby_reduce,
    groupby_scan,
+    rechunk_for_blockwise,
    rechunk_for_cohorts,
    reindex_,
    subset_to_blocks,
@@ -979,26 +980,39 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
    assert_equal(actual, expected)


+@requires_dask
@pytest.mark.parametrize(
-    "inchunks, expected",
+    "inchunks, expected, expected_method",
    [
-        [(1,) * 10, (3, 2, 2, 3)],
-        [(2,) * 5, (3, 2, 2, 3)],
-        [(3, 3, 3, 1), (3, 2, 5)],
-        [(3, 1, 1, 2, 1, 1, 1), (3, 2, 2, 3)],
-        [(3, 2, 2, 3), (3, 2, 2, 3)],
-        [(4, 4, 2), (3, 4, 3)],
-        [(5, 5), (5, 5)],
-        [(6, 4), (5, 5)],
-        [(7, 3), (7, 3)],
-        [(8, 2), (7, 3)],
-        [(9, 1), (10,)],
-        [(10,), (10,)],
+        [(1,) * 10, (3, 2, 2, 3), None],
+        [(2,) * 5, (3, 2, 2, 3), None],
+        [(3, 3, 3, 1), (3, 2, 5), None],
+        [(3, 1, 1, 2, 1, 1, 1), (3, 2, 2, 3), None],
+        [(3, 2, 2, 3), (3, 2, 2, 3), "blockwise"],
+        [(4, 4, 2), (3, 4, 3), None],
+        [(5, 5), (5, 5), "blockwise"],
+        [(6, 4), (5, 5), None],
+        [(7, 3), (7, 3), "blockwise"],
+        [(8, 2), (7, 3), None],
+        [(9, 1), (10,), None],
+        [(10,), (10,), "blockwise"],
    ],
)
-def test_rechunk_for_blockwise(inchunks, expected):
+def test_rechunk_for_blockwise(inchunks, expected, expected_method):
    labels = np.array([1, 1, 1, 2, 2, 3, 3, 5, 5, 5])
    assert _get_optimal_chunks_for_groups(inchunks, labels) == expected
    # reversed
    assert _get_optimal_chunks_for_groups(inchunks, labels[::-1]) == expected

+    with set_options(rechunk_blockwise_chunk_size_threshold=-1):
+        array = dask.array.ones(labels.size, chunks=(inchunks,))
+        method, array = rechunk_for_blockwise(array, -1, labels, force=False)
+        assert method == expected_method
+        assert array.chunks == (inchunks,)
+
+        method, array = rechunk_for_blockwise(array, -1, labels[::-1], force=False)
+        assert method == expected_method
+        assert array.chunks == (inchunks,)


@requires_dask