Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ Deferred items from PR reviews that were not addressed before merge.

| Issue | Location | PR | Priority |
|-------|----------|----|----------|
| Plotly renderers silently ignore styling kwargs (marker, markersize, linewidth, capsize, ci_linewidth) that the matplotlib backend honors; thread them through or reject when `backend="plotly"` | `visualization/_event_study.py`, `_diagnostic.py`, `_power.py` | #222 | Medium |
| R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
| CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low |
| ~376 `duplicate object description` Sphinx warnings — restructure `docs/api/*.rst` to avoid duplicate `:members:` + `autosummary` | `docs/api/*.rst` | — | Low |
Expand Down
36 changes: 36 additions & 0 deletions diff_diff/visualization/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,42 @@ def _color_to_rgba(color, alpha=1.0):
)


# Matplotlib marker code -> plotly symbol name mapping
_MPL_TO_PLOTLY_SYMBOL = {
"o": "circle",
"s": "square",
"D": "diamond",
"d": "diamond",
"^": "triangle-up",
"v": "triangle-down",
"<": "triangle-left",
">": "triangle-right",
"p": "pentagon",
"h": "hexagon",
"+": "cross",
"x": "x",
"*": "star",
".": "circle",
}


def _mpl_marker_to_plotly_symbol(marker):
"""Convert a matplotlib marker code to a plotly symbol name.

Parameters
----------
marker : str
Matplotlib marker shorthand (e.g., ``"o"``, ``"s"``, ``"D"``).

Returns
-------
str
Plotly symbol name (e.g., ``"circle"``, ``"square"``, ``"diamond"``).
Returns ``"circle"`` for unrecognized markers.
"""
return _MPL_TO_PLOTLY_SYMBOL.get(marker, "circle")


