Last active
May 24, 2021 06:10
-
-
Save tetlabo/54cbac8f67ec4265a893eecbfca2f4df to your computer and use it in GitHub Desktop.
scikit-learn-intelexライブラリの検証で使用したコード (抜粋)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
# scikit-learnのIntel拡張 (scikit-learn-intelex) を試してみる | |
### ライブラリの読み込み | |
import numpy as np | |
from sklearn.datasets import make_classification | |
from sklearn.svm import SVC | |
from sklearn.model_selection import GridSearchCV, train_test_split | |
from sklearn.metrics import classification_report | |
### データセットの生成 | |
X, y = make_classification(n_samples=2000, n_features=100, n_informative=40, n_classes=2, random_state=334) | |
### 学習用、テスト用データの分割 | |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=334) | |
### グリッドサーチパラメータの設定 | |
params = { | |
'C': [1, 1.5, 2, 2.5, 3, 3.5, 4], | |
'gamma': [0.003, 0.002, 0.001, 0.0009, 0.0008], | |
'kernel': ['rbf'] | |
} | |
### モデリングの実行 | |
# F1スコアを指標とした5-Fold CVを行います。筆者のPCが8スレッド (Intel Core i5-8250U CPU @ 1.60GHz) なので、6スレッドまで使用するよう指定しています。 | |
clf = GridSearchCV(SVC(), params, cv=5, scoring='f1_macro', n_jobs=6, verbose=10) | |
clf.fit(X_train, y_train) | |
# ### 精度評価 | |
# 今回は速度の計測が主であって、本題ではないですが。 | |
y_true, y_pred = y_test, clf.predict(X_test) | |
print(clf.best_params_) | |
print(classification_report(y_true, y_pred)) | |
## Intel拡張を使用する場合 | |
### ダイナミック・パッチ | |
# 利用時に、機能のオン・オフを切り替えられるそうです。 | |
from sklearnex import patch_sklearn, unpatch_sklearn | |
patch_sklearn() | |
### モデリングの実行 | |
# F1スコアを指標とした5-Fold CVを行います。 | |
clf = GridSearchCV(SVC(), params, cv=5, scoring='f1_macro', n_jobs=6, verbose=10) | |
clf.fit(X_train, y_train) | |
# ### 精度評価 | |
# 今回は速度の計測が主であって、本題ではないですが。 | |
y_true, y_pred = y_test, clf.predict(X_test) | |
print(clf.best_params_) | |
print(classification_report(y_true, y_pred)) | |
# ### ダイナミック・パッチをオフにする | |
unpatch_sklearn() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment