Skip to content

Commit

Permalink
typing and pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed May 13, 2024
1 parent a2a27a8 commit fe39cd4
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 23 deletions.
28 changes: 16 additions & 12 deletions src/dask_nested/accessor.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
from dask.dataframe.extensions import make_array_nonempty, make_scalar, register_series_accessor
import nested_pandas as npd
from nested_pandas import NestedDtype, NestSeriesAccessor
import dask
from dask.dataframe.extensions import register_series_accessor
from nested_pandas import NestedDtype


@register_series_accessor("nest")
class DaskNestSeriesAccessor(npd.NestSeriesAccessor):

"""The nested-dask version of the nested-pandas NestSeriesAccessor.
Note that this has a very limited implementation relative to nested-pandas.
Parameters
----------
series: dd.series
A series to tie to the accessor
"""

def __init__(self, series):
self._check_series(series)

self._series = series

@staticmethod
def _check_series(series):
"""chcek the validity of the tied series dtype"""
dtype = series.dtype
if not isinstance(dtype, NestedDtype):
raise AttributeError(f"Can only use .nest accessor with a Series of NestedDtype, got {dtype}")

@property
def fields(self) -> list[str]:
"""Names of the nested columns"""
return self._series.head(0).nest.fields
#hacky
#return self._series.partitions[0:1].map_partitions(lambda x: x.nest.fields)
#return self._series.array.field_names

@dask.delayed
def test_fields(self):
return self._series.head(0).nest.fields

return self._series.head(0).nest.fields # hacky
30 changes: 19 additions & 11 deletions src/dask_nested/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dask.dataframe as dd
import dask_expr as dx
import nested_pandas as npd
Expand All @@ -8,8 +10,11 @@
from pandas._typing import AnyAll, Axis, IndexLabel
from pandas.api.extensions import no_default

# need this for the base _Frame class
# mypy: disable-error-code="misc"


class _Frame(dx.FrameBase):
class _Frame(dx.FrameBase): # type: ignore
"""Base class for extensions of Dask Dataframes that track additional
Ensemble-related metadata.
"""
Expand All @@ -34,7 +39,7 @@ def __dask_postpersist__(self):

return self._rebuild, (func, args)

def _rebuild(self, graph, func, args):
def _rebuild(self, graph, func, args): # type: ignore
collection = func(graph, *args)
return collection

Expand Down Expand Up @@ -63,7 +68,9 @@ def __getitem__(self, key):
return result

@classmethod
def from_nestedpandas(cls, data, npartitions=None, chunksize=None, sort=True, label=None, ensemble=None):
def from_nestedpandas(
cls, data, npartitions=None, chunksize=None, sort=True, label=None, ensemble=None
) -> NestedFrame:
"""Returns an EnsembleFrame constructed from a TapeFrame.
Parameters
Expand All @@ -89,10 +96,10 @@ def from_nestedpandas(cls, data, npartitions=None, chunksize=None, sort=True, la
The constructed EnsembleFrame object.
"""
result = dd.from_pandas(data, npartitions=npartitions, chunksize=chunksize, sort=sort)
return result
return NestedFrame.from_dask_dataframe(result)

@classmethod
def from_dask_dataframe(cl, df):
def from_dask_dataframe(cls, df) -> NestedFrame:
"""Converts a Dask Dataframe to a Dask-Nested NestedFrame
Parameters
Expand Down Expand Up @@ -129,7 +136,7 @@ def nested_columns(self) -> list:
nest_cols.append(column)
return nest_cols

def add_nested(self, nested, name):
def add_nested(self, nested, name) -> NestedFrame: # type: ignore[name-defined] # noqa: F821
"""Packs a dataframe into a nested column
Parameters
Expand All @@ -146,7 +153,7 @@ def add_nested(self, nested, name):
nested = nested.map_partitions(lambda x: pack_flat(x)).rename(name)
return self.join(nested, how="outer")

def query(self, expr):
def query(self, expr) -> Self: # type: ignore # noqa: F821:
"""
Query the columns of a NestedFrame with a boolean expression. Specified
queries can target nested columns in addition to the typical column set
Expand Down Expand Up @@ -203,7 +210,7 @@ def dropna(
subset: IndexLabel | None = None,
inplace: bool = False,
ignore_index: bool = False,
):
) -> Self: # type: ignore[name-defined] # noqa: F821:
"""
Remove missing values for one layer of the NestedFrame.
Expand Down Expand Up @@ -260,7 +267,6 @@ def dropna(
time.
"""
# grab meta from head, assumes row-based operation
meta = self.head(0)
return self.map_partitions(
lambda x: x.dropna(
axis=axis,
Expand All @@ -271,10 +277,10 @@ def dropna(
inplace=inplace,
ignore_index=ignore_index,
),
meta=meta,
meta=self._meta,
)

def reduce(self, func, *args, meta=None, **kwargs):
def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame:
"""
Takes a function and applies it to each top-level row of the NestedFrame.
Expand All @@ -292,6 +298,8 @@ def reduce(self, func, *args, meta=None, **kwargs):
args : positional arguments
Positional arguments to pass to the function, the first *args should be the names of the
columns to apply the function to.
meta : dataframe or series-like, optional
The dask meta of the output.
kwargs : keyword arguments, optional
Keyword arguments to pass to the function.
Expand Down

0 comments on commit fe39cd4

Please sign in to comment.