Skip to content

Commit

Permalink
Removed unused imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Renz committed Apr 2, 2020
1 parent fe542b1 commit 9d37529
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 47 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
**/__pycache__
**/.vscode
95 changes: 48 additions & 47 deletions fcd/FCD.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/env python3
''' Defines the functions necessary for calculating the Frechet ChemNet
''' Defines the functions necessary for calculating the Frechet ChemNet
Distance (FCD) to evalulate generative models for molecules.
The FCD metric calculates the distance between two distributions of molecules.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by the generative
of these distributions, while the 2nd distribution is given by the generative
model.
The FCD is calculated by assuming that X_1 and X_2 are the activations of
Expand All @@ -13,29 +13,22 @@
'''

from __future__ import absolute_import, division, print_function
from keras.models import load_model
import keras.backend as K
from scipy import linalg
import numpy as np
from multiprocessing import Pool
from rdkit import Chem
import warnings
warnings.filterwarnings('ignore')
import os
import gzip, pickle
import tensorflow as tf
from scipy.misc import imread
from scipy import linalg
import pathlib
import urllib
import keras
import keras.backend as K
from keras.models import load_model


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Params:
Expand Down Expand Up @@ -83,7 +76,8 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
tr_covmean = np.trace(covmean)

return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------


def build_masked_loss(loss_function, mask_value):
"""Builds a loss function that masks based on targets
Expand All @@ -101,25 +95,26 @@ def masked_loss_function(y_true, y_pred):
return loss_function(y_true * mask, y_pred * mask)

return masked_loss_function
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------


def masked_accuracy(y_true, y_pred):
mask_value = 0.5
a = K.sum(K.cast(K.equal(y_true,K.round(y_pred)),K.floatx()))
c = K.sum(K.cast(K.not_equal(y_true,0.5),K.floatx()))
acc = (a) / c
return acc
#-------------------------------------------------------------------------------
a = K.sum(K.cast(K.equal(y_true, K.round(y_pred)), K.floatx()))
c = K.sum(K.cast(K.not_equal(y_true, 0.5), K.floatx()))
acc = (a) / c
return acc
# -------------------------------------------------------------------------------


def get_one_hot(smiles, pad_len=-1):
one_hot = asym = ['C','N','O', 'H', 'F', 'Cl', 'P', 'B', 'Br', 'S', 'I', 'Si',
one_hot = ['C', 'N', 'O', 'H', 'F', 'Cl', 'P', 'B', 'Br', 'S', 'I', 'Si',
'#', '(', ')', '+', '-', '1', '2', '3', '4', '5', '6', '7', '8', '=', '[', ']', '@',
'c', 'n', 'o', 's', 'X', '.']
smiles = smiles + '.'
if pad_len < 0:
vec = np.zeros((len(smiles), len(one_hot) ))
vec = np.zeros((len(smiles), len(one_hot)))
else:
vec = np.zeros((pad_len, len(one_hot) ))
vec = np.zeros((pad_len, len(one_hot)))
cont = True
j = 0
i = 0
Expand All @@ -133,59 +128,65 @@ def get_one_hot(smiles, pad_len=-1):
if sym in one_hot:
vec[j, one_hot.index(sym)] = 1
else:
vec[j,one_hot.index('X')] = 1
j+=1
vec[j, one_hot.index('X')] = 1
j += 1
if smiles[i] == '.' or j >= (pad_len-1) and pad_len > 0:
vec[j,one_hot.index('.')] = 1
vec[j, one_hot.index('.')] = 1
cont = False
return (vec)
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------


def myGenerator_predict(smilesList, batch_size=128, pad_len=350):
while 1:
while 1:
N = len(smilesList)
nn = pad_len
nn = pad_len
idxSamples = np.arange(N)

for j in range(int(np.ceil(N / batch_size))):
idx = idxSamples[j*batch_size : min((j+1)*batch_size,N)]
idx = idxSamples[j*batch_size: min((j+1)*batch_size, N)]

x = []
for i in range(0,len(idx)):
for i in range(0, len(idx)):
currentSmiles = smilesList[idx[i]]
smiEnc = get_one_hot(currentSmiles, pad_len=nn)
x.append(smiEnc)

x = np.asarray(x)/35
yield x
#-------------------------------------------------------------------------------
def load_ref_model(model_file = None):
if model_file==None:
# -------------------------------------------------------------------------------


def load_ref_model(model_file=None):
if model_file is None:
model_file = 'ChemNet_v0.13_pretrained.h5'
masked_loss_function = build_masked_loss(K.binary_crossentropy,0.5)
model = load_model(model_file,
custom_objects={'masked_loss_function':masked_loss_function,'masked_accuracy':masked_accuracy})
masked_loss_function = build_masked_loss(K.binary_crossentropy, 0.5)
model = load_model(model_file,
custom_objects={'masked_loss_function': masked_loss_function, 'masked_accuracy': masked_accuracy})
model.pop()
model.pop()
return(model)
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------


def get_predictions(model, gen_mol):
gen_mol_act = model.predict_generator(myGenerator_predict(gen_mol, batch_size=128),
steps= np.ceil(len(gen_mol)/128))
steps=np.ceil(len(gen_mol)/128))
return gen_mol_act
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------


def canonical(smi):
try:
smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi))
except:
pass
return smi
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------


def canoncial_smiles(smiles):
pool = Pool(32)
smiles = pool.map(canonical, smiles)
pool.close()
return(smiles)



0 comments on commit 9d37529

Please sign in to comment.