Created
January 10, 2020 10:21
-
-
Save Everfighting/673b56521531321b7715e6ecc8657b9a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.datasets import load_iris | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.model_selection import train_test_split | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# 加载数据集 | |
iris = load_iris() | |
# 数据特征:150行, 4列 | |
features = iris['data'] | |
# 对应的鸢尾花种类: 150个,三种鸢尾花分别用 0,1,2 表示 | |
target = iris['target'] | |
# 自定义4个特征的名称 | |
feature_names = iris.feature_names | |
feature_names = ['花萼长度', '花萼宽度', '花瓣长度', '花瓣宽度'] | |
# 自定义三种鸢尾花的名称 | |
class_names = iris.target_names | |
class_names = ['山鸢尾花', '变色鸢尾花', '维吉尼亚鸢尾花'] | |
# 把样本分成训练集和测试集两部分, 两者比例为: 7:3 | |
X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.3, random_state=42) | |
# 训练 | |
lr = LogisticRegression() | |
lr.fit(X_train, y_train) | |
# 预测 | |
output = lr.predict(X_test) | |
# 计算准确率 | |
acc = np.mean(output == y_test)*100 | |
print("The accuracy of the logistic regression classifier is: \t", acc, "%") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment