Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TSI score code #1166

Merged
merged 20 commits into from
Mar 4, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 138 additions & 1 deletion src/cellrank/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import collections
import datetime
import enum
import pathlib
import types
from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Union
from pathlib import Path
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import ListedColormap, Normalize
Expand Down Expand Up @@ -532,6 +535,140 @@ def set_initial_states(
)
return self

def get_tsi(self, n_macrostates: int, terminal_states: List[str], cluster_key: str, **kwargs: Any) -> pd.DataFrame:
"""Compute terminal state identificiation (TSI).

WeilerP marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
n_macrostates
Maximum number of macrostates to consider.
terminal_states
List of terminal states.
cluster_key
Key in :attr:`~anndata.AnnData.obs` defining cluster labels including terminal states.
kwargs
Keyword arguments passed to the class' `compute_macrostates` function.
WeilerP marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Returns TSI as a Pandas DataFrame and adds the class attribute :attr:`tsi`. The DataFrame contains the columns
WeilerP marked this conversation as resolved.
Show resolved Hide resolved

WeilerP marked this conversation as resolved.
Show resolved Hide resolved
- "Number of macrostates": Number of macrostates computed
WeilerP marked this conversation as resolved.
Show resolved Hide resolved
- "Identified terminal states": Number of terminal states identified
- "Optimal identification": Number of terminal states identified when using an optimal identification scheme
"""
macrostates = {}
for n_states in range(n_macrostates, 0, -1):
self.compute_macrostates(n_states=n_states, cluster_key=cluster_key, **kwargs)
macrostates[n_states] = self.macrostates.cat.categories

max_terminal_states = len(terminal_states)

tsi_df = collections.defaultdict(list)
for n_states, states in macrostates.items():
n_terminal_states = (
states.str.replace(r"(_).*", "", regex=True).drop_duplicates().isin(terminal_states).sum()
)
tsi_df["Number of macrostates"].append(n_states)
WeilerP marked this conversation as resolved.
Show resolved Hide resolved
tsi_df["Identified terminal states"].append(n_terminal_states)

tsi_df["Optimal identification"].append(min(n_states, max_terminal_states))

tsi_df = pd.DataFrame(tsi_df)
self.tsi = tsi_df
WeilerP marked this conversation as resolved.
Show resolved Hide resolved

return tsi_df

def get_tsi_score(self) -> float:
WeilerP marked this conversation as resolved.
Show resolved Hide resolved
"""Compute TSI score."""
if not hasattr(self, "tsi"):
raise RuntimeError("Compute TSI with `get_tsi` first.")

optimal_score = self.tsi["Optimal identification"].sum()

return self.tsi["Identified terminal states"].sum() / optimal_score

@d.dedent
def plot_tsi(
self,
tsi_df: pd.DataFrame,
WeilerP marked this conversation as resolved.
Show resolved Hide resolved
x_offset: Tuple[float, float] = (0.2, 0.2),
y_offset: Tuple[float, float] = (0.1, 0.1),
figsize: Tuple[float, float] = (6, 4),
dpi: Optional[int] = None,
save: Optional[Union[str, Path]] = None,
**kwargs: Any,
) -> Tuple[plt.Figure, Axes]:
"""Plot terminal state identificiation (TSI).

WeilerP marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
tsi_df
Pre-computed TSI DataFrame with :meth:`get_tsi_score`.
x_offset
Offset of x-axis.
y_offset
Offset of y-axis.
%(plotting)s
kwargs
Keyword arguments for :meth:`~seaborn.lineplot`.
WeilerP marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Plot TSI of the kernel and an optimal identification strategy.
"""
optimal_identification = tsi_df[["Number of macrostates", "Optimal identification"]]
optimal_identification = optimal_identification.rename(
columns={"Optimal identification": "Identified terminal states"}
WeilerP marked this conversation as resolved.
Show resolved Hide resolved
)
optimal_identification["Method"] = "Optimal identification"
optimal_identification["line_style"] = "--"
WeilerP marked this conversation as resolved.
Show resolved Hide resolved

df = tsi_df[["Number of macrostates", "Identified terminal states"]]
df["Method"] = self.kernel.__class__.__name__
df["line_style"] = "-"

df = pd.concat([df, optimal_identification])

fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
sns.lineplot(
data=df,
x="Number of macrostates",
y="Identified terminal states",
hue="Method",
style="line_style",
drawstyle="steps-post",
ax=ax,
**kwargs,
)

ax.set_xticks(df["Number of macrostates"].unique().astype(int))
# Plot is generated from large to small values on the x-axis
for label_id, label in enumerate(ax.xaxis.get_ticklabels()[::-1]):
if ((label_id + 1) % 5 != 0) and label_id != 0:
label.set_visible(False)
ax.set_yticks(df["Identified terminal states"].unique())

x_min = df["Number of macrostates"].min() - x_offset[0]
x_max = df["Number of macrostates"].max() + x_offset[1]
y_min = df["Identified terminal states"].min() - y_offset[0]
y_max = df["Identified terminal states"].max() + y_offset[1]
ax.set(xlim=[x_min, x_max], ylim=[y_min, y_max])

ax.get_legend().remove()

n_methods = len(df["Method"].unique())
handles, labels = ax.get_legend_handles_labels()
handles[n_methods].set_linestyle("--")
handles = handles[: (n_methods + 1)]
labels = labels[: (n_methods + 1)]
fig.legend(handles=handles, labels=labels, loc="lower center", ncol=(n_methods + 1), bbox_to_anchor=(0.5, -0.1))

if save is not None:
save_fig(fig=fig, path=save)

return fig, ax
WeilerP marked this conversation as resolved.
Show resolved Hide resolved

@d.dedent
def fit(
self,
Expand Down
Loading