tianyunperfect 2 年 前
コミット
d90e30de09

+ 22 - 1
test1/pom.xml

@@ -19,7 +19,7 @@
         <dependency>
             <groupId>org.python</groupId>
             <artifactId>jython-standalone</artifactId>
-            <version>2.7.0</version>
+            <version>2.7.3</version>
         </dependency>
 
         <dependency>
@@ -27,5 +27,26 @@
             <artifactId>pmml-evaluator</artifactId>
             <version>1.5.15</version>
         </dependency>
+        <dependency>
+            <groupId>org.jpmml</groupId>
+            <artifactId>pmml-evaluator-extension</artifactId>
+            <version>1.5.15</version>
+        </dependency>
+        <dependency>
+            <groupId>org.projectlombok</groupId>
+            <artifactId>lombok</artifactId>
+        </dependency>
+
+        <!-- https://mvnrepository.com/artifact/ai.catboost/catboost-prediction -->
+        <dependency>
+            <groupId>ai.catboost</groupId>
+            <artifactId>catboost-prediction</artifactId>
+            <version>1.0.6</version>
+        </dependency>
+        <dependency>
+            <groupId>com.alibaba</groupId>
+            <artifactId>fastjson</artifactId>
+            <version>1.2.79</version>
+        </dependency>
     </dependencies>
 </project>

+ 59 - 0
test1/src/main/java/CatBoostClassifier.java

@@ -0,0 +1,59 @@
+
+import ai.catboost.CatBoostModel;
+import ai.catboost.CatBoostPredictions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.Map;
+ 
+/**
+ * Thin wrapper around a CatBoost model that scores one sample supplied as
+ * name-to-value maps of numeric and categorical features.
+ */
+public class CatBoostClassifier {
+
+    /** Loaded model; stays {@code null} when loading failed in the constructor. */
+    public CatBoostModel model;
+
+    private static final Logger LOG = LoggerFactory.getLogger(CatBoostClassifier.class);
+
+    /**
+     * Loads the model stored at {@code model_path}.
+     * <p>
+     * Loading failures are logged and printed but not rethrown, so {@link #model}
+     * may be {@code null} afterwards; {@link #predict} then returns {@code null}.
+     *
+     * @param model_path path of the serialized CatBoost model file
+     */
+    public CatBoostClassifier(String model_path) {
+        // try-with-resources closes the stream whether or not loading succeeds
+        try (InputStream model_file = new FileInputStream(new File(model_path))) {
+            this.model = CatBoostModel.loadModel(model_file);
+        } catch (Exception err) {
+            LOG.error("catboost-init-error", err);
+            System.out.println(err.toString());
+        }
+    }
+
+    /**
+     * Scores one sample and returns {@code sigmoid(raw)} of the model output at (0, 0).
+     * <p>
+     * Features are handed to the model in the order of the model's own feature
+     * names: numeric values fill one array, categorical values another, each in
+     * that order.
+     *
+     * @param num_features numeric feature values keyed by feature name
+     * @param cat_features categorical feature values keyed by feature name
+     * @return the sigmoid of the raw prediction, or {@code null} when anything
+     *         fails (model not loaded, feature missing, count mismatch, ...);
+     *         the failure is logged
+     */
+    public Double predict(Map<String,Double> num_features,Map<String,String> cat_features){
+        try {
+            String[] feature_names = this.model.getFeatureNames();
+            // Explicit checks instead of `assert`: assertions are disabled by default,
+            // and when enabled an AssertionError would bypass the catch below.
+            int supplied = num_features.size() + cat_features.size();
+            if (supplied != feature_names.length) {
+                throw new IllegalArgumentException(
+                        "model expects " + feature_names.length + " features but " + supplied + " were supplied");
+            }
+
+            float[] num_arr = new float[num_features.size()];
+            String[] cat_arr = new String[cat_features.size()];
+
+            int i = 0;
+            int j = 0;
+            for (String name : feature_names) {
+                if (num_features.containsKey(name)) {
+                    num_arr[i] = num_features.get(name).floatValue();
+                    i += 1;
+                } else if (cat_features.containsKey(name)) {
+                    cat_arr[j] = cat_features.get(name);
+                    j += 1;
+                } else {
+                    throw new IllegalArgumentException("no value supplied for model feature: " + name);
+                }
+            }
+            CatBoostPredictions prediction = this.model.predict(
+                    num_arr,
+                    cat_arr);
+            // 1 - 1/(1 + e^x) is the logistic sigmoid of the raw score x
+            Double prob = 1.0 - 1.0 / (1.0 + Math.exp(prediction.get(0, 0)));
+            return prob;
+        } catch (Exception err) {
+            LOG.error("catboost-predict-error", err);
+            return null;
+        }
+    }
+}

+ 66 - 0
test1/src/main/java/CatBoostTest.java

@@ -0,0 +1,66 @@
+import com.alibaba.fastjson.JSON;
+
+import java.math.BigDecimal;
+import java.util.*;
+
+/**
+ * Manual smoke test: loads a CatBoost model from a hard-coded local path,
+ * scores one hard-coded sample and prints the result.
+ */
+public class CatBoostTest {
+    /**
+     * Builds the numeric feature map from an inline JSON document, leaves the
+     * categorical map empty, calls {@code CatBoostClassifier.predict} and prints
+     * the returned value followed by a fixed marker line.
+     */
+    public static void testCatBoostClassifier(){
+        // NOTE(review): machine-specific absolute path; the model file must exist locally.
+        CatBoostClassifier clf = new CatBoostClassifier("/Users/alvin/Downloads/catboost.cbm");
+        
+        Map<String,Double> num_features = new HashMap<String,Double>();
+        Map<String,String> cat_features = new HashMap<String,String>();
+        
+        //num_features.put("fnlwgt",236379.0);
+        //num_features.put("education-num",11.0);
+        // Parsed through the raw HashMap class, so the BigDecimal value type is unchecked.
+        // NOTE(review): assumes fastjson decodes every decimal literal as BigDecimal;
+        // any other runtime type would fail in the loop below - TODO confirm.
+        Map<String, BigDecimal> hashMap = JSON.parseObject("{\n" +
+                "  \"income_tax_12m_chu_income_duty_amount_12m\": 0.5,\n" +
+                "  \"registcapi_chu_reccap\": 1.0,\n" +
+                "  \"vat_sales_huanbi_1y\": 0.2316023872825141,\n" +
+                "  \"late_fee_24m\": 0.0,\n" +
+                "  \"drs_nodebtscore\": 16.0,\n" +
+                "  \"vat_duty_amount_12m\": 11239.52,\n" +
+                "  \"stockpercent_jicha\": 2.0,\n" +
+                "  \"income_duty_amount_24m\": 7953.72,\n" +
+                "  \"cv_vat_sales_12m\": 1.7568520608821654,\n" +
+                "  \"vat_sales_huanbi_6m\": -0.146063232514475,\n" +
+                "  \"slope_vat_sales_12m\": -0.02365417557711674,\n" +
+                "  \"fraud18\": 43.0,\n" +
+                "  \"income_tax_12m\": 0.0,\n" +
+                "  \"stockpercent\": 49.0,\n" +
+                "  \"month_late_fee_24m\": 0.0,\n" +
+                "  \"cn_zeroincome_24m\": 8.0,\n" +
+                "  \"growth_index\": 0.0,\n" +
+                "  \"vat_duty_month_3m\": 3.0,\n" +
+                "  \"tl_id_m6_nbank_passorg\": 1.0,\n" +
+                "  \"now_amount_tax_arrears_24m\": 0.0,\n" +
+                "  \"cn_zerodeclaration_12m\": 0.0,\n" +
+                "  \"penalty_count_total\": 2.0,\n" +
+                "  \"now_month_tax_arrears_12m\": 0.0\n" +
+                "}", HashMap.class);
+        // Copy every parsed value into the Double-valued map that predict() expects.
+        for (Map.Entry<String, BigDecimal> entry : hashMap.entrySet()) {
+            num_features.put(entry.getKey(), entry.getValue().doubleValue());
+        }
+
+        //cat_features.put("workclass","Private");
+        //cat_features.put("sex","Male");
+        //cat_features.put("relationship","Husband");
+        //cat_features.put("race","White");
+        //cat_features.put("occupation","Craft-repair");
+        //cat_features.put("native-country","United-States");
+        //cat_features.put("marital-status","Married-civ-spouse");
+        //cat_features.put("hoursperweek","2");
+        //cat_features.put("education","Assoc-voc");
+        //cat_features.put("capitalloss","0");
+        //cat_features.put("capitalgain","0");
+        //cat_features.put("age","1");
+ 
+        // predict() returns null on any failure, in which case "null" is printed.
+        Double model_prob = clf.predict(num_features,cat_features);
+        System.out.println(model_prob);
+        System.out.println("sucessed!");
+    }
+
+    /** Entry point; runs the single smoke test. */
+    public static void main(String[] args) {
+        testCatBoostClassifier();
+    }
+ 
+}

