diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index 986beb9..8090e04 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -10,13 +10,11 @@ from anndata import AnnData from loguru import logger from tqdm import tqdm -from typing import Callable, Any from crested._logging import log_and_raise from crested.tl import TaskConfig from crested.tl._utils import ( _weighted_difference, - EnhancerOptimizer, generate_motif_insertions, generate_mutagenesis, hot_encoding_to_sequence, @@ -1138,58 +1136,47 @@ def enhancer_design_motif_implementation( def enhancer_design_in_silico_evolution( self, + target_class: str, n_mutations: int, n_sequences: int, - target_class: str | None = None, return_intermediate: bool = False, + class_penalty_weights: np.ndarray | None = None, no_mutation_flanks: tuple | None = None, target_len: int | None = None, - enhancer_optimizer: EnhancerOptimizer | None = None, - **kwargs: dict[str, Any] ) -> tuple[list[dict], list] | list: """ Create synthetic enhancers for a specified class using in silico evolution (ISE). Parameters ---------- + target_class + Class name for which the enhancers will be designed. n_mutations Number of mutations per sequence n_sequences Number of enhancers to design - target_class - Class name for which the enhancers will be designed for. If this value is set to None a custom target can be - defined using kwargs. return_intermediate If True, returns a dictionary with predictions and changes made in intermediate steps for selected sequences + class_penalty_weights + Array with a value per class, determining the penalty weight for that class to be used in the scoring + function for sequence selection. no_mutation_flanks A tuple of integers which determine the regions in each flank to not do implementations. target_len Length of the area in the center of the sequence to make implementations, ignored if no_mutation_flanks is supplied. 
- enhancer_optimizer - An instance of EnhancerOptimizer, defining how sequences should be optimized. - If None, a default EnhancerOptimizer will be initialized using `_weighted_difference` - as optimization function. - kwargs - Keyword arguments that will be passed to the `get_best` function of the EnhancerOptimizer Returns ------- A list of designed sequences and if return_intermediate is True a list of dictionaries of intermediate mutations and predictions """ - if target_class is not None: - self._check_contribution_scores_params([target_class]) - - all_class_names = list(self.anndatamodule.adata.obs_names) + self._check_contribution_scores_params([target_class]) - target = all_class_names.index(target_class) + all_class_names = list(self.anndatamodule.adata.obs_names) - if enhancer_optimizer is None: - enhancer_optimizer = EnhancerOptimizer( - optimize_func = _weighted_difference - ) + target = all_class_names.index(target_class) # get input sequence length of the model seq_len = ( @@ -1250,12 +1237,11 @@ def enhancer_design_in_silico_evolution( mutagenesis_predictions = self.model.predict(mutagenesis) # determine the best mutation - - best_mutation = enhancer_optimizer.get_best( - mutated_predictions = mutagenesis_predictions, - original_prediction = current_prediction, - target = target, - **kwargs + best_mutation = _weighted_difference( + mutagenesis_predictions, + current_prediction, + target, + class_penalty_weights, ) sequence_onehot = mutagenesis[best_mutation : best_mutation + 1] diff --git a/src/crested/tl/_utils.py b/src/crested/tl/_utils.py index 79ffc57..e6fbd06 100644 --- a/src/crested/tl/_utils.py +++ b/src/crested/tl/_utils.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, Callable - import numpy as np @@ -94,32 +92,9 @@ def generate_motif_insertions(x, motif, flanks=(0, 0), masked_locations=None): return np.concatenate(x_mut, axis=0), insertion_locations -class EnhancerOptimizer: - def __init__( - self, - 
optimize_func: Callable[..., np.intp] - ) -> None: - self.optimize_func = optimize_func - - def get_best( - self, - mutated_predictions: np.ndarray, - original_prediction: np.ndarray, - target: int | list[int], - **kwargs: dict[str, Any] - ) -> np.intp: - return self.optimize_func( - mutated_predictions, - original_prediction, - target, - **kwargs - ) def _weighted_difference( - mutated_predictions: np.ndarray, - original_prediction: np.ndarray, - target: int, - class_penalty_weights: np.ndarray | None = None + mutated_predictions, original_prediction, target, class_penalty_weights=None ): n_classes = original_prediction.shape[1] penalty_factor = 1 / n_classes