Created
December 23, 2018 12:05
-
-
Save marta-sd/bd359e5047e7bc1abb8ba5bb65799e35 to your computer and use it in GitHub Desktop.
Re-train Pafnucy (https://gitlab.com/cheminfIBB/pafnucy)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
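# Before running this script, clone Pafnucy's repository, create and activate
# the conda environment, and run the script from the repository root (it loads
# the checkpoint from results/ and the data from tests/ via relative paths):
# $ git clone https://gitlab.com/cheminfIBB/pafnucy
# $ cd pafnucy
# $ conda env create -f environment_gpu.yml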
import numpy as np | |
import h5py | |
import tensorflow as tf | |
from tfbio.data import make_grid | |
# load Pafnucy: import the pre-trained checkpoint's graph definition into a
# dedicated graph (its weights are restored later, inside a session)
graph = tf.Graph()
with graph.as_default():
    saver = tf.train.import_meta_graph(
        'results/batch5-2017-06-05T07:58:47-best.meta')

# look up, by name, the tensors used below: input grid, prediction, target
# affinity, the keep_prob placeholder and the training step
x, y, t, keep_prob, train = (
    graph.get_tensor_by_name(tensor_name)
    for tensor_name in (
        'input/structure:0',
        'output/prediction:0',
        'input/affinity:0',
        'fully_connected/keep_prob:0',
        'training/train:0',
    )
)
# load some data: build one input grid and one affinity target per complex
# from the small test set shipped with the repository
pdb_ids = ['1e66', '5c28']
grids = []
affinities = []
with h5py.File('tests/data/dataset/test_set.hdf', 'r') as f:
    for name in pdb_ids:
        dataset = f[name]
        # read the array once; columns 0-2 are coordinates and the remaining
        # columns are features (same split as before, without a second read)
        points = dataset[:]
        grids.append(make_grid(points[:, :3], points[:, 3:]))
        affinities.append(dataset.attrs['affinity'])

# stack the per-complex grids into one batch along the first axis
# (assumes make_grid returns arrays with a leading batch axis — confirm)
x_ = np.vstack(grids)
# targets as a column vector, one row per complex
y_ = np.reshape(affinities, (-1, 1))
print('target values:', y_)
# re-train Pafnucy: restore the pre-trained weights into the imported graph,
# take a few optimisation steps on the loaded complexes, and save the result
# as a new checkpoint
with tf.Session(graph=graph) as session:
    saver.restore(session, 'results/batch5-2017-06-05T07:58:47-best')

    # keep_prob is presumably the dropout keep probability; 1.0 disables it
    inference_feed = {x: x_, keep_prob: 1.0}
    # NOTE(review): training also feeds keep_prob = 1.0, i.e. no dropout
    # during these updates — confirm this is intended
    training_feed = {x: x_, t: y_, keep_prob: 1.0}

    print('predictions before training:',
          session.run(y, feed_dict=inference_feed))
    for _step in range(10):
        session.run(train, feed_dict=training_feed)
    print('predictions after training:',
          session.run(y, feed_dict=inference_feed))

    # the checkpoint prefix below is what the next section reloads
    # ('pafnucy_retrained.meta' + variables)
    saver.save(session, 'pafnucy_retrained')
# load and use the new model: import the re-trained checkpoint into a fresh
# graph and predict on the same inputs (presumably to compare against the
# "after training" predictions printed above)
new_graph = tf.Graph()
with new_graph.as_default():
    saver = tf.train.import_meta_graph('pafnucy_retrained.meta')

x, y, keep_prob = map(
    new_graph.get_tensor_by_name,
    ('input/structure:0',
     'output/prediction:0',
     'fully_connected/keep_prob:0'),
)

with tf.Session(graph=new_graph) as session:
    saver.restore(session, 'pafnucy_retrained')
    reloaded_predictions = session.run(y, feed_dict={x: x_, keep_prob: 1.0})
    print('predictions with loaded model:', reloaded_predictions)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment