Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 76 additions & 21 deletions dash_bio/component_factory/_clustergram.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,11 +558,11 @@ def figure(self, computed_traces=None):
col_dendro_traces_max_y = np.concatenate(col_dendro_traces_y).max()

# ensure that everything is aligned properly
# with the heatmap
# with the heatmap and dendrograms zoom synchronously
yaxis9 = fig["layout"]["yaxis9"] # pylint: disable=invalid-sequence-index
yaxis9.update(scaleanchor="y11")
yaxis9.update(scaleanchor="y11", matches="y11")
xaxis3 = fig["layout"]["xaxis3"] # pylint: disable=invalid-sequence-index
xaxis3.update(scaleanchor="x11")
xaxis3.update(scaleanchor="x11", matches="x11")

if len(tickvals_col) == 0:
tickvals_col = [10 * i + 5 for i in range(len(self._column_ids))]
Expand All @@ -576,9 +576,10 @@ def figure(self, computed_traces=None):
showticklabels=True,
side="bottom",
showline=False,
range=[min(tickvals_col) - 5, max(tickvals_col) + 5]
range=[min(tickvals_col) - 5, max(tickvals_col) + 5],
# workaround for autoscale issues above; otherwise
# the graph cuts off and must be scaled manually
fixedrange=False
)

if len(tickvals_row) == 0:
Expand All @@ -592,17 +593,20 @@ def figure(self, computed_traces=None):
showticklabels=True,
side="right",
showline=False,
ticks="",
showgrid=False,
fixedrange=False
)

# hide labels, if necessary
for label in self._hidden_labels:
fig["layout"][label].update(ticks="", showticklabels=False)

row_colors_heatmap = self._get_row_colors_heatmap()
row_colors_heatmap = self._get_row_colors_heatmap(tickvals_row)
if row_colors_heatmap is not None:
fig.append_trace(self._get_row_colors_heatmap(), 3, 2)
fig.append_trace(row_colors_heatmap, 3, 2)

col_colors_heatmap = self._get_column_colors_heatmap()
col_colors_heatmap = self._get_column_colors_heatmap(tickvals_col)
if col_colors_heatmap is not None:
fig.append_trace(col_colors_heatmap, 2, 3)

Expand Down Expand Up @@ -712,6 +716,43 @@ def figure(self, computed_traces=None):
domain=[0, 1 - col_ratio - col_colors_ratio]
)

# Link color heatmap axes to main heatmap for zoom synchronization
# Using 'matches' to ensure the same coordinate system
# Row colors (yaxis10) matches main heatmap y-axis (yaxis11)
if len(tickvals_row) > 0:
fig["layout"]["yaxis10"].update(
matches="y11",
range=[min(tickvals_row), max(tickvals_row)],
showticklabels=False,
ticks="",
showgrid=False,
tickmode="array",
tickvals=[],
ticktext=[]
)
# Similar setup for column colors: (xaxis7 and xaxis6) match main heatmap x-axis (xaxis11)
if len(tickvals_col) > 0:
fig["layout"]["xaxis7"].update(
matches="x11",
range=[min(tickvals_col), max(tickvals_col)],
showticklabels=False,
ticks="",
showgrid=False,
tickmode="array",
tickvals=[],
ticktext=[]
)
fig["layout"]["xaxis6"].update(
matches="x11",
range=[min(tickvals_col), max(tickvals_col)],
showticklabels=False,
ticks="",
showgrid=False,
tickmode="array",
tickvals=[],
ticktext=[]
)

fig["layout"][
"legend"
] = dict( # pylint: disable=unsupported-assignment-operation
Expand Down Expand Up @@ -833,7 +874,7 @@ def _get_clusters(self):

return (Zcol, Zrow)

def _get_row_colors_heatmap(self):
def _get_row_colors_heatmap(self, tickvals_row=None):
colors = self._row_colors

if colors is None:
Expand All @@ -854,14 +895,21 @@ def _get_row_colors_heatmap(self):

z = [[i] for i in range(len(colors))]

return go.Heatmap(
z=z,
colorscale=colorscale,
colorbar={"xpad": 100},
showscale=False
)
heatmap_kwargs = {
"z": z,
"colorscale": colorscale,
"colorbar": {"xpad": 100},
"showscale": False
}

# Use the same y-coordinates as the main heatmap for proper
# zoom synchronization
if tickvals_row is not None:
heatmap_kwargs["y"] = tickvals_row

def _get_column_colors_heatmap(self):
return go.Heatmap(**heatmap_kwargs)

def _get_column_colors_heatmap(self, tickvals_col=None):
colors = self._column_colors

if colors is None:
Expand All @@ -882,12 +930,19 @@ def _get_column_colors_heatmap(self):

z = [[i * 5 for i in range(len(colors))]]

return go.Heatmap(
z=z,
colorscale=colorscale,
colorbar={"xpad": 100},
showscale=False
)
heatmap_kwargs = {
"z": z,
"colorscale": colorscale,
"colorbar": {"xpad": 100},
"showscale": False
}

# Use the same x-coordinates as the main heatmap for proper
# zoom synchronization
if tickvals_col is not None:
heatmap_kwargs["x"] = tickvals_col

return go.Heatmap(**heatmap_kwargs)

def _compute_clustered_data(self):
"""Get the traces that need to be plotted for the row and column
Expand Down