Skip to content
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ scripts/
# Launch directories (local only)
launch/
launch-video/

# Reference implementations (local only)
trop_avg_ref/
3 changes: 3 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ test bootstrap::tests::test_webb_mean_approx_zero ... ok
- `TROPResults` - Results with ATT, factors, loadings, unit/time weights
- `trop()` - Convenience function for quick estimation
- Three robustness components: factor adjustment, unit weights, time weights
- Two estimation methods via `method` parameter:
- `"twostep"` (default): Per-observation model fitting (Algorithm 2 of paper)
- `"joint"`: Weighted least squares with homogeneous treatment effect (faster)
- Automatic rank selection via cross-validation, information criterion, or elbow detection
- Bootstrap and placebo-based variance estimation

Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,7 @@ trop = TROP(

```python
TROP(
method='twostep', # Estimation method: 'twostep' (default) or 'joint'
lambda_time_grid=None, # Time decay grid (default: [0, 0.1, 0.5, 1, 2, 5])
lambda_unit_grid=None, # Unit distance grid (default: [0, 0.1, 0.5, 1, 2, 5])
lambda_nn_grid=None, # Nuclear norm grid (default: [0, 0.01, 0.1, 1, 10])
Expand All @@ -1279,6 +1280,10 @@ TROP(
)
```

**Estimation methods:**
- `'twostep'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive.
- `'joint'`: Joint weighted least squares optimization. Estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Faster but assumes homogeneous treatment effects.

**Convenience function:**

```python
Expand Down
20 changes: 16 additions & 4 deletions diff_diff/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
project_simplex as _rust_project_simplex,
solve_ols as _rust_solve_ols,
compute_robust_vcov as _rust_compute_robust_vcov,
# TROP estimator acceleration
# TROP estimator acceleration (twostep method)
compute_unit_distance_matrix as _rust_unit_distance_matrix,
loocv_grid_search as _rust_loocv_grid_search,
bootstrap_trop_variance as _rust_bootstrap_trop_variance,
# TROP estimator acceleration (joint method)
loocv_grid_search_joint as _rust_loocv_grid_search_joint,
bootstrap_trop_variance_joint as _rust_bootstrap_trop_variance_joint,
)
_rust_available = True
except ImportError:
Expand All @@ -36,10 +39,13 @@
_rust_project_simplex = None
_rust_solve_ols = None
_rust_compute_robust_vcov = None
# TROP estimator acceleration
# TROP estimator acceleration (twostep method)
_rust_unit_distance_matrix = None
_rust_loocv_grid_search = None
_rust_bootstrap_trop_variance = None
# TROP estimator acceleration (joint method)
_rust_loocv_grid_search_joint = None
_rust_bootstrap_trop_variance_joint = None

# Determine final backend based on environment variable and availability
if _backend_env == 'python':
Expand All @@ -50,10 +56,13 @@
_rust_project_simplex = None
_rust_solve_ols = None
_rust_compute_robust_vcov = None
# TROP estimator acceleration
# TROP estimator acceleration (twostep method)
_rust_unit_distance_matrix = None
_rust_loocv_grid_search = None
_rust_bootstrap_trop_variance = None
# TROP estimator acceleration (joint method)
_rust_loocv_grid_search_joint = None
_rust_bootstrap_trop_variance_joint = None
elif _backend_env == 'rust':
# Force Rust mode - fail if not available
if not _rust_available:
Expand All @@ -73,8 +82,11 @@
'_rust_project_simplex',
'_rust_solve_ols',
'_rust_compute_robust_vcov',
# TROP estimator acceleration
# TROP estimator acceleration (twostep method)
'_rust_unit_distance_matrix',
'_rust_loocv_grid_search',
'_rust_bootstrap_trop_variance',
# TROP estimator acceleration (joint method)
'_rust_loocv_grid_search_joint',
'_rust_bootstrap_trop_variance_joint',
]
Loading