Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for perturbation magnitude/direction #17

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 52 additions & 50 deletions gears/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pandas as pd
import numpy as np
import scanpy as sc
from random import shuffle
sc.settings.verbosity = 0
from tqdm import tqdm
import requests
Expand All @@ -10,7 +9,10 @@
import warnings
warnings.filterwarnings("ignore")

from .utils import parse_single_pert, parse_combo_pert, parse_any_pert, print_sys
from .utils import (
parse_single_pert, parse_combo_pert, parse_any_pert, print_sys,
get_pert_genes, rm_magnitude
)

def rank_genes_groups_by_cov(
adata,
Expand Down Expand Up @@ -53,24 +55,24 @@ def rank_genes_groups_by_cov(
if return_dict:
return gene_dict


def get_DE_genes(adata, skip_calc_de):
adata.obs.loc[:, 'dose_val'] = adata.obs.condition.apply(lambda x: '1+1' if len(x.split('+')) == 2 else '1')
adata.obs.loc[:, 'control'] = adata.obs.condition.apply(lambda x: 0 if len(x.split('+')) == 2 else 1)
adata.obs.loc[:, 'condition_name'] = adata.obs.apply(lambda x: '_'.join([x.cell_type, x.condition, x.dose_val]), axis = 1)
adata.obs.loc[:, 'condition_name'] = adata.obs.apply(lambda x: '_'.join([x.cell_type, x.condition, x.dose_val]), axis = 1)

adata.obs = adata.obs.astype('category')
if not skip_calc_de:
rank_genes_groups_by_cov(adata,
groupby='condition_name',
covariate='cell_type',
control_group='ctrl_1',
rank_genes_groups_by_cov(adata,
groupby='condition_name',
covariate='cell_type',
control_group='ctrl_1',
n_genes=len(adata.var),
key_added = 'rank_genes_groups_cov_all')
return adata

def get_dropout_non_zero_genes(adata):

# calculate mean expression for each condition
unique_conditions = adata.obs.condition.unique()
conditions2index = {}
Expand All @@ -83,7 +85,7 @@ def get_dropout_non_zero_genes(adata):
pert_list = np.array(list(condition2mean_expression.keys()))
mean_expression = np.array(list(condition2mean_expression.values())).reshape(len(adata.obs.condition.unique()), adata.X.toarray().shape[1])
ctrl = mean_expression[np.where(pert_list == 'ctrl')[0]]

## in silico modeling and upperbounding
pert2pert_full_id = dict(adata.obs[['condition', 'condition_name']].values)
pert_full_id2pert = dict(adata.obs[['condition_name', 'condition']].values)
Expand Down Expand Up @@ -118,17 +120,17 @@ def get_dropout_non_zero_genes(adata):
non_dropout_gene_idx[pert] = np.sort(non_dropouts)
top_non_dropout_de_20[pert] = np.array(non_dropout_20_gene_id)
top_non_zero_de_20[pert] = np.array(non_zero_20_gene_id)

non_zero = np.where(np.array(X)[0] != 0)[0]
zero = np.where(np.array(X)[0] == 0)[0]
true_zeros = np.intersect1d(zero, np.where(np.array(ctrl)[0] == 0)[0])
non_dropouts = np.concatenate((non_zero, true_zeros))

adata.uns['top_non_dropout_de_20'] = top_non_dropout_de_20
adata.uns['non_dropout_gene_idx'] = non_dropout_gene_idx
adata.uns['non_zeros_gene_idx'] = non_zeros_gene_idx
adata.uns['top_non_zero_de_20'] = top_non_zero_de_20

return adata


Expand All @@ -152,11 +154,11 @@ def split_data(self, test_size=0.1, test_pert_genes=None,
np.random.seed(seed=seed)
unique_perts = [p for p in self.adata.obs['condition'].unique() if
p != 'ctrl']

if self.split_type == 'simulation':
train, test, test_subgroup = self.get_simulation_split(unique_perts,
train_gene_set_size,
combo_seen2_train_frac,
combo_seen2_train_frac,
seed, test_perts, only_test_set_perts)
train, val, val_subgroup = self.get_simulation_split(train,
0.9,
Expand All @@ -174,17 +176,17 @@ def split_data(self, test_size=0.1, test_pert_genes=None,
elif self.split_type == 'no_test':
print('test_pert_genes',str(test_pert_genes))
print('test_perts',str(test_perts))

train, val = self.get_split_list(unique_perts,
test_pert_genes=test_pert_genes,
test_perts=test_perts,
test_size=test_size)
test_size=test_size)
else:
train, test = self.get_split_list(unique_perts,
test_pert_genes=test_pert_genes,
test_perts=test_perts,
test_size=test_size)

train, val = self.get_split_list(train, test_size=val_size)

map_dict = {x: 'train' for x in train}
Expand All @@ -196,19 +198,19 @@ def split_data(self, test_size=0.1, test_pert_genes=None,
self.adata.obs[split_name] = self.adata.obs['condition'].map(map_dict)

if self.split_type == 'simulation':
return self.adata, {'test_subgroup': test_subgroup,
return self.adata, {'test_subgroup': test_subgroup,
'val_subgroup': val_subgroup
}
else:
return self.adata

def get_simulation_split_single(self, pert_list, train_gene_set_size = 0.85, seed = 1, test_set_perts = None, only_test_set_perts = False):
unique_pert_genes = self.get_genes_from_perts(pert_list)

pert_train = []
pert_test = []
np.random.seed(seed=seed)

if only_test_set_perts and (test_set_perts is not None):
ood_genes = np.array(test_set_perts)
train_gene_candidates = np.setdiff1d(unique_pert_genes, ood_genes)
Expand All @@ -223,24 +225,24 @@ def get_simulation_split_single(self, pert_list, train_gene_set_size = 0.85, see
ood_genes_exclude_test_set = np.setdiff1d(unique_pert_genes, np.union1d(train_gene_candidates, test_set_perts))
train_set_addition = np.random.choice(ood_genes_exclude_test_set, num_overlap, replace = False)
train_gene_candidates = np.concatenate((train_gene_candidates, train_set_addition))

## ood genes
ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates)
ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates)

pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single')
unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single')
assert len(unseen_single) + len(pert_single_train) == len(pert_list)

return pert_single_train, unseen_single, {'unseen_single': unseen_single}

def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, combo_seen2_train_frac = 0.85, seed = 1, test_set_perts = None, only_test_set_perts = False):

unique_pert_genes = self.get_genes_from_perts(pert_list)

pert_train = []
pert_test = []
np.random.seed(seed=seed)

if only_test_set_perts and (test_set_perts is not None):
ood_genes = np.array(test_set_perts)
train_gene_candidates = np.setdiff1d(unique_pert_genes, ood_genes)
Expand All @@ -255,35 +257,35 @@ def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, combo_seen
ood_genes_exclude_test_set = np.setdiff1d(unique_pert_genes, np.union1d(train_gene_candidates, test_set_perts))
train_set_addition = np.random.choice(ood_genes_exclude_test_set, num_overlap, replace = False)
train_gene_candidates = np.concatenate((train_gene_candidates, train_set_addition))

## ood genes
ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates)
ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates)

pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single')
pert_combo = self.get_perts_from_genes(train_gene_candidates, pert_list,'combo')
pert_train.extend(pert_single_train)

## the combo set with one of them in OOD
combo_seen1 = [x for x in pert_combo if len([t for t in x.split('+') if
combo_seen1 = [x for x in pert_combo if len([t for t in get_pert_genes(x) if
t in train_gene_candidates]) == 1]
pert_test.extend(combo_seen1)

pert_combo = np.setdiff1d(pert_combo, combo_seen1)
## randomly sample the combo seen 2 as a test set, the rest in training set
np.random.seed(seed=seed)
pert_combo_train = np.random.choice(pert_combo, int(len(pert_combo) * combo_seen2_train_frac), replace = False)

combo_seen2 = np.setdiff1d(pert_combo, pert_combo_train).tolist()
pert_test.extend(combo_seen2)
pert_train.extend(pert_combo_train)

## unseen single
unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single')
combo_ood = self.get_perts_from_genes(ood_genes, pert_list, 'combo')
pert_test.extend(unseen_single)

## here only keeps the seen 0, since seen 1 is tackled above
combo_seen0 = [x for x in combo_ood if len([t for t in x.split('+') if
combo_seen0 = [x for x in combo_ood if len([t for t in get_pert_genes(x) if
t in train_gene_candidates]) == 0]
pert_test.extend(combo_seen0)
assert len(combo_seen1) + len(combo_seen0) + len(unseen_single) + len(pert_train) + len(combo_seen2) == len(pert_list)
Expand All @@ -292,7 +294,7 @@ def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, combo_seen
'combo_seen1': combo_seen1,
'combo_seen2': combo_seen2,
'unseen_single': unseen_single}

def get_split_list(self, pert_list, test_size=0.1,
test_pert_genes=None, test_perts=None,
hold_outs=True):
Expand Down Expand Up @@ -336,7 +338,7 @@ def get_split_list(self, pert_list, test_size=0.1,
if hold_outs:
# This just checks that none of the combos have 2 seen genes
hold_out = [t for t in combo_perts if
len([t for t in t.split('+') if
len([t for t in get_pert_genes(t) if
t not in test_pert_genes]) > 0]
combo_perts = [c for c in combo_perts if c not in hold_out]
test_perts = single_perts + combo_perts
Expand All @@ -353,22 +355,22 @@ def get_split_list(self, pert_list, test_size=0.1,
if hold_outs:
# This just checks that none of the combos have 2 seen genes
hold_out = [t for t in combo_perts if
len([t for t in t.split('+') if
len([t for t in get_pert_genes(t) if
t not in test_pert_genes]) > 1]
combo_perts = [c for c in combo_perts if c not in hold_out]
test_perts = single_perts + combo_perts

elif self.seen == 2:
if test_perts is None:
test_perts = np.random.choice(combo_perts,
int(len(combo_perts) * test_size))
int(len(combo_perts) * test_size))
else:
test_perts = np.array(test_perts)
else:
if test_perts is None:
test_perts = np.random.choice(combo_perts,
int(len(combo_perts) * test_size))

train_perts = [p for p in pert_list if (p not in test_perts)
and (p not in hold_out)]
return train_perts, test_perts
Expand All @@ -380,16 +382,16 @@ def get_perts_from_genes(self, genes, pert_list, type_='both'):

single_perts = [p for p in pert_list if ('ctrl' in p) and (p != 'ctrl')]
combo_perts = [p for p in pert_list if 'ctrl' not in p]

perts = []

if type_ == 'single':
pert_candidate_list = single_perts
elif type_ == 'combo':
pert_candidate_list = combo_perts
elif type_ == 'both':
pert_candidate_list = pert_list

for p in pert_candidate_list:
for g in genes:
if g in parse_any_pert(p):
Expand All @@ -404,7 +406,7 @@ def get_genes_from_perts(self, perts):

if type(perts) is str:
perts = [perts]
gene_list = [p.split('+') for p in np.unique(perts)]
gene_list = [get_pert_genes(p) for p in np.unique(perts)]
gene_list = [item for sublist in gene_list for item in sublist]
gene_list = [g for g in gene_list if g != 'ctrl']
return np.unique(gene_list)
Loading