88https://en.wikipedia.org/wiki/Naive_Bayes_classifier
99"""
1010
11- from typing import List , Dict
11+ from typing import Dict , List , Tuple
1212import math
1313
1414
@@ -21,11 +21,12 @@ def gaussian_probability(x: float, mean: float, variance: float) -> float:
2121 >>> gaussian_probability(1.0, 1.0, 0.0)
2222 0.0
2323 """
24- if variance == 0 :
24+ if variance == 0.0 :
2525 return 0.0
2626
27- exponent = math .exp (- ((x - mean ) ** 2 ) / (2 * variance ))
28- return (1 / math .sqrt (2 * math .pi * variance )) * exponent
27+ exponent = math .exp (- ((x - mean ) ** 2 ) / (2.0 * variance ))
28+ coefficient = 1.0 / math .sqrt (2.0 * math .pi * variance )
29+ return coefficient * exponent
2930
3031
3132class GaussianNaiveBayes :
@@ -61,12 +62,11 @@ def fit(self, features: List[List[float]], labels: List[int]) -> None:
6162 for label , rows in separated .items ():
6263 self .class_priors [label ] = len (rows ) / total_samples
6364
64- transposed = list (zip (* rows ))
65- self .means [label ] = [sum (col ) / len (col ) for col in transposed ]
66-
65+ columns = list (zip (* rows ))
66+ self .means [label ] = [sum (col ) / len (col ) for col in columns ]
6767 self .variances [label ] = [
6868 sum ((x - mean ) ** 2 for x in col ) / len (col )
69- for col , mean in zip (transposed , self .means [label ])
69+ for col , mean in zip (columns , self .means [label ])
7070 ]
7171
7272 def predict (self , features : List [List [float ]]) -> List [int ]:
@@ -86,25 +86,22 @@ def predict(self, features: List[List[float]]) -> List[int]:
8686 predictions : List [int ] = []
8787
8888 for row in features :
89- class_scores : Dict [ int , float ] = {}
89+ scores : List [ Tuple [ int , float ]] = []
9090
9191 for label in self .class_priors :
92- score = math .log (self .class_priors [label ])
92+ log_likelihood = math .log (self .class_priors [label ])
9393
9494 for index , value in enumerate (row ):
95- mean = self .means [label ][index ]
96- variance = self .variances [label ][index ]
97- probability = gaussian_probability (value , mean , variance )
98-
99- if probability > 0 :
100- score += math .log (probability )
95+ probability = gaussian_probability (
96+ value ,
97+ self .means [label ][index ],
98+ self .variances [label ][index ],
99+ )
100+ if probability > 0.0 :
101+ log_likelihood += math .log (probability )
101102
102- class_scores [ label ] = score
103+ scores . append (( label , log_likelihood ))
103104
104- predicted_label = max (
105- class_scores .items (),
106- key = lambda item : item [1 ],
107- )[0 ]
108- predictions .append (predicted_label )
105+ predictions .append (max (scores , key = lambda pair : pair [1 ])[0 ])
109106
110107 return predictions
0 commit comments