diff --git a/CHANGELOG.md b/CHANGELOG.md index 126ff280b49..b296ee382b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Fixed +- Bug was that function marked the axis to be connected, but the trace_kwargs still had unique axes [[#5427](https://github.com/plotly/plotly.py/issues/5427)] +- Change: change the keyword argument for the trace, so that when the graph is initialized, it uses the correct axis instead of the autogenerated one +- Note: The program generates a unique axis label for each subgraph, and then overwrites the label (under this fix) + ### Fixed - Fix issue where user-specified `color_continuous_scale` was ignored when template had `autocolorscale=True` [[#5439](https://github.com/plotly/plotly.py/pull/5439)], with thanks to @antonymilne for the contribution! - Update tests to be compatible with numpy 2.4 [[#5522](https://github.com/plotly/plotly.py/pull/5522)], with thanks to @thunze for the contribution! diff --git a/plotly/_subplots.py b/plotly/_subplots.py index 16a3958637e..edb1cc7b468 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -916,8 +916,10 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): if x_or_y == "x": span = spec["colspan"] + match_axis = 'xaxis' else: span = spec["rowspan"] + match_axis = 'yaxis' if subplot_ref.subplot_type == "xy" and span == 1: if first_axis_id is None: @@ -926,6 +928,7 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): else: axis_name = subplot_ref.layout_keys[layout_key_ind] axis_to_match = layout[axis_name] + subplot_ref.trace_kwargs[match_axis] = first_axis_id # Changes the reference axis in the set up to the initial axis (the axis to match) axis_to_match.matches = first_axis_id if remove_label: axis_to_match.showticklabels = False @@ -981,6 +984,7 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): first_axis_id = update_axis_matches( first_axis_id, subplot_ref, spec, ok_to_remove_label ) + def _init_subplot_xy(layout, secondary_y, x_domain, y_domain, max_subplot_ids=None): diff --git a/tests/test_optional/test_subplots/test_make_subplots.py b/tests/test_optional/test_subplots/test_make_subplots.py index 4552c66a694..23df450388c 100644 --- a/tests/test_optional/test_subplots/test_make_subplots.py +++ b/tests/test_optional/test_subplots/test_make_subplots.py @@ -56,3 +56,202 @@ def test_add_traces_with_integers(self): expected_data_length = 4 self.assertEqual(expected_data_length, len(fig2.data)) + +class TestSharedAxisOnMakeColumn(TestCase): + """ + Regression test for #5427: traces should reference the primary axis + when shared_xaxes=True, so spike lines and hover sync work correctly. + """ + + def test_xaxes_shared_columns_mode_single_column(self): + """ + When 'columns' mode for shared_xaxis, all of the traces in the same column should reference the same x-axis + """ + + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=3, cols=1, shared_xaxes='columns') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=2, col=1) + fig.add_trace(trace_3, row=3, col=1) + + + # The x-axis of all of the figures should be the same + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 3 have different x-axes") + self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 2 and Figure 3 have different x-axes") + + def test_xaxes_shared_columns_mode_multiple_columns(self): + """ + When 'columns' mode for shared_xaxis, different columns should have different references + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='columns') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=2, col=1) + fig.add_trace(trace_3, row=1, col=2) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of figures that are in the same column should be the same, and different if they are in different columns + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis column don't match: Figure 3 and Figure 4 have different x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis column match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis column match: Figure 2 and Figure 4 have the same x-axes") + + def test_xaxes_shared_rows_mode_single_row(self): + """ + When 'rows' mode for shared_xaxis, all of the traces in the same row should reference the same x-axis + """ + + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=1, cols=3, shared_xaxes='rows') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=1, col=3) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all of the figures should be the same + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 3 have different x-axes") + self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 2 and Figure 3 have different x-axes") + + def test_xaxes_shared_rows_mode_multiple_rows(self): + """ + When 'rows' mode for shared_xaxis, different rows should have different references + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='rows') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of figures in the same row should be the same, and different if they are in different rows + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis row don't match: Figure 3 and Figure 4 have different x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis row match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis row match: Figure 2 and Figure 4 have the same x-axes") + + def test_xaxes_shared_all_mode(self): + """ + When 'all' mode for shared_xaxis, all rows share the same x-axes + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='all') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all the figures should be the same + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 3 and Figure 4 have different x-axes") + self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 3 have the same x-axes") + self.assertEqual(trace_2_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 2 and Figure 4 have the same x-axes") + + def test_xaxes_not_shared_mode(self): + """ + When not shared, all plots have different x-axes + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes=False) + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all of the figures should be different + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertNotEqual(trace_1_xaxis, trace_2_xaxis, "Different x-axis match: Figure 1 and Figure 2 have the same x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_1_xaxis, trace_4_xaxis, "Different x-axis match: Figure 1 and Figure 4 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_3_xaxis, "Different x-axis match: Figure 2 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis match: Figure 2 and Figure 4 have the same x-axes") + self.assertNotEqual(trace_3_xaxis, trace_4_xaxis, "Different x-axis match: Figure 3 and Figure 4 have the same x-axes") \ No newline at end of file