Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/PLNPCA.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA
#' @inherit PLN_param details
#' @export
PLNPCA_param <- function(
backend = "nlopt",
backend = c("nlopt", "torch"),
trace = 1 ,
config_optim = list() ,
config_post = list() ,
Expand Down
189 changes: 186 additions & 3 deletions R/PLNPCAfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,185 @@ PLNPCAfit <- R6Class(
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
private = list(
C = NULL,
svdCM = NULL
svdCM = NULL,

## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## PRIVATE TORCH METHODS FOR RANK-CONSTRAINED OPTIMIZATION
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
torch_elbo_rank_core = function(data, M, S, B, C, index) {
  ## Negative ELBO of the rank-constrained PLN model, restricted to the
  ## observations selected by `index` (a torch long tensor of row indices).
  ## M, S: variational means / std-devs of the latent scores (n, q);
  ## B: regression coefficients; C: loadings mapping q-space to p-space.
  S2 <- torch_square(S[index]) # variational variances, (batch, q)
  C2 <- torch_square(C) # squared loadings, (p, q)
  ## linear predictor: offsets + latent contribution + covariate effects
  Z <- data$O[index] +
    torch_mm(M[index], torch_t(C)) +
    torch_mm(data$X[index], B) # (batch, p)
  ## conditional mean E[Y | q] = exp(Z + 0.5 * Var(latent)), (batch, p)
  A <- torch_exp(Z + 0.5 * torch_mm(S2, torch_t(C2)))
  ## weighted Poisson likelihood part (sign-flipped: this function is a loss);
  ## `[index, NULL]` adds a trailing dim so weights broadcast over columns
  lik_part <- torch_sum(data$w[index, NULL] * (A - data$Y[index] * Z))
  ## KL divergence between q(W) = N(M, diag(S2)) and the N(0, I) prior
  kl_part <- 0.5 * torch_sum(data$w[index, NULL] *
    (torch_square(M[index]) + S2 - torch_log(S2) - 1))
  lik_part + kl_part
},

torch_elbo_rank = function(data, params, index = torch_tensor(1:self$n)) {
  ## Negative ELBO over the selected observations; defaults to all n rows.
  ## Thin wrapper unpacking `params` for the core routine above.
  private$torch_elbo_rank_core(data, params$M, params$S, params$B, params$C, index)
},

torch_vloglik_rank = function(data, params) {
  ## Per-sample variational log-likelihood Ji of the rank-constrained model.
  ## Takes torch tensors, returns a plain numeric vector of length n with
  ## the observation weights attached as the "weights" attribute.
  S2 <- torch_square(params$S)
  C2 <- torch_square(params$C)
  ## linear predictor and conditional mean in the observation space
  Z <- data$O + torch_mm(params$M, torch_t(params$C)) + torch_mm(data$X, params$B)
  A <- torch_exp(Z + 0.5 * torch_mm(S2, torch_t(C2)))
  ## Poisson log-lik (with log-factorial normalization) minus KL term, per row
  Ji <- - torch_sum(.logfactorial_torch(data$Y), dim = 2) +
    torch_sum(data$Y * Z - A, dim = 2) -
    0.5 * torch_sum(torch_square(params$M) + S2 - torch_log(S2) - 1, dim = 2)
  ## add the constant .5 * p per sample and move back to base R numerics
  Ji <- .5 * self$p + as.numeric(Ji$cpu())
  attr(Ji, "weights") <- as.numeric(data$w$cpu())
  Ji
},

torch_optimize_rank_core = function(data, params, config, n_obs, loss_fn) {
  ## Shared mini-batch optimization loop for the torch backend.
  ## `params` is a list of torch tensors with requires_grad = TRUE;
  ## `loss_fn` maps a batch index tensor to the scalar loss (negative ELBO).
  ## Returns list(params, objective, iterations, status).
  optimizer <- switch(config$algorithm,
    "RPROP"   = optim_rprop(params, lr = config$lr, etas = config$etas, step_sizes = config$step_sizes),
    "RMSPROP" = optim_rmsprop(params, lr = config$lr, weight_decay = config$weight_decay, momentum = config$momentum, centered = config$centered),
    "ADAM"    = optim_adam(params, lr = config$lr, weight_decay = config$weight_decay),
    "ADAGRAD" = optim_adagrad(params, lr = config$lr, weight_decay = config$weight_decay)
  )

  status <- 5  # default status: epoch budget exhausted (mirrors nlopt codes -- TODO confirm)
  num_epoch <- config$num_epoch
  num_batch <- config$num_batch
  ## NOTE(review): the last n_obs %% num_batch observations are silently
  ## dropped from every epoch -- confirm this is intended.
  batch_size <- floor(n_obs / num_batch)

  ## objective[1] stays 0 as a pre-optimization placeholder
  objective <- double(length = num_epoch + 1)
  ## defined up-front so the return value is valid even when num_epoch == 0
  iterate <- 0L
  loss <- NULL
  for (iterate in seq_len(num_epoch)) {
    ## fresh random partition of the observations at each epoch
    permute <- torch::torch_tensor(sample.int(n_obs), dtype = torch_long(), device = config$device)
    for (batch_idx in seq_len(num_batch)) {
      index <- permute[(batch_size * (batch_idx - 1) + 1):(batch_idx * batch_size)]
      optimizer$zero_grad()
      loss <- loss_fn(index)
      loss$backward()
      optimizer$step()
    }

    ## convergence is monitored on the last batch's loss of the epoch
    objective[iterate + 1] <- loss$item()
    delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1])

    if (!is.finite(loss$item())) {
      stop(sprintf(
        "The ELBO diverged during the optimization procedure.\nConsider using:\n* a different optimizer (current optimizer: %s)\n* a smaller learning rate (current rate: %.3f)\nwith `control = PLNPCA_param(backend = 'torch', config_optim = list(algorithm = ..., lr = ...))`",
        config$algorithm, config$lr
      ))
    }

    if (config$trace > 1 && (iterate %% 50 == 1))
      cat('\niteration:', iterate, 'objective', objective[iterate + 1],
          'delta_f', round(delta_f, 6))

    ## status 3: relative change of the objective fell below ftol_rel
    if (delta_f < config$ftol_rel) status <- 3
    if (status %in% c(3, 4)) {
      objective <- objective[seq_len(iterate + 1)]
      break
    }
  }

  list(
    params = params,
    objective = objective,
    iterations = iterate,
    status = status
  )
},

torch_optimize_vestep_rank = function(data, params, B, C, config) {
  ## VE step with the torch backend: optimize only the variational
  ## parameters (M, S) while the model parameters (B, C) stay fixed.
  if (config$trace > 1)
    message(paste("optimizing with device:", config$device))

  n_obs <- nrow(data$Y)

  ## move everything to the requested device; gradients only for M and S
  data   <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device)
  params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device)
  B_t <- torch_tensor(B, dtype = torch_float32(), device = config$device)
  C_t <- torch_tensor(C, dtype = torch_float32(), device = config$device)

  out <- private$torch_optimize_rank_core(
    data    = data,
    params  = params,
    config  = config,
    n_obs   = n_obs,
    loss_fn = function(index)
      private$torch_elbo_rank_core(data, params$M, params$S, B_t, C_t, index)
  )

  ## back to base R matrices + per-sample logliks at the fixed (B, C)
  estimates <- lapply(out$params, function(tensor) as.matrix(tensor$cpu()))
  Ji <- private$torch_vloglik_rank(data, c(out$params, list(B = B_t, C = C_t)))

  list(
    M  = estimates$M,
    S  = estimates$S,
    Ji = Ji,
    monitoring = list(
      objective  = out$objective,
      iterations = out$iterations,
      status     = out$status,
      backend    = "torch"
    )
  )
},

