文章目录
💌💌💌ID3算法实例
💨💨1.使用sklearn的决策树算法对葡萄酒数据集进行分类,要求:
(1)划分训练集和测试集(测试集占20%)
(2)对测试集的预测类别标签和真实标签进行对比
(3)输出分类的准确率
(4)调整参数比较不同算法(ID3, CART)的分类效果。
🕳🕳2. 利用给定ID3算法,画出下列训练集的决策树。
🍇🍇🍇1.葡萄酒分类
🚲🚲🚲(1)划分训练集和测试集(测试集占20%)
test_size等于几就是测试集占比
x_train, x_test, y_train, y_test = train_test_split(
X, Y, test_size=0.2, random_state=0)
🚓🚓(2)对测试集的预测类别标签和真实标签进行对比
预测类别标签
y_predict = clf.predict(x_test)
对比
pd.concat([pd.DataFrame(x_test), pd.DataFrame(y_test), pd.DataFrame(y_predict)], axis=1)
🛹🛹(3)输出分类的准确率
clf.fit(x_train, y_train)
score = clf.score(x_test, y_test)
🏍🛵🏍(4)调整参数比较不同算法(ID3, CART)的分类效果。
采用ID3算法进行计算
clf = tree.DecisionTreeClassifier(criterion=“entropy”)
采用CART算法进行计算
clf = tree.DecisionTreeClassifier(criterion=“gini”)
🚀🚀🚀完整代码:
# 导入相关库
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.datasets import load_wine
import pandas as pd
# 导入数据集
wine = load_wine()
X = wine.data
Y = wine.target
features_name = wine.feature_names
print(features_name)
print(pd.concat([pd.DataFrame(X), pd.DataFrame(Y)], axis=1))
# 打印数据
# 划分数据集,数据集划分为测试集占20%;
x_train, x_test, y_train, y_test = train_test_split(
X, Y, test_size=0.2, random_state=0)
# 采用ID3算法进行计算
clf = tree.DecisionTreeClassifier(criterion="entropy")
# 采用CART算法进行计算
# clf = tree.DecisionTreeClassifier(criterion="gini")
# 获取模型
clf.fit(x_train, y_train)
score = clf.score(x_test, y_test)
y_predict = clf.predict(x_test)
print('准确率为:', score)
# 对测试集的预测类别标签和真实标签进行对比
print(pd.concat([pd.DataFrame(x_test), pd.DataFrame(y_test), pd.DataFrame(y_predict)], axis=1))
🚀🚀🚀结果:
🛰🛰🌌2.只需修改数据集,标签集即可
🚩🚩部分代码:
if __name__ == '__main__':
def createDataSet():
dataSet = [[1, 1, 1, 1, 'No'],
[1, 1, 1, 2, 'No'],
[2, 1, 1, 1, 'Yes'],
[3, 2, 1, 1, 'Yes'],
[3, 3, 2, 1, 'Yes'],
[3, 3, 2, 2, 'No'],
[2, 3, 2, 2, 'Yes'],
[1, 2, 1, 1, 'No'],
[1, 3, 2, 1, 'Yes'],
[3, 2, 2, 1, 'Yes'],
[1, 2, 2, 2, 'Yes'],
[2, 2, 1, 2, 'Yes'],
[2, 1, 2, 1, 'Yes'],
[3, 2, 1, 2, 'No'], ]
features = ['outlook', 'temp', 'humidity', 'windy']
return dataSet, features
id3 = ID3Tree() # 创建一个ID3决策树
ds, labels = createDataSet()
id3.getDataSet(ds, labels)
id3.train() # 训练ID3决策树
print(id3.tree) # 输出ID3决策树
print(id3.predict(id3.tree, {'outlook': 2, 'temp': 2, 'humidity': 1, 'windy': 1}))
treePlotter.createPlot(id3.tree)
🍽🍽🌼🌼结果:
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/147450.html