Skip to content

Commit

Permalink
Refactor TSI computation (#1174)
Browse files Browse the repository at this point in the history
* Copy the estimator to prevent overwriting
attributes
  • Loading branch information
michalk8 authored Mar 4, 2024
1 parent 696a23b commit 3b8e314
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
9 changes: 6 additions & 3 deletions src/cellrank/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def tsi(
cluster_key: Optional[str] = None,
**kwargs: Any,
) -> float:
"""Compute terminal state identificiation (TSI) score.
"""Compute terminal state identification (TSI) score.
Parameters
----------
Expand Down Expand Up @@ -573,10 +573,13 @@ def tsi(
if cluster_key is None:
raise RuntimeError("`cluster_key` needs to be specified to compute TSI.")

# create a new GPCCA object to avoid unsetting attributes
# that depend on the macrostates, e.g. the terminal states
g = self.copy(deep=True)
macrostates = {}
for n_states in range(n_macrostates, 0, -1):
self.compute_macrostates(n_states=n_states, cluster_key=cluster_key, **kwargs)
macrostates[n_states] = self.macrostates.cat.categories
g = g.compute_macrostates(n_states=n_states, cluster_key=cluster_key, **kwargs)
macrostates[n_states] = g.macrostates.cat.categories

max_terminal_states = len(terminal_states)

Expand Down
1 change: 0 additions & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,7 +2259,6 @@ def test_scvelo_transition_matrix_projection(self, mc: GPCCA, fpath: str):

@compare(kind="gpcca")
def test_plot_tsi(self, mc: GPCCA, fpath: str):
mc = mc.copy(deep=True)
terminal_states = ["Neuroblast", "Astrocyte", "Granule mature"]
cluster_key = "clusters"
_ = mc.tsi(n_macrostates=3, terminal_states=terminal_states, cluster_key=cluster_key, n_cells=10)
Expand Down

0 comments on commit 3b8e314

Please sign in to comment.