I cannot figure how to capture scroll event on YAxis set from a twinx command.
On axe created from twin1 = ax.twinx()
, the detect_artist with ax.yaxis.contains(event)[0]
is None.
Here is the code:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.axis import XAxis, YAxis
#--------------------------------------------------------------
def detect_artist(event):
"""
Detect whether the event occurred on an axis (X or Y) or inside an Axes.
Returns the detected element (XAxis, YAxis, or Axes) or None.
"""
for ax in axs:
# Check if the event occurred on the X or Y axis ticks
if ax.xaxis.contains(event)[0]:
return ax.xaxis
if ax.yaxis.contains(event)[0]:
return ax.yaxis
# Check if the event occurred inside the Axes itself
if ax.contains(event)[0]:
return ax
return None
#--------------------------------------------------------------
def on_scroll(event):
"""
Handle scroll events to zoom in or out on the detected axis.
"""
scale_factor = 0.9 if event.button == 'up' else 1.1
artist = detect_artist(event) # Detect the Artist element under the mouse
print(artist)
if artist is None:
return # No axis detected, do nothing
if isinstance(artist, XAxis):
ax = artist.axes
# Zoom on the X axis
cur_xlim = ax.get_xlim()
xdata = event.xdata if event.xdata is not None else (cur_xlim[0] + cur_xlim[1]) / 2
new_xlim = [xdata - (xdata - cur_xlim[0]) * scale_factor,
xdata + (cur_xlim[1] - xdata) * scale_factor]
ax.set_xlim(new_xlim)
ax.figure.canvas.draw() # Redraw the canvas
print(f"Scrolled on XAxis of Axes: {ax}")
elif isinstance(artist, YAxis):
# Find the parent Axes of this YAxis
ax = artist.axes
ylim = ax.get_ylim()
center = (ylim[0] + ylim[1]) / 2
range_ = (ylim[1] - ylim[0]) * scale_factor / 2
ax.set_ylim(center - range_, center + range_)
ax.figure.canvas.draw() # Redraw the canvas
print(f"Scrolled on YAxis of Axes: {ax}")
#--------------------------------------------------------------
# Data
x = np.array([0, 1, 2])
y1 = np.array([0, 1, 2]) # Density
y2 = np.array([0, 3, 2]) # Temperature
y3 = np.array([50, 30, 15]) # Velocity
# Figure and Axes
fig, ax = plt.subplots(1, 1)
fig.subplots_adjust(left=0.4)
# Create twin axes
twin1 = ax.twinx()
twin1.spines["left"].set_position(("axes", -0.3))
twin1.spines["left"].set_visible(True)
twin1.yaxis.set_label_position("left")
twin1.yaxis.set_ticks_position("left")
twin2 = ax.twinx()
twin2.spines["left"].set_position(("axes", -0.5))
twin2.spines["left"].set_visible(True)
twin2.yaxis.set_label_position("left")
twin2.yaxis.set_ticks_position("left")
# Add all axes to a list for event handling
axs = [ax, twin1, twin2]
# Plot data
ax.plot(x, y1, color='red', label="Density")
twin1.plot(x, y2, color='blue', label="Temperature")
twin2.plot(x, y3, color='green', label="Velocity")
# Set axis labels
ax.set(xlabel="Distance", ylabel="Density")
twin1.set(ylabel="Temperature")
twin2.set(ylabel="Velocity")
# Color labels
ax.yaxis.label.set_color('red')
twin1.yaxis.label.set_color('blue')
twin2.yaxis.label.set_color('green')
# Connect event
fig.canvas.mpl_connect('scroll_event', on_scroll)
plt.show()