diff_diff/linalg.py (18 changes: 12 additions & 6 deletions)

@@ -828,10 +828,14 @@ def _compute_robust_vcov_numpy(
         # Vectorized meat computation: X' diag(u^2) X = (X * u^2)' X
         meat = X.T @ (X * u_squared[:, np.newaxis])
     else:
-        # Cluster-robust standard errors (vectorized via groupby)
+        # Cluster-robust standard errors (vectorized via NumPy aggregation)
         cluster_ids = np.asarray(cluster_ids)
-        unique_clusters = np.unique(cluster_ids)
-        n_clusters = len(unique_clusters)
+        valid_mask = ~pd.isna(cluster_ids)
+        if not np.all(valid_mask):
+            cluster_ids = cluster_ids[valid_mask]
+        # Factorize to contiguous int codes for fast aggregation
+        cluster_codes = pd.factorize(cluster_ids, sort=False)[0].astype(np.int64)
+        n_clusters = int(cluster_codes.max()) + 1 if cluster_codes.size else 0

         if n_clusters < 2:
             raise ValueError(
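
Note on the new factorize step: a minimal standalone sketch of what it computes, using toy labels. The example data and prints are illustrative, not part of the PR; only the masking and factorize lines mirror the diff.

import numpy as np
import pandas as pd

cluster_ids = np.asarray(["a", "b", np.nan, "a", "c"], dtype=object)

# Drop observations with missing cluster ids, as the new code does
valid_mask = ~pd.isna(cluster_ids)
cluster_ids = cluster_ids[valid_mask]

# factorize maps each cluster label to a contiguous integer code,
# in order of first appearance (sort=False): a -> 0, b -> 1, c -> 2
cluster_codes = pd.factorize(cluster_ids, sort=False)[0].astype(np.int64)
n_clusters = int(cluster_codes.max()) + 1 if cluster_codes.size else 0

print(cluster_codes)  # [0 1 0 2]
print(n_clusters)     # 3

Because the codes are contiguous integers starting at 0, max + 1 gives the cluster count without a second pass over unique labels.
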
@@ -844,10 +848,12 @@ def _compute_robust_vcov_numpy(
         # Compute cluster-level scores: sum of X_i * u_i within each cluster
         # scores[i] = X[i] * residuals[i] for each observation
         scores = X * residuals[:, np.newaxis]  # (n, k)
+        if not np.all(valid_mask):
+            scores = scores[valid_mask]

-        # Sum scores within each cluster using pandas groupby (vectorized)
-        # This is much faster than looping over clusters
-        cluster_scores = pd.DataFrame(scores).groupby(cluster_ids).sum().values  # (G, k)
+        # Aggregate by cluster using NumPy (faster than pandas groupby)
+        cluster_scores = np.zeros((n_clusters, k), dtype=scores.dtype)
+        np.add.at(cluster_scores, cluster_codes, scores)

         # Meat is the outer product sum: sum_g (score_g)(score_g)'
         # Equivalent to cluster_scores.T @ cluster_scores
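
As a sanity check that the NumPy scatter-add reproduces the old pandas groupby path, here is a self-contained sketch. The random test harness is assumed for illustration; only the aggregation lines mirror the diff.

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n, k, G = 100, 3, 7
X = rng.normal(size=(n, k))
residuals = rng.normal(size=n)
cluster_ids = rng.integers(0, G, size=n)

scores = X * residuals[:, np.newaxis]  # (n, k)
cluster_codes = pd.factorize(cluster_ids, sort=False)[0].astype(np.int64)
n_clusters = int(cluster_codes.max()) + 1

# New path: unbuffered scatter-add into a (G, k) accumulator.
# np.add.at handles repeated indices correctly, unlike fancy-index +=.
cluster_scores = np.zeros((n_clusters, k), dtype=scores.dtype)
np.add.at(cluster_scores, cluster_codes, scores)

# Old path: pandas groupby sum; groupby sorts keys ascending, and the
# integer codes 0..G-1 sort into the same row order as cluster_scores
old = pd.DataFrame(scores).groupby(cluster_codes).sum().values  # (G, k)

assert np.allclose(cluster_scores, old)

# Meat matrix: sum_g (score_g)(score_g)' == cluster_scores.T @ cluster_scores
meat = cluster_scores.T @ cluster_scores

The scatter-add avoids constructing a DataFrame and the groupby machinery entirely, which is where the claimed speedup comes from.
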