Created
November 22, 2022 10:54
-
-
Save Mec-iS/c17042298fa5f6d9034dc7caf1d687bf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "352982e0", | |
"metadata": {}, | |
"source": [ | |
"# Naive Bayes example\n", | |
"\n", | |
"## Gaussian" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "1b9d5215", | |
"metadata": { | |
"vscode": { | |
"languageId": "rust" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// Add smartcore as a dependency, built from the git branch named below\n", | |
"// this may take a while: the library has to be downloaded and compiled first\n", | |
":dep smartcore = {git = \"https://github.com/m7142yosuke/smartcore.git\", branch = \"feature/improve-display-for-naive-bayes\", features = [\"datasets\"]}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "c62e1893", | |
"metadata": { | |
"vscode": { | |
"languageId": "rust" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"use smartcore::dataset::iris::load_dataset;\n", | |
"use smartcore::naive_bayes::gaussian::{GaussianNB, GaussianNBParameters};\n", | |
"use smartcore::linalg::basic::matrix::DenseMatrix;\n", | |
"// Model performance\n", | |
"use smartcore::metrics::accuracy;" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "0acb0e13", | |
"metadata": { | |
"vscode": { | |
"languageId": "rust" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"GaussianNB:\n", | |
" inner: BaseNaiveBayes: {\n", | |
" distribution: GaussianNBDistribution: {\n", | |
" class_labels: [0, 1, 2],\n", | |
" class_count: [50, 50, 50],\n", | |
" class_priors: [0.3333333333333333, 0.3333333333333333, 0.3333333333333333],\n", | |
" var: [[0.12176398698426993, 0.1422760001411465, 0.0295040034446723, 0.01126400058841702], [0.26110400788880384, 0.09650000076294418, 0.21639999752045114, 0.03832399889564586], [0.3962559386291673, 0.10192399745559833, 0.2984959836425922, 0.07392400431251556]],\n", | |
" theta: [[5.006000003814697, 3.41800000667572, 1.4639999961853027, 0.24400000482797624], [5.935999975204468, 2.770000009536743, 4.259999980926514, 1.3259999918937684], [6.588000001907349, 2.9739999914169313, 5.551999988555909, 2.0259999775886537]]\n", | |
"}}\n", | |
"\n", | |
"\n", | |
"\n", | |
"accuracy: 0.96\n" | |
] | |
} | |
], | |
"source": [ | |
"// Load Iris dataset\n", | |
"let iris_dataset = load_dataset();\n", | |
"let iris_data = load_dataset().data;\n", | |
"let iris_targets = load_dataset().target;\n", | |
"\n", | |
"// Input data\n", | |
"let x: DenseMatrix<f32> = DenseMatrix::new(\n", | |
" iris_dataset.num_samples, // num rows\n", | |
" iris_dataset.num_features, // num columns\n", | |
" iris_data, // data as Vec\n", | |
" false, // column_major\n", | |
");\n", | |
"// Ground truth\n", | |
"let y: Vec<u32> = iris_targets;\n", | |
"\n", | |
"// Run Gaussian Naive Bayes without priors\n", | |
"let parameters = GaussianNBParameters::default();\n", | |
"let gnb = GaussianNB::fit(&x, &y, parameters).unwrap();\n", | |
"\n", | |
"// with priors would be:\n", | |
"// let priors = vec![...];\n", | |
"// let parameters = GaussianNBParameters::default().with_priors(priors.clone());\n", | |
"\n", | |
"println!(\"{}\", &gnb);\n", | |
"\n", | |
"let y_hat = gnb.predict(&x).unwrap(); // Predict class labels\n", | |
"// Calculate accuracy on the training data\n", | |
"println!(\"\\naccuracy: {}\", accuracy(&y, &y_hat)); // Prints 0.96" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9b9545b8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Rust", | |
"language": "rust", | |
"name": "rust" | |
}, | |
"language_info": { | |
"codemirror_mode": "rust", | |
"file_extension": ".rs", | |
"mimetype": "text/rust", | |
"name": "Rust", | |
"pygment_lexer": "rust", | |
"version": "" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment