Skip to content

Commit 00fe186

Browse files
committed
feat: Emit probabilities with decision tree.
1 parent ec5453e commit 00fe186

3 files changed

Lines changed: 288 additions & 2 deletions

File tree

src/DataFrame/DecisionTree.hs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,3 +804,117 @@ findBestSplit = findBestGreedySplit @a
804804

805805
-- | Prune an expression tree. This is a thin, type-annotated entry point
-- that hands its argument straight to 'pruneExpr'.
pruneTree :: forall a. (Columnable a, Eq a) => Expr a -> Expr a
pruneTree expr = pruneExpr expr
807+
808+
-- | A tree where each leaf stores a class-probability distribution.
809+
type ProbTree a = Tree (M.Map a Double)
810+
811+
{- | Empirical class distribution of the @target@ column over the rows
named by @indices@.

Each distinct label found at the given row positions is mapped to its
relative frequency (occurrences divided by the number of indices).

The result is the empty map when the target column cannot be interpreted
at type @a@, when it cannot be converted to a vector, or when @indices@
is empty. Row positions are looked up with 'V.!', so an out-of-range
index is a runtime error.
-}
probsFromIndices ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    V.Vector Int ->
    M.Map a Double
probsFromIndices target df indices =
    case interpret @a df (Col target) of
        Left _ -> M.empty
        Right (TColumn col) ->
            either (const M.empty) distribution (toVector @a col)
  where
    -- Denominator shared by every class: the number of rows considered.
    rowCount :: Double
    rowCount = fromIntegral (V.length indices)

    -- Tally labels at the selected rows, then normalise the tallies.
    distribution :: V.Vector a -> M.Map a Double
    distribution labels =
        let tally :: M.Map a Int
            tally =
                M.fromListWith
                    (+)
                    [(labels V.! i, 1) | i <- V.toList indices]
         in M.map (\n -> fromIntegral n / rowCount) tally
833+
834+
{- | Turn a fitted 'Tree a' into a 'ProbTree a' by pushing training rows
down through its splits.

Every 'Branch' keeps its condition unchanged; the rows are divided with
'partitionIndices' and each half is passed to the corresponding subtree.
Every 'Leaf' discards its original label and instead stores the class
distribution ('probsFromIndices') of the rows that reached it.
-}
buildProbTree ::
    forall a.
    (Columnable a) =>
    Tree a ->
    T.Text ->
    DataFrame ->
    V.Vector Int ->
    ProbTree a
buildProbTree tree target df = annotate tree
  where
    -- The target name and frame never change while descending, so the
    -- worker only threads the subtree and the rows that reached it.
    annotate :: Tree a -> V.Vector Int -> ProbTree a
    annotate (Leaf _) rows = Leaf (probsFromIndices @a target df rows)
    annotate (Branch cond subL subR) rows =
        let (rowsL, rowsR) = partitionIndices cond df rows
         in Branch cond (annotate subL rowsL) (annotate subR rowsR)
