Skip to content

Commit

Permalink
Support multiple SMARTS patterns in argument to --cut_smarts. See rdkit#15 for details.
Browse files Browse the repository at this point in the history
  • Loading branch information
baoilleach committed Sep 5, 2024
1 parent befd7a0 commit f43568a
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions mmpdblib/fragment_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,22 @@ def parse_cut_smarts(smarts):
from rdkit import Chem
from . import smarts_aliases

if smarts in smarts_aliases.cut_smarts_aliases_by_name:
smarts = smarts_aliases.cut_smarts_aliases_by_name[smarts].smarts
pattern = Chem.MolFromSmarts(smarts)
if pattern is None:
raise ValueError("unable to parse SMARTS")
if pattern.GetNumAtoms() != 2:
raise ValueError("cut SMARTS must match exactly two atoms")
if pattern.GetNumBonds() != 1:
raise ValueError("cut SMARTS must connect both atoms")
return pattern
smarts_terms = smarts.split("%%") # add support for multiple SMARTS
if not smarts_terms:
raise ValueError("cut SMARTS must not be empty")
patterns = []
for smarts_term in smarts_terms:
if smarts_term in smarts_aliases.cut_smarts_aliases_by_name:
smarts_term = smarts_aliases.cut_smarts_aliases_by_name[smarts_term].smarts
pattern = Chem.MolFromSmarts(smarts_term)
if pattern is None:
raise ValueError("unable to parse SMARTS")
if pattern.GetNumAtoms() != 2:
raise ValueError("cut SMARTS must match exactly two atoms")
if pattern.GetNumBonds() != 1:
raise ValueError("cut SMARTS must connect both atoms")
patterns.append(pattern)
return patterns


def parse_salt_remover(salt_remover_filename):
Expand Down Expand Up @@ -234,7 +240,7 @@ def __init__(
max_rotatable_bonds,
rotatable_pattern,
salt_remover,
cut_pattern,
cut_patterns,
num_cuts,
method,
options,
Expand All @@ -249,7 +255,7 @@ def __init__(
self.max_rotatable_bonds = max_rotatable_bonds
self.rotatable_pattern = rotatable_pattern
self.salt_remover = salt_remover
self.cut_pattern = cut_pattern
self.cut_patterns = cut_patterns
self.num_cuts = num_cuts
self.method = method
self.options = options
Expand Down Expand Up @@ -306,7 +312,15 @@ def apply_filters(self, mol):
return None

def get_cut_atom_pairs(self, mol):
return mol.GetSubstructMatches(self.cut_pattern, uniquify=True)
seen = set()
for pat in self.cut_patterns:
for (atom1_idx, atom2_idx) in mol.GetSubstructMatches(pat):
# put into canonical order so cuts are consistent across all patterns
if atom1_idx < atom2_idx:
seen.add((atom1_idx, atom2_idx))
else:
seen.add((atom2_idx, atom1_idx))
return list(seen)

def get_cut_lists(self, mol):
atom_pairs = self.get_cut_atom_pairs(mol)
Expand Down Expand Up @@ -339,7 +353,7 @@ def call(parse, name):
max_heavies = options.max_heavies
max_rotatable_bonds = options.max_rotatable_bonds
rotatable_pattern = call(parse_rotatable_smarts, "rotatable_smarts")
cut_pattern = call(parse_cut_smarts, "cut_smarts")
cut_patterns = call(parse_cut_smarts, "cut_smarts")

num_cuts = call(parse_num_cuts, "num_cuts")
method = call(parse_method, "method")
Expand All @@ -357,7 +371,7 @@ def call(parse, name):
max_rotatable_bonds=max_rotatable_bonds,
rotatable_pattern=rotatable_pattern,
salt_remover=salt_remover,
cut_pattern=cut_pattern,
cut_patterns=cut_patterns,
num_cuts=num_cuts,
method=method,
options=fragment_options,
Expand Down

0 comments on commit f43568a

Please sign in to comment.