Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 27 additions & 86 deletions inst/Classification/ClassificationGAM.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## Copyright (C) 2024 Ruchika Sonagote <ruchikasonagote2003@gmail.com>
## Copyright (C) 2024-2025 Andreas Bertsatos <abertsatos@biol.uoa.gr>
## Copyright (C) 2025 Swayam Shah <swayamshah66@gmail.com>
## Copyright (C) 2026 Jayant Chauhan <0001jayant@gmail.com>
##
## This file is part of the statistics package for GNU Octave.
##
Expand Down Expand Up @@ -842,7 +843,8 @@ function disp (this)
if (F_I > 0)
if (isempty (this.Formula))
## Analyze Interactions optional parameter
this.IntMatrix = this.parseInteractions ();
## Use external FormulaParser function
this.IntMatrix = FormulaParser (this.Interactions, this.NumPredictors);
## Append interaction terms to the predictor matrix
for i = 1:rows (this.IntMatrix)
tindex = logical (this.IntMatrix(i,:));
Expand All @@ -857,7 +859,14 @@ function disp (this)

else
## Analyze Formula optional parameter
this.IntMatrix = this.parseFormula ();
## Use external FormulaParser function and capture response
[this.IntMatrix, parsedResp, ~] = FormulaParser (this.Formula, this.PredictorNames);

## If the formula included a Response Name (LHS), update the object property
if (! isempty (parsedResp))
this.ResponseName = parsedResp;
endif

## Add selected predictors and interaction terms
XN = [];
for i = 1:rows (this.IntMatrix)
Expand Down Expand Up @@ -1240,90 +1249,6 @@ function savemodel (this, fname)
## Helper functions
methods (Access = private)

## Determine interactions from Interactions optional parameter
function intMat = parseInteractions (this)
if (islogical (this.Interactions))
## Check that interaction matrix corresponds to predictors
if (numel (this.PredictorNames) != columns (this.Interactions))
error (strcat ("ClassificationGAM: columns in 'Interactions'", ...
" matrix must equal to the number of predictors."));
endif
intMat = this.Interactions;
elseif (isnumeric (this.Interactions))
## Need to measure the effect of all interactions to keep the best
## performing. Just check that the given number is not higher than
## p*(p-1)/2, where p is the number of predictors.
p = this.NumPredictors;
if (this.Interactions > p * (p - 1) / 2)
error (strcat ("ClassificationGAM: number of interaction terms", ...
" requested is larger than all possible", ...
" combinations of predictors in X."));
endif
## Get all combinations except all zeros
allMat = flip (fullfact(p)([2:end],:), 2);
## Only keep interaction terms
iterms = find (sum (allMat, 2) != 1);
intMat = allMat(iterms);
elseif (strcmpi (this.Interactions, "all"))
p = this.NumPredictors;
## Calculate all p*(p-1)/2 interaction terms
allMat = flip (fullfact(p)([2:end],:), 2);
## Only keep interaction terms
iterms = find (sum (allMat, 2) != 1);
intMat = allMat(iterms);
endif
endfunction

## Determine interactions from formula
function intMat = parseFormula (this)
intMat = [];
## Check formula for syntax
if (isempty (strfind (this.Formula, '~')))
error ("ClassificationGAM: invalid syntax in 'Formula'.");
endif
## Split formula and keep predictor terms
formulaParts = strsplit (this.Formula, '~');
## Check there is some string after '~'
if (numel (formulaParts) < 2)
error ("ClassificationGAM: no predictor terms in 'Formula'.");
endif
predictorString = strtrim (formulaParts{2});
if (isempty (predictorString))
error ("ClassificationGAM: no predictor terms in 'Formula'.");
endif
## Split additive terms (between + sign)
aterms = strtrim (strsplit (predictorString, '+'));
## Process all terms
for i = 1:numel (aterms)
## Find individual terms (string missing ':')
if (isempty (strfind (aterms(i), ':'){:}))
## Search PredictorNames to associate with column in X
sterms = strcmp (this.PredictorNames, aterms(i));
## Append to interactions matrix
intMat = [intMat; sterms];
else
## Split interaction terms (string contains ':')
mterms = strsplit (aterms{i}, ':');
## Add each individual predictor to interaction term vector
iterms = logical (zeros (1, this.NumPredictors));
for t = 1:numel (mterms)
iterms = iterms | strcmp (this.PredictorNames, mterms(t));
endfor
## Check that all predictors have been identified
if (sum (iterms) != t)
error (strcat ("ClassificationGAM: some predictors", ...
" have not been identified."));
endif
## Append to interactions matrix
intMat = [intMat; iterms];
endif
endfor
## Check that all terms have been identified
if (! all (sum (intMat, 2) > 0))
error ("ClassificationGAM: some terms have not been identified.");
endif
endfunction

## Fit the model
function [iter, param, res, RSS, intercept] = fitGAM (this, X, Y, Inter, ...
Knots, Order, learning_rate, num_iterations)
Expand Down Expand Up @@ -1543,6 +1468,22 @@ function savemodel (this, fname)
%! ClassificationGAM (ones (5,2), ones (5,1), "Cost", "string")
%!error<ClassificationGAM: 'Cost' must be a numeric square matrix.> ...
%! ClassificationGAM (ones (5,2), ones (5,1), "Cost", {eye(2)})
%!error<ClassificationGAM: 'Formula' must be a string.>
%! ClassificationGAM (ones(10,2), ones (10,1), "formula", {"y~x1+x2"})
%!error<ClassificationGAM: 'Formula' must be a string.>
%! ClassificationGAM (ones(10,2), ones (10,1), "formula", [0, 1, 0])
%!error<FormulaParser: invalid syntax. Formula must contain '~'.> ...
%! ClassificationGAM (ones(10,2), ones (10,1), "formula", "something")
%!error<FormulaParser: no predictor terms found.> ...
%! ClassificationGAM (ones(10,2), ones (10,1), "formula", "something~")
%!error<FormulaParser: no predictor terms found.> ...
%! ClassificationGAM (ones(10,2), ones (10,1), "formula", "something~ ")
%!error<FormulaParser: invalid syntax> ...
%! ClassificationGAM (ones(10,2), ones (10,1), "formula", "something~x1:")
%!error<ClassificationGAM: 'Formula' has already been defined.> ...
%! ClassificationGAM (ones(10,2), ones (10,1), "formula", "y ~ x1 + x2", "interactions", 1)
%!error<ClassificationGAM: 'Interactions' have already been defined.> ...
%! ClassificationGAM (ones(10,2), ones (10,1), "interactions", 1, "formula", "y ~ x1 + x2")

## Test predict method
%!test
Expand Down
Loading