854+
855+
{- | Fit a TAO decision tree and return one @Expr Double@ per class.

The target must be a plain column reference (@Col name@); any other
expression aborts with 'error'.

Pipeline, as performed below:

1. candidate split conditions are generated from every column except the
   target ('numericConditions' and 'generateConditionsOld', de-duplicated);
2. a greedy tree is built up to @maxTreeDepth cfg@;
3. the tree is refined with 'taoOptimize' over all rows and then passed
   through 'pruneDead';
4. the training rows are routed through the result ('buildProbTree') and
   the per-leaf distributions are turned into expressions ('probExprs').

Each @(c, e)@ pair in the result map means: evaluate @e@ on a 'DataFrame'
row to get the predicted probability of class @c@. You can insert these
as new columns with 'derive' or evaluate them with 'interpret'.

Example:
@
let pes = fitProbTree \@T.Text cfg (Col \"species\") trainDf
    -- pes M.! \"setosa\" :: Expr Double
    df' = M.foldlWithKey' (\\d cls e -> D.derive (cls <> \"_prob\") e d) testDf pes
@
-}
fitProbTree ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    Expr a -> -- target column, e.g. @Col \"label\"@
    DataFrame ->
    M.Map a (Expr Double)
fitProbTree cfg (Col target) df =
    probExprs (buildProbTree @a fitted target df allRows)
  where
    -- Feature columns only: the target must not be used to split on.
    features = exclude [target] df
    candidates =
        nubOrd
            (numericConditions cfg features ++ generateConditionsOld cfg features)
    allRows = V.enumFromN 0 (nRows df)
    greedy = buildGreedyTree @a cfg (maxTreeDepth cfg) target candidates df
    fitted = pruneDead (taoOptimize @a cfg target candidates df allRows greedy)
fitProbTree _ expr _ =
    error $ "Cannot create prob tree for compound expression: " ++ show expr
889+
890+
{- | Convert a 'ProbTree' into one 'Expr Double' per class.

Each @(c, e)@ pair means: evaluate @e@ on a 'DataFrame' row to get the
predicted probability of class @c@. You can insert these as new columns
with 'derive' or evaluate them with 'interpret'.

The keys of the result are all classes that occur in at least one leaf.
For a given class, the expression mirrors the tree: every 'Branch' becomes
an if-then-else on the branch condition and every 'Leaf' becomes a literal,
@0.0@ when the class is absent from that leaf.

'fitProbTree' already applies 'probExprs' to the tree it fits and returns
the resulting map, so call this function directly only when you hold a
'ProbTree' yourself, e.g. one produced by 'buildProbTree':

@
let rows = V.enumFromN 0 (nRows trainDf)
    pt   = buildProbTree \@T.Text fittedTree \"species\" trainDf rows
    pes  = probExprs pt
    -- pes M.! \"setosa\" :: Expr Double
    df' = M.foldlWithKey' (\\d cls e -> D.derive (cls <> \"_prob\") e d) testDf pes
@

where @fittedTree :: Tree T.Text@ is an already-fitted tree.
-}
probExprs ::
    forall a.
    (Columnable a) =>
    ProbTree a ->
    M.Map a (Expr Double)
probExprs tree =
    let classes = nubOrd (allClasses tree)
     in M.fromList [(c, classExpr c tree) | c <- classes]
  where
    -- Every class mentioned by any leaf (may contain duplicates).
    allClasses :: ProbTree a -> [a]
    allClasses (Leaf m) = M.keys m
    allClasses (Branch _ l r) = allClasses l ++ allClasses r

    -- Probability expression for a single class, following the tree shape.
    classExpr :: a -> ProbTree a -> Expr Double
    classExpr c (Leaf m) = Lit (M.findWithDefault 0.0 c m)
    classExpr c (Branch cond l r) =
        F.ifThenElse cond (classExpr c l) (classExpr c r)

src/DataFrame/Operators.hs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ lift2Decorated f name rep comm prec =
9292
}
9393
)
9494

95+
{- | Equality comparison between two expressions of the same type,
yielding a boolean expression.

Built with 'lift2Decorated' from Haskell's '==', using the name @\"eq\"@,
the textual representation @\".==.\"@, the commutative flag set, and
precedence 4.

NOTE(review): no fixity declaration for this operator is visible in this
hunk -- confirm one exists elsewhere if precedence 4 is also intended for
parsing.
-}
(.==.) ::
    (Columnable a) =>
    Expr a ->
    Expr a ->
    Expr Bool
lhs .==. rhs = lift2Decorated (==) "eq" (Just ".==.") True 4 lhs rhs
101+
95102
-- Nullable-aware arithmetic operators
96103

97104
{- | Nullable-aware addition. Works for all combinations of nullable\/non-nullable operands.

tests/DecisionTree.hs

Lines changed: 167 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ import qualified DataFrame as D
88
import DataFrame.DecisionTree
99
import qualified DataFrame.Functions as F
1010
import qualified DataFrame.Internal.Column as DI
11-
import DataFrame.Internal.Expression (Expr)
11+
import DataFrame.Internal.Expression (Expr (..))
12+
import DataFrame.Internal.Interpreter (interpret)
1213
import DataFrame.Operators
1314

14-
import Data.List (sort)
15+
import Data.Function (on)
16+
import Data.List (maximumBy, sort)
17+
import qualified Data.Map.Strict as M
1518
import qualified Data.Text as T
1619
import qualified Data.Vector as V
1720
import Test.HUnit
@@ -559,6 +562,158 @@ numericExprsWithTermsMixedTest = TestCase $ do
559562
"combined exprs include NMaybeDouble (nullable arithmetic)"
560563
(any (\case NMaybeDouble _ -> True; _ -> False) exprs)
561564

565+
------------------------------------------------------------------------
566+
-- Probability tree tests
567+
------------------------------------------------------------------------
568+
569+
-- probsFromIndices: labels A,A,B over all three rows give 2/3 and 1/3
probsFromIndicesBasic :: Test
probsFromIndicesBasic = TestCase $ do
    assertBool "A prob ≈ 2/3" (closeTo (2 / 3) (dist M.! "A"))
    assertBool "B prob ≈ 1/3" (closeTo (1 / 3) (dist M.! "B"))
  where
    frame =
        D.fromNamedColumns
            [ ("label", DI.fromList (["A", "A", "B"] :: [T.Text]))
            , ("x", DI.fromList ([1.0, 2.0, 3.0] :: [Double]))
            ]

    dist :: M.Map T.Text Double
    dist = probsFromIndices @T.Text "label" frame (V.fromList [0, 1, 2])

    -- Floating-point comparison with the tolerance used throughout this file.
    closeTo :: Double -> Double -> Bool
    closeTo expected actual = abs (actual - expected) < 1e-9
580+
581+
-- probsFromIndices: rows outside the index vector must not be counted
probsFromIndicesSubset :: Test
probsFromIndicesSubset =
    TestCase (assertEqual "only rows 0,1 → A:1.0" expected actual)
  where
    frame =
        D.fromNamedColumns
            [ ("label", DI.fromList (["A", "A", "B", "B"] :: [T.Text]))
            , ("x", DI.fromList ([1.0, 2.0, 3.0, 4.0] :: [Double]))
            ]

    expected :: M.Map T.Text Double
    expected = M.fromList [("A", 1.0)]

    -- Only the first two rows (both labelled A) are selected.
    actual :: M.Map T.Text Double
    actual = probsFromIndices @T.Text "label" frame (V.fromList [0, 1])
591+
592+
-- probsFromIndices: when every selected row has the same label,
-- that label gets probability 1.0 and no other key appears
probsFromIndicesSingleClass :: Test
probsFromIndicesSingleClass =
    TestCase (assertEqual "rows 0,2 both A → A:1.0" expected actual)
  where
    expected :: M.Map T.Text Double
    expected = M.fromList [("A", 1.0)]

    actual :: M.Map T.Text Double
    actual = probsFromIndices @T.Text "label" fixtureDF (V.fromList [0, 2])
597+
598+
-- buildProbTree: a bare Leaf becomes a Leaf holding the distribution of
-- all rows passed in
buildProbTreeLeaf :: Test
buildProbTreeLeaf =
    TestCase $
        case buildProbTree @T.Text (Leaf "A") "label" frame (V.fromList [0, 1, 2]) of
            Leaf dist ->
                assertEqual "pure-A leaf → {A:1.0}" (M.fromList [("A", 1.0)]) dist
            _ -> assertFailure "expected Leaf"
  where
    -- Three rows, all labelled A.
    frame =
        D.fromNamedColumns
            [ ("label", DI.fromList (["A", "A", "A"] :: [T.Text]))
            , ("x", DI.fromList ([1.0, 2.0, 3.0] :: [Double]))
            ]
610+
611+
-- buildProbTree: a Branch sends rows to its left/right leaves and each
-- leaf reports the distribution of the rows it received.
--
-- Expected layout for the shared fixture (splitCond: x <= 2.5):
--   left  (idx 0,1): labels ["A","B"] → {A:0.5, B:0.5}
--   right (idx 2,3): labels ["A","C"] → {A:0.5, C:0.5}
buildProbTreeBranch :: Test
buildProbTreeBranch =
    TestCase $
        case buildProbTree @T.Text stump "label" fixtureDF allIndices of
            Branch _ (Leaf lm) (Leaf rm) ->
                mapM_
                    (\(msg, cls, leaf) -> assertBool msg (isHalf cls leaf))
                    [ ("left leaf has A:0.5", "A", lm)
                    , ("left leaf has B:0.5", "B", lm)
                    , ("right leaf has A:0.5", "A", rm)
                    , ("right leaf has C:0.5", "C", rm)
                    ]
            _ -> assertFailure "expected Branch with two Leaves"
  where
    -- The leaf labels of the input tree are irrelevant; only the split matters.
    stump :: Tree T.Text
    stump = Branch splitCond (Leaf "A") (Leaf "B")

    -- True when the class has probability 0.5 (missing classes count as 0).
    isHalf :: T.Text -> M.Map T.Text Double -> Bool
    isHalf cls leaf = abs (M.findWithDefault 0 cls leaf - 0.5) < 1e-9
630+
631+
-- probExprs: a single-leaf tree yields a plain literal for every class
probExprsLeaf :: Test
probExprsLeaf = TestCase $ do
    assertEqual "A expr is Lit 0.75" (Lit 0.75) (exprs M.! "A")
    assertEqual "B expr is Lit 0.25" (Lit 0.25) (exprs M.! "B")
  where
    leaf :: ProbTree T.Text
    leaf = Leaf (M.fromList [("A", 0.75), ("B", 0.25)])

    exprs :: M.Map T.Text (Expr Double)
    exprs = probExprs leaf
638+
639+
-- probExprs: a class that is missing from one leaf contributes Lit 0.0
-- on that side of the conditional
probExprsMissingClass :: Test
probExprsMissingClass = TestCase $ do
    assertEqual
        "A expr: If cond (Lit 1.0) (Lit 0.0)"
        (F.ifThenElse splitCond (Lit 1.0) (Lit 0.0))
        (exprs M.! "A")
    assertEqual
        "B expr: If cond (Lit 0.0) (Lit 1.0)"
        (F.ifThenElse splitCond (Lit 0.0) (Lit 1.0))
        (exprs M.! "B")
  where
    -- Left leaf is pure A, right leaf is pure B.
    tree :: ProbTree T.Text
    tree =
        Branch
            splitCond
            (Leaf (M.fromList [("A", 1.0)]))
            (Leaf (M.fromList [("B", 1.0)]))

    exprs :: M.Map T.Text (Expr Double)
    exprs = probExprs tree
657+
658+
-- probExprs: the result has one key for every class found in any leaf
probExprsAllClasses :: Test
probExprsAllClasses =
    TestCase $
        assertEqual
            "three classes in result"
            (sort ["A", "B", "C"])
            (sort (M.keys exprs))
  where
    -- A appears only on the left, B and C only on the right.
    tree :: ProbTree T.Text
    tree =
        Branch
            splitCond
            (Leaf (M.fromList [("A", 1.0)]))
            (Leaf (M.fromList [("B", 0.6), ("C", 0.4)]))

    exprs :: M.Map T.Text (Expr Double)
    exprs = probExprs tree
669+
670+
-- Adding up the per-class probability expressions must give 1.0 on
-- every row of the fixture
probsSumToOne :: Test
probsSumToOne =
    TestCase $
        case interpret @Double fixtureDF total of
            Left err -> assertFailure (show err)
            Right (DI.TColumn col) ->
                either
                    (assertFailure . show)
                    (V.mapM_ checkRow)
                    (DI.toVector @Double col)
  where
    stump :: Tree T.Text
    stump = Branch splitCond (Leaf "A") (Leaf "B")

    classExprs :: M.Map T.Text (Expr Double)
    classExprs = probExprs (buildProbTree @T.Text stump "label" fixtureDF allIndices)

    -- One expression that sums the probability of every class.
    total :: Expr Double
    total = foldl1 (.+) (M.elems classExprs)

    checkRow :: Double -> Assertion
    checkRow v = assertBool ("sum ≈ 1.0, got " ++ show v) (abs (v - 1.0) < 1e-9)
686+
687+
-- For every row of sepDF, the class with the highest predicted
-- probability (fitProbTree) must equal the label predicted by
-- fitDecisionTree under the same configuration
probArgmaxMatchesClassifier :: Test
probArgmaxMatchesClassifier =
    TestCase $
        case interpret @T.Text sepDF hardExpr of
            Left e -> assertFailure (show e)
            Right (DI.TColumn hardCol) ->
                case DI.toVector @T.Text hardCol of
                    Left e -> assertFailure (show e)
                    Right hardVals -> do
                        probCols <- mapM evalClass (M.toList softExprs)
                        mapM_ (checkRow hardVals probCols) [0 .. D.nRows sepDF - 1]
  where
    cfg = defaultTreeConfig{taoIterations = 5, expressionPairs = 4, minLeafSize = 1}
    hardExpr = fitDecisionTree @T.Text cfg (Col "label") sepDF
    softExprs = fitProbTree @T.Text cfg (Col "label") sepDF

    -- Evaluate one class's probability expression into a column of Doubles.
    -- The 'return' after 'assertFailure' only satisfies the type checker.
    evalClass :: (T.Text, Expr Double) -> IO (T.Text, V.Vector Double)
    evalClass (cls, expr) =
        case interpret @Double sepDF expr of
            Left e -> assertFailure (show e) >> return (cls, V.empty)
            Right (DI.TColumn col) ->
                case DI.toVector @Double col of
                    Left e -> assertFailure (show e) >> return (cls, V.empty)
                    Right v -> return (cls, v)

    -- Compare the hard prediction with the arg-max class on row @i@.
    checkRow :: V.Vector T.Text -> [(T.Text, V.Vector Double)] -> Int -> Assertion
    checkRow hardVals probCols i =
        assertEqual ("row " ++ show i) (hardVals V.! i) bestClass
      where
        bestClass = fst (maximumBy (compare `on` ((V.! i) . snd)) probCols)
716+
562717
------------------------------------------------------------------------
563718
-- Test list
564719
------------------------------------------------------------------------
@@ -596,4 +751,14 @@ tests =
596751
, TestLabel "nullableFitZeroLoss" nullableFitZeroLossTest
597752
, TestLabel "nullableFitWithNullsNoCrash" nullableFitWithNullsNoCrashTest
598753
, TestLabel "numericExprsWithTermsMixed" numericExprsWithTermsMixedTest
754+
, TestLabel "probsFromIndicesBasic" probsFromIndicesBasic
755+
, TestLabel "probsFromIndicesSubset" probsFromIndicesSubset
756+
, TestLabel "probsFromIndicesSingleClass" probsFromIndicesSingleClass
757+
, TestLabel "buildProbTreeLeaf" buildProbTreeLeaf
758+
, TestLabel "buildProbTreeBranch" buildProbTreeBranch
759+
, TestLabel "probExprsLeaf" probExprsLeaf
760+
, TestLabel "probExprsMissingClass" probExprsMissingClass
761+
, TestLabel "probExprsAllClasses" probExprsAllClasses
762+
, TestLabel "probsSumToOne" probsSumToOne
763+
, TestLabel "probArgmaxMatchesClassifier" probArgmaxMatchesClassifier
599764
]

0 commit comments

Comments
 (0)