tianyun 2 anos atrás
pai
commit
30ae12f31c

+ 16 - 1
test1/pom.xml

@@ -36,7 +36,18 @@
             <groupId>org.projectlombok</groupId>
             <artifactId>lombok</artifactId>
         </dependency>
-
+        <!-- Apache POI core: Excel 97-2003 (.xls) workbooks -->
+        <dependency>
+            <groupId>org.apache.poi</groupId>
+            <artifactId>poi</artifactId>
+            <version>4.1.2</version>
+        </dependency>
+        <!-- Apache POI OOXML: Excel 2007+ (.xlsx) workbooks -->
+        <dependency>
+            <groupId>org.apache.poi</groupId>
+            <artifactId>poi-ooxml</artifactId>
+            <version>4.1.2</version>
+        </dependency>
         <!-- https://mvnrepository.com/artifact/ai.catboost/catboost-prediction -->
         <dependency>
             <groupId>ai.catboost</groupId>
@@ -71,6 +82,10 @@
             <version>2.13.4.2</version>
             <scope>compile</scope>
         </dependency>
+        <dependency>
+            <groupId>junit</groupId>
+            <artifactId>junit</artifactId>
+        </dependency>
 
     </dependencies>
 </project>

+ 158 - 0
test1/src/main/java/ExcelUtil.java

