Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ROADMAP.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Extend the existing `TripleDifference` estimator to handle staggered adoption se
- Event study aggregation and pre-treatment placebo effects
- Multiplier bootstrap for valid inference in staggered settings

**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). *Working Paper*. R package: `triplediff`.
**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). "Better Understanding Triple Differences Estimators." *Working Paper*. R package: `triplediff`.

### Enhanced Visualization

Expand Down
3 changes: 3 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ Deferred items from PR reviews that were not addressed before merge.
| TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. | `prep_dgp.py`, `power.py` | #208 | Low |
| Survey design resolution/collapse patterns inconsistent across panel estimators — extract shared helpers for panel-to-unit collapse, post-filter re-resolution, metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low |
| TROP: `fit()` and `_fit_global()` share ~150 lines of near-identical data setup. Extract shared helpers to eliminate cross-file sync risk. | `trop.py`, `trop_global.py`, `trop_local.py` | — | Low |
| StaggeredTripleDifference R cross-validation: CSV fixtures not committed (gitignored); tests skip without local R + triplediff. Commit fixtures or generate deterministically. | `tests/test_methodology_staggered_triple_diff.py` | #245 | Medium |
| StaggeredTripleDifference R parity: benchmark only tests no-covariate path (xformla=~1). Add covariate-adjusted scenarios and aggregation SE parity assertions. | `benchmarks/R/benchmark_staggered_triplediff.R` | #245 | Medium |
| StaggeredTripleDifference: per-cohort group-effect SEs include WIF (conservative vs R's wif=NULL). Documented in REGISTRY. Could override mixin for exact R match. | `staggered_triple_diff.py` | #245 | Low |

#### Performance

Expand Down
147 changes: 147 additions & 0 deletions benchmarks/R/benchmark_staggered_triplediff.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#!/usr/bin/env Rscript
# Benchmark: Staggered Triple Difference (R `triplediff` package)
#
# Generates golden values for cross-validation against Python
# StaggeredTripleDifference estimator.
#
# Usage:
# Rscript benchmark_staggered_triplediff.R

library(triplediff)
library(jsonlite)
library(data.table)

cat("=== Staggered DDD Benchmark Generator ===\n")

# ---- Output directory ----
# Golden files are written to <project root>/benchmarks/data/synthetic.
# The script may be launched from benchmarks/R or from the project root, so
# both candidate locations are probed before anything is created.
dir_from_script_dir <- file.path(
  dirname(dirname(getwd())), "benchmarks", "data", "synthetic"
)
dir_from_root <- "benchmarks/data/synthetic"

if (dir.exists(dir_from_script_dir)) {
  output_dir <- dir_from_script_dir
} else if (dir.exists(dir_from_root)) {
  output_dir <- dir_from_root
} else {
  # Neither candidate exists yet (e.g. fresh checkout). Choose the one that
  # matches the working directory so that launching from benchmarks/R does not
  # create a nested benchmarks/R/benchmarks/data/synthetic tree.
  launched_from_script_dir <- basename(getwd()) == "R" &&
    basename(dirname(getwd())) == "benchmarks"
  output_dir <- if (launched_from_script_dir) {
    dir_from_script_dir
  } else {
    dir_from_root
  }
  dir.create(output_dir, recursive = TRUE)
}

results <- list()
# CSV paths already written during THIS run (one CSV per seed+dgp combo).
written_csvs <- character(0)
# Keys of scenarios whose ddd() call raised an error and were skipped.
failed_keys <- character(0)

# ---- Scenario definitions ----
# Each entry: RNG seed, DGP type, estimation method, control group, and the
# key under which the scenario's results are stored in the output JSON.
scenarios <- list(
  list(seed = 42, dgp = 1, method = "dr", cg = "nevertreated", key = "s42_dgp1_dr_nt"),
  list(seed = 42, dgp = 1, method = "ipw", cg = "nevertreated", key = "s42_dgp1_ipw_nt"),
  list(seed = 42, dgp = 1, method = "reg", cg = "nevertreated", key = "s42_dgp1_reg_nt"),
  list(seed = 42, dgp = 1, method = "dr", cg = "notyettreated", key = "s42_dgp1_dr_nyt"),
  list(seed = 42, dgp = 1, method = "ipw", cg = "notyettreated", key = "s42_dgp1_ipw_nyt"),
  list(seed = 42, dgp = 1, method = "reg", cg = "notyettreated", key = "s42_dgp1_reg_nyt"),
  list(seed = 123, dgp = 1, method = "dr", cg = "nevertreated", key = "s123_dgp1_dr_nt"),
  list(seed = 123, dgp = 1, method = "dr", cg = "notyettreated", key = "s123_dgp1_dr_nyt"),
  list(seed = 99, dgp = 1, method = "dr", cg = "nevertreated", key = "s99_dgp1_dr_nt"),
  list(seed = 99, dgp = 1, method = "dr", cg = "notyettreated", key = "s99_dgp1_dr_nyt")
)

for (sc in scenarios) {
  cat(sprintf(" Running scenario: %s ...\n", sc$key))

  # Re-seed before every draw so scenarios sharing a seed+dgp see the same data.
  set.seed(sc$seed)
  dgp <- gen_dgp_mult_periods(size = 500, dgp_type = sc$dgp)
  data <- dgp$data

  # Save data CSV (one per seed+dgp combo, reused across methods).
  # The CSV is rewritten once per run rather than skipped when a file already
  # exists: the JSON below is always regenerated, so a CSV left over from an
  # earlier run could otherwise disagree with the golden values.
  data_key <- sprintf("s%d_dgp%d", sc$seed, sc$dgp)
  csv_path <- file.path(
    output_dir, sprintf("staggered_ddd_data_%s.csv", data_key)
  )
  if (!(csv_path %in% written_csvs)) {
    fwrite(data, csv_path)
    written_csvs <- c(written_csvs, csv_path)
    cat(sprintf(" Saved data: %s\n", csv_path))
  }

  # Run DDD estimation; an error is logged and mapped to NULL so the remaining
  # scenarios still run.
  res <- tryCatch(
    ddd(
      yname = "y", tname = "time", idname = "id",
      gname = "state", pname = "partition",
      xformla = ~1, # no covariates for cross-validation
      data = data,
      control_group = sc$cg,
      base_period = "varying",
      est_method = sc$method,
      panel = TRUE
    ),
    error = function(e) {
      cat(sprintf(" ERROR: %s\n", conditionMessage(e)))
      NULL
    }
  )

  if (is.null(res)) {
    # Remember the failure so it is reported once the run finishes.
    failed_keys <- c(failed_keys, sc$key)
    next
  }

  # Group-time results
  gt_results <- data.frame(
    group = res$groups,
    period = res$periods,
    att = res$ATT,
    se = res$se
  )

  # Event study aggregation (failure is non-fatal; fields stay NA / NULL)
  agg_es <- tryCatch(
    agg_ddd(res, type = "eventstudy"),
    error = function(e) {
      cat(sprintf(" Event study agg failed: %s\n", conditionMessage(e)))
      NULL
    }
  )

  es_results <- NULL
  overall_att_es <- NA
  overall_se_es <- NA
  if (!is.null(agg_es)) {
    es_agg <- agg_es$aggte_ddd
    es_results <- data.frame(
      event_time = es_agg$egt,
      att = es_agg$att.egt,
      se = es_agg$se.egt
    )
    overall_att_es <- es_agg$overall.att
    overall_se_es <- es_agg$overall.se
  }

  # Simple aggregation (failure is non-fatal; fields stay NA)
  agg_simple <- tryCatch(
    agg_ddd(res, type = "simple"),
    error = function(e) {
      cat(sprintf(" Simple agg failed: %s\n", conditionMessage(e)))
      NULL
    }
  )

  overall_att_simple <- NA
  overall_se_simple <- NA
  if (!is.null(agg_simple)) {
    simple_agg <- agg_simple$aggte_ddd
    overall_att_simple <- simple_agg$overall.att
    overall_se_simple <- simple_agg$overall.se
  }

  # Store results under the scenario key
  results[[sc$key]] <- list(
    seed = sc$seed,
    dgp_type = sc$dgp,
    est_method = sc$method,
    control_group = sc$cg,
    n = res$n,
    gt_att = as.list(gt_results$att),
    gt_se = as.list(gt_results$se),
    gt_groups = as.list(gt_results$group),
    gt_periods = as.list(gt_results$period),
    overall_att_simple = overall_att_simple,
    overall_se_simple = overall_se_simple,
    overall_att_es = overall_att_es,
    overall_se_es = overall_se_es,
    es_event_times = if (!is.null(es_results)) as.list(es_results$event_time) else NULL,
    es_att = if (!is.null(es_results)) as.list(es_results$att) else NULL,
    es_se = if (!is.null(es_results)) as.list(es_results$se) else NULL
  )

  cat(sprintf(" GT ATT: %s\n", paste(round(res$ATT, 4), collapse = ", ")))
  cat(sprintf(" Overall ATT (simple): %.4f\n", overall_att_simple))
}

# Save all results as JSON
json_path <- file.path(output_dir, "staggered_ddd_r_results.json")
writeLines(
  toJSON(results, auto_unbox = TRUE, pretty = TRUE, digits = 10),
  json_path
)
cat(sprintf("\nResults saved to: %s\n", json_path))

# Surface skipped scenarios: the JSON above simply lacks their keys, which is
# easy to miss in the per-scenario log.
if (length(failed_keys) > 0) {
  warning(
    "Scenarios skipped due to estimation errors: ",
    paste(failed_keys, collapse = ", "),
    call. = FALSE
  )
}
cat("Done.\n")
Loading
Loading