Skip to content
章节导航

分组和分区

| 特性 | 分组 (Grouping) | 分区 (Partitioning) |
| --- | --- | --- |
| 分类依据 | 任意分类键 | 布尔条件(只有true/false两组) |
| 组数 | 多个组(取决于分类键的取值) | 固定2组(true/false) |
| 主要方法 | groupingBy() | partitioningBy() |
| 返回值 | Map<K, List<T>> | Map<Boolean, List<T>> |
| 适用场景 | 多类别分类 | 二分类(是/否) |

分组(GroupingBy)

1. 基础分组

java
// Sample data shared by the grouping examples: (name, city, age).
List<User> users = Arrays.asList(
    new User("Alice", "北京", 25),
    new User("Bob", "上海", 30),
    new User("Charlie", "北京", 28),
    new User("David", "上海", 22),
    new User("Eve", "广州", 35)
);

// 1.1 Group by city. toList() is spelled out here; it is also the downstream
//     collector groupingBy applies when only a classifier is given.
Map<String, List<User>> usersByCity = users.stream()
    .collect(Collectors.groupingBy(User::getCity, Collectors.toList()));

// Result:
// {
//   "北京": [Alice(25), Charlie(28)],
//   "上海": [Bob(30), David(22)],
//   "广州": [Eve(35)]
// }

// 1.2 Group by age bracket: under 25, 25 to 34, and 35 or older.
Map<String, List<User>> usersByAgeGroup = users.stream()
    .collect(Collectors.groupingBy(user ->
        user.getAge() < 25 ? "青年"
            : user.getAge() < 35 ? "中年"
            : "资深"));

// 1.3 Two-level grouping: by city first, then by age bracket inside each city.
Map<String, Map<String, List<User>>> multiLevelGroup = users.stream()
    .collect(Collectors.groupingBy(
        User::getCity,
        Collectors.groupingBy(user ->
            user.getAge() < 25 ? "<25"
                : user.getAge() < 35 ? "25-34"
                : "35+")));

2. 分组后的下游操作

java
// 2.1 Count the users in each city.
Map<String, Long> countByCity = users.stream()
    .collect(Collectors.groupingBy(User::getCity, Collectors.counting()));
// Result: {"北京": 2, "上海": 2, "广州": 1}

// 2.2 Sum the ages in each city.
Map<String, Integer> sumAgeByCity = users.stream()
    .collect(Collectors.groupingBy(User::getCity, Collectors.summingInt(User::getAge)));
// Result: {"北京": 53, "上海": 52, "广州": 35}

// 2.3 Average age in each city.
Map<String, Double> avgAgeByCity = users.stream()
    .collect(Collectors.groupingBy(User::getCity, Collectors.averagingInt(User::getAge)));
// Result: {"北京": 26.5, "上海": 26.0, "广州": 35.0}

// 2.4 Oldest user in each city. maxBy yields an Optional per group;
//     minBy is the counterpart for the minimum.
Map<String, Optional<User>> oldestByCity = users.stream()
    .collect(Collectors.groupingBy(
        User::getCity,
        Collectors.maxBy(Comparator.comparingInt(User::getAge))));

// 2.5 Map each group's elements before collecting: keep only the names.
Map<String, List<String>> namesByCity = users.stream()
    .collect(Collectors.groupingBy(
        User::getCity,
        Collectors.mapping(User::getName, Collectors.toList())));
// Result: {"北京": ["Alice", "Charlie"], "上海": ["Bob", "David"], "广州": ["Eve"]}

// 2.6 Join the names of each group into a single bracketed string.
Map<String, String> joinedNamesByCity = users.stream()
    .collect(Collectors.groupingBy(
        User::getCity,
        Collectors.mapping(User::getName, Collectors.joining(", ", "[", "]"))));
// Result: {"北京": "[Alice, Charlie]", "上海": "[Bob, David]", "广州": "[Eve]"}

3. 分组的高级用法

java
// 3.1 Choose the result Map type: TreeMap keeps the city keys sorted, and
//     toSet() de-duplicates users (User must override equals/hashCode).
Map<String, Set<User>> usersByCitySet = users.stream()
    .collect(Collectors.groupingBy(User::getCity, TreeMap::new, Collectors.toSet()));

// 3.2 Summary statistics per city: count, sum, min, average and max of the ages.
Map<String, IntSummaryStatistics> ageStatsByCity = users.stream()
    .collect(Collectors.groupingBy(User::getCity, Collectors.summarizingInt(User::getAge)));

// 3.3 Filter inside each group: only users older than 25 are kept per city.
Map<String, List<User>> filteredGroups = users.stream()
    .collect(Collectors.groupingBy(
        User::getCity,
        Collectors.filtering(user -> user.getAge() > 25, Collectors.toList())));

// 3.4 Flatten inside each group: every order's items are expanded, then the
//     item names are collected per customer.
List<Order> orders = getOrders();
Map<String, List<String>> productsByCustomer = orders.stream()
    .collect(Collectors.groupingBy(
        Order::getCustomerName,
        Collectors.flatMapping(
            order -> order.getItems().stream(),
            Collectors.mapping(Item::getName, Collectors.toList()))));

// 3.5 Combine two statistics per group with teeing (Java 12+):
//     the user count and the average age are merged into one CityStats.
Map<String, CityStats> statsByCity = users.stream()
    .collect(Collectors.groupingBy(
        User::getCity,
        Collectors.teeing(
            Collectors.counting(),
            Collectors.averagingInt(User::getAge),
            CityStats::new)));

分区(PartitioningBy)

1. 基础分区

java
// Sample data shared by the partitioning examples: (name, score).
List<Student> students = Arrays.asList(
    new Student("Alice", 85),
    new Student("Bob", 92),
    new Student("Charlie", 78),
    new Student("David", 65),
    new Student("Eve", 88)
);

// 1.1 Partition by pass/fail (pass = score of at least 60). toList() is
//     spelled out here; it is also the downstream collector partitioningBy
//     applies when only a predicate is given.
Map<Boolean, List<Student>> partitionedByPass = students.stream()
    .collect(Collectors.partitioningBy(
        student -> student.getScore() >= 60,
        Collectors.toList()));

// Result:
// {
//   true: [Alice(85), Bob(92), Charlie(78), David(65), Eve(88)],
//   false: []
// }

// 1.2 Partition by "excellent" (score of at least 80).
Map<Boolean, List<Student>> partitionedByExcellent = students.stream()
    .collect(Collectors.partitioningBy(student -> student.getScore() >= 80));

// 1.3 Partition on a compound condition: score of at least 90, or a name
//     starting with "A".
Map<Boolean, List<Student>> partitionedByGrade = students.stream()
    .collect(Collectors.partitioningBy(student ->
        student.getScore() >= 90 || student.getName().startsWith("A")));

2. 分区后的下游操作

java
// 2.1 Count the students on each side of the pass mark.
Map<Boolean, Long> countByPass = students.stream()
    .collect(Collectors.partitioningBy(
        student -> student.getScore() >= 60, Collectors.counting()));
// Result: {true: 5, false: 0}

// 2.2 Sum the scores on each side of the pass mark.
Map<Boolean, Integer> sumScoreByPass = students.stream()
    .collect(Collectors.partitioningBy(
        student -> student.getScore() >= 60, Collectors.summingInt(Student::getScore)));

// 2.3 Average score on each side of the pass mark.
Map<Boolean, Double> avgScoreByPass = students.stream()
    .collect(Collectors.partitioningBy(
        student -> student.getScore() >= 60, Collectors.averagingInt(Student::getScore)));

// 2.4 Map inside each partition: keep only the student names.
Map<Boolean, List<String>> namesByPass = students.stream()
    .collect(Collectors.partitioningBy(
        student -> student.getScore() >= 60,
        Collectors.mapping(Student::getName, Collectors.toList())));

// 2.5 Partition first (score of at least 80), then group each side by
//     letter grade: A for 90+, B for 80-89, C for 60-79, D below 60.
Map<Boolean, Map<String, List<Student>>> complexPartition = students.stream()
    .collect(Collectors.partitioningBy(
        student -> student.getScore() >= 80,
        Collectors.groupingBy(student ->
            student.getScore() >= 90 ? "A"
                : student.getScore() >= 80 ? "B"
                : student.getScore() >= 60 ? "C"
                : "D")));

实际应用场景

场景1:电商数据分析

java
/**
 * E-commerce reporting examples built on {@code Collectors.groupingBy} and
 * {@code Collectors.partitioningBy}.
 */
public class EcommerceAnalyzer {

    /**
     * Builds sales statistics for each product category.
     *
     * <p>{@code Collectors.teeing} accepts exactly two downstream collectors and a
     * two-argument merger. The item count and the revenue are therefore both read
     * from a single {@code summarizingDouble} result, while the second collector
     * gathers the product names.
     *
     * @param orders orders whose items are flattened into one stream
     * @return map from category to stats built from the number of order items,
     *         the summed {@code price * quantity}, and the set of product names
     */
    public Map<String, CategoryStats> analyzeSalesByCategory(List<Order> orders) {
        return orders.stream()
            .flatMap(order -> order.getItems().stream())
            .collect(Collectors.groupingBy(
                OrderItem::getCategory,
                Collectors.teeing(
                    // One pass yields both the item count and the revenue sum.
                    Collectors.summarizingDouble(item ->
                        item.getPrice() * item.getQuantity()
                    ),
                    // Distinct product names seen in the category.
                    Collectors.mapping(
                        OrderItem::getProductName,
                        Collectors.toSet()
                    ),
                    (revenueStats, products) ->
                        new CategoryStats(
                            revenueStats.getCount(),
                            revenueStats.getSum(),
                            products)
                )
            ));
    }

    /**
     * Groups products into price-range buckets.
     *
     * <p>Bucket boundaries are half-open: a price below 100 is "low", below 500
     * is "medium", below 2000 is "high", and anything from 2000 upward falls into
     * the last bucket.
     *
     * @param products products to classify
     * @return map from bucket label to the products in that bucket
     */
    public Map<String, List<Product>> groupProductsByPriceRange(List<Product> products) {
        return products.stream()
            .collect(Collectors.groupingBy(product -> {
                double price = product.getPrice();
                if (price < 100) {
                    return "低价(<100)";
                } else if (price < 500) {
                    return "中价(100-500)";
                } else if (price < 2000) {
                    return "高价(500-2000)";
                } else {
                    return "奢侈(>2000)";
                }
            }));
    }

    /**
     * Splits customers into VIP ({@code true}) and regular ({@code false}) and
     * summarises each side.
     *
     * @param customers customers to segment
     * @return map whose {@code true} entry covers customers with total spend
     *         above 10000; each entry holds the customer count and the mean of
     *         their average order values (0.0 for an empty side)
     */
    public Map<Boolean, CustomerStats> analyzeCustomerSegmentation(List<Customer> customers) {
        return customers.stream()
            .collect(Collectors.partitioningBy(
                customer -> customer.getTotalSpent() > 10000,  // VIP threshold
                Collectors.collectingAndThen(
                    Collectors.toList(),
                    list -> {
                        double avgOrderValue = list.stream()
                            .mapToDouble(Customer::getAvgOrderValue)
                            .average()
                            .orElse(0.0);
                        long count = list.size();
                        return new CustomerStats(count, avgOrderValue);
                    }
                )
            ));
    }
}

场景2:员工管理系统

java
/**
 * Employee reporting examples built on {@code Collectors.groupingBy} and
 * {@code Collectors.partitioningBy}.
 */
public class EmployeeManager {

    /**
     * Builds headcount, average salary and level distribution per department.
     *
     * <p>{@code Collectors.teeing} accepts exactly two downstream collectors and a
     * two-argument merger. The headcount and the average salary are therefore both
     * read from a single {@code summarizingDouble} result, while the second
     * collector gathers the distinct levels.
     *
     * @param employees employees to analyse
     * @return map from department to stats built from the employee count, the
     *         average salary, and the set of levels present in the department
     */
    public Map<String, DepartmentStats> getDepartmentStatistics(List<Employee> employees) {
        return employees.stream()
            .collect(Collectors.groupingBy(
                Employee::getDepartment,
                Collectors.teeing(
                    // One pass yields both the headcount and the average salary.
                    Collectors.summarizingDouble(Employee::getSalary),
                    // Distinct levels present in the department.
                    Collectors.mapping(
                        Employee::getLevel,
                        Collectors.toSet()
                    ),
                    (salaryStats, levels) ->
                        new DepartmentStats(
                            salaryStats.getCount(),
                            salaryStats.getAverage(),
                            levels)
                )
            ));
    }

    /**
     * Groups employees into salary bands.
     *
     * <p>Band boundaries are half-open: a salary below 10000 is the first band,
     * below 30000 the second, below 80000 the third, and anything from 80000
     * upward falls into the last band.
     *
     * @param employees employees to classify
     * @return map from band label to the employees in that band
     */
    public Map<String, List<Employee>> groupBySalaryLevel(List<Employee> employees) {
        return employees.stream()
            .collect(Collectors.groupingBy(employee -> {
                double salary = employee.getSalary();
                if (salary < 10000) {
                    return "初级(<10k)";
                } else if (salary < 30000) {
                    return "中级(10k-30k)";
                } else if (salary < 80000) {
                    return "高级(30k-80k)";
                } else {
                    return "专家(>80k)";
                }
            }));
    }

    /**
     * Splits employees by {@code isActive()} and summarises each side.
     *
     * @param employees employees to analyse
     * @return map whose {@code true} entry covers active employees; each entry
     *         holds the employee count, the average tenure in years and the
     *         average salary (averages are 0.0 for an empty side)
     */
    public Map<Boolean, EmployeeAnalysis> analyzeEmployeeStatus(List<Employee> employees) {
        return employees.stream()
            .collect(Collectors.partitioningBy(
                Employee::isActive,
                Collectors.collectingAndThen(
                    Collectors.toList(),
                    list -> {
                        double avgTenure = list.stream()
                            .mapToDouble(Employee::getTenureYears)
                            .average()
                            .orElse(0.0);
                        double avgSalary = list.stream()
                            .mapToDouble(Employee::getSalary)
                            .average()
                            .orElse(0.0);
                        return new EmployeeAnalysis(list.size(), avgTenure, avgSalary);
                    }
                )
            ));
    }

    /**
     * Two-level grouping: department first, then level within the department.
     *
     * @param employees employees to group
     * @return map from department to a map from level to employees
     */
    public Map<String, Map<String, List<Employee>>> groupByDeptAndLevel(List<Employee> employees) {
        return employees.stream()
            .collect(Collectors.groupingBy(
                Employee::getDepartment,
                Collectors.groupingBy(Employee::getLevel)
            ));
    }
}

场景3:日志分析

java
/**
 * Log analysis examples built on {@code Collectors.groupingBy} and
 * {@code Collectors.partitioningBy}.
 */
public class LogAnalyzer {

    /**
     * Builds per-level statistics: entry count, distinct sources, distinct dates.
     *
     * <p>{@code Collectors.teeing} accepts exactly two downstream collectors and a
     * two-argument merger. To combine three results, the two set collectors are
     * first paired by an inner {@code teeing}, and that pair is then merged with
     * the count by the outer one.
     *
     * @param logs log entries to analyse
     * @return map from log level to stats built from the entry count, the set of
     *         sources, and the set of calendar dates of the timestamps
     */
    public Map<String, LogStats> analyzeLogLevels(List<LogEntry> logs) {
        return logs.stream()
            .collect(Collectors.groupingBy(
                LogEntry::getLevel,
                Collectors.teeing(
                    Collectors.counting(),
                    Collectors.teeing(
                        // Distinct sources.
                        Collectors.mapping(
                            LogEntry::getSource,
                            Collectors.toSet()
                        ),
                        // Distinct dates; the explicit parameter type keeps
                        // inference simple inside the nested collectors.
                        Collectors.mapping(
                            (LogEntry log) -> log.getTimestamp().toLocalDate(),
                            Collectors.toSet()
                        ),
                        (sources, dates) -> Map.entry(sources, dates)
                    ),
                    (count, sourcesAndDates) ->
                        new LogStats(
                            count,
                            sourcesAndDates.getKey(),
                            sourcesAndDates.getValue())
                )
            ));
    }

    /**
     * Groups log entries by the hour of their timestamp.
     *
     * @param logs log entries to group
     * @return map from an hour label of the form {@code "HH:00-HH:59"} to the
     *         entries logged in that hour
     */
    public Map<String, List<LogEntry>> groupLogsByHour(List<LogEntry> logs) {
        return logs.stream()
            .collect(Collectors.groupingBy(log -> {
                int hour = log.getTimestamp().getHour();
                return String.format("%02d:00-%02d:59", hour, hour);
            }));
    }

    /**
     * Splits entries into error ({@code true}) and non-error ({@code false}) and
     * counts the entries of each side per source.
     *
     * @param logs log entries to analyse
     * @return map whose {@code true} entry covers entries with level
     *         {@code "ERROR"} or {@code "FATAL"}; each entry holds the number of
     *         log entries on that side and their per-source counts
     */
    public Map<Boolean, ErrorAnalysis> analyzeErrors(List<LogEntry> logs) {
        return logs.stream()
            .collect(Collectors.partitioningBy(
                log -> log.getLevel().equals("ERROR") || log.getLevel().equals("FATAL"),
                Collectors.collectingAndThen(
                    Collectors.toList(),
                    list -> {
                        Map<String, Long> errorBySource = list.stream()
                            .collect(Collectors.groupingBy(
                                LogEntry::getSource,
                                Collectors.counting()
                            ));
                        return new ErrorAnalysis(list.size(), errorBySource);
                    }
                )
            ));
    }

    /**
     * Builds per-IP access statistics: request count, distinct endpoints,
     * distinct user agents.
     *
     * <p>As in {@link #analyzeLogLevels(List)}, the two set collectors are paired
     * by an inner {@code teeing} because {@code teeing} only merges two results.
     *
     * @param logs access log entries to analyse
     * @return map from IP address to stats built from the access count, the set
     *         of endpoints, and the set of user agents
     */
    public Map<String, AccessStats> analyzeAccessPatterns(List<AccessLog> logs) {
        return logs.stream()
            .collect(Collectors.groupingBy(
                AccessLog::getIpAddress,
                Collectors.teeing(
                    Collectors.counting(),
                    Collectors.teeing(
                        // Distinct endpoints hit from this address.
                        Collectors.mapping(
                            AccessLog::getEndpoint,
                            Collectors.toSet()
                        ),
                        // Distinct user agents seen from this address.
                        Collectors.mapping(
                            AccessLog::getUserAgent,
                            Collectors.toSet()
                        ),
                        (endpoints, userAgents) -> Map.entry(endpoints, userAgents)
                    ),
                    (count, endpointsAndAgents) ->
                        new AccessStats(
                            count,
                            endpointsAndAgents.getKey(),
                            endpointsAndAgents.getValue())
                )
            ));
    }
}

性能优化和最佳实践

java
// Pick the Map implementation that fits the ordering you need when grouping:
//   HashMap::new        - the default, unordered
//   TreeMap::new        - sorted by key
//   LinkedHashMap::new  - keeps insertion order
Map<String, List<User>> groupByCity = users.stream()
    .collect(Collectors.groupingBy(User::getCity, HashMap::new, Collectors.toList()));

// Use the concurrent collector to improve parallel-stream performance;
// ConcurrentHashMap supplies the thread-safe result map.
Map<String, List<User>> concurrentGroup = users.parallelStream()
    .collect(Collectors.groupingByConcurrent(
        User::getCity, ConcurrentHashMap::new, Collectors.toList()));