Skip to content

Commit

Permalink
Merge pull request #28 from aertslab/dev
Browse files Browse the repository at this point in the history
updated pattern clustering plotting
  • Loading branch information
nkempynck authored Oct 1, 2024
2 parents 5359d02 + b89a96d commit a30a959
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 235 deletions.
379 changes: 181 additions & 198 deletions docs/tutorials/enhancer_code_analysis.ipynb

Large diffs are not rendered by default.

123 changes: 89 additions & 34 deletions src/crested/pl/patterns/_modisco_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,14 @@ def modisco_results(
def create_clustermap(
pattern_matrix: np.ndarray,
classes: list[str],
figsize: tuple[int, int] = (15, 13),
subset: list[str] | None = None, # Subset option
figsize: tuple[int, int] = (25,8),
grid: bool = False,
color_palette: str | list[str] = "hsv",
cmap: str = "coolwarm",
center: float = 0,
method: str = "average",
dy: float = 0.002,
fig_path: str | None = None,
pat_seqs: list[tuple[str, np.ndarray]] | None = None
) -> sns.matrix.ClusterGrid:
Expand All @@ -205,26 +207,45 @@ def create_clustermap(
Parameters
----------
pattern_matrix
pattern_matrix : np.ndarray
2D NumPy array containing pattern data.
classes
List of class labels.
figsize
classes : list[str]
List of class labels, matching the rows of the pattern matrix.
subset : list[str], optional
List of class labels to subset the matrix.
figsize : tuple[int, int], optional
Size of the figure.
grid
grid : bool, optional
Whether to add a grid to the heatmap.
color_palette
color_palette : str or list[str], optional
Color palette for the row colors.
- cmap (str): Colormap for the clustermap.
- center (float): Value at which to center the colormap.
- method (str): Clustering method to use (e.g., 'average', 'single', 'complete').
- fig_path (str, optional): Path to save the figure.
- pat_seqs (list, optional): List of sequences to use as xticklabels.
Returns
-------
The clustermap object.
cmap : str, optional
Colormap for the clustermap.
center : float, optional
Value at which to center the colormap.
method : str, optional
Clustering method to use.
dy: float, optional
Scaling parameter for vertical distance between nucleotides (if pat_seqs is not None) in xticklabels.
fig_path : str, optional
Path to save the figure.
pat_seqs : list[tuple[str, np.ndarray]], optional
List of sequences to use as xticklabels.
"""
# Subset the pattern_matrix and classes if subset is provided
if subset is not None:
subset_indices = [i for i, class_label in enumerate(classes) if class_label in subset]
pattern_matrix = pattern_matrix[subset_indices, :]
classes = [classes[i] for i in subset_indices]

# Remove columns that contain only zero values
non_zero_columns = np.any(pattern_matrix != 0, axis=0)
pattern_matrix = pattern_matrix[:, non_zero_columns]

# Reindex columns based on the original positions of non-zero columns
column_indices = np.where(non_zero_columns)[0]
data = pd.DataFrame(pattern_matrix, columns=column_indices)

data = pd.DataFrame(pattern_matrix)

if isinstance(color_palette, str):
Expand All @@ -236,17 +257,14 @@ def create_clustermap(
row_colors = pd.Series(classes).map(class_lut)

if pat_seqs is not None:
plt.rc("text", usetex = True)
plt.rc("text", usetex=False) # Turn off LaTeX to speed up rendering
scaling_factor = 10

# Plot the scaled x-tick labels based on the importance scores
xtick_labels = [
letters[0:2] + " " + r"".join(
[
r"{\fontsize{" + f"{s * scaling_factor}" + r"pt}{3em}\selectfont " + l + r"}"
for l, s in zip(letters[2:], scores[2:])
]
)
for letters, scores in pat_seqs
(letters, scores) for letters, scores in pat_seqs
]

else:
xtick_labels = True

Expand All @@ -257,29 +275,66 @@ def create_clustermap(
row_colors=None,
yticklabels=classes,
center=center,
xticklabels=xtick_labels,
xticklabels=True if pat_seqs is None else False, # Disable default xticklabels if pat_seqs provided. #xticklabels=xtick_labels,
method=method,
dendrogram_ratio=(0.1, 0.1),
cbar_pos=(1.05, 0.4, 0.01, 0.3)
)
col_order = g.dendrogram_col.reordered_ind
cbar = g.ax_heatmap.collections[0].colorbar
cbar.set_label('Motif importance', rotation=270, labelpad=20) # Rotate label and add padding
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), rotation=0)

#for label in class_lut:
# g.ax_col_dendrogram.bar(0, 0, color=class_lut[label], label=label, linewidth=0)
# Get the reordered column indices from the clustermap
col_order = g.dendrogram_col.reordered_ind

# Reorder the pat_seqs to follow the column order
if pat_seqs is not None:
reordered_pat_seqs = [pat_seqs[column_indices[i]] for i in col_order]
ax = g.ax_heatmap
x_positions = np.arange(len(reordered_pat_seqs)) + 0.5 # Shift labels by half a tick to the right

constant = (1/figsize[1])*64
for i, (letters, scores) in enumerate(reordered_pat_seqs):
previous_spacing = 0
for j, (letter, score) in enumerate(zip(reversed(letters), reversed(scores))):
fontsize = score*10
vertical_spacing = max((constant * score * dy), constant * 0.1 * dy) # Spacing proportional to figsize[1]

ax.text(
x_positions[i], -(constant*0.002) - previous_spacing, # Adjust y-position based on spacing
letter,
fontsize=fontsize, # Constant font size
ha='center', # Horizontal alignment
va='top', # Vertical alignment
rotation=90, # Rotate the labels vertically
transform=ax.get_xaxis_transform() # Ensure the text is placed relative to x-axis
)
previous_spacing += vertical_spacing

# Ensure x-ticks are visible
ax.set_xticks(x_positions)

if grid:
ax = g.ax_heatmap
ax.grid(
True,
which="both",
color="grey",
linewidth=0.25,
)
# Define the grid positions (between cells, hence the +0.5 offset)
x_positions = np.arange(pattern_matrix.shape[1] + 1)
y_positions = np.arange(len(pattern_matrix) + 1)

# Add horizontal grid lines
for y in y_positions:
ax.hlines(y, *ax.get_xlim(), color="grey", linewidth=0.25)

# Add vertical grid lines
for x in x_positions:
ax.vlines(x, *ax.get_ylim(), color="grey", linewidth=0.25)

g.fig.canvas.draw()

if fig_path is not None:
plt.savefig(fig_path)

plt.show()
return g


def plot_patterns(pattern_dict: dict, idcs: list[int]) -> None:
Expand Down
5 changes: 3 additions & 2 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ def enhancer_design_motif_implementation(
masked_locations=inserted_motif_locations,
)

mutagenesis_predictions = self.model.predict(mutagenesis)
mutagenesis_predictions = self.model.predict(mutagenesis, verbose=False)

# determine the best insertion site
best_mutation = enhancer_optimizer.get_best(
Expand Down Expand Up @@ -1551,7 +1551,8 @@ def enhancer_design_in_silico_evolution(
mutagenesis_predictions = self.model.predict(
mutagenesis.reshape(
(n_sequences * TOTAL_NUMBER_OF_MUTATIONS_PER_SEQ, seq_len, 4)
)
),
verbose=False
)

mutagenesis_predictions = mutagenesis_predictions.reshape(
Expand Down
2 changes: 1 addition & 1 deletion src/crested/tl/_modisco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def match_score_patterns(a: dict, b: dict) -> float:
#vizsequence.plot_weights(ic_b)
score = tangermeme_tomtom.tomtom(Qs = [ic_a.T], Ts = [ic_b.T])

log_score = -np.log10(max(score[0,0][0], 0))
log_score = -np.log10(max(score[0,0][0], 1e-12))

return log_score

Expand Down

0 comments on commit a30a959

Please sign in to comment.