diff --git a/flox/core.py b/flox/core.py index d1ac8815..633bd695 100644 --- a/flox/core.py +++ b/flox/core.py @@ -179,7 +179,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool: def _is_first_last_reduction(func: T_Agg) -> bool: - return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"] + if isinstance(func, Aggregation): + func = func.name + return func in ["nanfirst", "nanlast", "first", "last"] def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex: @@ -680,6 +682,7 @@ def rechunk_for_blockwise( abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD ) ): + logger.debug("Rechunking to enable blockwise.") # Less than 25% change in number of chunks, let's do it return array.rechunk({axis: newchunks}) else: @@ -1668,7 +1671,12 @@ def dask_groupby_agg( # This allows us to discover groups at compute time, support argreductions, lower intermediate # memory usage (but method="cohorts" would also work to reduce memory in some cases) labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None - do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown + do_grouped_combine = ( + _is_arg_reduction(agg) + or labels_are_unknown + or (_is_first_last_reduction(agg) and array.dtype.kind != "f") + ) + do_simple_combine = not do_grouped_combine if method == "blockwise": # use the "non dask" code path, but applied blockwise @@ -2012,8 +2020,13 @@ def _validate_reindex( expected_groups, any_by_dask: bool, is_dask_array: bool, + array_dtype: Any, ) -> bool | None: # logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa + def first_or_last(): + return func in ["first", "last"] or ( + _is_first_last_reduction(func) and array_dtype.kind != "f" + ) all_numpy = not is_dask_array and not any_by_dask if reindex is True and not all_numpy: @@ -2023,7 +2036,7 @@ def _validate_reindex( raise ValueError( "reindex=True is not a valid choice for method='blockwise' or method='cohorts'." ) - if func in ["first", "last"]: + if first_or_last(): raise ValueError("reindex must be None or False when func is 'first' or 'last.") if reindex is None: @@ -2034,9 +2047,10 @@ def _validate_reindex( if all_numpy: return True - if func in ["first", "last"]: + if first_or_last(): # have to do the grouped_combine since there's no good fill_value - reindex = False + # Also needed for nanfirst, nanlast with no-NaN dtypes + return False if method == "blockwise": # for grouping by dask arrays, we set reindex=True @@ -2439,7 +2453,13 @@ def groupby_reduce( raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.") reindex = _validate_reindex( - reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array) + reindex, + func, + method, + expected_groups, + any_by_dask, + is_duck_dask_array(array), + array.dtype, ) if not is_duck_array(array): @@ -2638,7 +2658,7 @@ def groupby_reduce( # TODO: clean this up reindex = _validate_reindex( - reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array) + reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype ) if TYPE_CHECKING: