@@ -72,10 +72,10 @@ def fit(self, texts: list[str], labels: list[str]) -> None:
7272 ...
7373 ValueError: training data must not be empty.
7474 """
75- if len (texts ) != len (labels ):
76- raise ValueError ("texts and labels must have the same length." )
7775 if not texts :
7876 raise ValueError ("training data must not be empty." )
77+ if len (texts ) != len (labels ):
78+ raise ValueError ("texts and labels must have the same length." )
7979
8080 self .classes_ = sorted (set (labels ))
8181 self .vocabulary_ .clear ()
@@ -114,6 +114,11 @@ def predict_proba(self, text: str) -> dict[str, float]:
114114 >>> probs['spam'] > probs['ham']
115115 True
116116
117+ An empty input text has no tokens, so predictions fall back to class priors.
118+ >>> empty_probs = model.predict_proba("")
119+ >>> round(empty_probs['spam'], 3), round(empty_probs['ham'], 3)
120+ (0.5, 0.5)
121+
117122 >>> NaiveBayesTextClassifier().predict_proba("hello")
118123 Traceback (most recent call last):
119124 ...
@@ -159,7 +164,7 @@ def predict(self, text: str) -> str:
159164 'ham'
160165 """
161166 probabilities = self .predict_proba (text )
162- return max (probabilities , key = probabilities . get )
167+ return max (probabilities , key = lambda label : probabilities [ label ] )
163168
164169
165170def build_toy_dataset () -> tuple [list [str ], list [str ]]:
@@ -188,3 +193,10 @@ def build_toy_dataset() -> tuple[list[str], list[str]]:
188193 import doctest
189194
190195 doctest .testmod ()
196+
197+ sample_texts , sample_labels = build_toy_dataset ()
198+ classifier = NaiveBayesTextClassifier (alpha = 1.0 )
199+ classifier .fit (sample_texts , sample_labels )
200+
201+ print ("Prediction:" ,classifier .predict ("cheap prizes available now" ))
202+ print ("Prediction:" ,classifier .predict ("team meeting about project timeline" ))
0 commit comments