Skip to content

Commit 6dd885c

Browse files
Add Gaussian Naive Bayes classifier
1 parent 3c5e410 commit 6dd885c

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

machine_learning/naive_bayes.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
https://en.wikipedia.org/wiki/Naive_Bayes_classifier
99
"""
1010

11-
from typing import List, Dict
11+
from typing import Dict, List, Tuple
1212
import math
1313

1414

@@ -21,11 +21,12 @@ def gaussian_probability(x: float, mean: float, variance: float) -> float:
2121
>>> gaussian_probability(1.0, 1.0, 0.0)
2222
0.0
2323
"""
24-
if variance == 0:
24+
if variance == 0.0:
2525
return 0.0
2626

27-
exponent = math.exp(-((x - mean) ** 2) / (2 * variance))
28-
return (1 / math.sqrt(2 * math.pi * variance)) * exponent
27+
exponent = math.exp(-((x - mean) ** 2) / (2.0 * variance))
28+
coefficient = 1.0 / math.sqrt(2.0 * math.pi * variance)
29+
return coefficient * exponent
2930

3031

3132
class GaussianNaiveBayes:
@@ -61,12 +62,11 @@ def fit(self, features: List[List[float]], labels: List[int]) -> None:
6162
for label, rows in separated.items():
6263
self.class_priors[label] = len(rows) / total_samples
6364

64-
transposed = list(zip(*rows))
65-
self.means[label] = [sum(col) / len(col) for col in transposed]
66-
65+
columns = list(zip(*rows))
66+
self.means[label] = [sum(col) / len(col) for col in columns]
6767
self.variances[label] = [
6868
sum((x - mean) ** 2 for x in col) / len(col)
69-
for col, mean in zip(transposed, self.means[label])
69+
for col, mean in zip(columns, self.means[label])
7070
]
7171

7272
def predict(self, features: List[List[float]]) -> List[int]:
@@ -86,25 +86,22 @@ def predict(self, features: List[List[float]]) -> List[int]:
8686
predictions: List[int] = []
8787

8888
for row in features:
89-
class_scores: Dict[int, float] = {}
89+
scores: List[Tuple[int, float]] = []
9090

9191
for label in self.class_priors:
92-
score = math.log(self.class_priors[label])
92+
log_likelihood = math.log(self.class_priors[label])
9393

9494
for index, value in enumerate(row):
95-
mean = self.means[label][index]
96-
variance = self.variances[label][index]
97-
probability = gaussian_probability(value, mean, variance)
98-
99-
if probability > 0:
100-
score += math.log(probability)
95+
probability = gaussian_probability(
96+
value,
97+
self.means[label][index],
98+
self.variances[label][index],
99+
)
100+
if probability > 0.0:
101+
log_likelihood += math.log(probability)
101102

102-
class_scores[label] = score
103+
scores.append((label, log_likelihood))
103104

104-
predicted_label = max(
105-
class_scores.items(),
106-
key=lambda item: item[1],
107-
)[0]
108-
predictions.append(predicted_label)
105+
predictions.append(max(scores, key=lambda pair: pair[1])[0])
109106

110107
return predictions

0 commit comments

Comments
 (0)