Download the model file from https://github.com/streamlit/mol-demo/blob/main/chembl_32_multitask.onnx and the streamlit-ketcher wheel file from https://github.com/streamlit/mol-demo/blob/main/streamlit_ketcher-0.0.1-py2.py3-none-any.whl
Created
October 30, 2024 16:03
-
-
Save churnikov/2e940935a3f91e45486b0a412c134bfa to your computer and use it in GitHub Desktop.
Streamlit app for serve
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
from streamlit_ketcher import st_ketcher | |
import utils | |
import target_predictions | |
# NOTE(review): the source listing had lost its indentation; the nesting below
# is reconstructed from the syntax -- confirm against the original app.py.

# ChEMBL ID shown in the text box the first time a session opens the page.
DEFAULT_COMPOUND = "CHEMBL141739"

# Seed the per-session values once; later runs keep whatever is stored.
for state_key, initial_value in (("molfile", None), ("chembl_id", DEFAULT_COMPOUND)):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

st.set_page_config(layout="wide")
st.subheader("🧪 Molecule editor")

# The molfile is fetched again from the ID in the text box each time the
# script runs.
chembl_id = st.text_input("ChEMBL ID:", st.session_state.chembl_id)
st.session_state.molfile = utils.id_to_molecule(chembl_id)

# One shortcut button per well-known molecule, each in its own column.
famous_molecules = [
    ('☕', 'Caffeine'), ('🥱', 'Melatonin'), ('🚬', 'Nicotine'), ('🌨️', 'Cocaine'), ('💊', 'Aspirin'),
    ('🍄', 'Psilocybine'), ('💎', 'Lysergide')
]
button_columns = st.columns(len(famous_molecules))
for (emoji, name), column in zip(famous_molecules, button_columns):
    with column:
        if st.button(f'{emoji} {name}'):
            # A click replaces both the structure and the remembered ID.
            picked_molfile, picked_id = utils.name_to_molecule(name)
            st.session_state.molfile = picked_molfile
            st.session_state.chembl_id = picked_id

editor_column, results_column = st.columns(2)
# Stays empty unless the similarity search below yields results.
similar_smiles = []

with editor_column:
    smiles = st_ketcher(st.session_state.molfile)
    similarity_threshold = st.slider("Similarity threshold:", min_value=60, max_value=100)
    with st.expander("Raw data"):
        st.markdown(f"```{smiles}```")

with results_column:
    similar_molecules = utils.find_similar_molecules(smiles, similarity_threshold)
    if similar_molecules:
        table = utils.render_similarity_table(similar_molecules)
        similar_smiles = utils.get_similar_smiles(similar_molecules)
        st.markdown(f'<div id="" style="overflow:scroll; height:600px; padding-left: 80px;">{table}</div>',
                    unsafe_allow_html=True)
    else:
        st.warning("No results found")

