From 24be2d191b80abb97d6adfd855df280c7a1da483 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 4 Sep 2024 10:15:39 -0600 Subject: [PATCH] Preserve dtype better when specified. --- flox/aggregations.py | 15 ++++++++------- flox/xrdtypes.py | 9 +++++++-- tests/strategies.py | 4 ++-- tests/test_core.py | 10 ++++++++++ tests/test_properties.py | 24 +++++++++++++++++++++++- 5 files changed, 50 insertions(+), 12 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 5515a5b6..5772af4a 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -292,8 +292,8 @@ def __repr__(self) -> str: combine="sum", fill_value=0, final_fill_value=0, - dtypes=np.intp, - final_dtype=np.intp, + dtypes=np.integer, + final_dtype=np.integer, ) # note that the fill values are the result of np.func([np.nan, np.nan]) @@ -521,12 +521,15 @@ def quantile_new_dims_func(q) -> tuple[Dim]: return (Dim(name="quantile", values=q),) +# if the input contains integers or floats smaller than float64, +# the output data-type is float64. Otherwise, the output data-type is the same as that +# of the input. quantile = Aggregation( name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, - final_dtype=np.floating, + final_dtype=np.float64, new_dims_func=quantile_new_dims_func, ) nanquantile = Aggregation( @@ -534,7 +537,7 @@ def quantile_new_dims_func(q) -> tuple[Dim]: fill_value=dtypes.NA, chunk=None, combine=None, - final_dtype=np.floating, + final_dtype=np.float64, new_dims_func=quantile_new_dims_func, ) mode = Aggregation( @@ -780,10 +783,8 @@ def _initialize_aggregation( np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype ) final_dtype = dtypes._normalize_dtype( - dtype_ or agg.dtype_init["final"], array_dtype, fill_value + dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value ) - if not agg.preserves_dtype: - final_dtype = dtypes._maybe_promote_int(final_dtype) agg.dtype = { "user": dtype, # Save to automatically choose an engine "final": final_dtype, diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py index 3fd0f4fe..34d0d2a5 100644 --- a/flox/xrdtypes.py +++ b/flox/xrdtypes.py @@ -150,9 +150,14 @@ def is_datetime_like(dtype): return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) -def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype: +def _normalize_dtype( + dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None +) -> np.dtype: if dtype is None: - dtype = array_dtype + if not preserves_dtype: + dtype = _maybe_promote_int(array_dtype) + else: + dtype = array_dtype if dtype is np.floating: # mean, std, var always result in floating # but we preserve the array's dtype if it is floating diff --git a/tests/strategies.py b/tests/strategies.py index a2f95da1..52dd0e49 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -26,7 +26,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: # TODO: stop excluding everything but U -array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU") +array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU") by_dtype_st = supported_dtypes() NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list( @@ -38,7 +38,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: [f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS] ) numeric_arrays = npst.arrays( - elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st + elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes ) all_arrays = npst.arrays( elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes() diff --git a/tests/test_core.py b/tests/test_core.py index 5d4e7ec3..98f02690 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1929,3 +1929,13 @@ def test_ffill_bfill(chunks, size, add_nan_by, func): expected = flox.groupby_scan(array.compute(), by, func=func) actual = flox.groupby_scan(array, by, func=func) assert_equal(expected, actual) + + +def test_agg_dtypes(): + # regression test for GH388 + counts = np.array([0, 2, 1, 0, 1]) + group = np.array([1, 1, 1, 2, 2]) + actual, _ = groupby_reduce( + counts, group, expected_groups=(np.array([1, 2]),), func="sum", dtype="uint8" + ) + assert actual.dtype == np.uint8 diff --git a/tests/test_properties.py b/tests/test_properties.py index c032f074..16faa139 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -19,7 +19,7 @@ from flox.xrutils import notnull from . import assert_equal -from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays +from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays from .strategies import chunks as chunks_strategy dask.config.set(scheduler="sync") @@ -223,3 +223,25 @@ def test_first_last_useless(data, func): actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy") expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype) assert_equal(actual, expected) + + +@given( + func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]), + engine=st.sampled_from(["numpy", "flox"]), + array_dtype=st.none() | array_dtypes, + dtype=st.none() | array_dtypes, +) +def test_agg_dtype_specified(func, array_dtype, dtype, engine): + # regression test for GH388 + counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype) + group = np.array([1, 1, 1, 2, 2]) + actual, _ = groupby_reduce( + counts, + group, + expected_groups=(np.array([1, 2]),), + func=func, + dtype=dtype, + engine=engine, + ) + expected = getattr(np, func)(counts, keepdims=True, dtype=dtype) + assert actual.dtype == expected.dtype