Skip to content

Commit

Permalink
Merge branch 'fix-blockwise' into fix-dtype-again
Browse files Browse the repository at this point in the history
* fix-blockwise:
  Add test
  Avoid rechunking when preferred_method="blockwise"
  • Loading branch information
dcherian committed Sep 16, 2024
2 parents 38a556f + 40efff2 commit 44e3cc3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
DaskArray
Rechunked array
"""
# TODO: this should be unnecessary?
labels = factorize_((labels,), axes=())[0]
chunks = array.chunks[axis]
newchunks = _get_optimal_chunks_for_groups(chunks, labels)
Expand Down Expand Up @@ -2623,7 +2624,8 @@ def groupby_reduce(

partial_agg = partial(dask_groupby_agg, **kwargs)

if method == "blockwise" and by_.ndim == 1:
# if preferred method is already blockwise, no need to rechunk
if preferred_method != "blockwise" and method == "blockwise" and by_.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by_)

result, groups = partial_agg(
Expand Down
8 changes: 8 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1997,3 +1997,11 @@ def test_agg_dtypes(func, engine):
)
expected = _get_array_func(func)(counts, dtype="uint8")
assert actual.dtype == np.uint8 == expected.dtype


def test_blockwise_avoid_rechunk():
array = dask.array.zeros((6,), chunks=(2, 4), dtype=np.int64)
by = np.array(["1", "1", "0", "", "0", ""], dtype="<U1")
actual, groups = groupby_reduce(array, by, func="first")
assert_equal(groups, ["", "0", "1"])
assert_equal(actual, [0, 0, 0])

0 comments on commit 44e3cc3

Please sign in to comment.