@@ -548,38 +548,108 @@ findBestGreedySplit cfg target conds df =
548548 (boolExpansion (synthConfig cfg))
549549 )
550550
551+ -- | Unifies non-nullable and nullable Double expressions for feature generation.
552+ data NumExpr
553+ = NDouble ! (Expr Double )
554+ | NMaybeDouble ! (Expr (Maybe Double ))
555+
556+ numExprCols :: NumExpr -> [T. Text ]
557+ numExprCols (NDouble e) = getColumns e
558+ numExprCols (NMaybeDouble e) = getColumns e
559+
560+ numExprEq :: NumExpr -> NumExpr -> Bool
561+ numExprEq (NDouble e1) (NDouble e2) = e1 == e2
562+ numExprEq (NMaybeDouble e1) (NMaybeDouble e2) = e1 == e2
563+ numExprEq _ _ = False
564+
565+ combineNumExprs :: NumExpr -> NumExpr -> [NumExpr ]
566+ combineNumExprs (NDouble e1) (NDouble e2) =
567+ [ NDouble (e1 .+ e2)
568+ , NDouble (e1 .- e2)
569+ , NDouble (e1 .* e2)
570+ , NDouble
571+ (F. ifThenElse (e2 ./= F. lit (0 :: Double )) (e1 ./ e2) (F. lit (0 :: Double )))
572+ ]
573+ combineNumExprs (NDouble e1) (NMaybeDouble e2) =
574+ [ NMaybeDouble (e1 .+ e2)
575+ , NMaybeDouble (e1 .- e2)
576+ , NMaybeDouble (e1 .* e2)
577+ , NMaybeDouble
578+ ( F. ifThenElse
579+ (F. fromMaybe False (e2 ./= F. lit (0 :: Double )))
580+ (e1 ./ e2)
581+ (F. lit (Nothing :: Maybe Double ))
582+ )
583+ ]
584+ combineNumExprs (NMaybeDouble e1) (NDouble e2) =
585+ [ NMaybeDouble (e1 .+ e2)
586+ , NMaybeDouble (e1 .- e2)
587+ , NMaybeDouble (e1 .* e2)
588+ , NMaybeDouble
589+ ( F. ifThenElse
590+ (e2 ./= F. lit (0 :: Double ))
591+ (e1 ./ e2)
592+ (F. lit (Nothing :: Maybe Double ))
593+ )
594+ ]
595+ combineNumExprs (NMaybeDouble e1) (NMaybeDouble e2) =
596+ [ NMaybeDouble (e1 .+ e2)
597+ , NMaybeDouble (e1 .- e2)
598+ , NMaybeDouble (e1 .* e2)
599+ , NMaybeDouble
600+ ( F. ifThenElse
601+ (F. fromMaybe False (e2 ./= F. lit (0 :: Double )))
602+ (e1 ./ e2)
603+ (F. lit (Nothing :: Maybe Double ))
604+ )
605+ ]
606+
551607numericConditions :: TreeConfig -> DataFrame -> [Expr Bool ]
552608numericConditions = generateNumericConds
553609
554610generateNumericConds :: TreeConfig -> DataFrame -> [Expr Bool ]
555611generateNumericConds cfg df = do
556612 expr <- numericExprsWithTerms (synthConfig cfg) df
557- let thresholds = map ( \ p -> percentile p expr df) (percentiles cfg)
613+ let thresholds = numericThresholds expr
558614 threshold <- thresholds
559- [ expr .<= F. lit threshold
560- , expr .>= F. lit threshold
561- , expr .< F. lit threshold
562- , expr .> F. lit threshold
615+ numericCondsFromExpr expr threshold
616+ where
617+ numericThresholds (NDouble e) = map (\ p -> percentile p e df) (percentiles cfg)
618+ numericThresholds (NMaybeDouble e) = map (\ p -> percentile p (F. fromMaybe 0 e) df) (percentiles cfg)
619+
620+ numericCondsFromExpr (NDouble e) t =
621+ [e .<= F. lit t, e .>= F. lit t, e .< F. lit t, e .> F. lit t]
622+ numericCondsFromExpr (NMaybeDouble e) t =
623+ [ F. fromMaybe False (e .<= F. lit t)
624+ , F. fromMaybe False (e .>= F. lit t)
625+ , F. fromMaybe False (e .< F. lit t)
626+ , F. fromMaybe False (e .> F. lit t)
563627 ]
564628
565- numericExprsWithTerms :: SynthConfig -> DataFrame -> [Expr Double ]
629+ numericExprsWithTerms :: SynthConfig -> DataFrame -> [NumExpr ]
566630numericExprsWithTerms cfg df =
567631 concatMap (numericExprs cfg df [] 0 ) [0 .. maxExprDepth cfg]
568632
569- numericCols :: DataFrame -> [Expr Double ]
633+ numericCols :: DataFrame -> [NumExpr ]
570634numericCols df = concatMap extract (columnNames df)
571635 where
572636 extract col = case unsafeGetColumn col df of
573637 UnboxedColumn (_ :: VU. Vector b ) ->
574638 case testEquality (typeRep @ b ) (typeRep @ Double ) of
575- Just Refl -> [Col col]
639+ Just Refl -> [NDouble ( Col col) ]
576640 Nothing -> case sIntegral @ b of
577- STrue -> [F. toDouble (Col @ b col)]
641+ STrue -> [NDouble (F. toDouble (Col @ b col))]
642+ SFalse -> []
643+ OptionalColumn (_ :: V. Vector (Maybe b )) ->
644+ case testEquality (typeRep @ b ) (typeRep @ Double ) of
645+ Just Refl -> [NMaybeDouble (Col @ (Maybe b ) col)]
646+ Nothing -> case sIntegral @ b of
647+ STrue -> [NMaybeDouble (F. whenPresent (realToFrac @ b @ Double ) (Col @ (Maybe b ) col))]
578648 SFalse -> []
579649 _ -> []
580650
581651numericExprs ::
582- SynthConfig -> DataFrame -> [Expr Double ] -> Int -> Int -> [Expr Double ]
652+ SynthConfig -> DataFrame -> [NumExpr ] -> Int -> Int -> [NumExpr ]
583653numericExprs cfg df prevExprs depth maxDepth
584654 | depth == 0 = baseExprs ++ numericExprs cfg df baseExprs (depth + 1 ) maxDepth
585655 | depth >= maxDepth = []
@@ -592,20 +662,16 @@ numericExprs cfg df prevExprs depth maxDepth
592662 | otherwise = do
593663 e1 <- prevExprs
594664 e2 <- baseExprs
595- let cols = getColumns e1 <> getColumns e2
665+ let cols = numExprCols e1 <> numExprCols e2
596666 guard
597- ( e1 /= e2
667+ ( not (numExprEq e1 e2)
598668 && not
599669 ( any
600670 (\ (l, r) -> l `elem` cols && r `elem` cols)
601671 (disallowedCombinations cfg)
602672 )
603673 )
604- [ e1 + e2
605- , e1 - e2
606- , e1 * e2
607- , F. ifThenElse (e2 ./= (0 :: Expr Double )) (e1 / e2) 0
608- ]
674+ combineNumExprs e1 e2
609675
610676boolExprs ::
611677 DataFrame -> [Expr Bool ] -> [Expr Bool ] -> Int -> Int -> [Expr Bool ]
@@ -631,37 +697,9 @@ generateConditionsOld cfg df =
631697 let ps = map (Lit . (`percentileOrd'` col)) [1 , 25 , 75 , 99 ]
632698 in map (F. lift2 (==) (Col @ a colName)) ps
633699 (OptionalColumn (col :: V. Vector (Maybe a ))) -> case sFloating @ a of
634- STrue ->
635- let doubleCol =
636- VU. convert
637- (V. map fromJust (V. filter isJust (V. map (fmap (realToFrac @ a @ Double )) col)))
638- in zipWith
639- ($)
640- [ F. lift2 (==) (Col @ (Maybe a ) colName)
641- , F. lift2 (<=) (Col @ (Maybe a ) colName)
642- , F. lift2 (>=) (Col @ (Maybe a ) colName)
643- ]
644- ( Lit Nothing
645- : map
646- (Lit . Just . realToFrac . (`percentile'` doubleCol))
647- (percentiles cfg)
648- )
700+ STrue -> [] -- handled by numericCols / numericExprs
649701 SFalse -> case sIntegral @ a of
650- STrue ->
651- let doubleCol =
652- VU. convert
653- (V. map fromJust (V. filter isJust (V. map (fmap (fromIntegral @ a @ Double )) col)))
654- in zipWith
655- ($)
656- [ F. lift2 (==) (Col @ (Maybe a ) colName)
657- , F. lift2 (<=) (Col @ (Maybe a ) colName)
658- , F. lift2 (>=) (Col @ (Maybe a ) colName)
659- ]
660- ( Lit Nothing
661- : map
662- (Lit . Just . round . (`percentile'` doubleCol))
663- (percentiles cfg)
664- )
702+ STrue -> [] -- handled by numericCols / numericExprs
665703 SFalse ->
666704 map
667705 (F. lift2 (==) (Col @ (Maybe a ) colName) . Lit . (`percentileOrd'` col))
0 commit comments