# tmp1.py (1.7 KB)
  1. import pandas as pd
  2. from catboost import CatBoostClassifier, Pool
  3. from sklearn.model_selection import train_test_split
  4. def train_model(data_path, model_path):
  5. # 加载数据集
  6. data = pd.read_csv(data_path)
  7. print(data)
  8. X = pd.get_dummies(data.drop('class', axis=1))
  9. y = data['class']
  10. # 划分训练集和测试集
  11. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  12. # 定义CatBoost分类器并训练模型
  13. model = CatBoostClassifier(iterations=100, depth=2, learning_rate=0.1, loss_function='Logloss')
  14. model.fit(X_train, y_train, verbose=False)
  15. # 保存模型
  16. model.save_model(model_path)
  17. # 返回训练好的模型
  18. return model, list(X.columns)
  19. def predict(model_path, input_data, input_columns):
  20. # 加载模型
  21. loaded_model = CatBoostClassifier()
  22. loaded_model.load_model(model_path)
  23. # 将输入数据转换为DataFrame格式
  24. input_df = pd.DataFrame(input_data, columns=input_columns)
  25. input_df = pd.get_dummies(input_df)
  26. # 使用模型进行预测
  27. preds_class = loaded_model.predict(input_df)
  28. preds_proba = loaded_model.predict_proba(input_df)
  29. # 返回预测结果
  30. return preds_class, preds_proba
  31. data_path = 'mushroom.csv'
  32. model_path = 'catboost_model.bin'
  33. # 训练模型
  34. trained_model, input_columns = train_model(data_path, model_path)
  35. # 输入数据示例
  36. input_data = [['x', 's', 'n', 't', 'p', 'f', 'c', 'n', 'k', 'e', 'e', 's', 's', 'w', 'w', 'p', 'w', 'o', 'p', 'k', 's', 'u']]
  37. # 进行预测
  38. preds_class, preds_proba = predict(model_path, input_data, input_columns)
  39. # 输出预测结果
  40. print('Predicted classes:', preds_class)
  41. print('Predicted probabilities:', preds_proba)