# Default color constants
DEFAULT_BLUE = "#2563eb"
DEFAULT_RED = "#dc2626"
Expand Down
27 changes: 23 additions & 4 deletions diff_diff/visualization/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def plot_sensitivity(
bounds_color=bounds_color,
bounds_alpha=bounds_alpha,
ci_color=ci_color,
ci_linewidth=ci_linewidth,
breakdown_color=breakdown_color,
original_color=original_color,
show=show,
Expand Down Expand Up @@ -242,6 +243,7 @@ def _render_sensitivity_plotly(
bounds_color,
bounds_alpha,
ci_color,
ci_linewidth,
breakdown_color,
original_color,
show,
Expand Down Expand Up @@ -291,7 +293,7 @@ def _render_sensitivity_plotly(
x=M_list,
y=list(ci_arr[:, 0]),
mode="lines",
line=dict(color=ci_color, width=1.5),
line=dict(color=ci_color, width=ci_linewidth),
name="Robust CI",
)
)
Expand All @@ -300,7 +302,7 @@ def _render_sensitivity_plotly(
x=M_list,
y=list(ci_arr[:, 1]),
mode="lines",
line=dict(color=ci_color, width=1.5),
line=dict(color=ci_color, width=ci_linewidth),
showlegend=False,
)
)
Expand Down Expand Up @@ -449,6 +451,8 @@ def plot_bacon(
xlabel=xlabel,
ylabel=ylabel,
colors=colors,
marker=marker,
markersize=markersize,
alpha=alpha,
show_weighted_avg=show_weighted_avg,
show_twfe_line=show_twfe_line,
Expand Down Expand Up @@ -699,13 +703,19 @@ def _render_bacon_plotly(
xlabel,
ylabel,
colors,
marker,
markersize,
alpha,
show_weighted_avg,
show_twfe_line,
show,
):
"""Render Bacon decomposition plot with plotly."""
from diff_diff.visualization._common import _plotly_default_layout, _require_plotly
from diff_diff.visualization._common import (
_mpl_marker_to_plotly_symbol,
_plotly_default_layout,
_require_plotly,
)

go = _require_plotly()

Expand All @@ -727,6 +737,10 @@ def _render_bacon_plotly(
"later_vs_earlier": "Later vs Earlier (forbidden)",
}

# Convert matplotlib scatter area (points^2) to plotly diameter (px)
plotly_size = max(1, int(round(markersize**0.5)))
symbol = _mpl_marker_to_plotly_symbol(marker)

for ctype, points in by_type.items():
if not points:
continue
Expand All @@ -737,7 +751,12 @@ def _render_bacon_plotly(
x=estimates,
y=weights,
mode="markers",
marker=dict(color=colors[ctype], size=10, opacity=alpha),
marker=dict(
color=colors[ctype],
size=plotly_size,
symbol=symbol,
opacity=alpha,
),
name=labels[ctype],
)
)
Expand Down
28 changes: 24 additions & 4 deletions diff_diff/visualization/_event_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ def plot_event_study(
xlabel=xlabel,
ylabel=ylabel,
color=color,
marker=marker,
markersize=markersize,
shade_pre=shade_pre,
shade_color=shade_color,
show_zero_line=show_zero_line,
Expand Down Expand Up @@ -422,6 +424,8 @@ def _render_event_study_plotly(
xlabel,
ylabel,
color,
marker,
markersize,
shade_pre,
shade_color,
show_zero_line,
Expand All @@ -431,6 +435,7 @@ def _render_event_study_plotly(
"""Render event study plot with plotly."""
from diff_diff.visualization._common import (
_color_to_rgba,
_mpl_marker_to_plotly_symbol,
_plotly_default_layout,
_require_plotly,
)
Expand Down Expand Up @@ -504,13 +509,15 @@ def _render_event_study_plotly(

hover_tpl = "Period: %{customdata}<br>Effect: %{y:.4f}<extra></extra>"

symbol = _mpl_marker_to_plotly_symbol(marker)

if non_ref_x:
fig.add_trace(
go.Scatter(
x=non_ref_x,
y=non_ref_e,
mode="markers",
marker=dict(color=color, size=10),
marker=dict(color=color, size=markersize, symbol=symbol),
name="Effect",
customdata=non_ref_labels,
hovertemplate=hover_tpl,
Expand All @@ -525,7 +532,8 @@ def _render_event_study_plotly(
mode="markers",
marker=dict(
color="white",
size=10,
size=markersize,
symbol=symbol,
line=dict(color=color, width=2),
),
name="Reference",
Expand Down Expand Up @@ -842,6 +850,8 @@ def plot_honest_event_study(
ylabel=ylabel,
original_color=original_color,
honest_color=honest_color,
marker=marker,
markersize=markersize,
show=show,
)

Expand Down Expand Up @@ -987,11 +997,14 @@ def _render_honest_event_study_plotly(
ylabel,
original_color,
honest_color,
marker,
markersize,
show,
):
"""Render honest event study plot with plotly."""
from diff_diff.visualization._common import (
_color_to_rgba,
_mpl_marker_to_plotly_symbol,
_plotly_default_layout,
_require_plotly,
)
Expand Down Expand Up @@ -1036,13 +1049,15 @@ def _render_honest_event_study_plotly(
ref_p = [p for p, r in zip(periods, is_ref) if r]
ref_e = [e for e, r in zip(effects, is_ref) if r]

symbol = _mpl_marker_to_plotly_symbol(marker)

if non_ref_p:
fig.add_trace(
go.Scatter(
x=non_ref_p,
y=non_ref_e,
mode="markers",
marker=dict(color=honest_color, size=10),
marker=dict(color=honest_color, size=markersize, symbol=symbol),
name="Effect",
)
)
Expand All @@ -1053,7 +1068,12 @@ def _render_honest_event_study_plotly(
x=ref_p,
y=ref_e,
mode="markers",
marker=dict(color="white", size=10, line=dict(color=honest_color, width=2)),
marker=dict(
color="white",
size=markersize,
symbol=symbol,
line=dict(color=honest_color, width=2),
),
name="Reference",
)
)
Expand Down
18 changes: 14 additions & 4 deletions diff_diff/visualization/_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ def plot_power_curve(
color=color,
mde_color=mde_color,
target_color=target_color,
linewidth=linewidth,
show_mde_line=show_mde_line,
show_target_line=show_target_line,
show_grid=show_grid,
show=show,
)

Expand Down Expand Up @@ -291,8 +293,10 @@ def _render_power_curve_plotly(
color,
mde_color,
target_color,
linewidth,
show_mde_line,
show_target_line,
show_grid,
show,
):
"""Render power curve with plotly."""
Expand All @@ -307,7 +311,7 @@ def _render_power_curve_plotly(
x=effect_sizes,
y=powers,
mode="lines",
line=dict(color=color, width=2),
line=dict(color=color, width=linewidth),
name="Power",
)
)
Expand All @@ -331,7 +335,8 @@ def _render_power_curve_plotly(
)

_plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
fig.update_yaxes(range=[0, 1.05], tickformat=".0%")
fig.update_xaxes(showgrid=show_grid)
fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)

if show:
fig.show()
Expand Down Expand Up @@ -482,8 +487,10 @@ def plot_pretrends_power(
color=color,
mdv_color=mdv_color,
target_color=target_color,
linewidth=linewidth,
show_mdv_line=show_mdv_line,
show_target_line=show_target_line,
show_grid=show_grid,
show=show,
)

Expand Down Expand Up @@ -602,8 +609,10 @@ def _render_pretrends_power_plotly(
color,
mdv_color,
target_color,
linewidth,
show_mdv_line,
show_target_line,
show_grid,
show,
):
"""Render pre-trends power curve with plotly."""
Expand All @@ -619,7 +628,7 @@ def _render_pretrends_power_plotly(
x=M_values,
y=powers,
mode="lines",
line=dict(color=color, width=2),
line=dict(color=color, width=linewidth),
name="Power",
)
)
Expand All @@ -643,7 +652,8 @@ def _render_pretrends_power_plotly(
)

_plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
fig.update_yaxes(range=[0, 1.05], tickformat=".0%")
fig.update_xaxes(showgrid=show_grid)
fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)

if show:
fig.show()
Expand Down
Loading
Loading