# Source code for syelink.plotting.validation

"""Create validation plots showing target positions and gaze offsets.

Shows left and right eye data with error vectors and offset labels.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
from adjustText import adjust_text
from matplotlib import patches
from matplotlib.lines import Line2D
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from PIL import Image

from syelink.plotting.style import DEFAULT_VALIDATION_STYLE

if TYPE_CHECKING:
    from pathlib import Path

    from matplotlib.figure import Figure

    from syelink.models import SessionData
    from syelink.plotting.style import ValidationPlotStyle


def _axis_limits(points, screen_w, screen_h, padding=20):
    """Return ``(xlim, ylim)`` spanning the screen and every gaze sample, plus padding."""
    gaze_x = [p.gaze_x for p in points]
    gaze_y = [p.gaze_y for p in points]
    # Candidates are wrapped in a list so min()/max() still work when there
    # are no gaze points at all; ``min(0, *[])`` would raise TypeError.
    xlim = [min([0, *gaze_x]) - padding, max([screen_w, *gaze_x]) + padding]
    ylim = [min([0, *gaze_y]) - padding, max([screen_h, *gaze_y]) + padding]
    return xlim, ylim


def _draw_targets(ax, target_x, target_y, xlim, target_scale, target_image_path, style):
    """Draw every target as the custom image if a path is given, else as an ``x`` marker."""
    if not target_image_path:
        # Fallback: plot target markers as simple crosses
        ax.scatter(
            target_x,
            target_y,
            c=style.color_target,
            marker="x",
            s=style.marker_size,
            linewidths=style.marker_linewidth,
            zorder=1,
        )
        return

    # Copy the pixel data into memory so the underlying file handle is
    # released deterministically instead of whenever Pillow gets around to it.
    with Image.open(target_image_path) as image_file:
        target_img = image_file.copy()

    # Scale the image relative to the visible x-range of the plot
    zoom = target_scale * (xlim[1] - xlim[0]) / target_img.width
    for tx, ty in zip(target_x, target_y, strict=False):
        imagebox = OffsetImage(target_img, zoom=zoom)
        ax.add_artist(AnnotationBbox(imagebox, (tx, ty), frameon=False, pad=0))


def _plot_eye_data(ax, eye_points, targets, color, style):
    """Draw error vectors, gaze markers and offset labels for one eye; return the labels."""
    labels = []
    for p in eye_points:
        target = targets[p.point_number]
        tx, ty = target[0], target[1]
        gx, gy = p.gaze_x, p.gaze_y

        # Draw line from target to gaze point
        ax.plot(
            [tx, gx],
            [ty, gy],
            color=color,
            linewidth=style.line_width,
            linestyle=style.line_style,
            alpha=style.line_alpha,
            zorder=2,
        )

        # Draw gaze point marker
        ax.scatter(
            gx,
            gy,
            c=color,
            marker=style.marker,
            s=style.marker_size,
            linewidths=style.marker_linewidth,
            zorder=3,
        )

        if style.show_labels:
            # Place label at gaze point (adjustText resolves overlaps later)
            labels.append(
                ax.text(
                    gx,
                    gy,
                    f"{p.offset_deg:.2f}",
                    fontsize=style.label_fontsize,
                    fontweight=style.label_fontweight,
                    color=color,
                    zorder=4,
                )
            )
    return labels


def _legend_handles(style):
    """Build the proxy artists for the left-eye, right-eye and target legend entries."""
    handles = [
        Line2D(
            [0],
            [0],
            marker=style.marker,
            color=color,
            markerfacecolor=color,
            markeredgecolor=color,
            markersize=12,
            markeredgewidth=2,
            label=label,
            linestyle=style.line_style,
            linewidth=style.line_width,
        )
        for label, color in (("Left eye", style.color_left), ("Right eye", style.color_right))
    ]
    # NOTE(review): the target entry uses style.marker, while the fallback
    # target markers are drawn with "x" — confirm this is intended.
    handles.append(
        Line2D(
            [0],
            [0],
            marker=style.marker,
            color=style.color_target,
            markerfacecolor=style.color_target,
            markeredgecolor=style.color_target,
            markersize=10,
            label="Target",
            linestyle="",
        )
    )
    return handles


def plot_validation(
    session: SessionData,
    validation_index: int = 0,
    save_path: str | Path | None = None,
    target_scale: float = 0.015,
    target_image_path: str | Path | None = None,
    style: ValidationPlotStyle | None = None,
    show_connectors: bool = False,
) -> Figure:
    """Plot a single validation showing left and right eye data.

    Draws the screen outline, the validation targets, and for every
    validation point a line from its target to the measured gaze position,
    a gaze marker and (optionally) the offset in degrees as a text label.

    Args:
        session: SessionData object containing validation data
        validation_index: Index of validation to plot (0-based)
        save_path: Optional path to save the plot
        target_scale: Scaling factor for target image size (default: 0.015)
        target_image_path: Optional path to custom target image
        style: Optional ValidationPlotStyle for customizing colors, markers, etc.
        show_connectors: Whether to draw connector lines from original label
            positions to adjusted label positions (default: False)

    Returns:
        matplotlib Figure object

    Raises:
        ValueError: If the session has no display coordinates or the selected
            validation has no target data.
        IndexError: If ``validation_index`` is out of range for
            ``session.validations``.

    Example:
        >>> from syelink.plotting import plot_validation, ValidationPlotStyle
        >>> style = ValidationPlotStyle(color_left="red", color_right="green")
        >>> fig = plot_validation(session, style=style, show_connectors=True)
    """
    if style is None:
        style = DEFAULT_VALIDATION_STYLE

    # Get display coordinates
    if not session.display_coords:
        msg = "Session data missing display_coords"
        raise ValueError(msg)
    dc = session.display_coords
    screen_w = dc.width
    screen_h = dc.height

    # Get validation data; re-raise with a message that names the bad index
    try:
        validation = session.validations[validation_index]
    except IndexError:
        msg = (
            f"validation_index {validation_index} is out of range: "
            f"session has {len(session.validations)} validation(s)"
        )
        raise IndexError(msg) from None
    if not validation.targets:
        msg = f"Validation {validation_index} has no target data"
        raise ValueError(msg)

    targets = validation.targets.targets
    target_x = [t[0] for t in targets]
    target_y = [t[1] for t in targets]

    # Separate left and right eye data
    points = validation.points
    left_data = [p for p in points if p.eye == "LEFT"]
    right_data = [p for p in points if p.eye == "RIGHT"]

    summary_left = validation.summary_left
    summary_right = validation.summary_right

    fig, ax = plt.subplots(figsize=style.figsize)
    xlim, ylim = _axis_limits(points, screen_w, screen_h)

    # Draw screen boundary
    ax.add_patch(
        patches.Rectangle(
            (0, 0),
            screen_w,
            screen_h,
            linewidth=2,
            edgecolor=style.color_screen,
            facecolor="none",
        )
    )

    _draw_targets(ax, target_x, target_y, xlim, target_scale, target_image_path, style)

    # Left eye first, then right eye, collecting the labels for adjustment
    texts = _plot_eye_data(ax, left_data, targets, style.color_left, style)
    texts += _plot_eye_data(ax, right_data, targets, style.color_right, style)

    # Remember where each label started so connectors can be drawn afterwards
    original_positions = [(t.get_position(), t) for t in texts]

    # Adjust text labels to avoid overlaps with each other and the targets
    if texts:
        target_points = [(float(tx), float(ty)) for tx, ty in zip(target_x, target_y, strict=False)]
        adjust_text(
            texts,
            points=target_points,
            expand=(1.1, 1.1),
            force_points=(1.5, 1.5),
            force_text=(0.7, 0.7),
            ax=ax,
        )

        if show_connectors:
            for orig_pos, text in original_positions:
                new_pos = text.get_position()
                # Only draw a line if the label moved by more than one data unit
                if abs(orig_pos[0] - new_pos[0]) > 1 or abs(orig_pos[1] - new_pos[1]) > 1:
                    ax.plot(
                        [orig_pos[0], new_pos[0]],
                        [orig_pos[1], new_pos[1]],
                        color="gray",
                        lw=0.5,
                        alpha=0.5,
                        zorder=3,
                    )

    # Set axis properties
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.invert_yaxis()  # Invert y-axis to match screen coordinates
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("X (pixels)", fontsize=12)
    ax.set_ylabel("Y (pixels)", fontsize=12)

    # Create title with summary statistics
    title_parts = [f"Validation #{validation_index + 1}"]
    if summary_left:
        title_parts.append(
            f"Left eye: {summary_left.error_avg_deg:.2f}° avg, {summary_left.error_max_deg:.2f}° max"
        )
    if summary_right:
        title_parts.append(
            f"Right eye: {summary_right.error_avg_deg:.2f}° avg, {summary_right.error_max_deg:.2f}° max"
        )
    ax.set_title("\n".join(title_parts), fontsize=style.title_fontsize, fontweight="bold", pad=20)

    if style.show_legend:
        ax.legend(
            handles=_legend_handles(style),
            loc=style.legend_loc,
            fontsize=style.legend_fontsize,
            framealpha=0.9,
        )

    # Operate on this figure explicitly rather than pyplot's "current" figure
    fig.tight_layout()

    # Save if path provided
    if save_path:
        fig.savefig(save_path, dpi=style.dpi, bbox_inches="tight", facecolor="white")

    return fig