Created
May 14, 2019 07:50
-
-
Save andersx/76369c2bca8d84140302b8ff77bc1cab to your computer and use it in GitHub Desktop.
Test case for asymmetric FCHL kernels
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function

import ast
import csv
import os

import numpy as np
from scipy.linalg import lstsq

import qml
from qml.fchl import generate_representation
from qml.fchl import get_atomic_local_kernels
# Seed NumPy's global RNG. Nothing visible in this script draws random
# numbers itself, so this only matters if a library call does.
np.random.seed(667)

# ';'-separated input file parsed by csv_to_reps().
CSV_FILE = "force_test.csv"

# Molecules taken from the start (training) and from the end (test) of the file.
TRAINING = 100
TEST = 101

# Cutoff passed to both representation generation and kernel evaluation.
# NOTE(review): the very large value presumably means "no cutoff" -- confirm.
CUT_DISTANCE = 1e6

# Gaussian kernel widths; test_asymmetric() indexes one kernel per entry.
SIGMAS = [2.5]

# Keyword arguments forwarded unchanged to get_atomic_local_kernels().
# The inner dict holds the same SIGMAS list object, not a copy.
KERNEL_ARGS = dict(
    verbose=False,
    cut_distance=CUT_DISTANCE,
    kernel="gaussian",
    kernel_args=dict(sigma=SIGMAS),
)
def mae(a, b):
    """Return the mean absolute error between two NumPy arrays.

    Both arrays are flattened before subtraction, so they are compared
    element by element in flattened (row-major) order.

    Args:
        a: Array of predicted values.
        b: Array of reference values.

    Returns:
        The mean of ``|a - b|`` over all flattened elements.
    """
    residual = a.flatten() - b.flatten()
    return np.abs(residual).mean()
def csv_to_reps(csv_filename):
    """Read molecules from a ';'-separated file and build FCHL representations.

    Each row is read with '#' as the quote character. Column 1 holds the
    energy, column 3 a Python-literal list of coordinates, and column 8 a
    Python-literal list of nuclear charges.

    Args:
        csv_filename: Path of the file to read.

    Returns:
        A tuple ``(representations, energies)`` of NumPy arrays with one entry
        per row, in file order.
    """
    representations = []
    energies = []

    with open(csv_filename, 'r') as handle:
        for fields in csv.reader(handle, delimiter=";", quotechar='#'):
            # literal_eval only evaluates Python literals, so these text
            # columns cannot run arbitrary code.
            xyz = np.array(ast.literal_eval(fields[3]))
            charges = ast.literal_eval(fields[8])
            target = float(fields[1])

            representations.append(
                generate_representation(xyz, charges, cut_distance=CUT_DISTANCE)
            )
            energies.append(target)

    return np.array(representations), np.array(energies)
def test_asymmetric():
    """Fit per-atom coefficients on a training set and report energy MAEs.

    Loads CSV_FILE, then takes the first TRAINING molecules as the training
    set and the last TEST molecules as the test set. It builds FCHL local
    kernels training-vs-training (``K``) and training-vs-test (``Ks``). For
    each entry of SIGMAS it solves ``K[i].T @ alphas = E`` by least squares
    (SciPy's ``gelsd`` driver, singular values below ``1e-9`` times the
    largest are treated as zero). It then prints the test-set MAE followed by
    the training-set MAE.

    NOTE(review): the test set is ``[-TEST:]``, so if the CSV holds fewer than
    TRAINING + TEST rows the two sets overlap.
    """
    Xall, Eall = csv_to_reps(CSV_FILE)

    X = Xall[:TRAINING]
    Xs = Xall[-TEST:]
    E = Eall[:TRAINING]
    Es = Eall[-TEST:]

    # Same training representations on one side of both kernels; only the
    # second argument differs, which makes Ks non-square ("asymmetric").
    K = get_atomic_local_kernels(X, X, **KERNEL_ARGS)
    Ks = get_atomic_local_kernels(X, Xs, **KERNEL_ARGS)

    for i, _ in enumerate(SIGMAS):
        # Only the solution vector is used; lstsq also returns residues,
        # rank and singular values, which this script ignores.
        # Presumably K[i] is (n_train_atoms, n_train_molecules), giving one
        # coefficient per training atom -- confirm against the qml docs.
        alphas = lstsq(K[i].T, E, cond=1e-9, lapack_driver="gelsd")[0]

        # Test energy prediction
        Ess = np.dot(Ks[i].T, alphas)
        # Training energy predictions
        Et = np.dot(K[i].T, alphas)

        print(mae(Ess, Es))
        print(mae(Et, E))
if __name__ == "__main__":
    # Entry point when run as a script: load data, fit, and print the MAEs.
    test_asymmetric()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment