Skip to content

Commit

Permalink
Add GPCCA::plot_tsi
Browse files Browse the repository at this point in the history
Add class method to plot terminal state identification.
  • Loading branch information
WeilerP committed Feb 22, 2024
1 parent 88d244a commit c829bfb
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions src/cellrank/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import enum
import pathlib
import types
from pathlib import PosixPath
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
Expand All @@ -10,6 +11,7 @@
from pandas.api.types import infer_dtype

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import ListedColormap, Normalize
Expand Down Expand Up @@ -586,6 +588,96 @@ def get_tsi_score(self) -> float:

return self.tsi["Identified terminal states"].sum() / optimal_score

def plot_tsi(
self,
tsi_df: pd.DataFrame,
fname: Optional[PosixPath] = None,
x_offset: Optional[Tuple[float, float]] = None,
y_offset: Optional[Tuple[float, float]] = None,
**kwargs: Any,
):
"""Plot terminal state identificiation (TSI).
Parameters
----------
tsi_df
Pre-computed TSI DataFrame.
fname
File name under which the plot is saved. The plot is not saved if the argument is not specified.
x_offset
Offset of x-axis. Defaults to `[0.2, 0.2]` if not specified.
y_offset
Offset of y-axis. Defaults to `[0.1, 0.1]` if not specified.
kwargs
Keyword arguments for :meth:`~seaborn.lineplot`.
Returns
-------
Returns TSI as a Pandas DataFrame and adds the class attribute :attr:`tsi`.
"""
if x_offset is None:
x_offset = [0.2, 0.2]

if y_offset is None:
y_offset = [0.1, 0.1]

optimal_identification = tsi_df[["Number of macrostates", "Optimal identification"]]
optimal_identification = optimal_identification.rename(
columns={"Optimal identification": "Identified terminal states"}
)
optimal_identification["Method"] = "Optimal identification"
optimal_identification["line_style"] = "--"

df = tsi_df[["Number of macrostates", "Identified terminal states"]]
df["Method"] = self.kernel.__class__.__name__
df["line_style"] = "-"

df = pd.concat([df, optimal_identification])

fig, ax = plt.subplots(figsize=(6, 4))
sns.lineplot(
data=df,
x="Number of macrostates",
y="Identified terminal states",
hue="Method",
style="line_style",
drawstyle="steps-post",
ax=ax,
**kwargs,
)

ax.set_xticks(df["Number of macrostates"].unique().astype(int))
for label_id, label in enumerate(ax.xaxis.get_ticklabels()):
if ((label_id + 1) % 5 != 0) and label_id != 0:
label.set_visible(False)
ax.set_yticks(df["Identified terminal states"].unique())

x_min = df["Number of macrostates"].min() - x_offset[0]
x_max = df["Number of macrostates"].max() + x_offset[1]
y_min = df["Identified terminal states"].min() - y_offset[0]
y_max = df["Identified terminal states"].max() + y_offset[1]
ax.set(xlim=[x_min, x_max], ylim=[y_min, y_max])

ax.get_legend().remove()

n_methods = len(df["Method"].unique())
handles, labels = ax.get_legend_handles_labels()
handles[n_methods].set_linestyle("--")
handles = handles[: (n_methods + 1)]
labels = labels[: (n_methods + 1)]
fig.legend(handles=handles, labels=labels, loc="lower center", ncol=(n_methods + 1), bbox_to_anchor=(0.5, -0.1))
plt.tight_layout()
plt.show()

if fname is not None:
format = fname.suffix[1:]
fig.savefig(
fname=fname,
format=format,
transparent=True,
bbox_inches="tight",
)

@d.dedent
def fit(
self,
Expand Down

0 comments on commit c829bfb

Please sign in to comment.