Skip to content

Commit df7999f

Browse files
committed
added skim_arrow; fixed parallelization in add_spatial_lags, test_fable_resids
1 parent 1bfaa7c commit df7999f

File tree

11 files changed

+379
-81
lines changed

11 files changed

+379
-81
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(print,skim_arrow)
34
export("%>%")
45
export(":=")
56
export(.data)
@@ -18,6 +19,7 @@ export(expr)
1819
export(get_boot_ci)
1920
export(prewhitened_ccf)
2021
export(scale_by_mase)
22+
export(skim_arrow)
2123
export(sym)
2224
export(syms)
2325
export(test_fable_resids)

R/add-spatial-lags.R

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,6 @@
8787
#'
8888
#' cat(attributes(tib_spat_lags)$summ_wgts_spatlag_1, sep = "\n")
8989
#'
90-
#' rlang::check_installed(
91-
#' "mirai (>= 2.1.0.9000)",
92-
#' action = function(...) {
93-
#' remotes::install_version('mirai',
94-
#' version = ">= 2.1.0.9000",
95-
#' repos = c('https://shikokuchuo.r-universe.dev',
96-
#' 'https://cloud.r-project.org'))
97-
#' }
98-
#' )
9990
#'
10091
#' library(mirai)
10192
#'
@@ -116,12 +107,12 @@
116107

117108

118109
add_spatial_lags <- function(nblist,
119-
y,
120-
.data,
121-
lags,
122-
type = NULL,
123-
parallel = FALSE,
124-
...) {
110+
y,
111+
.data,
112+
lags,
113+
type = NULL,
114+
parallel = FALSE,
115+
...) {
125116

126117
# ---------------- tests ------------------
127118
# Check if nblist is of class "nb"
@@ -151,15 +142,32 @@ add_spatial_lags <- function(nblist,
151142
}
152143
# -----------------------------------------
153144

154-
get_vec_lags <- function(lag_nb, vec_num, .data, lag, type, ...) {
145+
get_vec_lags <- function(lag_nb, vec_num, .data, lag, type, dots) {
155146

156147
# add weights to nb list
157148
if (is.null(type)) {
149+
listw_args <- list(neighbours = lag_nb)
150+
if (length(dots) != 0) {
151+
listw_args <- append(listw_args, dots)
152+
}
158153
ls_wts <-
159-
spdep::nb2listw(lag_nb, ...)
154+
do.call(
155+
spdep::nb2listw,
156+
listw_args
157+
)
160158
} else {
159+
listwdist_args <-
160+
list(neighbours = lag_nb,
161+
x = .data,
162+
type = type)
163+
if (length(dots) != 0) {
164+
listwdist_args <- append(listwdist_args, dots)
165+
}
161166
ls_wts <-
162-
spdep::nb2listwdist(lag_nb, .data, type, ...)
167+
do.call(
168+
spdep::nb2listwdist,
169+
listwdist_args
170+
)
163171
}
164172

165173
# get weights summary
@@ -200,20 +208,25 @@ add_spatial_lags <- function(nblist,
200208
purrr::map2(
201209
lags_nb,
202210
1:lags,
203-
carrier::crate(
211+
purrr::in_parallel(
204212
\(x1, x2) {
205213
get_vec_lags(
206214
x1,
207-
!!vec_num,
208-
!!.data,
215+
vec_num,
216+
.data,
209217
x2,
210-
!!type,
211-
!!!dots
218+
type,
219+
dots
212220
)
213221
},
214-
get_vec_lags = get_vec_lags
222+
get_vec_lags = get_vec_lags,
223+
vec_num = vec_num,
224+
.data = .data,
225+
type = type,
226+
dots = dots,
227+
y = y # not a fun arg, but needed to add to env for glue variable naming
215228
),
216-
.parallel = TRUE
229+
.progress = TRUE
217230
)
218231

219232
} else {
@@ -229,7 +242,7 @@ add_spatial_lags <- function(nblist,
229242
.data,
230243
x2,
231244
type,
232-
...
245+
dots
233246
)
234247
}
235248
)

R/skim-arrow.R

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#' Skim an Arrow Dataset
2+
#'
3+
#' @description
4+
#' Provides a \{skimr\}-style summary of an Arrow Dataset with statistics
5+
#' organized by variable type. Computes summary statistics efficiently using
6+
#' Arrow's query engine without loading the full dataset into memory.
7+
#'
8+
#' @param ds An Arrow Dataset object created with `arrow::open_dataset()`. This would probably work on any \{arrow\} data object with a schema.
9+
#'
10+
#' @return A list of class "skim_arrow" containing:
11+
#' \item{overview}{A tibble with dataset dimensions and column type counts}
12+
#' \item{numeric}{A tibble with statistics for numeric columns (missing_pct, mean, sd, min, max)}
13+
#' \item{character}{A tibble with statistics for character columns (missing_pct, n_unique)}
14+
#' \item{timestamp}{A tibble with statistics for timestamp columns (missing_pct, min, max)}
15+
#'
16+
#' @details
17+
#' The function classifies columns by type and computes appropriate summary
18+
#' statistics for each:
19+
#' \itemize{
20+
#' \item Numeric columns: missing percentage, mean, standard deviation, min, max
21+
#' \item Character columns: missing percentage, number of unique values
22+
#' \item Timestamp columns: missing percentage, min, max (as POSIXct objects)
23+
#' }
24+
#'
25+
#' All computations are performed using Arrow's query engine, making this
26+
#' function efficient even for very large datasets stored in Parquet files.
27+
#'
28+
#' @examples
29+
#' \dontrun{
30+
#' # Open a directory of Parquet files
31+
#' ds <- arrow::open_dataset("path/to/parquet/files")
32+
#'
33+
#' # Get summary statistics
34+
#' summary <- skim_arrow(ds)
35+
#'
36+
#' # View all sections
37+
#' summary
38+
#'
39+
#' # Access specific sections
40+
#' summary$numeric
41+
#' summary$character
42+
#' summary$timestamp
43+
#' }
44+
#'
45+
#' @seealso \code{\link[arrow]{open_dataset}}, \code{\link[skimr]{skim}}
46+
#'
47+
#' @export
48+
skim_arrow <- function(ds) {
49+
50+
# Get schema to identify column types
51+
schema <- ds$schema
52+
col_names <- names(schema)
53+
54+
# Classify columns by type
55+
numeric_cols <- col_names[sapply(schema, function(field) {
56+
type_name <- field$type$ToString()
57+
grepl("int|float|double|decimal", type_name, ignore.case = TRUE)
58+
})]
59+
60+
character_cols <- col_names[sapply(schema, function(field) {
61+
type_name <- field$type$ToString()
62+
grepl("string|utf8", type_name, ignore.case = TRUE)
63+
})]
64+
65+
timestamp_cols <- col_names[sapply(schema, function(field) {
66+
type_name <- field$type$ToString()
67+
grepl("timestamp", type_name, ignore.case = TRUE)
68+
})]
69+
70+
# Build the summary query
71+
result <- ds |>
72+
dplyr::summarize(
73+
# Missingness for ALL columns
74+
dplyr::across(
75+
dplyr::everything(),
76+
~mean(is.na(.)) * 100,
77+
.names = "{.col}_missing_pct"
78+
),
79+
80+
# Numeric column stats
81+
dplyr::across(
82+
dplyr::all_of(numeric_cols),
83+
list(
84+
min = ~min(., na.rm = TRUE),
85+
max = ~max(., na.rm = TRUE),
86+
mean = ~mean(., na.rm = TRUE),
87+
sd = ~sd(., na.rm = TRUE)
88+
),
89+
.names = "{.col}_{.fn}"
90+
),
91+
92+
# Character column stats
93+
dplyr::across(
94+
dplyr::all_of(character_cols),
95+
~dplyr::n_distinct(., na.rm = TRUE),
96+
.names = "{.col}_n_unique"
97+
),
98+
99+
# Timestamp column stats (min/max only)
100+
dplyr::across(
101+
dplyr::all_of(timestamp_cols),
102+
list(
103+
min = ~min(., na.rm = TRUE),
104+
max = ~max(., na.rm = TRUE)
105+
),
106+
.names = "{.col}_{.fn}"
107+
)
108+
) |>
109+
dplyr::collect()
110+
111+
# Create separate tables for each variable type
112+
output <- list()
113+
114+
# Overview table
115+
output$overview <- dplyr::tibble(
116+
n_rows = nrow(ds),
117+
n_cols = length(col_names),
118+
n_numeric = length(numeric_cols),
119+
n_character = length(character_cols),
120+
n_timestamp = length(timestamp_cols)
121+
)
122+
123+
# Numeric variables table
124+
if (length(numeric_cols) > 0) {
125+
numeric_data <- result |>
126+
dplyr::select(dplyr::ends_with("_missing_pct"), dplyr::ends_with(c("_min", "_max", "_mean", "_sd"))) |>
127+
dplyr::select(dplyr::matches(paste0("^(", paste(numeric_cols, collapse = "|"), ")_")))
128+
129+
output$numeric <- dplyr::tibble(
130+
variable = numeric_cols,
131+
missing_pct = as.numeric(numeric_data[1, paste0(numeric_cols, "_missing_pct")]),
132+
mean = as.numeric(numeric_data[1, paste0(numeric_cols, "_mean")]),
133+
sd = as.numeric(numeric_data[1, paste0(numeric_cols, "_sd")]),
134+
min = as.numeric(numeric_data[1, paste0(numeric_cols, "_min")]),
135+
max = as.numeric(numeric_data[1, paste0(numeric_cols, "_max")])
136+
)
137+
}
138+
139+
# Character variables table
140+
if (length(character_cols) > 0) {
141+
char_data <- result |>
142+
dplyr::select(dplyr::matches(paste0("^(", paste(character_cols, collapse = "|"), ")_(missing_pct|n_unique)")))
143+
144+
output$character <- dplyr::tibble(
145+
variable = character_cols,
146+
missing_pct = as.numeric(char_data[1, paste0(character_cols, "_missing_pct")]),
147+
n_unique = as.numeric(char_data[1, paste0(character_cols, "_n_unique")])
148+
)
149+
}
150+
151+
# Timestamp variables table
152+
if (length(timestamp_cols) > 0) {
153+
ts_data <- result |>
154+
dplyr::select(dplyr::matches(paste0("^(", paste(timestamp_cols, collapse = "|"), ")_(missing_pct|min|max)")))
155+
156+
output$timestamp <- dplyr::tibble(
157+
variable = timestamp_cols,
158+
missing_pct = as.numeric(ts_data[1, paste0(timestamp_cols, "_missing_pct")]),
159+
min = as.POSIXct(unlist(ts_data[1, paste0(timestamp_cols, "_min")]), origin = "1970-01-01", tz = "UTC"),
160+
max = as.POSIXct(unlist(ts_data[1, paste0(timestamp_cols, "_max")]), origin = "1970-01-01", tz = "UTC")
161+
)
162+
}
163+
164+
# Set class for custom print method
165+
class(output) <- c("skim_arrow", "list")
166+
167+
return(output)
168+
}
169+
170+
#' Print Method for skim_arrow Objects
171+
#'
172+
#' Provides formatted output for skim_arrow results, displaying summary
173+
#' statistics organized by variable type in a `skimr`-style format.
174+
#'
175+
#' @param x A skim_arrow object (output from `skim_arrow()`)
176+
#' @param ... Additional arguments passed to print methods (currently unused)
177+
#'
178+
#' @return Invisibly returns the input object `x`
179+
#' @keywords internal
180+
#' @export
181+
print.skim_arrow <- function(x, ...) {
182+
cat("\u2500\u2500 Data Summary \u2500\u2500\n\n")
183+
print(x$overview)
184+
185+
if (!is.null(x$numeric)) {
186+
cat("\n\u2500\u2500 Numeric Variables \u2500\u2500\n\n")
187+
print(x$numeric, n = Inf)
188+
}
189+
190+
if (!is.null(x$character)) {
191+
cat("\n\u2500\u2500 Character Variables \u2500\u2500\n\n")
192+
print(x$character, n = Inf)
193+
}
194+
195+
if (!is.null(x$timestamp)) {
196+
cat("\n\u2500\u2500 Timestamp Variables \u2500\u2500\n\n")
197+
print(x$timestamp, n = Inf)
198+
}
199+
200+
invisible(x)
201+
}

R/test-fable-resids.R

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
#'
1919
#' @examples
2020
#'
21-
#' library(dplyr, warn.conflicts = FALSE)
22-
#' library(fable, quietly = TRUE)
23-
#' library(furrr, quietly = TRUE)
24-
#' plan(multisession)
21+
#' library(dplyr, warn.conflicts = FALSE)
22+
#' library(fable, quietly = TRUE)
23+
#' library(mirai)
2524
#'
26-
#' head(ohio_covid)[,1:6]
25+
#' head(ohio_covid)[,1:6]
2726
#'
28-
#' models_dyn <- ohio_covid[ ,1:7] %>%
27+
#' daemons(3)
28+
#'
29+
#' models_dyn <- ohio_covid[ ,1:7] %>%
2930
#' tidyr::pivot_longer(
3031
#' cols = contains("lead"),
3132
#' names_to = "lead",
@@ -37,21 +38,20 @@
3738
#' tidyr::drop_na() %>%
3839
#' tidyr::nest(data = c(date, cases, lead_deaths)) %>%
3940
#' # Run a regression on lagged cases and date vs deaths
40-
#' mutate(model = furrr::future_map(data, function(df) {
41-
#' model(.data = df,
42-
#' dyn_reg = ARIMA(lead_deaths ~ 1 + cases),
43-
#' dyn_reg_trend = ARIMA(lead_deaths ~ 1 + cases + trend()),
44-
#' dyn_reg_quad = ARIMA(lead_deaths ~ 1 + cases + poly(date, 2))
45-
#' )
46-
#' }
47-
#' ))
48-
#' # shut down workers
49-
#' plan(sequential)
41+
#' mutate(model = purrr::map(data, purrr::in_parallel(\(df) {
42+
#' fabletools::model(
43+
#' .data = df,
44+
#' dyn_reg = fable::ARIMA(lead_deaths ~ 1 + cases),
45+
#' dyn_reg_trend = fable::ARIMA(lead_deaths ~ 1 + cases + trend()),
46+
#' dyn_reg_quad = fable::ARIMA(lead_deaths ~ 1 + cases + poly(date, 2))
47+
#' )})))
5048
#'
51-
#' dyn_mod_tbl <- select(models_dyn, -data)
49+
#' # shut down workers
50+
#' daemons(0)
5251
#'
53-
#' fable_resid_res <- test_fable_resids(dyn_mod_tbl, grp_col = "lead", mod_col = "model")
54-
#' head(fable_resid_res)
52+
#' dyn_mod_tbl <- select(models_dyn, -data)
53+
#' fable_resid_res <- test_fable_resids(dyn_mod_tbl, grp_col = "lead", mod_col = "model")
54+
#' head(fable_resid_res)
5555

5656

5757

0 commit comments

Comments
 (0)