Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Necessary bug fixes in pattern matching code #25

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
821 changes: 440 additions & 381 deletions docs/tutorials/enhancer_code_analysis.ipynb

Large diffs are not rendered by default.

19 changes: 13 additions & 6 deletions src/crested/pl/patterns/_modisco_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def modisco_results(
for metacluster_name in [f"{contribution[:3]}_patterns"]:
all_pattern_names = list(hdf5_results[metacluster_name])

for _pattern_idx, pattern_name in enumerate(all_pattern_names):
for i in range(len(all_pattern_names)):
pattern_name = 'pattern_'+str(i)
if len(classes) > 1:
ax = axes[motif_counter - 1, idx]
elif max_num_patterns > 1:
Expand Down Expand Up @@ -308,7 +309,7 @@ def plot_patterns(pattern_dict: dict, idcs: list[int]) -> None:
plt.tight_layout()
plt.show()

def plot_pattern_instances(pattern_dict: dict, idx: int) -> None:
def plot_pattern_instances(pattern_dict: dict, idx: int, class_representative: bool = False) -> None:
"""
Plots all the pattern instances clustered together in the pattern dictionary for a given pattern index.

Expand All @@ -318,22 +319,28 @@ def plot_pattern_instances(pattern_dict: dict, idx: int) -> None:
A dictionary containing pattern data.
idx
Index of the pattern whose instances should be plotted.
class_representative
If True, plot only the representative (highest-IC) pattern for each class. If False, plot every instance of the pattern, including multiple instances that come from the same class. Default False.
"""
n_instances = len(pattern_dict[str(idx)]['classes'])
if class_representative:
key = 'classes'
else:
key='instances'
n_instances = len(pattern_dict[str(idx)][key])
figure, axes = plt.subplots(nrows=n_instances, ncols=1, figsize=(8, 2 * n_instances))
if n_instances == 1:
axes = [axes]

instance_classes = list(pattern_dict[str(idx)]['classes'].keys())
instance_classes = list(pattern_dict[str(idx)][key].keys())

for i, cl in enumerate(instance_classes):
ax = _plot_attribution_map(
ax=axes[i],
saliency_df=np.array(pattern_dict[str(idx)]["classes"][cl]["contrib_scores"]),
saliency_df=np.array(pattern_dict[str(idx)][key][cl]["contrib_scores"]),
return_ax=True,
figsize=None,
)
ax.set_title(pattern_dict[str(idx)]['classes'][cl]["id"])
ax.set_title(pattern_dict[str(idx)][key][cl]["id"])

plt.tight_layout()
plt.show()
Expand Down
2 changes: 1 addition & 1 deletion src/crested/tl/_modisco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def match_score_patterns(a: dict, b: dict) -> float:
#vizsequence.plot_weights(ic_b)
score = tangermeme_tomtom.tomtom(Qs = [ic_a.T], Ts = [ic_b.T])

log_score = -np.log10(max(score[0,0], 0))
log_score = -np.log10(max(score[0,0][0], 0))

return log_score

Expand Down
57 changes: 44 additions & 13 deletions src/crested/tl/_tfmodisco.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def add_pattern_to_dict(

all_patterns[str(idx)]['ppm']=ppm
all_patterns[str(idx)]["ic"] = np.mean(ic_pos)#np.mean(_get_ic(p["contrib_scores"], pos_pattern))
all_patterns[str(idx)]["instances"] = {}
all_patterns[str(idx)]["instances"][p['id']] = p
all_patterns[str(idx)]["classes"] = {}
all_patterns[str(idx)]["classes"][cell_type] = p
return all_patterns
Expand Down Expand Up @@ -260,8 +262,12 @@ def match_to_patterns(

ppm = _pattern_to_ppm(p)
ic, ic_pos, ic_mat = compute_ic(ppm)
p_ic = np.mean(ic_pos)
p['ic'] = p_ic
p['ppm'] = ppm

p['class']=cell_type

for pat_idx, pattern in enumerate(all_patterns.keys()):
sim = match_score_patterns(p, all_patterns[pattern]["pattern"])
if sim > sim_threshold:
Expand All @@ -276,10 +282,18 @@ def match_to_patterns(

if verbose:
print(
f'Match between {pattern_id} and {all_patterns[str(match_idx)]["pattern"]["id"]}'
f'Match between {pattern_id} and {all_patterns[str(match_idx)]["pattern"]["id"]} with similarity score {str(max_sim)}'
)
all_patterns[str(match_idx)]["classes"][cell_type] = p
p_ic = np.mean(ic_pos)#np.mean(_get_ic(p["contrib_scores"], pos_pattern))

all_patterns[str(match_idx)]["instances"][pattern_id] = p

if(cell_type in all_patterns[str(match_idx)]["classes"].keys()):
ic_class_representative = all_patterns[str(match_idx)]["classes"][cell_type]['ic']
if p_ic > ic_class_representative:
all_patterns[str(match_idx)]["classes"][cell_type] = p
else:
all_patterns[str(match_idx)]["classes"][cell_type] = p

if p_ic > all_patterns[str(match_idx)]["ic"]:
all_patterns[str(match_idx)]["ic"] = p_ic
all_patterns[str(match_idx)]["pattern"] = p
Expand Down Expand Up @@ -421,7 +435,22 @@ def merge_patterns(pattern1: dict, pattern2: dict) -> dict:
-------
Merged pattern with updated metadata.
"""
merged_classes = {}
for cell_type in pattern1["classes"].keys():
if cell_type in pattern2["classes"].keys():
ic_a = pattern1["classes"][cell_type]['ic']
ic_b = pattern2["classes"][cell_type]['ic']
merged_classes[cell_type] = pattern1["classes"][cell_type] if ic_a > ic_b else pattern2["classes"][cell_type]
else:
merged_classes[cell_type] = pattern1["classes"][cell_type]

for cell_type in pattern2["classes"].keys():
if cell_type not in merged_classes.keys():
merged_classes[cell_type] = pattern2["classes"][cell_type]

merged_classes = {**pattern1["classes"], **pattern2["classes"]}
merged_instances = {**pattern1["instances"], **pattern2["instances"]}


if pattern2["ic"] > pattern1["ic"]:
representative_pattern = pattern2["pattern"]
Expand All @@ -434,6 +463,7 @@ def merge_patterns(pattern1: dict, pattern2: dict) -> dict:
"pattern": representative_pattern,
"ic": highest_ic,
"classes": merged_classes,
"instances": merged_instances
}


Expand Down Expand Up @@ -464,7 +494,7 @@ def pattern_similarity(
all_patterns[str(idx2)]["pattern"], all_patterns[str(idx1)]["pattern"]
),
)
return sim[0]
return sim


def normalize_rows(arr: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -553,9 +583,9 @@ def match_h5_files_to_classes(

def process_patterns(
matched_files: dict[str, str | list[str] | None],
sim_threshold: float = 0.5,
sim_threshold: float = 3,
trim_ic_threshold: float = 0.1,
discard_ic_threshold: float = 0.15,
discard_ic_threshold: float = 0.1,
verbose: bool = False,
) -> dict[str, dict[str, str | list[float]]]:
"""
Expand All @@ -566,7 +596,7 @@ def process_patterns(
matched_files
dictionary with class names as keys and paths to HDF5 files as values.
sim_threshold
Similarity threshold for matching patterns.
Similarity threshold for matching patterns, expressed as -log10(p-value), where the p-value is obtained from TOMTOM matching via tangermeme.
trim_ic_threshold
Information content threshold for trimming patterns.
discard_ic_threshold
Expand All @@ -584,7 +614,6 @@ def process_patterns(
trimmed_patterns = []
pattern_ids = []
is_pattern_pos = []
pattern_idx = 0

if matched_files[cell_type] is None:
continue
Expand All @@ -599,7 +628,9 @@ def process_patterns(
try:
with h5py.File(h5_file) as hdf5_results:
for metacluster_name in list(hdf5_results.keys()):
for p in hdf5_results[metacluster_name]:
pattern_idx = 0
for i in range(len(list(hdf5_results[metacluster_name]))):
p = 'pattern_'+str(i)
pattern_ids.append(
f"{cell_type.replace(' ', '_')}_{metacluster_name}_{pattern_idx}"
)
Expand Down Expand Up @@ -806,7 +837,7 @@ def generate_html_paths(

for i, _ in enumerate(all_patterns):
pattern_id = all_patterns[str(i)]["pattern"]["id"]
pattern_class_parts = pattern_id.split("_")[:-4]
pattern_class_parts = pattern_id.split("_")[:-3]
pattern_class = (
"_".join(pattern_class_parts)
if len(pattern_class_parts) > 1
Expand Down Expand Up @@ -845,11 +876,11 @@ def find_pattern_matches(
pattern_id = all_patterns[p_idx]["pattern"]["id"]
pattern_id_parts = pattern_id.split("_")
pattern_id = (
pattern_id_parts[-4]
pattern_id_parts[-3]
+ "_"
+ pattern_id_parts[-3]
+ "."
+ pattern_id_parts[-2]
+ "."
+ 'pattern'
+ "_"
+ pattern_id_parts[-1]
)
Expand Down
Loading