Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions R/Classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ utils::globalVariables(
#' @title Creates a binary classifier to classify cells within a Seurat object
#'
#' @description Creates a binary classifier to classify cells
#' @param training_matrix A counts or data slot provided by TrainModelsFromSeurat
#' @param training_matrix A matrix (counts or data layer) provided by TrainModelsFromSeurat
#' @param celltype The celltype (provided by TrainModelsFromSeurat) used as classifier's positive prediction
#' @param hyperparameter_tuning logical that determines whether or not hyperparameter tuning should be performed.
#' @param learner The mlr3 learner that should be used. Currently fixed to "classif.ranger" if hyperparameter tuning is FALSE. Otherwise, "classif.xgboost" and "classif.ranger" are supported.
Expand Down Expand Up @@ -161,7 +161,7 @@ TrainModel <- function(training_matrix, celltype, hyperparameter_tuning = F, lea
#' @param seuratObj The Seurat Object to be updated
#' @param celltype_column The metadata column containing the celltypes. One classifier will be created for each celltype present in this column.
#' @param assay SeuratObj assay containing the desired count matrix/metadata
#' @param slot Slot containing the count data. Should be restricted to counts, data, or scale.data.
#' @param layer Layer containing the count data. Should be restricted to counts, data, or scale.data.
#' @param output_dir The directory in which models, metrics, and training data will be saved.
#' @param hyperparameter_tuning Logical that determines whether or not hyperparameter tuning should be performed.
#' @param learner The mlr3 learner that should be used. Currently fixed to "classif.ranger" if hyperparameter tuning is FALSE. Otherwise, "classif.xgboost" and "classif.ranger" are supported.
Expand All @@ -178,7 +178,7 @@ TrainModel <- function(training_matrix, celltype, hyperparameter_tuning = F, lea
#' @param verbose Whether or not to print the metrics data for each model after training.
#' @param min_cells_per_class If provided, any classes (and corresponding cells) with fewer than this many cells will be dropped from the training data
#' @export
TrainModelsFromSeurat <- function(seuratObj, celltype_column, assay = "RNA", slot = "data", output_dir = "./classifiers", hyperparameter_tuning = F, learner = "classif.ranger", inner_resampling = "cv", outer_resampling = "cv", inner_folds = 4, inner_ratio = 0.8, outer_folds = 3, outer_ratio = 0.8, n_models = 20, n_cores = NULL, gene_list = NULL, gene_exclusion_list = NULL, verbose = TRUE, min_cells_per_class = 20){
TrainModelsFromSeurat <- function(seuratObj, celltype_column, assay = "RNA", layer = "data", output_dir = "./classifiers", hyperparameter_tuning = F, learner = "classif.ranger", inner_resampling = "cv", outer_resampling = "cv", inner_folds = 4, inner_ratio = 0.8, outer_folds = 3, outer_ratio = 0.8, n_models = 20, n_cores = NULL, gene_list = NULL, gene_exclusion_list = NULL, verbose = TRUE, min_cells_per_class = 20){
if (methods::missingArg(celltype_column)) {
stop('Must provide the celltype_column argument')
}
Expand All @@ -196,7 +196,7 @@ TrainModelsFromSeurat <- function(seuratObj, celltype_column, assay = "RNA", slo
}

#Read the raw data from a seurat object and parse into an mlr3-compatible labeled matrix
raw_data_matrix <- attr(x = seuratObj@assays[[assay]], which = slot)
raw_data_matrix <- attr(x = seuratObj@assays[[assay]], which = layer)
if (!all(is.null(gene_list))) {
gene_list <- ExpandGeneList(gene_list)
if (!all(gene_list %in% rownames(raw_data_matrix))) {
Expand Down Expand Up @@ -352,7 +352,7 @@ ScoreCellsWithSavedModel <- function(seuratObj, model, fieldToClass, batchSize =
classifier <- .ResolveModel(modelFile = model)

#De-sparse and transpose seuratObj normalized data & make names unique
gene_expression_matrix <- Matrix::t(Seurat::GetAssayData(seuratObj, assay = assayName, slot = "data"))
gene_expression_matrix <- Matrix::t(Seurat::GetAssayData(seuratObj, assay = assayName, layer = "data"))

# NOTE: make.names() will convert hyphens to periods, and also prefix gene names that start with a digit, like 7SK.2 -> X7SK.2
colnames(gene_expression_matrix) <- make.names(colnames(gene_expression_matrix))
Expand Down Expand Up @@ -572,11 +572,12 @@ InterpretModels <- function(output_dir= "./classifiers", plot_type = "ratio"){
#' @param seuratObj The Seurat Object to be updated
#' @param model The trained sPLSDA model to use for prediction. This can be a file path to an RDS file, or a built-in model name.
#' @param modelList A list of trained sPLSDA models to use for prediction. This can be a list of file paths to RDS files, or built-in model names.
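#' @param removePreexistingColumns If TRUE, any pre-existing metadata columns whose names match "<modelName>_Activation_sPLSDA_Score_" are dropped for each model before that model is scored.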
#' @import nnet
#' @return A Seurat object with the sPLSDA scores and predicted probabilities added to the metadata
#' @export

PredictTcellActivation <- function(seuratObj, model = NULL, modelList = NULL) {
PredictTcellActivation <- function(seuratObj, model = NULL, modelList = NULL, removePreexistingColumns = TRUE) {
#################
### Sanitize ###
#################
Expand Down Expand Up @@ -638,6 +639,15 @@ PredictTcellActivation <- function(seuratObj, model = NULL, modelList = NULL) {
#check if component scores already exist in metadata and score if missing or contain NAs
metadataNames <- colnames(seuratObj@meta.data)
for (modelName in names(modelList)) {
if (removePreexistingColumns) {
toDrop <- grep(names(seuratObj@meta.data), pattern = paste0(modelName, "_Activation_sPLSDA_Score_"), value = TRUE)
if (length(toDrop) > 0) {
print(paste0('Dropping pre-existing columns: ', paste0(toDrop, collapse = ',')))
for (colName in toDrop) {
seuratObj[[toDrop]] <- NULL
}
}
}
#determine number of components for this model
nComponents <- .GetModelComponentCount(modelName)
modelVersion <- modelVersions[[modelName]]
Expand Down
4 changes: 2 additions & 2 deletions R/GeneModules.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ PlotUcellCorrelation <- function(seuratObj, toCalculate, assayName = 'RNA') {
# Drop any genes with all zeros
genesToSkip <- NULL
for (gene in geneList) {
if (sum(Seurat::GetAssayData(seuratObj, assay = assayName, slot = 'data')[gene,] > 0) == 0) {
if (sum(Seurat::GetAssayData(seuratObj, assay = assayName, layer = 'data')[gene,] > 0) == 0) {
genesToSkip <- c(genesToSkip, gene)
}
}
Expand All @@ -126,7 +126,7 @@ PlotUcellCorrelation <- function(seuratObj, toCalculate, assayName = 'RNA') {
next
}

geneData <- as.data.frame(t(as.matrix(Seurat::GetAssayData(seuratObj, assay = assayName, slot = 'data')[geneList,, drop = FALSE])))
geneData <- as.data.frame(t(as.matrix(Seurat::GetAssayData(seuratObj, assay = assayName, layer = 'data')[geneList,, drop = FALSE])))
if (! paste0(moduleName, '_UCell') %in% names(seuratObj@meta.data)) {
stop(paste0('Missing column: ', paste0(moduleName, '_UCell')))
}
Expand Down
6 changes: 3 additions & 3 deletions R/Utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ GetSeed <- function() {
return(pkg.env$RANDOM_SEED)
}

SeuratToMatrix <- function(seuratObj, outDir, assayName, slot = 'counts'){
SeuratToMatrix <- function(seuratObj, outDir, assayName, layer = 'counts'){
if (endsWith(outDir, "/")){
outDir <- gsub(outDir, pattern = "/$", replacement = "")
}

DropletUtils::write10xCounts(x = Seurat::GetAssayData(seuratObj, assay = assayName, slot = slot), path = outDir, overwrite = TRUE, type = 'sparse')
DropletUtils::write10xCounts(x = Seurat::GetAssayData(seuratObj, assay = assayName, layer = layer), path = outDir, overwrite = TRUE, type = 'sparse')

return(paste0(outDir, '/matrix.mtx'))
}
Expand Down Expand Up @@ -57,7 +57,7 @@ SeuratToMatrix <- function(seuratObj, outDir, assayName, slot = 'counts'){

assayObj <- Seurat::GetAssay(seuratObj, assay = assayName)
if (class(assayObj)[1] == 'Assay') {
return(!identical(Seurat::GetAssayData(seuratObj, assay = assayName, slot = 'counts'), Seurat::GetAssayData(seuratObj, assay = assayName, slot = 'data')))
return(!identical(Seurat::GetAssayData(seuratObj, assay = assayName, layer = 'counts'), Seurat::GetAssayData(seuratObj, assay = assayName, layer = 'data')))
} else if (class(assayObj)[1] == 'Assay5') {
return('data' %in% names(assayObj@layers))
} else {
Expand Down
9 changes: 8 additions & 1 deletion man/PredictTcellActivation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/TrainModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/TrainModelsFromSeurat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.