diff --git a/tests/test_kernels.py b/tests/test_kernels.py index ee0d510e..c08f3736 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -431,12 +431,11 @@ def test_write_adata_key(self, adata: AnnData): np.testing.assert_array_equal(adata.obsp["foo"].toarray(), vk.transition_matrix.toarray()) @pytest.mark.parametrize("model", ["deterministic", "stochastic"]) - def test_vk_row_normalized(self, adata: AnnData, model: str): + def test_vk_row_normalized(self, adata: AnnData, model: Literal["deterministic", "stochastic", "monte_carlo"]): if model == "stochastic": pytest.importorskip("jax") - pytest.importorskip("jaxlib") vk = VelocityKernel(adata) - vk.compute_transition_matrix(model="stochastic", softmax_scale=4) + vk.compute_transition_matrix(model=model, softmax_scale=4) np.testing.assert_allclose(vk.transition_matrix.sum(1), 1, rtol=_rtol)