Created
November 28, 2024 10:47
-
-
Save tiandiao123/b4af498778d69af761b85a5c32ba09af to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import lightgbm as lgb | |
import numpy as np | |
import pandas as pd | |
import shap | |
from sklearn.model_selection import train_test_split | |
def feature_selection_lgb(X, y, feature_names, threshold=0.01, use_shap=False,
                          num_class=3):
    """Select features with a LightGBM model.

    Trains a multiclass LightGBM model on an 80/20 train/validation split,
    ranks features either by the model's split importance or by mean
    absolute SHAP values, and keeps every feature whose normalized
    importance exceeds ``threshold``.

    Args:
        X: Feature matrix.
        y: Target variable.
        feature_names: Feature names, in the same order as the columns of X.
        threshold: Normalized-importance cutoff; features at or below this
            value are dropped.
        use_shap: If True, rank features by mean |SHAP| value instead of
            split counts.
        num_class: Number of target classes (default 3, matching the
            original hard-coded setting).

    Returns:
        selected_features: List of feature names that passed the threshold.
        importance_df: DataFrame with raw and normalized importances,
            sorted in descending order.
    """
    # Hold out 20% of the data for early stopping.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    train_data = lgb.Dataset(X_train, label=y_train, feature_name=feature_names)
    val_data = lgb.Dataset(X_val, label=y_val, feature_name=feature_names)

    params = {
        'objective': 'multiclass',
        'num_class': num_class,
        'boosting_type': 'gbdt',
        'metric': 'multi_logloss',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': -1
    }

    # LightGBM >= 4 removed the early_stopping_rounds / verbose_eval kwargs
    # from lgb.train; the callback API below works on both 3.x and 4.x.
    model = lgb.train(
        params,
        train_data,
        num_boost_round=100,
        valid_sets=[val_data],
        callbacks=[lgb.early_stopping(20), lgb.log_evaluation(0)],
    )

    if use_shap:
        # Rank features by SHAP values computed on the training split.
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X_train)
        if isinstance(shap_values, list):
            # Multiclass (older shap): one (n_samples, n_features) array per
            # class.  Average |SHAP| over samples, then over classes with
            # axis=0 — without axis=0, np.mean collapses everything to one
            # scalar and every feature gets the same importance.
            per_class = [np.abs(sv).mean(axis=0) for sv in shap_values]
            importance_values = np.mean(per_class, axis=0)
        else:
            # Newer shap may return one (n_samples, n_features, n_classes)
            # array; reduce over samples and classes to get one value per
            # feature.
            importance_values = np.abs(shap_values).mean(axis=(0, 2)) \
                if shap_values.ndim == 3 else np.abs(shap_values).mean(axis=0)
        importance_type = 'SHAP value'
    else:
        # Plain split-count importance from the trained booster.
        importance_values = model.feature_importance(importance_type='split')
        importance_type = 'split'

    importance_df = pd.DataFrame({
        'feature': feature_names,
        'importance': importance_values
    })
    importance_df = importance_df.sort_values('importance', ascending=False)

    # Normalize importances to sum to 1; guard against an all-zero model
    # (e.g. no splits made) to avoid division by zero.
    total = importance_df['importance'].sum()
    importance_df['importance_normalized'] = (
        importance_df['importance'] / total if total > 0 else 0.0
    )

    # Keep features strictly above the normalized threshold.
    selected_features = importance_df[
        importance_df['importance_normalized'] > threshold
    ]['feature'].tolist()

    print(f"\n特征选择结果 (基于 {importance_type}):")
    print(f"原始特征数量: {len(feature_names)}")
    print(f"筛选后特征数量: {len(selected_features)}")
    print("\n重要性排名前10的特征:")
    print(importance_df.head(10))

    return selected_features, importance_df
def plot_feature_importance(importance_df, top_n=20):
    """Draw a horizontal bar chart of the top-n normalized feature importances.

    Args:
        importance_df: DataFrame with 'feature' and 'importance_normalized'
            columns, as produced by feature_selection_lgb.
        top_n: How many of the leading rows to plot.
    """
    # Imported locally so the plotting dependencies are only required
    # when this helper is actually called.
    import matplotlib.pyplot as plt
    import seaborn as sns

    top_rows = importance_df.head(top_n)

    plt.figure(figsize=(12, 6))
    sns.barplot(data=top_rows, x='importance_normalized', y='feature')

    plt.title(f'Top {top_n} Feature Importance')
    plt.xlabel('Normalized Importance')
    plt.ylabel('Features')

    plt.tight_layout()
    plt.show()
# Usage example (requires your own data; shown as comments so nothing
# executes on import):
#
#     X = your_feature_matrix
#     y = your_target_variable
#     feature_names = your_feature_names
#
#     # Select features using split importance
#     selected_features, importance_df = feature_selection_lgb(
#         X,
#         y,
#         feature_names,
#         threshold=0.01,
#         use_shap=False
#     )
#     plot_feature_importance(importance_df)
#
#     # Select features using SHAP values
#     selected_features_shap, importance_df_shap = feature_selection_lgb(
#         X,
#         y,
#         feature_names,
#         threshold=0.01,
#         use_shap=True
#     )
#     plot_feature_importance(importance_df_shap)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment