|
1 | | -use std::collections::HashMap; |
2 | | - |
3 | | -use crate::core; |
4 | | -use crate::review::normalize_rule_id; |
5 | | - |
6 | | -use super::super::{EvalFixtureResult, EvalPattern, EvalRuleMetrics, EvalRuleScoreSummary}; |
7 | | - |
/// Raw per-rule tallies accumulated while scanning a fixture's expected
/// patterns, predicted comments, and matched pairs. Derived scores
/// (precision/recall/F1) are computed later in `build_rule_metrics_from_counts`.
#[derive(Debug, Default, Clone, Copy)]
struct RuleMetricCounts {
    // How many expected patterns carry this rule id.
    expected: usize,
    // How many predicted comments carry this rule id.
    predicted: usize,
    // Matched (expected, predicted) pairs whose rule ids agree.
    true_positives: usize,
}
14 | | - |
15 | | -pub(in super::super) fn compute_rule_metrics( |
16 | | - expected_patterns: &[EvalPattern], |
17 | | - comments: &[core::Comment], |
18 | | - matched_pairs: &[(usize, usize)], |
19 | | -) -> Vec<EvalRuleMetrics> { |
20 | | - let mut counts_by_rule: HashMap<String, RuleMetricCounts> = HashMap::new(); |
21 | | - |
22 | | - for pattern in expected_patterns { |
23 | | - if let Some(rule_id) = pattern.normalized_rule_id() { |
24 | | - counts_by_rule.entry(rule_id).or_default().expected += 1; |
25 | | - } |
26 | | - } |
27 | | - |
28 | | - for comment in comments { |
29 | | - if let Some(rule_id) = normalize_rule_id(comment.rule_id.as_deref()) { |
30 | | - counts_by_rule.entry(rule_id).or_default().predicted += 1; |
31 | | - } |
32 | | - } |
33 | | - |
34 | | - for (expected_idx, comment_idx) in matched_pairs { |
35 | | - let expected_rule = expected_patterns |
36 | | - .get(*expected_idx) |
37 | | - .and_then(EvalPattern::normalized_rule_id); |
38 | | - let predicted_rule = comments |
39 | | - .get(*comment_idx) |
40 | | - .and_then(|comment| normalize_rule_id(comment.rule_id.as_deref())); |
41 | | - if let (Some(expected_rule), Some(predicted_rule)) = (expected_rule, predicted_rule) { |
42 | | - if expected_rule == predicted_rule { |
43 | | - counts_by_rule |
44 | | - .entry(expected_rule) |
45 | | - .or_default() |
46 | | - .true_positives += 1; |
47 | | - } |
48 | | - } |
49 | | - } |
50 | | - |
51 | | - build_rule_metrics_from_counts(&counts_by_rule) |
52 | | -} |
53 | | - |
54 | | -pub(in super::super) fn aggregate_rule_metrics( |
55 | | - results: &[EvalFixtureResult], |
56 | | -) -> Vec<EvalRuleMetrics> { |
57 | | - let mut counts_by_rule: HashMap<String, RuleMetricCounts> = HashMap::new(); |
58 | | - for result in results { |
59 | | - for metric in &result.rule_metrics { |
60 | | - let counts = counts_by_rule.entry(metric.rule_id.clone()).or_default(); |
61 | | - counts.expected = counts.expected.saturating_add(metric.expected); |
62 | | - counts.predicted = counts.predicted.saturating_add(metric.predicted); |
63 | | - counts.true_positives = counts.true_positives.saturating_add(metric.true_positives); |
64 | | - } |
65 | | - } |
66 | | - |
67 | | - build_rule_metrics_from_counts(&counts_by_rule) |
68 | | -} |
69 | | - |
70 | | -pub(in super::super) fn summarize_rule_metrics( |
71 | | - metrics: &[EvalRuleMetrics], |
72 | | -) -> Option<EvalRuleScoreSummary> { |
73 | | - if metrics.is_empty() { |
74 | | - return None; |
75 | | - } |
76 | | - |
77 | | - let mut tp_sum = 0usize; |
78 | | - let mut predicted_sum = 0usize; |
79 | | - let mut expected_sum = 0usize; |
80 | | - let mut precision_sum = 0.0f32; |
81 | | - let mut recall_sum = 0.0f32; |
82 | | - let mut f1_sum = 0.0f32; |
83 | | - |
84 | | - for metric in metrics { |
85 | | - tp_sum = tp_sum.saturating_add(metric.true_positives); |
86 | | - predicted_sum = predicted_sum.saturating_add(metric.predicted); |
87 | | - expected_sum = expected_sum.saturating_add(metric.expected); |
88 | | - precision_sum += metric.precision; |
89 | | - recall_sum += metric.recall; |
90 | | - f1_sum += metric.f1; |
91 | | - } |
92 | | - |
93 | | - let micro_precision = if predicted_sum > 0 { |
94 | | - tp_sum as f32 / predicted_sum as f32 |
95 | | - } else { |
96 | | - 0.0 |
97 | | - }; |
98 | | - let micro_recall = if expected_sum > 0 { |
99 | | - tp_sum as f32 / expected_sum as f32 |
100 | | - } else { |
101 | | - 0.0 |
102 | | - }; |
103 | | - let micro_f1 = harmonic_mean(micro_precision, micro_recall); |
104 | | - let count = metrics.len() as f32; |
105 | | - |
106 | | - Some(EvalRuleScoreSummary { |
107 | | - micro_precision, |
108 | | - micro_recall, |
109 | | - micro_f1, |
110 | | - macro_precision: precision_sum / count, |
111 | | - macro_recall: recall_sum / count, |
112 | | - macro_f1: f1_sum / count, |
113 | | - }) |
114 | | -} |
115 | | - |
116 | | -fn build_rule_metrics_from_counts( |
117 | | - counts_by_rule: &HashMap<String, RuleMetricCounts>, |
118 | | -) -> Vec<EvalRuleMetrics> { |
119 | | - let mut metrics = Vec::new(); |
120 | | - for (rule_id, counts) in counts_by_rule { |
121 | | - let false_positives = counts.predicted.saturating_sub(counts.true_positives); |
122 | | - let false_negatives = counts.expected.saturating_sub(counts.true_positives); |
123 | | - let precision = if counts.predicted > 0 { |
124 | | - counts.true_positives as f32 / counts.predicted as f32 |
125 | | - } else { |
126 | | - 0.0 |
127 | | - }; |
128 | | - let recall = if counts.expected > 0 { |
129 | | - counts.true_positives as f32 / counts.expected as f32 |
130 | | - } else { |
131 | | - 0.0 |
132 | | - }; |
133 | | - let f1 = harmonic_mean(precision, recall); |
134 | | - |
135 | | - metrics.push(EvalRuleMetrics { |
136 | | - rule_id: rule_id.clone(), |
137 | | - expected: counts.expected, |
138 | | - predicted: counts.predicted, |
139 | | - true_positives: counts.true_positives, |
140 | | - false_positives, |
141 | | - false_negatives, |
142 | | - precision, |
143 | | - recall, |
144 | | - f1, |
145 | | - }); |
146 | | - } |
147 | | - |
148 | | - metrics.sort_by(|left, right| { |
149 | | - right |
150 | | - .expected |
151 | | - .cmp(&left.expected) |
152 | | - .then_with(|| right.predicted.cmp(&left.predicted)) |
153 | | - .then_with(|| left.rule_id.cmp(&right.rule_id)) |
154 | | - }); |
155 | | - metrics |
156 | | -} |
157 | | - |
/// F1-style harmonic mean of two ratios.
///
/// Defined as 0.0 when both inputs are (effectively) zero, guarding the
/// division by a near-zero denominator.
fn harmonic_mean(precision: f32, recall: f32) -> f32 {
    let denominator = precision + recall;
    if denominator <= f32::EPSILON {
        return 0.0;
    }
    2.0 * precision * recall / denominator
}
//! Module root for rule-level metric computation; the implementation is
//! split across files under the `rules/` directory.

// `#[path]` points each submodule at its file under `rules/`, keeping this
// file as the module root without a separate `rules.rs`/`rules/mod.rs` layout.
#[path = "rules/build.rs"]
mod build;
#[path = "rules/compute.rs"]
mod compute;
#[path = "rules/counts.rs"]
mod counts;
#[path = "rules/summary.rs"]
mod summary;

// Re-export the public entry points at the same visibility the flat module had.
pub(in super::super) use compute::{aggregate_rule_metrics, compute_rule_metrics};
pub(in super::super) use summary::summarize_rule_metrics;
0 commit comments