# Target prediction is only offered when there are similar structures to feed it.
if similar_smiles:
    st.subheader("Target prediction based on [ChEMBL multitask model](https://github.com/chembl/chembl_multitask_model)")
    if st.button("🔮 Predict targets"):
        preds = target_predictions.predict_all(similar_smiles)
        table = utils.render_target_predictions_table(preds)
        st.markdown(table, unsafe_allow_html=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
FROM python:3.9-slim

# Unprivileged account the container runs as; the app lives under its home.
ENV USER=username
ENV HOME=/home/$USER

# Add user to system
RUN useradd -m -u 1000 $USER

WORKDIR $HOME/app

# System packages (presumably needed by pip packages that build native code --
# confirm). The apt package lists are removed in the same layer so they do not
# end up in the image.
RUN apt-get update && apt-get install --no-install-recommends -y \
        build-essential \
        software-properties-common \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies before copying the application code so that editing the
# code does not invalidate the dependency layer. requirements.txt refers to the
# wheel by relative path, so the wheel has to be present first.
COPY streamlit_ketcher-0.0.1-py2.py3-none-any.whl $HOME/app/
COPY requirements.txt $HOME/app/
# --no-cache-dir keeps pip's download cache out of the image.
RUN pip install --no-cache-dir -r requirements.txt

COPY chembl_32_multitask.onnx .
COPY app.py .
COPY target_predictions.py .
COPY utils.py .

# Drop root for everything that runs in the container.
USER $USER
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
rdkit | |
onnxruntime==1.17.1 | |
chembl-webresource-client==0.10.8 | |
matplotlib==3.7.1 | |
streamlit==1.20.0 | |
pandas==1.5.3 | |
numpy==1.24.0 | |
./streamlit_ketcher-0.0.1-py2.py3-none-any.whl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import onnxruntime | |
import numpy as np | |
from rdkit import Chem | |
from rdkit.Chem import rdMolDescriptors | |
# Morgan fingerprint settings used by calc_morgan_fp: bit-vector length and radius.
FP_SIZE = 1024
RADIUS = 2

# Load the model once at import time; the path is relative to the working
# directory. The execution provider is named explicitly because onnxruntime
# builds that ship more than one provider refuse to create a session without it.
ort_session = onnxruntime.InferenceSession(
    "chembl_32_multitask.onnx", providers=["CPUExecutionProvider"]
)
def calc_morgan_fp(smiles):
    """Return the Morgan fingerprint of a SMILES string as a float32 array.

    Args:
        smiles: SMILES representation of the molecule.

    Returns:
        A NumPy float32 array filled from the ``FP_SIZE``-bit fingerprint
        computed with radius ``RADIUS``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail here
        # with a readable message instead of inside the fingerprint call.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(
        mol, RADIUS, nBits=FP_SIZE)
    # The array is created empty and filled in place by ConvertToNumpyArray.
    a = np.zeros((0,), dtype=np.float32)
    Chem.DataStructs.ConvertToNumpyArray(fp, a)
    return a
def format_preds(preds, targets):
    """Pair prediction values with target names, highest prediction first.

    Args:
        preds: Sequence of arrays; they are concatenated and flattened.
        targets: Names paired positionally with the flattened values.

    Returns:
        A structured array with fields ``chembl_id`` and ``pred``, ordered by
        ``pred`` in descending order.
    """
    scores = np.concatenate(preds).ravel()
    record_dtype = [('chembl_id', '|U20'), ('pred', '<f4')]
    table = np.array(list(zip(targets, scores)), dtype=record_dtype)
    # Sorting the reversed view ascending, in place, leaves the underlying
    # array in descending order.
    reversed_view = table[::-1]
    reversed_view.sort(order='pred')
    return table
def predict(smiles):
    """Run the ONNX model on one SMILES string.

    Args:
        smiles: SMILES representation of the molecule.

    Returns:
        The structured array produced by ``format_preds``, labelled with the
        names of the model's outputs.
    """
    # Fingerprint of the molecule is the model's only input.
    fingerprint = calc_morgan_fp(smiles)
    input_name = ort_session.get_inputs()[0].name
    # Passing None as the output list requests every model output.
    raw_outputs = ort_session.run(None, {input_name: fingerprint})
    output_names = [output.name for output in ort_session.get_outputs()]
    return format_preds(raw_outputs, output_names)
def predict_all(smiles):
    """Predict targets for several SMILES strings and stack the results.

    Args:
        smiles: Iterable of SMILES strings.

    Returns:
        One array holding the ``predict`` output of every input, in input
        order. An empty ``smiles`` makes ``np.concatenate`` raise ValueError.
    """
    return np.concatenate([predict(single_smiles) for single_smiles in smiles])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from typing import Optional, Tuple | |
from chembl_webresource_client.new_client import new_client as ch | |
# Base URL of the ChEMBL site; report-card links and image URLs are built on it.
EBI_URL = "https://www.ebi.ac.uk/chembl/"
def name_to_molecule(name: str) -> Tuple[str, str]:
    """Look up a molecule by synonym and return its molfile and ChEMBL ID.

    Args:
        name: Molecule name, matched against ChEMBL synonyms with the
            ``iexact`` filter.

    Returns:
        ``(molfile, molecule_chembl_id)`` taken from the first record of the
        query result.

    NOTE(review): a name without any match is not handled here; indexing the
    first record presumably raises in that case -- confirm.
    """
    wanted_fields = ['molecule_chembl_id', 'molecule_structures']
    matches = ch.molecule.filter(molecule_synonyms__molecule_synonym__iexact=name)
    first_hit = matches.only(wanted_fields)[0]
    molfile = first_hit["molecule_structures"]["molfile"]
    return molfile, first_hit["molecule_chembl_id"]
def id_to_molecule(chembl_id: str) -> str:
    """Fetch the molfile of the molecule with the given ChEMBL ID.

    Args:
        chembl_id: ChEMBL identifier to filter on.

    Returns:
        The ``molfile`` entry of the first matching record's
        ``molecule_structures``. This is a single value, not a tuple.

    NOTE(review): an ID without any match is not handled here; indexing the
    first record presumably raises in that case -- confirm.
    """
    record = ch.molecule.filter(chembl_id=chembl_id).only('molecule_structures')[0]
    return record["molecule_structures"]["molfile"]
def style_table(df: pd.DataFrame) -> pd.io.formats.style.Styler: | |
return df.style.hide_index().format( | |
subset=['Similarity'], | |
decimal=',', precision=2 | |
).bar( | |
subset=['Similarity'], | |
align="mid", | |
cmap="coolwarm" | |
).applymap(lambda x: 'background-color: #aaaaaa', subset=['Image']) | |
def style_predictions(df: pd.DataFrame) -> pd.io.formats.style.Styler: | |
return df.style.hide_index().format( | |
subset=['Prediction'], | |
decimal=',', precision=2 | |
).bar( | |
subset=['Prediction'], | |
align="mid", | |
cmap="plasma_r", | |
vmax=1.0, | |
vmin=0.8 | |
) | |
def render_chembl_url(chembl_id: str) -> str:
    """Return an HTML anchor linking a ChEMBL ID to its compound report card.

    The ID is used both in the link target and as the link text; it is
    inserted as-is, without escaping.
    """
    report_card = f'{EBI_URL}compound_report_card/{chembl_id}/'
    return f'<a href="{report_card}">{chembl_id}</a>'
def render_chembl_img(chembl_id: str) -> str:
    """Return an HTML ``<img>`` tag showing the SVG image of a ChEMBL ID.

    The image is fixed at 100px by 100px; the ID is inserted into the URL
    as-is, without escaping.
    """
    image_src = f'{EBI_URL}api/data/image/{chembl_id}.svg'
    return f'<img src="{image_src}" height="100px" width="100px">'
def render_row(row):
    """Convert one similarity record into a row of the similarity table.

    Args:
        row: Mapping with the keys ``similarity``, ``pref_name`` and
            ``molecule_chembl_id``.

    Returns:
        Dict keyed by the table's column titles; the ID is rendered as a link
        and as an image tag.
    """
    rendered = {}
    rendered["Similarity"] = float(row["similarity"])
    rendered["Preferred name"] = row["pref_name"]
    molecule_id = row["molecule_chembl_id"]
    rendered["ChEMBL ID"] = render_chembl_url(molecule_id)
    rendered["Image"] = render_chembl_img(molecule_id)
    return rendered
def render_target(target):
    """Convert one prediction record into a row of the predictions table.

    Args:
        target: Mapping with the keys ``pred`` and ``chembl_id``.

    Returns:
        Dict with the prediction as a float and the ID rendered as a link.
    """
    score = float(target["pred"])
    link = render_chembl_url(target["chembl_id"])
    return {"Prediction": score, "ChEMBL ID": link}
def find_similar_molecules(smiles: str, threshold: int):
    """Query ChEMBL for molecules similar to a structure.

    Args:
        smiles: Structure to search for.
        threshold: Value passed as the ``similarity`` filter.

    Returns:
        The similarity query result restricted to the ID, similarity,
        preferred name and structures fields, or ``None`` if building the
        query raises.

    NOTE(review): if the client evaluates queries lazily, request failures
    would surface where the result is consumed rather than here -- confirm.
    """
    wanted_fields = ['molecule_chembl_id', 'similarity', 'pref_name', 'molecule_structures']
    try:
        query = ch.similarity.filter(smiles=smiles, similarity=threshold)
        return query.only(wanted_fields)
    except Exception:
        # Any failure is reported to the caller as "no result".
        return None
def render_similarity_table(similar_molecules) -> Optional[str]:
    """Render similar molecules as a styled HTML table.

    Records whose ``molecule_structures`` entry is falsy are left out.

    Args:
        similar_molecules: Iterable of similarity records.

    Returns:
        HTML produced by the Styler from ``style_table``.
    """
    records = []
    for molecule in similar_molecules:
        if molecule["molecule_structures"]:
            records.append(render_row(molecule))
    frame = pd.DataFrame.from_records(records)
    return style_table(frame).to_html(render_links=True)
def render_target_predictions_table(predictions) -> Optional[str]:
    """Render the 20 highest predictions as a styled HTML table.

    Args:
        predictions: Data accepted by ``pd.DataFrame`` that yields a ``pred``
            and a ``chembl_id`` column.

    Returns:
        HTML produced by the Styler from ``style_predictions``.
    """
    ranked = pd.DataFrame(predictions).sort_values(by=['pred'], ascending=False)
    # Only the top of the ranking is shown.
    top_rows = ranked.head(20).to_dict('records')
    frame = pd.DataFrame.from_records([render_target(entry) for entry in top_rows])
    return style_predictions(frame).to_html(render_links=True)
def get_similar_smiles(similar_molecules):
    """Collect the canonical SMILES of every record that has structures.

    Args:
        similar_molecules: Iterable of records with a ``molecule_structures``
            entry.

    Returns:
        List of ``canonical_smiles`` values, in input order; records whose
        ``molecule_structures`` is falsy are skipped.
    """
    collected = []
    for mol in similar_molecules:
        structures = mol["molecule_structures"]
        if not structures:
            continue
        collected.append(structures["canonical_smiles"])
    return collected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment