Skip to content

Commit

Permalink
Preserve dtype better when specified.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Sep 4, 2024
1 parent 4dbadae commit 24be2d1
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 12 deletions.
15 changes: 8 additions & 7 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ def __repr__(self) -> str:
combine="sum",
fill_value=0,
final_fill_value=0,
dtypes=np.intp,
final_dtype=np.intp,
dtypes=np.integer,
final_dtype=np.integer,
)

# note that the fill values are the result of np.func([np.nan, np.nan])
Expand Down Expand Up @@ -521,20 +521,23 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
return (Dim(name="quantile", values=q),)


# if the input contains integers or floats smaller than float64,
# the output data-type is float64. Otherwise, the output data-type is the same as that
# of the input.
quantile = Aggregation(
name="quantile",
fill_value=dtypes.NA,
chunk=None,
combine=None,
final_dtype=np.floating,
final_dtype=np.float64,
new_dims_func=quantile_new_dims_func,
)
nanquantile = Aggregation(
name="nanquantile",
fill_value=dtypes.NA,
chunk=None,
combine=None,
final_dtype=np.floating,
final_dtype=np.float64,
new_dims_func=quantile_new_dims_func,
)
mode = Aggregation(
Expand Down Expand Up @@ -780,10 +783,8 @@ def _initialize_aggregation(
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
)
final_dtype = dtypes._normalize_dtype(
dtype_ or agg.dtype_init["final"], array_dtype, fill_value
dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value
)
if not agg.preserves_dtype:
final_dtype = dtypes._maybe_promote_int(final_dtype)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
Expand Down
9 changes: 7 additions & 2 deletions flox/xrdtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,14 @@ def is_datetime_like(dtype):
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
def _normalize_dtype(
dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None
) -> np.dtype:
if dtype is None:
dtype = array_dtype
if not preserves_dtype:
dtype = _maybe_promote_int(array_dtype)
else:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
Expand Down
4 changes: 2 additions & 2 deletions tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:


# TODO: stop excluding everything but U
array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
by_dtype_st = supported_dtypes()

NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
Expand All @@ -38,7 +38,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
[f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
)
numeric_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
)
all_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
Expand Down
10 changes: 10 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,3 +1929,13 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
expected = flox.groupby_scan(array.compute(), by, func=func)
actual = flox.groupby_scan(array, by, func=func)
assert_equal(expected, actual)


def test_agg_dtypes():
# regression test for GH388
counts = np.array([0, 2, 1, 0, 1])
group = np.array([1, 1, 1, 2, 2])
actual, _ = groupby_reduce(
counts, group, expected_groups=(np.array([1, 2]),), func="sum", dtype="uint8"
)
assert actual.dtype == np.uint8
24 changes: 23 additions & 1 deletion tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from flox.xrutils import notnull

from . import assert_equal
from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import chunks as chunks_strategy

dask.config.set(scheduler="sync")
Expand Down Expand Up @@ -223,3 +223,25 @@ def test_first_last_useless(data, func):
actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
assert_equal(actual, expected)


@given(
func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
engine=st.sampled_from(["numpy", "flox"]),
array_dtype=st.none() | array_dtypes,
dtype=st.none() | array_dtypes,
)
def test_agg_dtype_specified(func, array_dtype, dtype, engine):
# regression test for GH388
counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype)
group = np.array([1, 1, 1, 2, 2])
actual, _ = groupby_reduce(
counts,
group,
expected_groups=(np.array([1, 2]),),
func=func,
dtype=dtype,
engine=engine,
)
expected = getattr(np, func)(counts, keepdims=True, dtype=dtype)
assert actual.dtype == expected.dtype

0 comments on commit 24be2d1

Please sign in to comment.