typing fixes #235

Merged
merged 104 commits into main from Illviljan:typing_fixes on Jul 3, 2023
Changes from all commits
Commits (104)
da0f0be
Update core.py
Illviljan Apr 21, 2023
fd65311
Update xarray.py
Illviljan Apr 23, 2023
e9997bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2023
8a63dce
avoid renaming
Illviljan Apr 23, 2023
97fd352
Merge branch 'typing_fixes' of https://github.com/Illviljan/flox into…
Illviljan Apr 23, 2023
a935ca9
Update xarray.py
Illviljan Apr 24, 2023
1c82e0f
Update xarray.py
Illviljan Apr 24, 2023
4695dcd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2023
46f429c
Update xarray.py
Illviljan Apr 24, 2023
c3c7828
Update xarray.py
Illviljan Apr 24, 2023
eff861d
Update xarray.py
Illviljan Apr 24, 2023
6762ae2
Merge remote-tracking branch 'upstream/main' into typing_fixes
Illviljan May 1, 2023
d06b7a9
Merge branch 'main' into typing_fixes
dcherian May 1, 2023
053646a
Merge remote-tracking branch 'upstream/main' into typing_fixes
Illviljan May 30, 2023
ad51845
Update xarray.py
Illviljan May 30, 2023
aee3e6c
Update xarray.py
Illviljan May 30, 2023
52027ea
split to optional
Illviljan May 30, 2023
c4a7347
Update xarray.py
Illviljan May 30, 2023
6941966
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2023
63a413d
Update xarray.py
Illviljan May 30, 2023
836214f
convert to pd.Index instead of ndarray
Illviljan May 30, 2023
80169df
Handled different slicer types?
Illviljan May 30, 2023
b96de24
not supported instead?
Illviljan May 30, 2023
bc5a404
specify type for simple_combine
Illviljan May 30, 2023
8e20f16
Handle None in agg.min_count
Illviljan May 30, 2023
dd6cafe
Update core.py
Illviljan May 30, 2023
97c30a5
Update core.py
Illviljan May 30, 2023
f3e10ad
Update core.py
Illviljan May 30, 2023
2359be9
Update core.py
Illviljan May 30, 2023
d2378de
Update core.py
Illviljan May 30, 2023
0a7d7c1
Update core.py
Illviljan May 30, 2023
57a3ce7
Update core.py
Illviljan May 30, 2023
305cec9
Update core.py
Illviljan May 30, 2023
8e5413f
Update core.py
Illviljan May 30, 2023
9f5bf4a
Update core.py
Illviljan May 30, 2023
51af4cc
Update core.py
Illviljan May 30, 2023
f869980
Update core.py
Illviljan May 30, 2023
3234096
Update core.py
Illviljan May 31, 2023
9aee5a0
Update core.py
Illviljan Jun 1, 2023
9fab7cf
Update core.py
Illviljan Jun 1, 2023
ee7c042
add overloads and rename
Illviljan Jun 1, 2023
c2d5d15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
4a8a926
more overloads
Illviljan Jun 1, 2023
630e1bf
ignore
Illviljan Jun 1, 2023
4d4a697
Update core.py
Illviljan Jun 1, 2023
cba746a
Update xarray.py
Illviljan Jun 1, 2023
b346137
Update core.py
Illviljan Jun 1, 2023
d86795e
Update core.py
Illviljan Jun 1, 2023
edf5dea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
1fc3d7c
Update core.py
Illviljan Jun 1, 2023
b11631b
Update core.py
Illviljan Jun 1, 2023
27d0418
Update core.py
Illviljan Jun 1, 2023
1c6dd95
Update core.py
Illviljan Jun 1, 2023
87b2e9f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
1f3e561
Update core.py
Illviljan Jun 1, 2023
dfacd9c
Merge branch 'typing_fixes' of https://github.com/Illviljan/flox into…
Illviljan Jun 1, 2023
17af0ca
Update core.py
Illviljan Jun 1, 2023
d02983e
Update core.py
Illviljan Jun 1, 2023
7808b75
Update core.py
Illviljan Jun 1, 2023
269ca37
Update core.py
Illviljan Jun 1, 2023
f7bdad0
Update flox/core.py
Illviljan Jun 2, 2023
83d2612
Update flox/core.py
Illviljan Jun 2, 2023
3a3b871
Update core.py
Illviljan Jun 4, 2023
ccb18d5
Merge branch 'typing_fixes' of https://github.com/Illviljan/flox into…
Illviljan Jun 4, 2023
ce9e071
Update core.py
Illviljan Jun 5, 2023
e5ca125
Update core.py
Illviljan Jun 5, 2023
9ad0df3
Update core.py
Illviljan Jun 5, 2023
0864db6
Update core.py
Illviljan Jun 5, 2023
7a41ed9
Update core.py
Illviljan Jun 5, 2023
67c6864
Update core.py
Illviljan Jun 5, 2023
bb42b43
Update core.py
Illviljan Jun 5, 2023
979c66a
Update core.py
Illviljan Jun 5, 2023
d8b555f
Update core.py
Illviljan Jun 5, 2023
179a51b
Update core.py
Illviljan Jun 6, 2023
34eb030
Update core.py
Illviljan Jun 6, 2023
f520b46
Update core.py
Illviljan Jun 6, 2023
e51a7ea
Update core.py
Illviljan Jun 6, 2023
5eaeb11
Update core.py
Illviljan Jun 6, 2023
e52a3c0
Merge branch 'main' into typing_fixes
dcherian Jun 14, 2023
78aebe7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2023
197b924
Update xarray.py
Illviljan Jun 17, 2023
fb5e6ec
Have to add another type here because of xarray not supporting Interv…
Illviljan Jun 17, 2023
fb03ef0
Update xarray.py
Illviljan Jun 17, 2023
8e55d3a
test ex instead of e
Illviljan Jun 23, 2023
8b441cb
Revert "test ex instead of e"
Illviljan Jun 23, 2023
bce7c73
check reveal_type
Illviljan Jun 24, 2023
36f872f
without e
Illviljan Jun 24, 2023
ad2037a
try no redefinition
Illviljan Jun 24, 2023
5122703
IF redefining ex, mypy always takes the first definition of ex. even…
Illviljan Jun 24, 2023
12c91e1
test min_count=0
Illviljan Jun 27, 2023
b0f4154
test min_count=0
Illviljan Jun 27, 2023
1e11f8f
test min_count=0
Illviljan Jun 27, 2023
173facd
test min_count=0
Illviljan Jun 27, 2023
f11542b
test min_count = 0
Illviljan Jun 27, 2023
dff70e6
test min_count=0
Illviljan Jun 27, 2023
20c0269
test min_count=0
Illviljan Jun 27, 2023
7c0720b
test min_count=0
Illviljan Jun 27, 2023
2989080
test min_count=0
Illviljan Jun 27, 2023
0573a9a
test min_count=0
Illviljan Jun 27, 2023
45a9bcf
test min_count=0
Illviljan Jun 27, 2023
90bb6dc
test min_count=0
Illviljan Jun 27, 2023
0380f97
test min_count=0
Illviljan Jun 27, 2023
8a6d04a
Update asv_bench/benchmarks/combine.py
dcherian Jul 3, 2023
fc96902
Merge branch 'main' into typing_fixes
dcherian Jul 3, 2023
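
Several of the June commits ("check reveal_type", "try no redefinition", "IF redefining ex, mypy always takes the first definition of ex…") revolve around the fact that mypy fixes a variable's type at its first assignment. A minimal standalone sketch of that behavior, with illustrative variable names:

```python
from __future__ import annotations

import numpy as np
import pandas as pd

ex = pd.IntervalIndex.from_breaks([0, 1, 2])

# mypy pins `ex` to the type inferred above (IntervalIndex), so re-assigning
# it with an array is reported roughly as:
#   Incompatible types in assignment (expression has type "ndarray",
#   variable has type "IntervalIndex")
ex = np.sort(np.array([3, 1, 2]))

# A distinct name side-steps the redefinition entirely:
ex_sorted = np.sort(np.array([3, 1, 2]))
print(ex_sorted)
```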
5 changes: 3 additions & 2 deletions asv_bench/benchmarks/combine.py
@@ -1,4 +1,5 @@
from functools import partial
from typing import Any

import numpy as np

@@ -43,8 +44,8 @@ class Combine1d(Combine):
this is for reducting along a single dimension
"""

def setup(self, *args, **kwargs):
def construct_member(groups):
def setup(self, *args, **kwargs) -> None:
def construct_member(groups) -> dict[str, Any]:
return {
"groups": groups,
"intermediates": [
11 changes: 4 additions & 7 deletions flox/aggregations.py
@@ -185,7 +185,7 @@ def __init__(
# how to aggregate results after first round of reduction
self.combine: FuncTuple = _atleast_1d(combine)
# simpler reductions used with the "simple combine" algorithm
self.simple_combine = None
self.simple_combine: tuple[Callable, ...] = ()
# final aggregation
self.aggregate: Callable | str = aggregate if aggregate else self.combine[0]
# finalize results (see mean)
@@ -207,7 +207,7 @@ def __init__(

# The following are set by _initialize_aggregation
self.finalize_kwargs: dict[Any, Any] = {}
self.min_count: int | None = None
self.min_count: int = 0

def _normalize_dtype_fill_value(self, value, name):
value = _atleast_1d(value)
@@ -504,7 +504,7 @@ def _initialize_aggregation(
dtype,
array_dtype,
fill_value,
min_count: int | None,
min_count: int,
finalize_kwargs: dict[Any, Any] | None,
) -> Aggregation:
if not isinstance(func, Aggregation):
@@ -559,9 +559,6 @@ def _initialize_aggregation(
assert isinstance(finalize_kwargs, dict)
agg.finalize_kwargs = finalize_kwargs

if min_count is None:
min_count = 0

# This is needed for the dask pathway.
# Because we use intermediate fill_value since a group could be
# absent in one block, but present in another block
@@ -579,7 +576,7 @@
else:
agg.min_count = 0

simple_combine = []
simple_combine: list[Callable] = []
for combine in agg.combine:
if isinstance(combine, str):
if combine in ["nanfirst", "nanlast"]:
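
The two hunks above swap `None` placeholders for typed empty defaults. A rough sketch of the pattern on a toy class (not the real `Aggregation`), showing why downstream code no longer needs `is not None` checks:

```python
from typing import Callable


class AggSketch:
    """Toy stand-in for Aggregation, illustrating typed defaults."""

    def __init__(self) -> None:
        # Instead of `int | None = None`, use 0 as the "unset" value.
        self.min_count: int = 0
        # Instead of `= None`, start with an empty tuple of the right type.
        self.simple_combine: tuple[Callable, ...] = ()


agg = AggSketch()
if agg.min_count > 0:  # plain int comparison, no Optional narrowing needed
    print("min_count is set")
print(len(agg.simple_combine))  # 0 until the combines are filled in
```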
102 changes: 69 additions & 33 deletions flox/core.py
@@ -51,11 +51,15 @@
T_DuckArray = Union[np.ndarray, DaskArray] # Any ?
T_By = T_DuckArray
T_Bys = tuple[T_By, ...]
T_ExpectIndex = Union[pd.Index, None]
T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
T_ExpectIndex = Union[pd.Index]
T_ExpectIndexTuple = tuple[T_ExpectIndex, ...]
T_ExpectIndexOpt = Union[T_ExpectIndex, None]
T_ExpectIndexOptTuple = tuple[T_ExpectIndexOpt, ...]
T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
T_ExpectTuple = tuple[T_Expect, ...]
T_ExpectedGroups = Union[T_Expect, T_ExpectTuple]
T_ExpectOpt = Union[Sequence, np.ndarray, T_ExpectIndexOpt]
T_ExpectOptTuple = tuple[T_ExpectOpt, ...]
T_ExpectedGroups = Union[T_Expect, T_ExpectOptTuple]
T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
T_Func = Union[str, Callable]
T_Funcs = Union[T_Func, Sequence[T_Func]]
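
A standalone restatement of how the new aliases are intended to be read (simplified: the real `T_Expect`/`T_ExpectOpt` variants also admit `Sequence` and `np.ndarray`):

```python
from __future__ import annotations

from typing import Union

import pandas as pd

T_ExpectIndex = pd.Index                              # one fully resolved grouping
T_ExpectIndexOpt = Union[pd.Index, None]              # resolved, or "not provided"
T_ExpectIndexTuple = tuple[T_ExpectIndex, ...]        # every grouping resolved
T_ExpectIndexOptTuple = tuple[T_ExpectIndexOpt, ...]  # some entries may be None


def total_known_groups(expected: T_ExpectIndexTuple) -> int:
    # Every entry is a pd.Index, so len() needs no narrowing.
    return sum(len(idx) for idx in expected)


def total_provided_groups(expected: T_ExpectIndexOptTuple) -> int:
    # Entries may be None, so callers narrow first.
    return sum(len(idx) for idx in expected if idx is not None)


print(total_known_groups((pd.Index([1, 2]), pd.Index([3]))))
print(total_provided_groups((pd.Index([1, 2]), None)))
```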
@@ -98,7 +102,7 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]


def _get_expected_groups(by: T_By, sort: bool) -> pd.Index:
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
if is_duck_dask_array(by):
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
flatby = by.reshape(-1)
@@ -219,8 +223,13 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
raveled = labels.reshape(-1)
# these are chunks where a label is present
label_chunks = pd.Series(which_chunk).groupby(raveled).unique()

# These invert the label_chunks mapping so we know which labels occur together.
chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
def invert(x) -> tuple[np.ndarray, ...]:
arr = label_chunks.get(x)
return tuple(arr) # type: ignore [arg-type] # pandas issue?

chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

if merge:
# First sort by number of chunks occupied by cohort
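
A sketch of the refactor in this hunk: replacing the lambda with a named `invert` helper lets the key function carry a return annotation and a targeted ignore comment. Toy data; toolz is assumed available (flox already uses it as `tlz`):

```python
import pandas as pd
import toolz as tlz

# Toy stand-in for the label -> chunks mapping built above.
label_chunks = pd.Series({"a": [0, 1], "b": [1], "c": [0, 1]})


def invert(label) -> tuple[int, ...]:
    # A named helper (unlike a lambda) can be annotated and, if needed,
    # carry a single-line `# type: ignore`.
    chunks = label_chunks.get(label)
    return tuple(chunks)


chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
print(chunks_cohorts)  # {(0, 1): ['a', 'c'], (1,): ['b']}
```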
@@ -459,7 +468,7 @@ def factorize_(
axes: T_Axes,
*,
fastpath: Literal[True],
expected_groups: tuple[pd.Index, ...] | None = None,
expected_groups: T_ExpectIndexOptTuple | None = None,
reindex: bool = False,
sort: bool = True,
) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, None]:
@@ -471,7 +480,7 @@
by: T_Bys,
axes: T_Axes,
*,
expected_groups: tuple[pd.Index, ...] | None = None,
expected_groups: T_ExpectIndexOptTuple | None = None,
reindex: bool = False,
sort: bool = True,
fastpath: Literal[False] = False,
@@ -484,7 +493,7 @@
by: T_Bys,
axes: T_Axes,
*,
expected_groups: tuple[pd.Index, ...] | None = None,
expected_groups: T_ExpectIndexOptTuple | None = None,
reindex: bool = False,
sort: bool = True,
fastpath: bool = False,
@@ -496,7 +505,7 @@
by: T_Bys,
axes: T_Axes,
*,
expected_groups: tuple[pd.Index, ...] | None = None,
expected_groups: T_ExpectIndexOptTuple | None = None,
reindex: bool = False,
sort: bool = True,
fastpath: bool = False,
@@ -546,7 +555,7 @@ def factorize_(
else:
idx = np.zeros_like(flat, dtype=np.intp) - 1

found_groups.append(expect)
found_groups.append(np.array(expect))
else:
if expect is not None and reindex:
sorter = np.argsort(expect)
@@ -560,7 +569,7 @@
idx = sorter[(idx,)]
idx[mask] = -1
else:
idx, groups = pd.factorize(flat, sort=sort)
idx, groups = pd.factorize(flat, sort=sort) # type: ignore # pandas issue?

found_groups.append(np.array(groups))
factorized.append(idx.reshape(groupvar.shape))
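
A short illustration of why this hunk wraps `expect` in `np.array(...)`: `found_groups` is meant to hold a uniform list of ndarrays, while `expect` may be a `pd.Index`. Toy values, illustrative only:

```python
import numpy as np
import pandas as pd

found_groups: list[np.ndarray] = []

expect = pd.Index([10, 20, 30])
found_groups.append(np.array(expect))  # converted, so the list stays ndarray-only

idx, groups = pd.factorize(np.array(["a", "b", "a"]), sort=True)
found_groups.append(np.array(groups))  # same treatment for the factorize output

print([type(g).__name__ for g in found_groups])  # ['ndarray', 'ndarray']
```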
@@ -853,7 +862,8 @@ def _finalize_results(
"""
squeezed = _squeeze_results(results, axis)

if agg.min_count > 0:
min_count = agg.min_count
if min_count > 0:
counts = squeezed["intermediates"][-1]
squeezed["intermediates"] = squeezed["intermediates"][:-1]

@@ -864,8 +874,8 @@
else:
finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)

if agg.min_count > 0:
count_mask = counts < agg.min_count
if min_count > 0:
count_mask = counts < min_count
if count_mask.any():
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
# necessary
@@ -1283,7 +1293,7 @@ def dask_groupby_agg(
array: DaskArray,
by: T_By,
agg: Aggregation,
expected_groups: pd.Index | None,
expected_groups: T_ExpectIndexOpt,
axis: T_Axes = (),
fill_value: Any = None,
method: T_Method = "map-reduce",
@@ -1423,9 +1433,11 @@ def dask_groupby_agg(
group_chunks = ((np.nan,),)
else:
if expected_groups is None:
expected_groups = _get_expected_groups(by_input, sort=sort)
groups = (expected_groups.to_numpy(),)
group_chunks = ((len(expected_groups),),)
expected_groups_ = _get_expected_groups(by_input, sort=sort)
else:
expected_groups_ = expected_groups
groups = (expected_groups_.to_numpy(),)
group_chunks = ((len(expected_groups_),),)

elif method == "cohorts":
chunks_cohorts = find_group_cohorts(
@@ -1569,7 +1581,7 @@ def _validate_reindex(
return reindex


def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys):
def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys) -> None:
assert all(b.ndim == by[0].ndim for b in by[1:])
for idx, b in enumerate(by):
if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
@@ -1584,18 +1596,33 @@ def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys):
)


@overload
def _convert_expected_groups_to_index(
expected_groups: tuple[None, ...], isbin: Sequence[bool], sort: bool
) -> tuple[None, ...]:
...


@overload
def _convert_expected_groups_to_index(
expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool
) -> T_ExpectIndexTuple:
out: list[pd.Index | None] = []
...


def _convert_expected_groups_to_index(
expected_groups: T_ExpectOptTuple, isbin: Sequence[bool], sort: bool
) -> T_ExpectIndexOptTuple:
out: list[T_ExpectIndexOpt] = []
for ex, isbin_ in zip(expected_groups, isbin):
if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin_):
if sort:
ex = ex.sort_values()
out.append(ex)
out.append(ex.sort_values())
else:
out.append(ex)
elif ex is not None:
if isbin_:
out.append(pd.IntervalIndex.from_breaks(ex))
out.append(pd.IntervalIndex.from_breaks(ex)) # type: ignore [arg-type] # TODO: what do we want here?
else:
if sort:
ex = np.sort(ex)
@@ -1613,7 +1640,7 @@ def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:

def _factorize_multiple(
by: T_Bys,
expected_groups: T_ExpectIndexTuple,
expected_groups: T_ExpectIndexOptTuple,
any_by_dask: bool,
reindex: bool,
sort: bool = True,
@@ -1668,7 +1695,17 @@ def _factorize_multiple(
return (group_idx,), found_groups, grp_shape


def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple:
@overload
def _validate_expected_groups(nby: int, expected_groups: None) -> tuple[None, ...]:
...


@overload
def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroups) -> T_ExpectTuple:
...


def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectOptTuple:
if expected_groups is None:
return (None,) * nby

@@ -1935,21 +1972,20 @@ def groupby_reduce(
# Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
if min_count is None:
if nax < by_.ndim or fill_value is not None:
min_count = 1
min_count_: int = 1
else:
min_count_ = 0
else:
min_count_ = min_count

# TODO: set in xarray?
if (
min_count is not None
and min_count > 0
and func in ["nansum", "nanprod"]
and fill_value is None
):
if min_count_ > 0 and func in ["nansum", "nanprod"] and fill_value is None:
# nansum, nanprod have fill_value=0, 1
# overwrite than when min_count is set
fill_value = np.nan

kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)

groups: tuple[np.ndarray | DaskArray, ...]
if not has_dask:
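
A toy sketch of the narrowing pattern used in this hunk: rather than re-assigning the `int | None` parameter (whose declared type would still be optional, so later comparisons would need narrowing), a new `min_count_` variable holds the resolved `int`. The function below is illustrative, not the real `groupby_reduce` logic:

```python
from __future__ import annotations


def resolve_min_count(min_count: int | None, has_fill_value: bool) -> int:
    if min_count is None:
        min_count_: int = 1 if has_fill_value else 0
    else:
        min_count_ = min_count
    # From here on, min_count_ is a plain int for both mypy and readers.
    return min_count_


print(resolve_min_count(None, True), resolve_min_count(3, False))
```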