tmp2.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # -*- coding: utf-8 -*-
  2. # 导入必要的库
  3. import pandas as pd
  4. from catboost import CatBoostClassifier
  5. from matplotlib import pyplot as plt
  6. from sklearn.model_selection import train_test_split
  7. # 加载示例数据集
  8. data = pd.read_excel('/Users/alvin/Downloads/ai_v3_bill_sample02.xlsx')
  9. # 获取除了risk_buss_no之外的所有列,并转为数字
  10. data = data.drop('risk_buss_no', axis=1)
  11. data = data.apply(pd.to_numeric, errors='coerce')
  12. print(data.columns)
  13. X = pd.get_dummies(data.drop('y', axis=1))
  14. y = data['y']
  15. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  16. # 定义CatBoost分类器并训练模型
  17. # n_estimators 78
  18. # max_depth 2
  19. # learning_rate 0.05
  20. # l2_leaf_reg 17
  21. # subsample 0.7
  22. # max_leaves 28
  23. # min_data_in_leaf 2
  24. # grow_policy Lossguide
  25. # model = CatBoostClassifier()
  26. model = CatBoostClassifier(n_estimators=78, max_depth=2, learning_rate=0.05, l2_leaf_reg=17, subsample=0.7, max_leaves=28, min_data_in_leaf=2, grow_policy='Lossguide')
  27. model.fit(X_train, y_train, verbose=False)
  28. # 评估模型性能
  29. print('Train accuracy:', model.score(X_train, y_train))
  30. print('Test accuracy:', model.score(X_test, y_test))
  31. # 保存模型
  32. # model.save_model('catboost_model.bin')
  33. # 加载模型
  34. # loaded_model = CatBoostClassifier()
  35. # loaded_model.load_model('catboost_model.bin')
  36. #
  37. # 使用模型进行预测
  38. preds_class = model.predict(X_test)
  39. preds_proba = model.predict_proba(X_test)
  40. # 输出预测结果
  41. print('Predicted classes:', preds_class)
  42. print('Predicted probabilities:', preds_proba)
  43. # 获取 preds_proba 的第一列,打印最大值、最小值
  44. print(preds_proba[:, 0].max())
  45. print(preds_proba[:, 0].min())
  46. #
  47. # from sklearn.metrics import classification_report, confusion_matrix
  48. # from sklearn.metrics import precision_recall_curve, roc_curve, auc
  49. #
  50. # print(classification_report(y_test, model.predict(X_test)))
  51. # confusion = confusion_matrix(y_test, model.predict(X_test), normalize='all')
  52. # print(confusion)
  53. # # AUC ROC Curve plotting
  54. # probs = model.predict_proba(X_test)
  55. # preds = probs[:, 1]
  56. # fpr, tpr, threshold = roc_curve(y_test, preds)
  57. # roc_auc = auc(fpr, tpr)
  58. #
  59. # # plt.figure(figsize = (12, 7))
  60. # plt.title('Receiver Operating Characteristic', weight='bold')
# plt.plot(fpr, tpr, 'b', label='CatBoostClassifier (AUC = %0.2f)' % roc_auc)
  62. # plt.legend(loc='lower right')
  63. # plt.plot([0, 1], [0, 1], 'r--')
  64. # plt.ylabel('True Positive Rate', fontsize=12)
  65. # plt.xlabel('False Positive Rate', fontsize=12)
  66. # plt.show()