torch_optimize_rank = function(data, params, config) {
  ## Full fit of the rank-constrained model with the torch backend: jointly
  ## optimizes B (regression), C (loadings) and the variational parameters
  ## (M, S), then derives Sigma, Omega and per-sample logliks on the CPU.
  if (config$trace > 1)
    message(paste("optimizing with device:", config$device))

  ## send data and parameters to the requested device; all params get gradients
  data <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device)
  params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device)

  optim_out <- private$torch_optimize_rank_core(
    data = data,
    params = params,
    config = config,
    n_obs = self$n,
    loss_fn = function(index) {
      private$torch_elbo_rank(data, params, index)
    }
  )

  ## Compute derived quantities on CPU
  params_r <- lapply(optim_out$params, function(x) as.matrix(x$cpu()))
  data_r <- lapply(data, function(x) as.matrix(x$cpu()))

  q <- ncol(params_r$M)   # rank of the latent space
  S2_r <- params_r$S^2    # variational variances, (n, q)
  C2_r <- params_r$C^2
  ## linear predictor and fitted conditional means in the observation space
  Z_r <- data_r$O + params_r$M %*% t(params_r$C) + data_r$X %*% params_r$B
  A_r <- exp(Z_r + 0.5 * S2_r %*% t(C2_r))
  w_r <- as.numeric(data_r$w)  # observation weights

  ## weighted second moment of the latent scores in q-space
  ## (mean part M'WM plus the variational variances), normalized by sum(w)
  wM <- params_r$M * sqrt(w_r)
  inner_q <- (crossprod(wM) + diag(colSums(S2_r * w_r), nrow = q)) / sum(w_r)
  ## mapped back to the p-dimensional observation space: rank-q covariance
  Sigma_r <- params_r$C %*% inner_q %*% t(params_r$C)
  ## NOTE(review): C inner^{-1} C' is a rank-q "precision", not the matrix
  ## inverse of Sigma_r (singular for q < p) -- confirm it matches the nlopt
  ## backend's convention.
  Omega_r <- params_r$C %*% solve(inner_q) %*% t(params_r$C)

  ## per-sample variational loglik (same formula as torch_vloglik_rank,
  ## recomputed here in base R from the CPU copies)
  Ji_r <- .5 * self$p - rowSums(.logfactorial(as.matrix(data_r$Y))) +
    rowSums(data_r$Y * Z_r - A_r) -
    0.5 * rowSums(params_r$M^2 + S2_r - log(S2_r) - 1)
  attr(Ji_r, "weights") <- w_r

  list(
    B = params_r$B,
    C = params_r$C,
    M = params_r$M,
    S = params_r$S,
    Z = Z_r,
    A = A_r,
    Sigma = Sigma_r,
    Omega = Omega_r,
    Ji = Ji_r,
    monitoring = list(
      objective = optim_out$objective,
      iterations = optim_out$iterations,
      status = optim_out$status,
      backend = "torch"
    )
  )
}
),
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## PUBLIC MEMBERS ----
Expand All @@ -58,7 +236,11 @@ PLNPCAfit <- R6Class(
#' @description Initialize a [`PLNPCAfit`] object
initialize = function(rank, responses, covariates, offsets, weights, formula, control) {
super$initialize(responses, covariates, offsets, weights, formula, control)
private$optimizer$main <- nlopt_optimize_rank
if (control$backend == "torch") {
private$optimizer$main <- private$torch_optimize_rank
} else {
private$optimizer$main <- nlopt_optimize_rank
}
private$optimizer$vestep <- nlopt_optimize_vestep_rank
if (!is.null(control$svdM)) {
svdM <- control$svdM
Expand Down Expand Up @@ -125,7 +307,8 @@ PLNPCAfit <- R6Class(
B = private$B,
C = private$C,
config = control$config_optim)
optim_out <- do.call(private$optimizer$vestep, args)
vestep_optimizer <- if (control$backend == "torch") private$torch_optimize_vestep_rank else private$optimizer$vestep
optim_out <- do.call(vestep_optimizer, args)
optim_out
},

Expand Down
11 changes: 7 additions & 4 deletions R/PLNfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -577,9 +577,12 @@ PLNfit <- R6Class(
if (is.null(O)) O <- matrix(0, n_new, self$p)

# Compute parameters of the law
vcov11 <- private$Sigma[cond , cond, drop = FALSE]
vcov22 <- private$Sigma[!cond, !cond, drop = FALSE]
vcov12 <- private$Sigma[cond , !cond, drop = FALSE]
# as.matrix() coerces sparse Matrix (returned by diagonal/spherical covariance
# models) to dense, so that simplify2array() in the map below produces a
# numeric array rather than a list of sparse Matrix objects.
vcov11 <- as.matrix(private$Sigma[cond , cond, drop = FALSE])
vcov22 <- as.matrix(private$Sigma[!cond, !cond, drop = FALSE])
vcov12 <- as.matrix(private$Sigma[cond , !cond, drop = FALSE])
prec11 <- solve(vcov11)
A <- crossprod(vcov12, prec11)
Sigma21 <- vcov22 - A %*% vcov12
Expand Down Expand Up @@ -931,7 +934,7 @@ PLNfit_fixedcov <- R6Class(
},

torch_Omega = function(data, params) {
  ## Fixed-covariance model: Omega is known a priori, so expose the stored
  ## value as a torch tensor matching the dtype and device of the other
  ## parameters (taken from params$B) instead of re-estimating it.
  ## Fix: removed the stale pre-change assignment (without dtype/device)
  ## that was left alongside the corrected line.
  ## NOTE(review): `params` is an R list, so this mutates a local copy; the
  ## caller presumably uses the (invisibly) returned tensor -- confirm.
  params$Omega <- torch_tensor(private$Omega, dtype = params$B$dtype, device = params$B$device)
},

## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Expand Down
2 changes: 1 addition & 1 deletion man/PLNPCA_param.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions tests/testthat/test-plnfit.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,30 @@ test_that("PLN fit: Check conditional prediction", {

})

test_that("PLN fit: Check conditional prediction with sparse covariance models", {

  ## dimensions of the conditioning block and of the predicted block
  n_cond <- 10
  p_cond <- 2
  n_pred <- ncol(trichoptera$Abundance) - p_cond

  Yc      <- trichoptera$Abundance[seq_len(n_cond), seq_len(p_cond), drop = FALSE]
  newdata <- trichoptera[seq_len(n_cond), , drop = FALSE]

  ## fit with the given covariance model, then check that conditional
  ## prediction returns dense numeric arrays with the expected dimensions
  check_one <- function(covariance) {
    model <- PLN(
      Abundance ~ 1,
      data = trichoptera,
      control = PLN_param(covariance = covariance, trace = 0)
    )

    pred <- predict_cond(model, newdata, Yc, type = "response", var_par = TRUE)
    expect_equal(dim(pred), c(n_cond, n_pred))
    expect_equal(dim(attr(pred, "M")), dim(pred))
    expect_equal(dim(attr(pred, "S")), c(n_pred, n_pred, n_cond))
    expect_true(is.array(attr(pred, "S")))
    expect_true(is.numeric(attr(pred, "S")))
  }

  for (covariance in c("diagonal", "spherical")) check_one(covariance)
})

test_that("PLN fit: Check number of parameters", {

p <- ncol(trichoptera$Abundance)
Expand Down
40 changes: 40 additions & 0 deletions tests/testthat/test-plnnetworkfit.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,43 @@ test_that("PLNnetwork fit: check classes, getters and field access", {
expect_true(inherits(myPLNfit$plot_network(output = "corrplot", plot = FALSE), "Matrix"))

})

test_that("PLNnetwork fit accepts torch backend", {
  ## gate on the torch R package being installed AND its libtorch runtime
  ## being available (torch_is_installed checks the downloaded binaries)
  skip_if_not_installed("torch")
  skip_if_not(torch::torch_is_installed())

  data("trichoptera", package = "PLNmodels", envir = environment())
  ## small subset (10 sites x 4 species) to keep the test fast
  trichoptera_small <- prepare_data(
    trichoptera$Abundance[1:10, 1:4],
    trichoptera$Covariate[1:10, , drop = FALSE]
  )
  Y <- as.matrix(trichoptera_small$Abundance)
  ## minimal torch configuration: few epochs, single batch, two outer loops
  torch_control <- PLNnetwork_param(
    backend = "torch",
    trace = 0,
    config_optim = list(
      algorithm = "RPROP",
      lr = 0.01,
      num_epoch = 5,
      num_batch = 1,
      maxit_out = 2
    )
  )

  models <- NULL
  expect_no_error(models <- PLNnetwork(
    Abundance ~ 1,
    data = trichoptera_small,
    penalties = 0.1,
    control = torch_control
  ))
  expect_false(is.null(models))

  ## sanity-check the dimensions of the fitted objects and that the stored
  ## loglik is consistent with the per-sample contributions
  myPLNfit <- getBestModel(models)
  expect_equal(dim(myPLNfit$latent), dim(Y))
  expect_equal(dim(myPLNfit$model_par$B), c(1, ncol(Y)))
  expect_equal(dim(myPLNfit$model_par$Omega), c(ncol(Y), ncol(Y)))
  expect_equal(dim(myPLNfit$var_par$M), dim(Y))
  expect_equal(dim(myPLNfit$var_par$S), dim(Y))
  expect_equal(sum(myPLNfit$loglik_vec), myPLNfit$loglik, tolerance = 1e-4)
})
34 changes: 34 additions & 0 deletions tests/testthat/test-plnpcafit.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,40 @@ test_that("PLNPCA fit: check classes, getters and field access", {
expect_true(inherits(myPLNfit, "PCA"))
})

test_that("PLNPCA torch backend works for fit and project", {
  ## gate on the torch R package and its libtorch runtime
  skip_if_not_installed("torch")
  skip_if_not(torch::torch_is_installed())

  torch_control <- PLNPCA_param(
    backend = "torch",
    trace = 0,
    config_optim = list(algorithm = "RPROP", lr = 0.01, num_epoch = 20, num_batch = 1)
  )

  ## fit a rank-1 PLNPCA with the torch backend and extract that model
  torch_fit <- getModel(
    PLNPCA(
      Abundance ~ 1,
      data = trichoptera,
      ranks = 1,
      control = torch_control
    ),
    1
  )

  ## recompute the per-sample variational loglik from the fitted quantities
  ## (latent Z, fitted means A, variational M/S) and check it matches the
  ## value stored by the torch backend
  Y <- as.matrix(trichoptera$Abundance)
  expected_loglik_vec <- .5 * ncol(Y) - rowSums(PLNmodels:::.logfactorial(Y)) +
    rowSums(Y * torch_fit$latent - fitted(torch_fit)) -
    .5 * rowSums(torch_fit$var_par$M^2 + torch_fit$var_par$S^2 - log(torch_fit$var_par$S^2) - 1)

  expect_equal(torch_fit$loglik_vec, expected_loglik_vec, tolerance = 1e-4, check.attributes = FALSE)

  ## NOTE(review): `models` is not defined in this test block -- it presumably
  ## comes from a fixture defined earlier in this test file; confirm it exists,
  ## otherwise project on `torch_fit` instead.
  model1 <- getModel(models, 1)
  ## project() with a torch control must run and reproduce the stored scores
  expect_no_error(scores <- model1$project(newdata = trichoptera, control = torch_control))
  expect_false(is.null(scores))
  expect_equal(dim(scores), dim(model1$scores))
  expect_equal(dimnames(scores), dimnames(model1$scores))
})

test_that("Bindings for factoextra return sensible values", {
## $eig
expect_gte(min(myPLNfit$eig[, "eigenvalue"]), 0)
Expand Down
Loading