@@ -0,0 +1,158 @@
+import org.apache.poi.ss.usermodel.*;
+import org.apache.poi.xssf.usermodel.XSSFWorkbook;
+import org.apache.poi.hssf.usermodel.HSSFWorkbook;
+import org.junit.Test;
+
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * excel工具类
+ *
+ * @author alvin
+ * @date 2023/07/28
+ */
+public class ExcelUtil {
+
+    /**
+     * Writes tabular data to a new sheet of a new workbook and saves it to {@code filePath}.
+     *
+     * <p>The workbook format is chosen from the file extension (compared case-insensitively):
+     * {@code .xlsx} produces an {@link XSSFWorkbook}, {@code .xls} an {@link HSSFWorkbook}.
+     * The header row is written at {@code rowBeginIndex}; every element of {@code col} becomes
+     * one row below it, and for each header name the value stored under that name in the row
+     * map is written to the matching column.
+     *
+     * @param filePath      path of the Excel file to write
+     * @param sheetName     name of the sheet to create
+     * @param rowBeginIndex zero-based index of the row that receives the header
+     * @param header        column titles, also used as lookup keys into each row map
+     * @param col           row data; each map holds the values of one row keyed by header name
+     * @return {@code true} if the file was written, {@code false} if an I/O error occurred
+     * @throws IllegalArgumentException if the file extension is neither {@code .xlsx} nor {@code .xls}
+     */
+    public static boolean writeExcel(String filePath, String sheetName, int rowBeginIndex, List<String> header, List<Map<String, Object>> col) {
+        // try-with-resources releases the workbook even when filling or saving it fails.
+        try (Workbook workbook = createWorkbook(filePath)) {
+            Sheet sheet = workbook.createSheet(sheetName);
+
+            int rowIndex = rowBeginIndex;
+            Row headerRow = sheet.createRow(rowIndex++);
+            for (int i = 0; i < header.size(); i++) {
+                headerRow.createCell(i).setCellValue(header.get(i));
+            }
+
+            for (Map<String, Object> rowData : col) {
+                Row dataRow = sheet.createRow(rowIndex++);
+                int colIndex = 0;
+                for (String key : header) {
+                    setCellValue(dataRow.createCell(colIndex++), rowData.get(key));
+                }
+            }
+
+            try (FileOutputStream outputStream = new FileOutputStream(filePath)) {
+                workbook.write(outputStream);
+            }
+            return true;
+        } catch (IOException e) {
+            e.printStackTrace();
+            return false;
+        }
+    }
+
+    /**
+     * Reads the rows below a header row from one sheet of an Excel file.
+     *
+     * <p>The row at {@code rowBeginIndex} is treated as the header. Every later row that exists
+     * in the sheet yields one map, in sheet order, whose keys are the header texts and whose
+     * values are the converted cell contents (see {@link #getCellValue(Cell)}). Cells that are
+     * missing, and columns whose header cell is missing, are left out of the map.
+     *
+     * @param filePath      path of the Excel file ({@code .xlsx} or {@code .xls}, case-insensitive)
+     * @param sheetName     name of the sheet to read
+     * @param rowBeginIndex zero-based index of the header row; data starts on the next row
+     * @param colBeginIndex zero-based index of the first column to read; negative values are treated as 0
+     * @return one map per data row; empty if the file could not be read
+     * @throws IllegalArgumentException if the extension is unsupported, the sheet does not exist,
+     *                                  or the header row does not exist
+     */
+    public static LinkedList<Map<String, Object>> readExcel(String filePath, String sheetName, int rowBeginIndex, int colBeginIndex) {
+        LinkedList<Map<String, Object>> result = new LinkedList<>();
+
+        // Both the stream and the workbook are closed automatically, in reverse order.
+        try (FileInputStream fileInputStream = new FileInputStream(filePath);
+             Workbook workbook = openWorkbook(filePath, fileInputStream)) {
+
+            Sheet sheet = workbook.getSheet(sheetName);
+            if (sheet == null) {
+                throw new IllegalArgumentException("Sheet not found");
+            }
+
+            Row headerRow = sheet.getRow(rowBeginIndex);
+            if (headerRow == null) {
+                throw new IllegalArgumentException("Header row not found");
+            }
+
+            int firstCol = Math.max(colBeginIndex, 0);
+            // getLastCellNum() is -1 for a row without cells, so the loops below simply do not run.
+            int colCount = headerRow.getLastCellNum();
+
+            // Resolve the header texts once instead of once per data row. DataFormatter also
+            // copes with non-text header cells, which getStringCellValue() would reject.
+            DataFormatter formatter = new DataFormatter();
+            String[] keys = new String[Math.max(colCount, 0)];
+            for (int j = firstCol; j < colCount; j++) {
+                Cell headerCell = headerRow.getCell(j);
+                keys[j] = headerCell == null ? null : formatter.formatCellValue(headerCell);
+            }
+
+            int lastRowNum = sheet.getLastRowNum();
+            for (int i = rowBeginIndex + 1; i <= lastRowNum; i++) {
+                Row dataRow = sheet.getRow(i);
+                if (dataRow == null) {
+                    continue;
+                }
+
+                Map<String, Object> rowData = new LinkedHashMap<>();
+                for (int j = firstCol; j < colCount; j++) {
+                    Cell cell = dataRow.getCell(j);
+                    if (cell == null || keys[j] == null) {
+                        continue;
+                    }
+                    rowData.put(keys[j], getCellValue(cell));
+                }
+
+                result.add(rowData);
+            }
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+
+        return result;
+    }
+
+    /**
+     * Converts a cell to a Java value: text as {@code String}, numbers as {@code Double},
+     * date-formatted numbers as {@code java.util.Date}, booleans as {@code Boolean}.
+     * Formula cells yield the formula text, not its result; every other cell type yields
+     * {@code null}.
+     */
+    private static Object getCellValue(Cell cell) {
+        switch (cell.getCellType()) {
+            case STRING:
+                return cell.getStringCellValue();
+            case NUMERIC:
+                if (DateUtil.isCellDateFormatted(cell)) {
+                    return cell.getDateCellValue();
+                } else {
+                    return cell.getNumericCellValue();
+                }
+            case BOOLEAN:
+                return cell.getBooleanCellValue();
+            case FORMULA:
+                return cell.getCellFormula();
+            default:
+                return null;
+        }
+    }
+
+    /**
+     * Stores {@code value} in {@code cell}: any {@code Number} as a numeric cell, {@code Boolean}
+     * and {@code String} natively, anything else through its string form. {@code null} leaves
+     * the cell blank.
+     */
+    private static void setCellValue(Cell cell, Object value) {
+        if (value == null) {
+            return;
+        }
+        if (value instanceof String) {
+            cell.setCellValue((String) value);
+        } else if (value instanceof Number) {
+            // Covers Long, Float, BigDecimal, ... in addition to Double and Integer.
+            cell.setCellValue(((Number) value).doubleValue());
+        } else if (value instanceof Boolean) {
+            cell.setCellValue((Boolean) value);
+        } else {
+            cell.setCellValue(String.valueOf(value));
+        }
+    }
+
+    /** Creates an empty workbook of the format implied by the extension of {@code filePath}. */
+    private static Workbook createWorkbook(String filePath) {
+        if (hasExtension(filePath, ".xlsx")) {
+            return new XSSFWorkbook();
+        } else if (hasExtension(filePath, ".xls")) {
+            return new HSSFWorkbook();
+        }
+        throw new IllegalArgumentException("Invalid file extension");
+    }
+
+    /** Opens the workbook contained in {@code in}, using the extension of {@code filePath} to pick the format. */
+    private static Workbook openWorkbook(String filePath, FileInputStream in) throws IOException {
+        if (hasExtension(filePath, ".xlsx")) {
+            return new XSSFWorkbook(in);
+        } else if (hasExtension(filePath, ".xls")) {
+            return new HSSFWorkbook(in);
+        }
+        throw new IllegalArgumentException("Invalid file extension");
+    }
+
+    /** Case-insensitive, locale-independent extension check. */
+    private static boolean hasExtension(String filePath, String extension) {
+        return filePath.toLowerCase(Locale.ROOT).endsWith(extension);
+    }
+
+    public static void main(String[] args) {
+        LinkedList<Map<String, Object>> detailedListOfUserNumbers = readExcel("/Users/alvin/Downloads/授用信各资方家数表.xlsx", "用信家数明细表", 0, 0);
+        System.out.println(detailedListOfUserNumbers.size());
+    }
+}

+ 113 - 0
test1/src/main/java/RecommendationSystemTest.java

@@ -0,0 +1,113 @@
+import org.junit.Test;
+
+import java.util.*;
+
+/**
+ * Tests for {@code RecommenderSystem}: similarity on hand-built ratings, top-N recommendation
+ * on a three-user table, and an end-to-end run driven by an Excel workbook.
+ */
+public class RecommendationSystemTest {
+
+    /** Similarity is computed over the items both users rated (here items 1, 2 and 3). */
+    @Test
+    public void testCalculateSimilarity() {
+        Map<String, Double> user1 = new HashMap<>();
+        user1.put("1", 2.5);
+        user1.put("2", 3.5);
+        user1.put("3", 3.0);
+
+        Map<String, Double> user2 = new HashMap<>();
+        user2.put("1", 2.5);
+        user2.put("2", 3.5);
+        user2.put("3", 3.5);
+        user2.put("4", 4.5);
+        user2.put("5", 5.0);
+
+        RecommenderSystem rs = new RecommenderSystem(new HashMap<>(), 3);
+        double similarity = rs.calculateSimilarity(user1, user2);
+        System.out.println(similarity);
+
+        // dot = 2.5*2.5 + 3.5*3.5 + 3.0*3.5 = 29; squared norms over the common items are 27.5 and 30.75.
+        double expected = 29.0 / Math.sqrt(27.5 * 30.75);
+        org.junit.Assert.assertEquals(expected, similarity, 1e-9);
+    }
+
+    /** User "1" has rated items 1-3, so only items seen by the neighbours may be recommended. */
+    @Test
+    public void testRecommendItems() {
+        Map<String, Double> user1 = new HashMap<>();
+        user1.put("1", 2.5);
+        user1.put("2", 3.5);
+        user1.put("3", 3.0);
+
+        Map<String, Double> user2 = new HashMap<>();
+        user2.put("1", 2.5);
+        user2.put("2", 3.5);
+        user2.put("3", 3.0);
+        user2.put("4", 4.0);
+        user2.put("7", 4.0);
+
+        Map<String, Double> user3 = new HashMap<>();
+        user3.put("1", 2.5);
+        user3.put("2", 3.5);
+        user3.put("3", 3.5);
+        user3.put("4", 4.5);
+        user3.put("5", 4.0);
+        user3.put("6", 1.0);
+
+        Map<String, Map<String, Double>> userItemRatingTable = new HashMap<>();
+        userItemRatingTable.put("1", user1);
+        userItemRatingTable.put("2", user2);
+        userItemRatingTable.put("3", user3);
+
+        RecommenderSystem rs = new RecommenderSystem(userItemRatingTable, 2);
+        Map<String, Double> recommendations = rs.recommendItems("1", 2);
+
+        System.out.println(recommendations);
+
+        org.junit.Assert.assertEquals(2, recommendations.size());
+        // Item 4 is rated 4.0 and 4.5 by the two neighbours, so its weighted mean tops every other candidate.
+        org.junit.Assert.assertEquals("4", recommendations.keySet().iterator().next());
+        for (String alreadyRated : user1.keySet()) {
+            org.junit.Assert.assertFalse(recommendations.containsKey(alreadyRated));
+        }
+    }
+
+    /**
+     * Builds the rating table from an Excel sheet, computes recommendations for every company
+     * and writes them to a result workbook.
+     */
+    @Test
+    public void testByExcel() {
+        String inputPath = "/Users/alvin/Downloads/授用信各资方家数表.xlsx";
+        // The input workbook is a developer-local file: skip the test when it is absent instead
+        // of "passing" on an empty table and writing a header-only result file.
+        org.junit.Assume.assumeTrue(new java.io.File(inputPath).isFile());
+
+        LinkedList<Map<String, Object>> list = ExcelUtil.readExcel(inputPath, "用信家数明细表", 0, 0);
+
+        // Columns that do not hold per-item ratings.
+        HashSet<String> ignore = new HashSet<>();
+        ignore.add("企业名称");
+        ignore.add("总计");
+        ignore.add("家数");
+
+        Map<String, Map<String, Double>> userItemRatingTable = new HashMap<>();
+        for (Map<String, Object> map : list) {
+            HashMap<String, Double> tmpMap = new HashMap<>();
+            for (Map.Entry<String, Object> column : map.entrySet()) {
+                if (ignore.contains(column.getKey())) {
+                    continue;
+                }
+                Object value = column.getValue();
+                // Blank cells count as "no usage".
+                tmpMap.put(column.getKey(), value == null ? 0D : Double.valueOf(value + ""));
+            }
+            userItemRatingTable.put(map.get("企业名称") + "", tmpMap);
+        }
+        RecommenderSystem rs = new RecommenderSystem(userItemRatingTable, userItemRatingTable.size());
+
+        ArrayList<Map<String, Object>> outRes = new ArrayList<>();
+
+        for (String companyName : userItemRatingTable.keySet()) {
+            Map<String, Double> recommendations = rs.recommendItems(companyName);
+            // Drop empty scores; NaN (0/0 weighted mean) carries no information either.
+            recommendations.entrySet().removeIf(entry -> entry.getValue() == 0.0 || entry.getValue().isNaN());
+
+            HashMap<String, Object> map = new HashMap<>();
+            map.put("企业名称", companyName);
+            map.putAll(recommendations);
+            outRes.add(map);
+        }
+
+        boolean written = ExcelUtil.writeExcel("系统过滤推荐.xlsx", "推荐结果", 0,
+                Arrays.asList("企业名称", "工商银行", "瀚华普惠有限公司", "平安银行", "普惠", "苏宁银行", "天能惠商贷", "邮储银行", "邮惠万家"),
+                outRes);
+        org.junit.Assert.assertTrue(written);
+    }
+
+}

+ 83 - 0
test1/src/main/java/RecommenderSystem.java

@@ -0,0 +1,83 @@
+import jnr.ffi.annotations.In;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * User-based collaborative filtering over an in-memory rating table.
+ *
+ * <p>The table maps a user id to that user's ratings (item id to rating). Scores for a user are
+ * similarity-weighted means of the ratings given by the most similar other users.
+ */
+public class RecommenderSystem {
+    private final Map<String, Map<String, Double>> userItemRatingTable;
+    private final int neighborhoodSize;
+
+    /**
+     * @param userItemRatingTable user id to (item id to rating); kept by reference, not copied
+     * @param neighborhoodSize    maximum number of most-similar users consulted per recommendation
+     */
+    public RecommenderSystem(Map<String, Map<String, Double>> userItemRatingTable, int neighborhoodSize) {
+        this.userItemRatingTable = userItemRatingTable;
+        this.neighborhoodSize = neighborhoodSize;
+    }
+
+    /**
+     * Scores every item rated by the nearest neighbours of {@code userId}.
+     *
+     * <p>A score is the sum of {@code similarity * rating} over the neighbours that rated the
+     * item, divided by the sum of their similarities. Items whose similarity sum is zero are
+     * omitted because the weighted mean is undefined for them. Items the user already rated
+     * are included.
+     *
+     * @param userId id of the user to recommend for
+     * @return a new mutable map from item id to score; empty if the user is unknown or
+     *         {@code neighborhoodSize} is not positive
+     */
+    public Map<String, Double> recommendItems(String userId) {
+        Map<String, Double> recommendedItemScores = new HashMap<>();
+        Map<String, Double> targetRatings = userItemRatingTable.get(userId);
+        if (targetRatings == null) {
+            return recommendedItemScores;
+        }
+
+        // One entry per neighbour. A map keyed by similarity would silently merge neighbours
+        // that happen to share the same similarity value.
+        List<Map.Entry<String, Double>> neighbors = new ArrayList<>();
+        for (Map.Entry<String, Map<String, Double>> userEntry : userItemRatingTable.entrySet()) {
+            String neighborId = userEntry.getKey();
+            // Compare ids by content: equal ids are not necessarily the same String instance.
+            if (Objects.equals(neighborId, userId)) {
+                continue;
+            }
+            double similarity = calculateSimilarity(targetRatings, userEntry.getValue());
+            neighbors.add(new AbstractMap.SimpleImmutableEntry<String, Double>(neighborId, similarity));
+        }
+        // Most similar first.
+        neighbors.sort(Map.Entry.<String, Double>comparingByValue().reversed());
+
+        Map<String, Double> ratingTotalMap = new HashMap<>();
+        Map<String, Double> weightTotalMap = new HashMap<>();
+        int used = 0;
+        for (Map.Entry<String, Double> neighbor : neighbors) {
+            if (used >= neighborhoodSize) {
+                break;
+            }
+            double similarity = neighbor.getValue();
+            for (Map.Entry<String, Double> itemEntry : userItemRatingTable.get(neighbor.getKey()).entrySet()) {
+                String itemId = itemEntry.getKey();
+                double rating = itemEntry.getValue();
+                ratingTotalMap.merge(itemId, similarity * rating, Double::sum);
+                weightTotalMap.merge(itemId, similarity, Double::sum);
+            }
+            used++;
+        }
+
+        for (Map.Entry<String, Double> ratingTotalEntry : ratingTotalMap.entrySet()) {
+            String itemId = ratingTotalEntry.getKey();
+            double weightTotal = weightTotalMap.get(itemId);
+            if (weightTotal == 0.0) {
+                // Avoids a 0/0 = NaN score when no consulted neighbour is similar at all.
+                continue;
+            }
+            recommendedItemScores.put(itemId, ratingTotalEntry.getValue() / weightTotal);
+        }
+        return recommendedItemScores;
+    }
+
+    /**
+     * Returns the {@code topN} best-scored items that {@code userId} has not rated yet.
+     *
+     * @param userId id of the user to recommend for
+     * @param topN   maximum number of items to return; must not be negative
+     * @return item id to score, ordered from highest to lowest score; empty if the user is unknown
+     */
+    public LinkedHashMap<String, Double> recommendItems(String userId, int topN) {
+        Map<String, Double> userMap = this.userItemRatingTable.get(userId);
+        if (userMap == null) {
+            return new LinkedHashMap<>();
+        }
+
+        LinkedHashMap<String, Double> topItems = this.recommendItems(userId).entrySet()
+                .stream()
+                .filter(entry -> !userMap.containsKey(entry.getKey()))  // skip items the user already rated
+                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())  // best score first
+                .limit(topN)
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new));
+        return topItems;
+    }
+
+
+    /**
+     * Cosine similarity of two users, restricted to the items both of them rated.
+     *
+     * @param user1 ratings of the first user (item id to rating)
+     * @param user2 ratings of the second user (item id to rating)
+     * @return the cosine of the two rating vectors over their common items, or {@code 0.0} when
+     *         they share no items or one of the restricted vectors is all zeros
+     */
+    public double calculateSimilarity(Map<String, Double> user1, Map<String, Double> user2) {
+        Set<String> commonItemIds = new HashSet<>(user1.keySet());
+        commonItemIds.retainAll(user2.keySet());
+
+        double dotProduct = 0.0;
+        double squaredNorm1 = 0.0;
+        double squaredNorm2 = 0.0;
+        for (String itemId : commonItemIds) {
+            double rating1 = user1.get(itemId);
+            double rating2 = user2.get(itemId);
+            dotProduct += rating1 * rating2;
+            squaredNorm1 += rating1 * rating1;
+            squaredNorm2 += rating2 * rating2;
+        }
+
+        double denominator = Math.sqrt(squaredNorm1) * Math.sqrt(squaredNorm2);
+        return denominator == 0 ? 0.0 : dotProduct / denominator;
+    }
+}