From c829bfb2693aac32ed1770f89d3fbf5671324b41 Mon Sep 17 00:00:00 2001 From: Philipp Weiler Date: Thu, 22 Feb 2024 10:55:11 +0000 Subject: [PATCH] Add `GPCCA::plot_tsi` Add class method to plot terminal state identification. --- .../estimators/terminal_states/_gpcca.py | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/src/cellrank/estimators/terminal_states/_gpcca.py b/src/cellrank/estimators/terminal_states/_gpcca.py index a119fde1d..f9a207f09 100644 --- a/src/cellrank/estimators/terminal_states/_gpcca.py +++ b/src/cellrank/estimators/terminal_states/_gpcca.py @@ -2,6 +2,7 @@ import enum import pathlib import types +from pathlib import PosixPath from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -10,6 +11,7 @@ from pandas.api.types import infer_dtype import matplotlib.pyplot as plt +import seaborn as sns from matplotlib.axes import Axes from matplotlib.colorbar import ColorbarBase from matplotlib.colors import ListedColormap, Normalize @@ -586,6 +588,96 @@ def get_tsi_score(self) -> float: return self.tsi["Identified terminal states"].sum() / optimal_score + def plot_tsi( + self, + tsi_df: pd.DataFrame, + fname: Optional[PosixPath] = None, + x_offset: Optional[Tuple[float, float]] = None, + y_offset: Optional[Tuple[float, float]] = None, + **kwargs: Any, + ): + """Plot terminal state identificiation (TSI). + + Parameters + ---------- + tsi_df + Pre-computed TSI DataFrame. + fname + File name under which the plot is saved. The plot is not saved if the argument is not specified. + x_offset + Offset of x-axis. Defaults to `[0.2, 0.2]` if not specified. + y_offset + Offset of y-axis. Defaults to `[0.1, 0.1]` if not specified. + kwargs + Keyword arguments for :meth:`~seaborn.lineplot`. + + Returns + ------- + Returns TSI as a Pandas DataFrame and adds the class attribute :attr:`tsi`. + """ + if x_offset is None: + x_offset = [0.2, 0.2] + + if y_offset is None: + y_offset = [0.1, 0.1] + + optimal_identification = tsi_df[["Number of macrostates", "Optimal identification"]] + optimal_identification = optimal_identification.rename( + columns={"Optimal identification": "Identified terminal states"} + ) + optimal_identification["Method"] = "Optimal identification" + optimal_identification["line_style"] = "--" + + df = tsi_df[["Number of macrostates", "Identified terminal states"]] + df["Method"] = self.kernel.__class__.__name__ + df["line_style"] = "-" + + df = pd.concat([df, optimal_identification]) + + fig, ax = plt.subplots(figsize=(6, 4)) + sns.lineplot( + data=df, + x="Number of macrostates", + y="Identified terminal states", + hue="Method", + style="line_style", + drawstyle="steps-post", + ax=ax, + **kwargs, + ) + + ax.set_xticks(df["Number of macrostates"].unique().astype(int)) + for label_id, label in enumerate(ax.xaxis.get_ticklabels()): + if ((label_id + 1) % 5 != 0) and label_id != 0: + label.set_visible(False) + ax.set_yticks(df["Identified terminal states"].unique()) + + x_min = df["Number of macrostates"].min() - x_offset[0] + x_max = df["Number of macrostates"].max() + x_offset[1] + y_min = df["Identified terminal states"].min() - y_offset[0] + y_max = df["Identified terminal states"].max() + y_offset[1] + ax.set(xlim=[x_min, x_max], ylim=[y_min, y_max]) + + ax.get_legend().remove() + + n_methods = len(df["Method"].unique()) + handles, labels = ax.get_legend_handles_labels() + handles[n_methods].set_linestyle("--") + handles = handles[: (n_methods + 1)] + labels = labels[: (n_methods + 1)] + fig.legend(handles=handles, labels=labels, loc="lower center", ncol=(n_methods + 1), bbox_to_anchor=(0.5, -0.1)) + plt.tight_layout() + plt.show() + + if fname is not None: + format = fname.suffix[1:] + fig.savefig( + fname=fname, + format=format, + transparent=True, + bbox_inches="tight", + ) + @d.dedent def fit( self,