
custom loss function enhancer design #15

Merged · 12 commits · Sep 27, 2024

Conversation

SeppeDeWinter
Collaborator

@SeppeDeWinter SeppeDeWinter commented Sep 5, 2024

  • Fix typing syntax error
  • Allow a custom optimization function for in silico evolution
  • Change typing of the target param in EnhancerOptimizer
  • Allow target to be passed via params
  • Allow a custom optimization function for motif embedding

For example:

```python
import numpy as np
from sklearn.metrics import pairwise

from crested.tl._utils import EnhancerOptimizer

# use Heart, muscle and myoblast cells as contrast
classes_of_interest = [
    i for i, ct in enumerate(adata.obs_names)
    if "Heart" in ct or "muscle" in ct or "myoblast" in ct
]

# design enhancers that are high in heart but low in myoblast
target = np.array(
    [
        0 if "Cardiac muscle" not in x else 1
        for x in adata.obs_names
        if "Heart" in x or "muscle" in x or "myoblast" in x
    ]
)

# sanity check: the positive targets are exactly the cardiac muscle cell types
assert all(
    "Cardiac muscle" in x
    for x in adata.obs_names[np.array(classes_of_interest)[np.where(target)[0]]]
)


def L2_distance(
    mutated_predictions: np.ndarray,
    original_prediction: np.ndarray,
    target: np.ndarray,
    classes_of_interest: list[int],
) -> int:
    def scale(X):
        # min-max scale each row to [0, 1]
        return ((X.T - X.min(1)) / (X.max(1) - X.min(1))).T

    L2_sat_mut = pairwise.euclidean_distances(
        scale(mutated_predictions)[:, classes_of_interest], target.reshape(1, -1)
    )
    L2_baseline = pairwise.euclidean_distances(
        scale(original_prediction)[:, classes_of_interest], target.reshape(1, -1)
    )
    # index of the mutation that most reduces the distance to the target
    return np.argmax((L2_baseline - L2_sat_mut).squeeze())


L2_optimizer = EnhancerOptimizer(optimize_func=L2_distance)

intermediate_info_list, designed_sequences = evaluator.enhancer_design_in_silico_evolution(
    target_class=None,
    n_sequences=1,
    n_mutations=30,
    enhancer_optimizer=L2_optimizer,
    target=target,
    return_intermediate=True,
    no_mutation_flanks=(807, 807),
    classes_of_interest=classes_of_interest,
)
```
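As a quick sanity check (not part of the PR), the selection logic of the custom scorer can be exercised on toy predictions. The sketch below is a NumPy-only equivalent of `L2_distance`; `improvement_argmax` and the toy arrays are illustrative names, not CREsted API:

```python
import numpy as np

def scale(X):
    # min-max scale each row to [0, 1]
    return ((X.T - X.min(1)) / (X.max(1) - X.min(1))).T

def improvement_argmax(mutated, original, target, classes_of_interest):
    # NumPy-only equivalent of the L2_distance scorer: compare each mutated
    # prediction's distance to the target (restricted to the classes of
    # interest) against the baseline distance, and pick the biggest improvement.
    d_mut = np.linalg.norm(scale(mutated)[:, classes_of_interest] - target, axis=1)
    d_base = np.linalg.norm(scale(original)[:, classes_of_interest] - target, axis=1)
    return int(np.argmax(d_base - d_mut))

# toy example: target high in class 0, low in class 1
target = np.array([1.0, 0.0])
original = np.array([[0.2, 0.8, 0.5]])   # far from target after scaling
mutated = np.array([[0.8, 0.2, 0.5],     # mutation 0 moves toward the target
                    [0.2, 0.8, 0.5]])    # mutation 1 leaves it unchanged
best = improvement_argmax(mutated, original, target, [0, 1])  # → 0
```

The mutation that most reduces the scaled L2 distance to the target wins, which is exactly what `np.argmax((L2_baseline - L2_sat_mut).squeeze())` selects above.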

@SeppeDeWinter
Collaborator Author

Note: the code for motif embedding has not been tested yet.

@LukasMahieu
Collaborator

Okay, looks good and makes sense to me. In the near future we should really make a separate tutorial for enhancer design (including the information here), since it's currently just a one-liner in the introductory tutorial.
@erceksi could you take a look too since you implemented the original function?

@SeppeDeWinter
Collaborator Author

Added some extra changes.

Multiple sequences are now processed in parallel. Previously, model.predict was called once per sequence per iteration; now a single batched call is made per iteration.
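The batching change can be illustrated with a minimal sketch. Here `predict` is a hypothetical stand-in for `model.predict` (any function mapping a batch of one-hot sequences to per-class scores), not CREsted's actual model:

```python
import numpy as np

def predict(batch):
    # hypothetical stand-in for model.predict: (n, seq_len, 4) -> (n, n_classes)
    return batch.reshape(batch.shape[0], -1) @ np.ones((batch.shape[1] * 4, 3))

sequences = np.random.default_rng(0).random((5, 10, 4))  # 5 candidate sequences

# before: one model call per sequence per iteration
per_seq = np.vstack([predict(seq[None]) for seq in sequences])

# after: a single batched call per iteration
batched = predict(sequences)

assert np.allclose(per_seq, batched)  # same scores, far fewer model calls
```

Since each `model.predict` call carries fixed overhead, collapsing the per-sequence loop into one batched call per iteration is where the reported speedup comes from.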

@SeppeDeWinter
Collaborator Author

From a quick and dirty benchmark, this code should be around 2x faster.

@SeppeDeWinter SeppeDeWinter merged commit 81ba818 into main Sep 27, 2024
4 checks passed
@nkempynck nkempynck deleted the custom_loss_function_in_silico_evolution branch September 30, 2024 15:41