Skip to content

Commit

Permalink
Fix first, last again
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 2, 2024
1 parent ebcd06c commit 8f7d093
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool:


def _is_first_last_reduction(func: T_Agg) -> bool:
return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
if isinstance(func, Aggregation):
func = func.name
return func in ["nanfirst", "nanlast", "first", "last"]


def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
Expand Down Expand Up @@ -680,6 +682,7 @@ def rechunk_for_blockwise(
abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
)
):
logger.debug("Rechunking to enable blockwise.")
# Less than 25% change in number of chunks, let's do it
return array.rechunk({axis: newchunks})
else:
Expand Down Expand Up @@ -1668,7 +1671,12 @@ def dask_groupby_agg(
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)
labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
do_grouped_combine = (
_is_arg_reduction(agg)
or labels_are_unknown
or (_is_first_last_reduction(agg) and array.dtype.kind != "f")
)
do_simple_combine = not do_grouped_combine

if method == "blockwise":
# use the "non dask" code path, but applied blockwise
Expand Down Expand Up @@ -2012,8 +2020,13 @@ def _validate_reindex(
expected_groups,
any_by_dask: bool,
is_dask_array: bool,
array_dtype: Any,
) -> bool | None:
# logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa
def first_or_last():
return func in ["first", "last"] or (
_is_first_last_reduction(func) and array_dtype.kind != "f"
)

all_numpy = not is_dask_array and not any_by_dask
if reindex is True and not all_numpy:
Expand All @@ -2023,7 +2036,7 @@ def _validate_reindex(
raise ValueError(
"reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
)
if func in ["first", "last"]:
if first_or_last():
raise ValueError("reindex must be None or False when func is 'first' or 'last.")

if reindex is None:
Expand All @@ -2034,9 +2047,10 @@ def _validate_reindex(
if all_numpy:
return True

if func in ["first", "last"]:
if first_or_last():
# have to do the grouped_combine since there's no good fill_value
reindex = False
# Also needed for nanfirst, nanlast with no-NaN dtypes
return False

if method == "blockwise":
# for grouping by dask arrays, we set reindex=True
Expand Down Expand Up @@ -2439,7 +2453,13 @@ def groupby_reduce(
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

reindex = _validate_reindex(
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
reindex,
func,
method,
expected_groups,
any_by_dask,
is_duck_dask_array(array),
array.dtype,
)

if not is_duck_array(array):
Expand Down Expand Up @@ -2638,7 +2658,7 @@ def groupby_reduce(

# TODO: clean this up
reindex = _validate_reindex(
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
)

if TYPE_CHECKING:
Expand Down

0 comments on commit 8f7d093

Please sign in to comment.