# -*- coding: utf-8 -*-
"""Train and evaluate a CatBoost classifier on the ai_v3 bill sample workbook.

Steps: load the Excel sample, drop the ``risk_buss_no`` column, coerce every
remaining column to numeric, hold out 20% of the rows as a test set, fit a
``CatBoostClassifier`` with fixed hyper-parameters, then print the train/test
accuracy and the test-set predictions.
"""
# Import the required libraries
import pandas as pd
from catboost import CatBoostClassifier
from matplotlib import pyplot as plt  # only used by the commented-out ROC plot below
from sklearn.model_selection import train_test_split

# Load the sample dataset (machine-specific absolute path; adjust before running elsewhere)
data = pd.read_excel('/Users/alvin/Downloads/ai_v3_bill_sample02.xlsx')

# Drop risk_buss_no and convert every remaining column to numbers;
# cells that cannot be parsed become NaN (errors='coerce').
data = data.drop('risk_buss_no', axis=1)
data = data.apply(pd.to_numeric, errors='coerce')
print(data.columns)

# Coercion can also turn an unparseable label into NaN. A NaN target is not a
# usable class for training or for the accuracy score, so drop those rows.
# NaN *features* are left in place for CatBoost's own missing-value handling.
data = data.dropna(subset=['y'])

X = pd.get_dummies(data.drop('y', axis=1))
y = data['y']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the CatBoost classifier and train the model.
# Hyper-parameters are fixed here (presumably the result of an earlier tuning run - confirm).
# NOTE(review): with max_depth=2 a tree has at most 4 leaves, so max_leaves=28 can
# never be reached under grow_policy='Lossguide'; one of the two is effectively inert.
# model = CatBoostClassifier()  # default hyper-parameters, kept for comparison
model = CatBoostClassifier(
    n_estimators=78,
    max_depth=2,
    learning_rate=0.05,
    l2_leaf_reg=17,
    subsample=0.7,
    max_leaves=28,
    min_data_in_leaf=2,
    grow_policy='Lossguide',
)
model.fit(X_train, y_train, verbose=False)

# Evaluate model performance
print('Train accuracy:', model.score(X_train, y_train))
print('Test accuracy:', model.score(X_test, y_test))

# Save the model
# model.save_model('catboost_model.bin')

# Load the model
# loaded_model = CatBoostClassifier()
# loaded_model.load_model('catboost_model.bin')

# Use the model for prediction
preds_class = model.predict(X_test)
preds_proba = model.predict_proba(X_test)

# Print the prediction results
print('Predicted classes:', preds_class)
print('Predicted probabilities:', preds_proba)

# Take the first column of preds_proba and print its maximum and minimum
print(preds_proba[:, 0].max())
print(preds_proba[:, 0].min())

# Optional detailed evaluation (disabled). Reuses the predictions computed above
# instead of calling model.predict / predict_proba again.
# from sklearn.metrics import classification_report, confusion_matrix
# from sklearn.metrics import precision_recall_curve, roc_curve, auc
#
# print(classification_report(y_test, preds_class))
# confusion = confusion_matrix(y_test, preds_class, normalize='all')
# print(confusion)
# # AUC ROC curve plotting (column 1 = probability of the second class)
# preds = preds_proba[:, 1]
# fpr, tpr, threshold = roc_curve(y_test, preds)
# roc_auc = auc(fpr, tpr)
#
# plt.figure(figsize=(12, 7))
# plt.title('Receiver Operating Characteristic', weight='bold')
# plt.plot(fpr, tpr, 'b', label='CatBoostClassifier (AUC = %0.2f)' % roc_auc)
# plt.legend(loc='lower right')
# plt.plot([0, 1], [0, 1], 'r--')
# plt.ylabel('True Positive Rate', fontsize=12)
# plt.xlabel('False Positive Rate', fontsize=12)
# plt.show()