+ 63 - 0
test1/src/main/java/FileUtil.java

@@ -0,0 +1,63 @@
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.UnsupportedEncodingException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Small file-reading helpers: read a whole file into a string or into a list
+ * of lines, with an explicit or default (UTF-8) encoding.
+ *
+ * @author mlamp
+ * @date 2022/05/21
+ */
+public class FileUtil {
+    public static final String DefaultEncode = "UTF-8";
+
+    /**
+     * Reads the whole file and decodes it with the given charset.
+     * <p>
+     * Reading is best-effort: an I/O failure is printed and whatever was read
+     * up to that point (possibly nothing) is returned.
+     *
+     * @param filePath path of the file to read
+     * @param encode   charset name used to decode the bytes
+     * @return the decoded content that could be read
+     * @throws UnsupportedEncodingException if {@code encode} is not a supported charset
+     */
+    public static String readToString(String filePath, String encode) throws UnsupportedEncodingException {
+        File file = new File(filePath);
+        // NOTE: files larger than Integer.MAX_VALUE bytes are not supported by this helper.
+        byte[] fileContent = new byte[(int) file.length()];
+        int filled = 0;
+        // try-with-resources closes the stream even when a read fails
+        try (FileInputStream in = new FileInputStream(file)) {
+            // A single read() may return fewer bytes than requested, so loop until
+            // the buffer is full or the stream ends.
+            while (filled < fileContent.length) {
+                int count = in.read(fileContent, filled, fileContent.length - filled);
+                if (count < 0) {
+                    break;
+                }
+                filled += count;
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+        // Decode only the bytes actually read, so a failed or short read does not
+        // pad the result with NUL characters.
+        return new String(fileContent, 0, filled, encode);
+    }
+
+    /**
+     * Reads the whole file as UTF-8.
+     *
+     * @param filePath path of the file to read
+     * @return the decoded content that could be read
+     * @throws UnsupportedEncodingException declared for signature compatibility
+     */
+    public static String readToString(String filePath) throws UnsupportedEncodingException {
+        return readToString(filePath, DefaultEncode);
+    }
+
+    /**
+     * Reads the file and splits it into lines on {@code \n} or {@code \r\n}.
+     * Trailing empty lines are dropped.
+     *
+     * @param filePath path of the file to read
+     * @param encode   charset name used to decode the bytes
+     * @return a mutable list of lines without their line terminators
+     * @throws UnsupportedEncodingException if {@code encode} is not a supported charset
+     */
+    public static List<String> readToLines(String filePath, String encode) throws UnsupportedEncodingException {
+        // "\r?\n" keeps a stray carriage return out of each line for CRLF files
+        String[] strArr = readToString(filePath, encode).split("\\r?\\n");
+        ArrayList<String> strings = new ArrayList<>(strArr.length);
+        Collections.addAll(strings, strArr);
+        return strings;
+    }
+
+    /**
+     * Reads the file as UTF-8 and splits it into lines.
+     *
+     * @param filePath path of the file to read
+     * @return a mutable list of lines without their line terminators
+     * @throws UnsupportedEncodingException declared for signature compatibility
+     */
+    public static List<String> readToLines(String filePath) throws UnsupportedEncodingException {
+        return readToLines(filePath, DefaultEncode);
+    }
+
+    /** Ad-hoc check: prints the number of lines in {@code pom.xml}. */
+    public static void main(String[] args) throws UnsupportedEncodingException {
+        //System.out.println(FileUtil.readToString("pom.xml"));
+        System.out.println(readToLines("pom.xml").size());
+    }
+
+
+}

+ 0 - 8
test1/src/main/java/JavaPythonFile.java

@@ -1,8 +0,0 @@
-import org.python.util.PythonInterpreter;
-public class JavaPythonFile {
-    public static void main(String[] args) {
-        PythonInterpreter interpreter = new PythonInterpreter();
-        //我在这里使用相对路径,注意区分
-        interpreter.execfile("C:\\Users\\root\\IdeaProjects\\python-base\\tmp.py");
-    }
-}

+ 83 - 0
test1/src/main/java/PythonUtil.java

@@ -0,0 +1,83 @@
+import org.python.core.Py;
+import org.python.core.PyObject;
+import org.python.core.PyString;
+import org.python.util.PythonInterpreter;
+ 
+import java.util.*;
+ 
+/**
+ * Runs small Python snippets through an embedded Jython interpreter and
+ * returns the value the snippet "returns" as a string.
+ *
+ * @author lilong
+ * @date 2020/12/01
+ */
+public class PythonUtil {
+
+    /**
+     * Name of the Python variable used to carry the script's result back to Java.
+     */
+    private static final String RESULT_KEY = "result_fc3905687c494c09a6dbfe713390bf3c";
+
+    /*
+     * Interpreter properties; per the original author, initialization is very
+     * slow when they are not set.
+     * NOTE(review): "python.home" still holds a placeholder string, not a real path.
+     */
+    static {
+        Properties props = new Properties();
+        props.put("python.home", "path to the Lib folder");
+        props.put("python.console.encoding", "UTF-8");
+        props.put("python.security.respectJavaAccessibility", "false");
+        props.put("python.import.site", "false");
+        Properties preProps = System.getProperties();
+        PythonInterpreter.initialize(preProps, props, new String[0]);
+    }
+
+    /**
+     * Executes a Python script with the given variables bound.
+     * <p>
+     * Every occurrence of the text {@code return} in the script is replaced by an
+     * assignment to an internal result variable, whose final value is returned.
+     * The replacement is purely textual, so {@code return} inside string literals
+     * or identifiers is rewritten as well.
+     *
+     * @param express the script; line breaks and indentation must be valid Python
+     * @param map     input variables, key is the variable name, value its value
+     * @return the string form of the script's result
+     * @throws RuntimeException if binding, execution or result lookup fails;
+     *                          the original failure is attached as the cause
+     */
+    public static String doExpress(String express, Map<String, Object> map) {
+        try (PythonInterpreter pythonInterpreter = new PythonInterpreter()) {
+            map.forEach((key, value) -> {
+                PyObject pyValue;
+                if (value instanceof String) {
+                    // UTF-8 conversion avoids garbled non-ASCII text
+                    pyValue = Py.newStringUTF8(value.toString());
+                } else if (value instanceof Float || value instanceof Double) {
+                    // Java and Python float precision differ; normalize everything to double
+                    pyValue = Py.newFloat(Double.parseDouble(value.toString()));
+                } else if (value instanceof Integer) {
+                    pyValue = Py.newInteger((int) value);
+                } else {
+                    pyValue = Py.java2py(value);
+                }
+                pythonInterpreter.set(key, pyValue);
+            });
+            String script = express.replace("return", RESULT_KEY + "=");
+
+            PyString pyString = Py.newStringUTF8(script);
+            pythonInterpreter.exec(pyString);
+            PyObject result = pythonInterpreter.get(RESULT_KEY);
+            if (result == null) {
+                // A script that never hits a "return" leaves the result variable unset
+                throw new IllegalStateException("python script produced no result");
+            }
+            return result.toString();
+        } catch (Exception e) {
+            // Keep the original message text and attach the cause for the stack trace
+            throw new RuntimeException("执行python脚本出错" + e, e);
+        }
+    }
+
+    /** Demo: evaluates the same membership test against three different inputs. */
+    public static void main(String[] args) {
+        Map<String, Object> input = new HashMap<>(4);
+
+        String express = 
+                "if args in [1.0,'1','挑葱夫']:\n" +
+                "    return 1\n" +
+                "else:\n" +
+                "    return 0";
+
+        input.put("args", "1");
+        System.out.println(doExpress(express, input));
+        input.put("args", "挑葱夫");
+        System.out.println(doExpress(express, input));
+        input.put("args", 1.0);
+        System.out.println(doExpress(express, input));
+    }
+}

+ 39 - 33
test1/src/main/java/Server.java

@@ -1,6 +1,8 @@
-import org.dmg.pmml.FieldName;
 import org.jpmml.evaluator.*;
+import org.jpmml.evaluator.visitors.DefaultModelEvaluatorBattery;
+import org.jpmml.model.visitors.VisitorBattery;
 import org.xml.sax.SAXException;
+import org.dmg.pmml.FieldName;
 
 import javax.xml.bind.JAXBException;
 import java.io.File;
@@ -13,9 +15,12 @@ public class Server {
 
     public static void main(String[] args) throws JAXBException, SAXException, IOException {
         // Building a model evaluator from a PMML file
-        String modelPath = "lgb_model.pmml";
-        System.out.println(modelPath);
-        Evaluator evaluator = new LoadingModelEvaluatorBuilder().load(new File(modelPath)).build();
+        String modelPath = "/Users/alvin/Downloads/first_test.pmml";
+        //String modelPath = "lgb_model.pmml";
+        //System.out.println(modelPath);
+        Evaluator evaluator = new LoadingModelEvaluatorBuilder()
+                .setLocatable(false)
+                .load(new File(modelPath)).build();
 
         // Performing the self-check
         evaluator.verify();
@@ -38,37 +43,38 @@ public class Server {
             System.out.println(outputField);
         }
 
-        // Predicting
-        Map<String, Double> inputRecord = new LinkedHashMap<String, Double>();
-        // 5.1, 3.5, 1.4, 0.2 -> 0
-        // 6.4, 3.2, 4.5, 1.5 -> 1
-        // 5.9, 3. , 5.1, 1.8 -> 2
-        inputRecord.put("sepal_length_(cm)", 5.1);
-        inputRecord.put("sepal_width_(cm)", 3.5);
-        inputRecord.put("petal_length_(cm)", 1.4);
-        inputRecord.put("petal_width_(cm)", 0.2);
-
-        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
-
-        // Mapping the record field-by-field from data source schema to PMML schema
-        for (InputField inputField : inputFields) {
-            FieldName inputName = inputField.getName();
-
-            Object rawValue = inputRecord.get(inputName.getValue());
-
-            // Transforming an arbitrary user-supplied value to a known-good PMML value
-            FieldValue inputValue = inputField.prepare(rawValue);
-
-            arguments.put(inputName, inputValue);
+        List<String> strings = FileUtil.readToLines("/Users/alvin/Downloads/11111111111111111111111111111111111111.csv");
+        for (String string : strings) {
+            // Predicting
+            Float[] doubles = new Float[23];
+            String[] split = string.split(",");
+            for (int i = 0; i < split.length; i++) {
+                if (split[i].length() == 0) {
+                    doubles[i] = new Float(0);
+                } else {
+                    doubles[i] = new Float(split[i]);
+                }
+            }
+            Float[] arr = new Float[]{5.1f, 3.5f, 1.4f, 0.2f};
+
+            Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
+
+            // Mapping the record field-by-field from data source schema to PMML schema
+            for (int i = 0; i < inputFields.size(); i++) {
+                InputField inputField = inputFields.get(i);
+                arguments.put(inputField.getName(), inputField.prepare(doubles[i]));
+            }
+
+            // Evaluating the model with known-good arguments
+            Map<FieldName, ?> results = evaluator.evaluate(arguments);
+            //System.out.println(results);
+
+            // Decoupling results from the JPMML-Evaluator runtime environment
+            Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
+            System.out.println(resultRecord);
+            break;
         }
 
-        // Evaluating the model with known-good arguments
-        Map<FieldName, ?> results = evaluator.evaluate(arguments);
-        System.out.println(results);
-
-        // Decoupling results from the JPMML-Evaluator runtime environment
-        Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
-        System.out.println(resultRecord);
 
     }
 }

+ 78 - 0
test1/src/main/java/Server1.java

@@ -0,0 +1,78 @@
+import org.dmg.pmml.FieldName;
+import org.jpmml.evaluator.*;
+import org.xml.sax.SAXException;
+
+import javax.xml.bind.JAXBException;
+import java.io.File;
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Scratch driver: loads a PMML model, prints its field metadata and evaluates
+ * one hard-coded sample.
+ */
+public class Server1 {
+
+    /**
+     * Loads the model from a hard-coded path, verifies it, prints input, target
+     * and output fields, then evaluates a single fixed feature vector and prints
+     * the decoded result.
+     * <p>
+     * NOTE(review): both file paths are machine-specific absolute paths.
+     */
+    public static void main(String[] args) throws JAXBException, SAXException, IOException {
+        // Building a model evaluator from a PMML file
+        String modelPath = "/Users/alvin/Downloads/RandomForestClassifier_Iris.pmml";
+        //String modelPath = "lgb_model.pmml";
+        //System.out.println(modelPath);
+        Evaluator evaluator = new LoadingModelEvaluatorBuilder()
+                .setLocatable(false)
+                .load(new File(modelPath)).build();
+
+        // Performing the self-check
+        evaluator.verify();
+
+        // Printing input (x1, x2, .., xn) fields
+        List<? extends InputField> inputFields = evaluator.getInputFields();
+        System.out.println("Input fields: ");
+        for (InputField inputField : inputFields) {
+            System.out.println(inputField);
+        }
+
+        // Printing primary result (y) field(s)
+        List<? extends TargetField> targetFields = evaluator.getTargetFields();
+        System.out.println("Target field(s): " + targetFields);
+
+        // Printing secondary result (eg. probability(y), decision(y)) fields
+        List<? extends OutputField> outputFields = evaluator.getOutputFields();
+        System.out.println("Output fields: ");
+        for (OutputField outputField : outputFields) {
+            System.out.println(outputField);
+        }
+
+        List<String> strings = FileUtil.readToLines("/Users/alvin/Downloads/11111111111111111111111111111111111111.csv");
+        for (String string : strings) {
+            // Predicting
+            String[] split = string.split(",");
+            // Sized from the row itself: a fixed length of 23 overflowed on wider rows
+            Float[] doubles = new Float[split.length];
+            for (int i = 0; i < split.length; i++) {
+                // Empty cells count as 0; valueOf replaces the deprecated Float constructors
+                if (split[i].length() == 0) {
+                    doubles[i] = Float.valueOf(0f);
+                } else {
+                    doubles[i] = Float.valueOf(split[i]);
+                }
+            }
+            // NOTE(review): the parsed CSV row above is not used; the model is fed
+            // the fixed sample below instead.
+            Float[] arr = new Float[]{6.6f, 3.0f, 4.4f, 1.4f};
+            if (inputFields.size() > arr.length) {
+                throw new IllegalStateException("model declares " + inputFields.size()
+                        + " input fields but the sample has only " + arr.length + " values");
+            }
+
+            Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
+
+            // Mapping the record field-by-field from data source schema to PMML schema
+            for (int i = 0; i < inputFields.size(); i++) {
+                InputField inputField = inputFields.get(i);
+                arguments.put(inputField.getName(), inputField.prepare(arr[i]));
+            }
+
+            // Evaluating the model with known-good arguments
+            Map<FieldName, ?> results = evaluator.evaluate(arguments);
+            //System.out.println(results);
+
+            // Decoupling results from the JPMML-Evaluator runtime environment
+            Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
+            System.out.println(resultRecord);
+            // Only the first line of the file is processed
+            break;
+        }
+
+
+    }
+}

+ 26 - 6
test1/src/main/java/tmp.java

@@ -1,10 +1,30 @@
 import java.util.HashMap;
 
 public class tmp {
-    public static void main(String[] args) throws InterruptedException {
-        HashMap<String, String> stringStringHashMap = new HashMap<>();
-        stringStringHashMap.put(null, null);
-        stringStringHashMap.put(null, null);
-        System.out.println(stringStringHashMap);
+    /**
+     * Parses a comma-separated list of integers and returns, as a string, the
+     * largest difference {@code later - earlier} between an element and the
+     * smallest element seen before it (0 when no positive difference exists).
+     *
+     * @param value comma-separated integers, e.g. {@code "7,1,5,3,6,4"}
+     * @return the maximum difference in decimal form
+     */
+    private static String solution(String value) {
+        String[] tokens = value.split(",");
+        // The integer array from the problem statement
+        int[] numbers = new int[tokens.length];
+        int pos = 0;
+        for (String token : tokens) {
+            numbers[pos++] = Integer.parseInt(token);
+        }
+
+        // Track the smallest value so far and the best gain over it
+        int lowest = Integer.MAX_VALUE;
+        int best = 0;
+        for (int current : numbers) {
+            if (current < lowest) {
+                lowest = current;
+                continue;
+            }
+            int gain = current - lowest;
+            if (gain > best) {
+                best = gain;
+            }
+        }
+
+        // The result is handed back in string form
+        return Integer.toString(best);
+    }
+
+    public static void main(String[] args) {
+        System.out.println(solution("7,1,5,3,6,4"));
     }
-}
+}

+ 12 - 0
test1/src/main/java/tmp2.java

@@ -0,0 +1,12 @@
+/** Demonstrates passing several int arrays to a varargs-of-arrays method. */
+public class tmp2 {
+    /** Builds two arrays and hands both to {@link #go}. */
+    public static void main(String[] args) {
+        int[] first = {1, 2, 3};
+        int[] second = {4, 5, 6};
+        tmp2 runner = new tmp2();
+        runner.go(first, second);
+    }
+
+    /**
+     * Prints the first element of every array passed in, without separators.
+     *
+     * @param z any number of int arrays
+     */
+    void go(int[]... z) {
+        for (int idx = 0; idx < z.length; idx++) {
+            System.out.print(z[idx][0]);
+        }
+    }
+}