@@ -8,10 +8,13 @@ import qualified DataFrame as D
88import DataFrame.DecisionTree
99import qualified DataFrame.Functions as F
1010import qualified DataFrame.Internal.Column as DI
11- import DataFrame.Internal.Expression (Expr )
11+ import DataFrame.Internal.Expression (Expr (.. ))
12+ import DataFrame.Internal.Interpreter (interpret )
1213import DataFrame.Operators
1314
14- import Data.List (sort )
15+ import Data.Function (on )
16+ import Data.List (maximumBy , sort )
17+ import qualified Data.Map.Strict as M
1518import qualified Data.Text as T
1619import qualified Data.Vector as V
1720import Test.HUnit
@@ -559,6 +562,158 @@ numericExprsWithTermsMixedTest = TestCase $ do
559562 " combined exprs include NMaybeDouble (nullable arithmetic)"
560563 (any (\ case NMaybeDouble _ -> True ; _ -> False ) exprs)
561564
------------------------------------------------------------------------
-- Probability tree tests
------------------------------------------------------------------------
568+
-- | 'probsFromIndices' over all three rows of a small frame: two "A" labels
-- and one "B" label are expected to give probabilities of 2/3 and 1/3.
probsFromIndicesBasic :: Test
probsFromIndicesBasic = TestCase $ do
    let df =
            D.fromNamedColumns
                [ ("label", DI.fromList (["A", "A", "B"] :: [T.Text]))
                , ("x", DI.fromList ([1.0, 2.0, 3.0] :: [Double]))
                ]
        probs = probsFromIndices @T.Text "label" df (V.fromList [0, 1, 2])
        -- Total lookup: a class missing from the map reads as 0 and fails the
        -- labelled assertion below instead of raising from a partial 'M.!'.
        probOf cls = M.findWithDefault 0 cls probs
    assertBool "A prob ≈ 2/3" (abs (probOf "A" - 2 / 3) < 1e-9)
    assertBool "B prob ≈ 1/3" (abs (probOf "B" - 1 / 3) < 1e-9)
580+
-- | 'probsFromIndices' restricted to rows 0 and 1 (both labelled "A"): the
-- "B" rows are outside the index set, so the expected result is the
-- single-entry map @A -> 1.0@.
probsFromIndicesSubset :: Test
probsFromIndicesSubset = TestCase $ do
    let df =
            D.fromNamedColumns
                [ ("label", DI.fromList (["A", "A", "B", "B"] :: [T.Text]))
                , ("x", DI.fromList ([1.0, 2.0, 3.0, 4.0] :: [Double]))
                ]
        probs = probsFromIndices @T.Text "label" df (V.fromList [0, 1])
    assertEqual "only rows 0,1 → A:1.0" (M.fromList [("A", 1.0)]) probs
591+
-- | 'probsFromIndices' on rows 0 and 2 of @fixtureDF@. The assertion expects
-- both rows to carry the label "A", giving the single-entry map @A -> 1.0@.
probsFromIndicesSingleClass :: Test
probsFromIndicesSingleClass = TestCase $ do
    let probs = probsFromIndices @T.Text "label" fixtureDF (V.fromList [0, 2])
    assertEqual "rows 0,2 both A → A:1.0" (M.fromList [("A", 1.0)]) probs
597+
-- | 'buildProbTree' applied to a bare 'Leaf': the result must still be a
-- 'Leaf', now holding the label distribution of the supplied rows. All three
-- rows are "A", so the expected distribution is @A -> 1.0@.
buildProbTreeLeaf :: Test
buildProbTreeLeaf = TestCase $ do
    let df =
            D.fromNamedColumns
                [ ("label", DI.fromList (["A", "A", "A"] :: [T.Text]))
                , ("x", DI.fromList ([1.0, 2.0, 3.0] :: [Double]))
                ]
        pt = buildProbTree @T.Text (Leaf "A") "label" df (V.fromList [0, 1, 2])
    case pt of
        Leaf m -> assertEqual "pure-A leaf → {A:1.0}" (M.fromList [("A", 1.0)]) m
        _ -> assertFailure "expected Leaf"
610+
-- | 'buildProbTree' on a one-split stump: the result must be a 'Branch' with
-- two 'Leaf' children, each holding the label distribution of the rows routed
-- to that side.
buildProbTreeBranch :: Test
buildProbTreeBranch = TestCase $ do
    -- splitCond: x <= 2.5 → idx 0,1 go left; idx 2,3 go right
    --   left  (idx 0,1): labels ["A","B"] → {A:0.5, B:0.5}
    --   right (idx 2,3): labels ["A","C"] → {A:0.5, C:0.5}
    let stump = Branch splitCond (Leaf "A") (Leaf "B") :: Tree T.Text
        pt = buildProbTree @T.Text stump "label" fixtureDF allIndices
        -- A class absent from a leaf reads as probability 0, so the check
        -- fails with its message rather than raising on a missing key.
        isHalf cls dist = abs (M.findWithDefault 0 cls dist - 0.5) < 1e-9
    case pt of
        Branch _ (Leaf lm) (Leaf rm) -> do
            assertBool "left leaf has A:0.5" (isHalf "A" lm)
            assertBool "left leaf has B:0.5" (isHalf "B" lm)
            assertBool "right leaf has A:0.5" (isHalf "A" rm)
            assertBool "right leaf has C:0.5" (isHalf "C" rm)
        _ -> assertFailure "expected Branch with two Leaves"
630+
-- | 'probExprs' on a single 'Leaf': every class in the leaf's distribution
-- must map to a literal expression holding that class's probability.
probExprsLeaf :: Test
probExprsLeaf = TestCase $ do
    let pt = Leaf (M.fromList [("A", 0.75), ("B", 0.25)]) :: ProbTree T.Text
        pe = probExprs pt
    -- 'M.lookup' keeps the check total: a missing class shows up as a failed
    -- assertion (expected Just, got Nothing) rather than an exception.
    assertEqual "A expr is Lit 0.75" (Just (Lit 0.75)) (M.lookup "A" pe)
    assertEqual "B expr is Lit 0.25" (Just (Lit 0.25)) (M.lookup "B" pe)
638+
-- | 'probExprs' on a branch whose leaves have disjoint classes: a class that
-- is absent from one leaf must contribute @Lit 0.0@ on that side of the
-- conditional.
probExprsMissingClass :: Test
probExprsMissingClass = TestCase $ do
    let pt =
            Branch
                splitCond
                (Leaf (M.fromList [("A", 1.0)]))
                (Leaf (M.fromList [("B", 1.0)])) ::
                ProbTree T.Text
        pe = probExprs pt
    -- 'M.lookup' keeps the check total: a missing class shows up as a failed
    -- assertion (expected Just, got Nothing) rather than an exception.
    assertEqual
        "A expr: If cond (Lit 1.0) (Lit 0.0)"
        (Just (F.ifThenElse splitCond (Lit 1.0) (Lit 0.0)))
        (M.lookup "A" pe)
    assertEqual
        "B expr: If cond (Lit 0.0) (Lit 1.0)"
        (Just (F.ifThenElse splitCond (Lit 0.0) (Lit 1.0)))
        (M.lookup "B" pe)
657+
-- | The key set returned by 'probExprs' must be the union of the classes
-- found in any leaf: here "A" from the left leaf plus "B" and "C" from the
-- right leaf.
probExprsAllClasses :: Test
probExprsAllClasses = TestCase $ do
    let pt =
            Branch
                splitCond
                (Leaf (M.fromList [("A", 1.0)]))
                (Leaf (M.fromList [("B", 0.6), ("C", 0.4)])) ::
                ProbTree T.Text
        pe = probExprs pt
    assertEqual "three classes in result" (sort ["A", "B", "C"]) (sort (M.keys pe))
669+
-- | Adds together every per-class expression from 'probExprs', evaluates the
-- sum against @fixtureDF@, and checks that each row's total is within 1e-9
-- of 1.0.
probsSumToOne :: Test
probsSumToOne = TestCase $ do
    let stump = Branch splitCond (Leaf "A") (Leaf "B") :: Tree T.Text
        pt = buildProbTree @T.Text stump "label" fixtureDF allIndices
        pe = probExprs pt
    case M.elems pe of
        -- 'foldl1' would raise on an empty list; report it as a test failure.
        [] -> assertFailure "probExprs returned no class expressions"
        (firstExpr : restExprs) -> do
            let sumExpr = foldl (.+) firstExpr restExprs
            case interpret @Double fixtureDF sumExpr of
                Left e -> assertFailure (show e)
                Right (DI.TColumn col) ->
                    case DI.toVector @Double col of
                        Left e -> assertFailure (show e)
                        Right vals ->
                            mapM_
                                (\v -> assertBool ("sum ≈ 1.0, got " ++ show v) (abs (v - 1.0) < 1e-9))
                                (V.toList vals)
686+
-- | For every row of @sepDF@, the class whose 'fitProbTree' expression
-- evaluates to the largest value must equal the label predicted by
-- 'fitDecisionTree' under the same configuration.
probArgmaxMatchesClassifier :: Test
probArgmaxMatchesClassifier = TestCase $ do
    let cfg = defaultTreeConfig{taoIterations = 5, expressionPairs = 4, minLeafSize = 1}
        hardExpr = fitDecisionTree @T.Text cfg (Col "label") sepDF
        pe = fitProbTree @T.Text cfg (Col "label") sepDF
        indices = [0 .. D.nRows sepDF - 1]
    case interpret @T.Text sepDF hardExpr of
        Left e -> assertFailure (show e)
        Right (DI.TColumn hardCol) ->
            case DI.toVector @T.Text hardCol of
                Left e -> assertFailure (show e)
                Right hardVals -> do
                    -- Evaluate each class's probability expression to a vector
                    -- of per-row values, keeping the class alongside it.
                    probCols <-
                        mapM
                            ( \(cls, expr) -> case interpret @Double sepDF expr of
                                Left e -> assertFailure (show e) >> return (cls, V.empty)
                                Right (DI.TColumn col) -> case DI.toVector @Double col of
                                    Left e -> assertFailure (show e) >> return (cls, V.empty)
                                    Right v -> return (cls, v)
                            )
                            (M.toList pe)
                    -- 'maximumBy' is partial: fail with a readable message
                    -- instead of an "empty list" exception when no class
                    -- expression was produced.
                    assertBool
                        "fitProbTree produced no class probabilities"
                        (not (null probCols))
                    mapM_
                        ( \i ->
                            let argmax = fst $ maximumBy (compare `on` (V.! i) . snd) probCols
                                hard = hardVals V.! i
                             in assertEqual ("row " ++ show i) hard argmax
                        )
                        indices
716+
------------------------------------------------------------------------
-- Test list
------------------------------------------------------------------------
@@ -596,4 +751,14 @@ tests =
596751 , TestLabel " nullableFitZeroLoss" nullableFitZeroLossTest
597752 , TestLabel " nullableFitWithNullsNoCrash" nullableFitWithNullsNoCrashTest
598753 , TestLabel " numericExprsWithTermsMixed" numericExprsWithTermsMixedTest
754+ , TestLabel " probsFromIndicesBasic" probsFromIndicesBasic
755+ , TestLabel " probsFromIndicesSubset" probsFromIndicesSubset
756+ , TestLabel " probsFromIndicesSingleClass" probsFromIndicesSingleClass
757+ , TestLabel " buildProbTreeLeaf" buildProbTreeLeaf
758+ , TestLabel " buildProbTreeBranch" buildProbTreeBranch
759+ , TestLabel " probExprsLeaf" probExprsLeaf
760+ , TestLabel " probExprsMissingClass" probExprsMissingClass
761+ , TestLabel " probExprsAllClasses" probExprsAllClasses
762+ , TestLabel " probsSumToOne" probsSumToOne
763+ , TestLabel " probArgmaxMatchesClassifier" probArgmaxMatchesClassifier
599764 ]
0 commit comments