Skip to content

Commit 711d0e9

Browse files
committed
feat: Include stratified sampling.
1 parent 2537870 commit 711d0e9

8 files changed

Lines changed: 890 additions & 53 deletions

File tree

app/Synthesis.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ main = do
5353
fitDecisionTree
5454
( defaultTreeConfig
5555
{ maxTreeDepth = 5
56-
, minSamplesSplit = 10
56+
, minSamplesSplit = 5
5757
, minLeafSize = 3
5858
, taoIterations = 100
5959
, synthConfig =
6060
defaultSynthConfig
61-
{ complexityPenalty = 0.00
62-
, maxExprDepth = 2
61+
{ complexityPenalty = 0.1
62+
, maxExprDepth = 3
6363
, disallowedCombinations =
6464
[ (F.name age, F.name fare)
6565
, ("passenger_class", "number_of_siblings_and_spouses")

dataframe.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ test-suite tests
249249
type: exitcode-stdio-1.0
250250
main-is: Main.hs
251251
other-modules: Assertions,
252+
DecisionTree,
252253
Functions,
253254
GenDataFrame,
254255
Internal.Parsing,

src/DataFrame.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,8 @@ import DataFrame.Operations.Subset as Subset (
362362
sample,
363363
select,
364364
selectBy,
365+
stratifiedSample,
366+
stratifiedSplit,
365367
take,
366368
takeLast,
367369
)

src/DataFrame/DecisionTree.hs

Lines changed: 85 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -548,38 +548,108 @@ findBestGreedySplit cfg target conds df =
548548
(boolExpansion (synthConfig cfg))
549549
)
550550

551+
{- | A numeric feature expression that is either always present or possibly
missing. Wrapping both shapes in one type lets feature generation carry them
in a single list.
-}
data NumExpr
    = -- | Expression producing a non-nullable @Double@.
      NDouble !(Expr Double)
    | -- | Expression producing a nullable @Maybe Double@.
      NMaybeDouble !(Expr (Maybe Double))
555+
556+
-- | Column names reported by 'getColumns' for the wrapped expression,
-- regardless of which 'NumExpr' variant holds it.
numExprCols :: NumExpr -> [T.Text]
numExprCols wrapped = case wrapped of
    NDouble plain -> getColumns plain
    NMaybeDouble nullable -> getColumns nullable
559+
560+
-- | Equality on 'NumExpr': two values are equal only when they use the same
-- constructor and their wrapped expressions compare equal with '=='.
-- A nullable and a non-nullable expression are never equal.
numExprEq :: NumExpr -> NumExpr -> Bool
numExprEq lhs rhs = case (lhs, rhs) of
    (NDouble a, NDouble b) -> a == b
    (NMaybeDouble a, NMaybeDouble b) -> a == b
    _ -> False
564+
565+
{- | Combine two numeric expressions into four derived ones, in this order:
sum, difference, product, and a guarded quotient.

* When both inputs are non-nullable the results are non-nullable, and the
  quotient falls back to the literal @0@ when the divisor equals zero.
* When either input is nullable the results are nullable, and the quotient
  falls back to a literal @Nothing@ whenever the non-zero test on the divisor
  does not hold. For a nullable divisor that test is passed through
  @F.fromMaybe False@ before use.
-}
combineNumExprs :: NumExpr -> NumExpr -> [NumExpr]
combineNumExprs lhs rhs = case (lhs, rhs) of
    (NDouble a, NDouble b) ->
        map
            NDouble
            [ a .+ b
            , a .- b
            , a .* b
            , F.ifThenElse (b ./= zero) (a ./ b) zero
            ]
    (NDouble a, NMaybeDouble b) ->
        map
            NMaybeDouble
            [ a .+ b
            , a .- b
            , a .* b
            , F.ifThenElse (F.fromMaybe False (b ./= zero)) (a ./ b) missing
            ]
    (NMaybeDouble a, NDouble b) ->
        map
            NMaybeDouble
            [ a .+ b
            , a .- b
            , a .* b
            , F.ifThenElse (b ./= zero) (a ./ b) missing
            ]
    (NMaybeDouble a, NMaybeDouble b) ->
        map
            NMaybeDouble
            [ a .+ b
            , a .- b
            , a .* b
            , F.ifThenElse (F.fromMaybe False (b ./= zero)) (a ./ b) missing
            ]
  where
    -- Literal divisor sentinel shared by every case.
    zero = F.lit (0 :: Double)
    -- Literal "no value" used as the nullable quotient fallback.
    missing = F.lit (Nothing :: Maybe Double)
606+
551607
-- | Boolean conditions over numeric expressions for the given configuration
-- and dataframe. Delegates directly to 'generateNumericConds'.
numericConditions :: TreeConfig -> DataFrame -> [Expr Bool]
numericConditions cfg df = generateNumericConds cfg df
553609

554610
generateNumericConds :: TreeConfig -> DataFrame -> [Expr Bool]
555611
generateNumericConds cfg df = do
556612
expr <- numericExprsWithTerms (synthConfig cfg) df
557-
let thresholds = map (\p -> percentile p expr df) (percentiles cfg)
613+
let thresholds = numericThresholds expr
558614
threshold <- thresholds
559-
[ expr .<= F.lit threshold
560-
, expr .>= F.lit threshold
561-
, expr .< F.lit threshold
562-
, expr .> F.lit threshold
615+
numericCondsFromExpr expr threshold
616+
where
617+
numericThresholds (NDouble e) = map (\p -> percentile p e df) (percentiles cfg)
618+
numericThresholds (NMaybeDouble e) = map (\p -> percentile p (F.fromMaybe 0 e) df) (percentiles cfg)
619+
620+
numericCondsFromExpr (NDouble e) t =
621+
[e .<= F.lit t, e .>= F.lit t, e .< F.lit t, e .> F.lit t]
622+
numericCondsFromExpr (NMaybeDouble e) t =
623+
[ F.fromMaybe False (e .<= F.lit t)
624+
, F.fromMaybe False (e .>= F.lit t)
625+
, F.fromMaybe False (e .< F.lit t)
626+
, F.fromMaybe False (e .> F.lit t)
563627
]
564628

565-
numericExprsWithTerms :: SynthConfig -> DataFrame -> [Expr Double]
629+
numericExprsWithTerms :: SynthConfig -> DataFrame -> [NumExpr]
566630
numericExprsWithTerms cfg df =
567631
concatMap (numericExprs cfg df [] 0) [0 .. maxExprDepth cfg]
568632

569-
numericCols :: DataFrame -> [Expr Double]
633+
numericCols :: DataFrame -> [NumExpr]
570634
numericCols df = concatMap extract (columnNames df)
571635
where
572636
extract col = case unsafeGetColumn col df of
573637
UnboxedColumn (_ :: VU.Vector b) ->
574638
case testEquality (typeRep @b) (typeRep @Double) of
575-
Just Refl -> [Col col]
639+
Just Refl -> [NDouble (Col col)]
576640
Nothing -> case sIntegral @b of
577-
STrue -> [F.toDouble (Col @b col)]
641+
STrue -> [NDouble (F.toDouble (Col @b col))]
642+
SFalse -> []
643+
OptionalColumn (_ :: V.Vector (Maybe b)) ->
644+
case testEquality (typeRep @b) (typeRep @Double) of
645+
Just Refl -> [NMaybeDouble (Col @(Maybe b) col)]
646+
Nothing -> case sIntegral @b of
647+
STrue -> [NMaybeDouble (F.whenPresent (realToFrac @b @Double) (Col @(Maybe b) col))]
578648
SFalse -> []
579649
_ -> []
580650

581651
numericExprs ::
582-
SynthConfig -> DataFrame -> [Expr Double] -> Int -> Int -> [Expr Double]
652+
SynthConfig -> DataFrame -> [NumExpr] -> Int -> Int -> [NumExpr]
583653
numericExprs cfg df prevExprs depth maxDepth
584654
| depth == 0 = baseExprs ++ numericExprs cfg df baseExprs (depth + 1) maxDepth
585655
| depth >= maxDepth = []
@@ -592,20 +662,16 @@ numericExprs cfg df prevExprs depth maxDepth
592662
| otherwise = do
593663
e1 <- prevExprs
594664
e2 <- baseExprs
595-
let cols = getColumns e1 <> getColumns e2
665+
let cols = numExprCols e1 <> numExprCols e2
596666
guard
597-
( e1 /= e2
667+
( not (numExprEq e1 e2)
598668
&& not
599669
( any
600670
(\(l, r) -> l `elem` cols && r `elem` cols)
601671
(disallowedCombinations cfg)
602672
)
603673
)
604-
[ e1 + e2
605-
, e1 - e2
606-
, e1 * e2
607-
, F.ifThenElse (e2 ./= (0 :: Expr Double)) (e1 / e2) 0
608-
]
674+
combineNumExprs e1 e2
609675

610676
boolExprs ::
611677
DataFrame -> [Expr Bool] -> [Expr Bool] -> Int -> Int -> [Expr Bool]
@@ -631,37 +697,9 @@ generateConditionsOld cfg df =
631697
let ps = map (Lit . (`percentileOrd'` col)) [1, 25, 75, 99]
632698
in map (F.lift2 (==) (Col @a colName)) ps
633699
(OptionalColumn (col :: V.Vector (Maybe a))) -> case sFloating @a of
634-
STrue ->
635-
let doubleCol =
636-
VU.convert
637-
(V.map fromJust (V.filter isJust (V.map (fmap (realToFrac @a @Double)) col)))
638-
in zipWith
639-
($)
640-
[ F.lift2 (==) (Col @(Maybe a) colName)
641-
, F.lift2 (<=) (Col @(Maybe a) colName)
642-
, F.lift2 (>=) (Col @(Maybe a) colName)
643-
]
644-
( Lit Nothing
645-
: map
646-
(Lit . Just . realToFrac . (`percentile'` doubleCol))
647-
(percentiles cfg)
648-
)
700+
STrue -> [] -- handled by numericCols / numericExprs
649701
SFalse -> case sIntegral @a of
650-
STrue ->
651-
let doubleCol =
652-
VU.convert
653-
(V.map fromJust (V.filter isJust (V.map (fmap (fromIntegral @a @Double)) col)))
654-
in zipWith
655-
($)
656-
[ F.lift2 (==) (Col @(Maybe a) colName)
657-
, F.lift2 (<=) (Col @(Maybe a) colName)
658-
, F.lift2 (>=) (Col @(Maybe a) colName)
659-
]
660-
( Lit Nothing
661-
: map
662-
(Lit . Just . round . (`percentile'` doubleCol))
663-
(percentiles cfg)
664-
)
702+
STrue -> [] -- handled by numericCols / numericExprs
665703
SFalse ->
666704
map
667705
(F.lift2 (==) (Col @(Maybe a) colName) . Lit . (`percentileOrd'` col))

src/DataFrame/Operations/Subset.hs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE BangPatterns #-}
12
{-# LANGUAGE ExplicitNamespaces #-}
23
{-# LANGUAGE FlexibleContexts #-}
34
{-# LANGUAGE GADTs #-}
@@ -36,10 +37,12 @@ import DataFrame.Internal.DataFrame (
3637
derivingExpressions,
3738
empty,
3839
getColumn,
40+
unsafeGetColumn,
3941
)
4042
import DataFrame.Internal.Expression
4143
import DataFrame.Internal.Interpreter
4244
import DataFrame.Operations.Core
45+
import DataFrame.Operations.Merge ()
4346
import DataFrame.Operations.Transformations (apply)
4447
import System.Random
4548
import Type.Reflection
@@ -471,3 +474,81 @@ generateRandomVector pureGen k = VU.fromList $ go pureGen k
471474
(v, g') = uniformR (0 :: Double, 1 :: Double) g
472475
in
473476
v : go g' (n - 1)
477+
478+
{- | Turn any 'Column' into one 'T.Text' label per row.

A boxed column that already holds 'T.Text' is returned as is. Every other
column has each element rendered with 'show' and packed into 'T.Text'.
-}
columnToTextVec :: Column -> V.Vector T.Text
columnToTextVec column = case column of
    BoxedColumn (values :: V.Vector a)
        | Just Refl <- testEquality (typeRep @a) (typeRep @T.Text) -> values
        | otherwise -> V.map render values
    UnboxedColumn values -> V.map render (V.convert values)
    OptionalColumn values -> V.map render values
  where
    -- Textual rendering used for every non-Text element.
    render :: (Show x) => x -> T.Text
    render = T.pack . show
486+
487+
{- | Group row positions by the stringified label of each row.

The labels come from 'columnToTextVec'. Each map value holds the zero-based
row indices carrying that label, in ascending row order.
-}
groupByIndices :: Column -> M.Map T.Text (VU.Vector Int)
groupByIndices col = M.map (VU.fromList . L.reverse) buckets
  where
    -- Indices are prepended while folding, hence the reverse above.
    buckets = V.ifoldl' collect M.empty (columnToTextVec col)
    collect acc rowIx label = M.insertWith (++) label [rowIx] acc
497+
498+
{- | Restrict every column of a dataframe to the rows at the given indices.

Each column is passed through @atIndicesStable ixs@. The row count in the
dimensions becomes the number of indices, and the column count is unchanged.
-}
rowsAtIndices :: VU.Vector Int -> DataFrame -> DataFrame
rowsAtIndices ixs df =
    df{columns = pickedColumns, dataframeDimensions = (rowCount, columnCount)}
  where
    pickedColumns = V.map (atIndicesStable ixs) (columns df)
    rowCount = VU.length ixs
    (_, columnCount) = dataframeDimensions df
505+
506+
{- | Sample a dataframe stratum by stratum.

Rows are grouped by the stringified value of the strata expression. Each
group is sampled independently with 'sample' at proportion @p@, and the
per-group results are concatenated with '<>'. Every group receives its own
generator obtained via 'splitGen'.

A bare column reference is read straight from the dataframe. Any other
expression is interpreted first, and an interpretation error is thrown.

==== __Example__
@
ghci> import System.Random
ghci> D.stratifiedSample (mkStdGen 42) 0.8 "label" df
@
-}
stratifiedSample ::
    forall a g.
    (SplitGen g, RandomGen g, Columnable a) =>
    g -> Double -> Expr a -> DataFrame -> DataFrame
stratifiedSample gen p strataCol df = sampleEach gen strata
  where
    strataColumn = case strataCol of
        Col name -> unsafeGetColumn name df
        _ -> unwrapTypedColumn (either throw id (interpret @a df strataCol))
    strata = M.elems (groupByIndices strataColumn)
    sampleEach _ [] = mempty
    sampleEach g (ixs : rest) =
        let (here, later) = splitGen g
         in sample here p (rowsAtIndices ixs df) <> sampleEach later rest
529+
530+
{- | Split a dataframe into two parts, stratum by stratum.

Rows are grouped by the stringified value of the strata expression. Each
group is split independently with 'randomSplit' at proportion @p@. The first
halves are concatenated into the first result and the second halves into the
second. Every group receives its own generator obtained via 'splitGen'.

A bare column reference is read straight from the dataframe. Any other
expression is interpreted first, and an interpretation error is thrown.

==== __Example__
@
ghci> import System.Random
ghci> D.stratifiedSplit (mkStdGen 42) 0.8 "label" df
@
-}
stratifiedSplit ::
    forall a g.
    (SplitGen g, RandomGen g, Columnable a) =>
    g -> Double -> Expr a -> DataFrame -> (DataFrame, DataFrame)
stratifiedSplit gen p strataCol df = splitEach gen strata
  where
    strataColumn = case strataCol of
        Col name -> unsafeGetColumn name df
        _ -> unwrapTypedColumn (either throw id (interpret @a df strataCol))
    strata = M.elems (groupByIndices strataColumn)
    splitEach _ [] = (mempty, mempty)
    splitEach g (ixs : rest) =
        let (here, later) = splitGen g
            (firstPart, secondPart) = randomSplit here p (rowsAtIndices ixs df)
            (firstRest, secondRest) = splitEach later rest
         in (firstPart <> firstRest, secondPart <> secondRest)

0 commit comments

Comments
 (0)