# Source code for charted.charts.scatter

from __future__ import annotations

import math
from typing import TYPE_CHECKING, TypedDict, cast

from charted.charts.chart import Chart
from charted.constants import (
    DEFAULT_CHART_HEIGHT,
    DEFAULT_CHART_WIDTH,
    QUADRANT_BOTTOM_MARGIN_FACTOR,
    QUADRANT_LABEL_LINE_GAP,
)
from charted.html.element import Circle, Element, G, Path, Rect, Text
from charted.themes.core import Theme
from charted.utils.types import (
    PointStyleConfig,
    ReferenceLineDict,
    SeriesStyleConfig,
    Vector,
    Vector2D,
)

if TYPE_CHECKING:
    from charted.charts.chart import _Annotation


class _PlacedLabel(TypedDict):
    """A data label's placement state during collision avoidance.

    One entry is built per labelled point by
    ``ScatterChart._render_data_labels`` and then adjusted in place by
    ``ScatterChart._deoverlap_labels``. All coordinates are in plot
    coordinates. Only ``cx``/``cy`` are changed by the de-overlap pass; the
    remaining fields are read-only inputs to it.
    """

    # Label text to draw.
    text: str
    # Position of the marker this label belongs to (the leader-line origin).
    px: float
    py: float
    # Centre of the label's bounding box; moved by the de-overlap pass.
    cx: float
    cy: float
    # Estimated label box width and height (height is the font size).
    w: float
    h: float
    # Marker size (radius / half-extent); the marker's collision box is
    # treated as a square of side ``2 * marker`` centred on (px, py).
    marker: float


class ScatterChart(Chart):
    """Scatter plot for displaying relationships between two variables.

    Plots individual data points at (x, y) coordinates to show correlations,
    clusters, or distributions. Supports multi-series data with custom marker
    shapes and sizes.

    Args:
        x_data: X-coordinates of the points (``Vector`` or ``Vector2D``).
            Required; forwarded to ``Chart``.
        y_data: Y-coordinates of the points (``Vector`` or ``Vector2D``).
            Required; forwarded to ``Chart``.
        width, height: Chart dimensions in pixels.
        title: Optional chart title.
        subtitle, subtitle_leading: Forwarded unchanged to ``Chart``.
        theme: Optional theme configuration.
        series_names: Names for each series (shown in legend).
        series_styles: Per-series style overrides (``fill``,
            ``marker_shape``, ``marker_size``).
        point_styles: Per-POINT marker overrides, a list of per-series rows
            mirroring the data shape. Each entry is a ``PointStyleConfig``
            (``marker_shape``, ``marker_size``, ``fill``, ``opacity``) or
            ``None``. Any present field wins over the series-level/shape-cycle
            resolution; omitted fields fall through. Defaults to None, leaving
            every point styled by its series (existing behaviour).
        data_labels: Optional label text per point: a flat list (treated as
            a single series) or one list per series. Data-label colour comes
            from the theme's ``data_label_color`` token (override via
            ``Theme(data_label_color=...)``); the default token reproduces
            the previous axis-title colour.
        x_label, y_label, h_lines, v_lines, annotations, x_scale, y_scale,
        reference_lines, colors, value_labels: Forwarded unchanged to
            ``Chart``; see the base class for their meaning.
        quadrant_labels: Up to four strings ordered
            ``[top-left, top-right, bottom-left, bottom-right]``. Each may
            contain newlines for multi-line labels; empty entries are
            skipped and entries beyond the fourth are ignored.
        quadrant_label_inset: Extra padding (px) used to inset the four
            quadrant labels from the plot corners so they clear the axis tick
            numbers instead of sitting flush. Defaults to 12.0; pass 0 to
            restore the original flush-corner placement.
        quadrant_label_backplate: When True, draws a semi-opaque rounded
            background plate behind each quadrant label for contrast.
            Defaults to False.
        shape_cycle: Redundant shape encoding for multi-series scatters so
            series differ by marker SHAPE as well as colour. Defaults to None
            (every series uses circles, preserving existing behaviour). Pass
            True to enable the built-in cycle (circle, square, triangle,
            diamond, star), or a custom list of shape names to cycle through.
            A per-series ``marker_shape`` in ``series_styles`` always wins
            over the cycle.
        legend: Placement for a series legend that maps each
            ``series_names`` entry to its marker shape and colour swatch. One
            of ``'none'`` (default, no legend), ``'right'``, ``'bottom'``, or
            ``'top'``. When shown, layout space is reserved on that side so
            the legend never overlaps the plot. Requires ``series_names``;
            with no names there is nothing to label and the layout is left
            unchanged.
        x_range: Optional (min, max) to fix the x-axis domain instead of
            deriving it from the data, removing the need for invisible anchor
            points to control the visible range.
        y_range: Optional (min, max) to fix the y-axis domain.
        domain_padding: Optional fraction (e.g. 0.1) padding the data-derived
            domain by that amount on each side. Ignored on an axis with an
            explicit range. Defaults to None (no padding).
        avoid_label_collisions: When True, run a collision-avoidance pass
            over the data labels so they overlap each other (and their
            markers) as little as possible, drawing a thin leader line back
            to a point whenever its label is pushed noticeably away. Defaults
            to False, which keeps the original fixed-offset label placement
            so existing renders are unchanged. See ``_render_data_labels``
            for the algorithm and its limitations.

    Example:
        >>> from charted import ScatterChart
        >>> # Basic scatter plot
        >>> chart = ScatterChart(x_data=[1, 2, 3, 4], y_data=[5, 8, 12, 15])
        >>> chart.save('correlation.svg')
        >>>
        >>> # Multi-series with custom markers
        >>> chart = ScatterChart(
        ...     x_data=[1, 2, 3],
        ...     y_data=[[5, 8, 12], [7, 10, 14]],
        ...     series_styles=[{'marker_shape': 'circle'}, {'marker_shape': 'square'}]
        ... )
    """

    # Default marker shapes cycled for redundant (shape + colour) encoding
    # when ``shape_cycle=True`` and no per-series shape is given.
    DEFAULT_SHAPE_CYCLE = ["circle", "square", "triangle", "diamond", "star"]

    @staticmethod
    def _resolve_shape_cycle(
        shape_cycle: list[str] | bool | None,
    ) -> list[str] | None:
        """Normalise the ``shape_cycle`` argument to a list of shapes or None.

        None/False disable cycling (every series uses circles). True selects
        the built-in cycle. A non-empty list (or tuple) is copied and used
        verbatim. Anything else, including an empty sequence, disables
        cycling.
        """
        if shape_cycle is None or shape_cycle is False:
            return None
        if shape_cycle is True:
            return list(ScatterChart.DEFAULT_SHAPE_CYCLE)
        # Tuples are accepted too so an immutable cycle is not silently
        # ignored; a bare string is deliberately not treated as a sequence.
        if isinstance(shape_cycle, (list, tuple)) and shape_cycle:
            return list(shape_cycle)
        return None

    def __init__(
        self,
        x_data: Vector | Vector2D,
        y_data: Vector | Vector2D,
        width: float = DEFAULT_CHART_WIDTH,
        height: float = DEFAULT_CHART_HEIGHT,
        title: str | None = None,
        subtitle: str | None = None,
        subtitle_leading: float = 8.0,
        theme: Theme | None = None,
        series_names: list[str] | None = None,
        series_styles: list[SeriesStyleConfig] | None = None,
        point_styles: list[list[PointStyleConfig | None]] | None = None,
        data_labels: list[str] | list[list[str]] | None = None,
        x_label: str | None = None,
        y_label: str | None = None,
        h_lines: list[float] | None = None,
        v_lines: list[float] | None = None,
        annotations: list[_Annotation] | None = None,
        quadrant_labels: list[str] | None = None,
        quadrant_label_inset: float = 12.0,
        quadrant_label_backplate: bool = False,
        shape_cycle: list[str] | bool | None = None,
        legend: str = "none",
        x_scale: object | None = None,
        y_scale: object | None = None,
        reference_lines: list[ReferenceLineDict] | None = None,
        colors: list[str] | None = None,
        x_range: tuple[float, float] | None = None,
        y_range: tuple[float, float] | None = None,
        domain_padding: float | None = None,
        avoid_label_collisions: bool = False,
        value_labels: bool | str | dict[str, object] | None = None,
    ):
        """Store the scatter-specific options, then initialise the base chart.

        See the class docstring for the meaning of each argument.
        """
        # Scatter-specific state is assigned BEFORE the base initialiser
        # runs; presumably the base class may already consult these hooks
        # while it sets up layout (verify against ``Chart.__init__``).
        self._avoid_label_collisions = avoid_label_collisions
        self._point_styles = point_styles
        self._quadrant_labels = quadrant_labels
        self._quadrant_label_inset = quadrant_label_inset
        self._quadrant_label_backplate = quadrant_label_backplate
        self._shape_cycle = self._resolve_shape_cycle(shape_cycle)
        super().__init__(
            y_data=y_data,
            x_data=x_data,
            width=width,
            height=height,
            title=title,
            subtitle=subtitle,
            subtitle_leading=subtitle_leading,
            theme=theme,
            series_names=series_names,
            chart_type="scatter",
            series_styles=series_styles,
            data_labels=data_labels,
            x_label=x_label,
            y_label=y_label,
            h_lines=h_lines,
            v_lines=v_lines,
            annotations=annotations,
            x_scale=x_scale,
            y_scale=y_scale,
            reference_lines=reference_lines,
            colors=colors,
            x_range=x_range,
            y_range=y_range,
            domain_padding=domain_padding,
            value_labels=value_labels,
            legend=legend,
        )

    # =====================================================================
    # Legend (shape + colour mapping, reserved placement)
    # =====================================================================

    def _legend_entries(self) -> list[tuple[str, str, str]]:
        """Build (name, colour, shape) tuples for each plotted series.

        Returns an empty list when there are no series names to label, which
        is what keeps the legend off by default for unnamed scatters. Names
        beyond the number of plotted series are ignored; a series without a
        palette colour falls back to black.
        """
        names = self.series_names
        if not names:
            return []
        entries: list[tuple[str, str, str]] = []
        for idx, name in enumerate(names):
            if idx >= len(self.y_values):
                break
            color = self.colors[idx] if idx < len(self.colors) else "#000000"
            shape = self._series_shape(idx)
            entries.append((str(name), color, shape))
        return entries

    def _series_shape(self, series_idx: int) -> str:
        """Resolve the marker shape for a series.

        Precedence: ``series_styles[idx]['marker_shape']`` if set, else the
        shape cycle (when enabled), else ``"circle"``. Used by both the
        legend and ``representation`` so the two always agree.
        """
        if self._shape_cycle:
            shape = self._shape_cycle[series_idx % len(self._shape_cycle)]
        else:
            shape = "circle"
        if self.series_styles and series_idx < len(self.series_styles):
            style = self.series_styles[series_idx] or {}
            if style.get("marker_shape"):
                shape = cast(str, style["marker_shape"])
        return shape

    def _series_marker_size(self, series_idx: int) -> float:
        """Resolve the marker size (radius / half-extent) for a series.

        Default is 4px. A theme that explicitly sets ``marker_size`` (e.g.
        high-contrast) raises it for legibility while the standard themes
        keep the historical 4px. A per-series ``marker_size`` in
        ``series_styles`` wins over both. The value is returned uncoerced so
        rendered attributes keep their original formatting.
        """
        marker_size: float = 4
        if self.theme._is_explicit("marker_size"):
            marker_size = self.theme.marker_size
        if self.series_styles and series_idx < len(self.series_styles):
            style = self.series_styles[series_idx] or {}
            if style.get("marker_size"):
                marker_size = cast(float, style["marker_size"])
        return marker_size

    @property
    def representation(self) -> G:
        """Build the chart body: clipped markers plus unclipped label layers.

        Returns a wrapper group containing the clipped marker group and,
        when present, the data-label and quadrant-label groups (each in its
        own unclipped group carrying the base transform).
        """
        g = G(
            opacity=0.8,
            transform=[*self.get_base_transform()],
            clip_path="url(#plot-clip)",
        )
        for series_idx, (y_values, y_offsets, x_values, color) in enumerate(
            zip(self.y_values, self.y_offsets, self.x_values, self.colors),
        ):
            # Series-level resolution: palette colour unless series_styles
            # overrides the fill; size and shape come from the shared
            # helpers so the legend and collision pass agree with the plot.
            fill = color
            if self.series_styles and series_idx < len(self.series_styles):
                style = self.series_styles[series_idx] or {}
                if style.get("fill"):
                    fill = cast(str, style["fill"])
            marker_size = self._series_marker_size(series_idx)
            marker_shape = self._series_shape(series_idx)

            series = G(fill=fill)
            x_offset = self.x_offset
            for i, (x, y, y_offset) in enumerate(zip(x_values, y_values, y_offsets)):
                x += x_offset
                y = self._apply_stacking(y, y_offset)
                title = self._tooltip_title(series_idx, i)

                # Per-point overrides (point_styles) win over the series-level
                # shape/size/fill resolved above. Each field is independent, so
                # an omitted field keeps the series value.
                p_shape = marker_shape
                p_size = marker_size
                p_fill = fill
                p_opacity: float | None = None
                pstyle = self._point_style(series_idx, i)
                if pstyle:
                    if pstyle.get("marker_shape"):
                        p_shape = cast(str, pstyle["marker_shape"])
                    if pstyle.get("marker_size"):
                        p_size = cast(float, pstyle["marker_size"])
                    if pstyle.get("fill"):
                        p_fill = cast(str, pstyle["fill"])
                    if pstyle.get("opacity") is not None:
                        p_opacity = pstyle["opacity"]

                mark = self._marker_element(p_shape, x, y, p_size, p_fill)
                if mark is not None:
                    # Only set fill on the marker when it differs from the
                    # series group fill, so unstyled points stay byte-identical
                    # (they inherit fill from the enclosing group).
                    if p_fill != fill:
                        mark.kwargs["fill"] = p_fill
                    if p_opacity is not None:
                        # The cast is a typing-only no-op (presumably to
                        # satisfy the kwargs value type); the numeric value
                        # is stored unchanged.
                        mark.kwargs["opacity"] = cast(str, p_opacity)
                    if title is not None:
                        mark.add_child(title)
                    series.add_child(mark)
            g.add_children(series)

        # Data labels and quadrant labels rendered outside the clip group
        # so they don't get clipped at chart edges
        wrapper = G()
        wrapper.add_child(g)

        data_labels_g = self._render_data_labels()
        if data_labels_g:
            unclipped = G(
                transform=[*self.get_base_transform()],
            )
            unclipped.add_child(data_labels_g)
            wrapper.add_child(unclipped)

        quadrant_g = self._render_quadrant_labels()
        if quadrant_g:
            unclipped_q = G(
                transform=[*self.get_base_transform()],
            )
            unclipped_q.add_child(quadrant_g)
            wrapper.add_child(unclipped_q)

        return wrapper

    def _point_style(self, series_idx: int, point_idx: int) -> PointStyleConfig | None:
        """Return the ``PointStyleConfig`` for one point, or None.

        ``point_styles`` is a list of per-series rows mirroring the data
        shape; any missing series, point, or ``None``/empty entry yields no
        override.
        """
        styles = self._point_styles
        if not styles or series_idx >= len(styles):
            return None
        row = styles[series_idx]
        if not row or point_idx >= len(row):
            return None
        return row[point_idx] or None

    def _marker_element(
        self, shape: str, x: float, y: float, size: float, fill: str
    ) -> Element | None:
        """Build a marker element centred at (x, y).

        ``size`` is the radius / half-extent, so every shape shares the same
        bounding box for a given size (a square, circle, diamond, triangle
        and star with the same ``size`` all fit the same 2*size box).
        Returns None for shape ``"none"``; any unrecognised shape falls back
        to a circle.

        ``fill`` is set explicitly only on the path-based shapes (diamond,
        triangle, star); squares and circles carry no fill of their own and
        inherit it from the enclosing series group unless the caller adds
        one.
        """
        if shape == "none":
            return None
        if shape == "square":
            return Rect(x=x - size, y=y - size, width=size * 2, height=size * 2)
        if shape == "diamond":
            pts = f"{x},{y - size} {x + size},{y} {x},{y + size} {x - size},{y}"
            return Path(d=f"M{pts} Z", fill=fill)
        if shape == "triangle":
            pts = self._polygon_points(x, y, size, sides=3)
            return Path(d=f"M{pts} Z", fill=fill)
        if shape == "star":
            pts = self._star_points(x, y, size)
            return Path(d=f"M{pts} Z", fill=fill)
        # default: circle
        return Circle(cx=x, cy=y, r=size)

    # The plot group is rendered with a net vertical flip (see
    # LayoutEngine.get_base_transform), so a vertex placed at the bottom in
    # local space appears at the top on screen. The +90 start angle therefore
    # makes the apex/first tip point UP for the viewer.
    @staticmethod
    def _polygon_points(cx: float, cy: float, r: float, sides: int) -> str:
        """Vertices of a regular polygon, apex pointing up on screen.

        Returns space-separated ``x,y`` pairs (3 decimals) on the circle of
        radius ``r`` around (cx, cy).
        """
        out = []
        for k in range(sides):
            a = math.radians(90 + k * 360 / sides)
            out.append(f"{cx + r * math.cos(a):.3f},{cy + r * math.sin(a):.3f}")
        return " ".join(out)

    @staticmethod
    def _star_points(
        cx: float, cy: float, r: float, points: int = 5, inner_ratio: float = 0.4
    ) -> str:
        """Vertices of a star with ``points`` tips, first tip up on screen.

        Alternates between the outer radius ``r`` (tips) and the inner
        radius ``r * inner_ratio`` (notches); returns space-separated
        ``x,y`` pairs (3 decimals).
        """
        inner = r * inner_ratio
        out = []
        for k in range(points * 2):
            radius = r if k % 2 == 0 else inner
            a = math.radians(90 + k * 180 / points)
            out.append(
                f"{cx + radius * math.cos(a):.3f},{cy + radius * math.sin(a):.3f}"
            )
        return " ".join(out)

    def _render_data_labels(self) -> G | None:
        """Render scatter data labels, optionally de-overlapping them.

        With ``avoid_label_collisions=False`` (the default) this defers to the
        base-class placement so existing renders are byte-for-byte unchanged.

        With it enabled, every label starts at a fixed offset above-right of
        its marker, then a greedy iterative pass nudges labels apart whenever
        their axis-aligned bounding boxes overlap another label or a marker.
        When a label ends up displaced far enough from its point a thin
        leader line is drawn back to the marker so the association stays
        readable. The marker size used for the collision boxes is resolved
        the same way as for drawing (theme, then ``series_styles``, then
        ``point_styles``).

        Limitations: the de-overlap is a local greedy relaxation, not a
        global optimiser, so dense clusters can still leave some residual
        overlap and the result depends on point order. Label widths are
        estimated from the font metrics helper (no real text shaping), labels
        are not clamped to the plot rectangle, and rotated/multi-line labels
        are not handled. Markers are approximated by their square bounding
        box.

        Returns:
            A group with leader lines and label text, or None when there is
            nothing to draw.
        """
        # getattr: tolerate being called before __init__ assigned the flag.
        if not getattr(self, "_avoid_label_collisions", False):
            return super()._render_data_labels()
        if not self._data_labels:
            return None

        # A flat list of strings is treated as the labels of one series.
        labels = self._data_labels
        if labels and not isinstance(labels[0], list):
            labels = [labels]

        from charted.utils.helpers import calculate_text_dimensions

        font_size = max(8, self.theme.title_font_size - 4)
        font_family = self.theme.title_font_family
        font_color = self.theme.resolved_data_label_color
        line_color = self.theme.resolved_reference_line_color

        # Gather placed labels and their anchor markers in plot coordinates.
        placed: list[_PlacedLabel] = []
        n_series = min(len(self.y_values), len(self.y_offsets), len(self.x_values))
        for series_idx, label_row in enumerate(labels):
            if series_idx >= n_series:
                break
            y_vals = self.y_values[series_idx]
            y_offs = self.y_offsets[series_idx]
            x_vals = self.x_values[series_idx]
            # Guard all three vectors so a surplus label can never index
            # past the end of a shorter one.
            n_points = min(len(x_vals), len(y_vals), len(y_offs))
            series_marker = float(self._series_marker_size(series_idx))
            for i, label_text in enumerate(label_row):
                if i >= n_points or not label_text:
                    continue
                # Match the size actually drawn for this point.
                marker_size = series_marker
                pstyle = self._point_style(series_idx, i)
                if pstyle and pstyle.get("marker_size"):
                    marker_size = float(cast(float, pstyle["marker_size"]))
                px = x_vals[i] + self.x_offset
                py = self._apply_stacking(y_vals[i], y_offs[i])
                text = str(label_text)
                tw = calculate_text_dimensions(text, font_size=font_size).width
                th = font_size
                # Initial offset: up and to the right of the marker.
                off = marker_size + th * 0.5
                cx = px + off + tw / 2
                cy = py + off + th / 2
                placed.append(
                    {
                        "text": text,
                        "px": px,
                        "py": py,
                        "cx": cx,
                        "cy": cy,
                        "w": tw,
                        "h": th,
                        "marker": marker_size,
                    }
                )

        if not placed:
            return None

        self._deoverlap_labels(placed)

        g = G()
        # Leader lines first so labels render on top.
        threshold = font_size * 1.6
        for lab in placed:
            dx = lab["cx"] - lab["px"]
            dy = lab["cy"] - lab["py"]
            if math.hypot(dx, dy) > threshold:
                g.add_child(
                    Path(
                        d=f"M{lab['px']:.2f},{lab['py']:.2f} "
                        f"L{lab['cx']:.2f},{lab['cy']:.2f}",
                        stroke=line_color,
                        stroke_width=1,
                        fill="none",
                    )
                )
        for lab in placed:
            tx = lab["cx"]
            ty = lab["cy"]
            g.add_child(
                Text(
                    text=lab["text"],
                    x=tx,
                    y=ty,
                    fill=font_color,
                    font_size=font_size,
                    font_family=font_family,
                    text_anchor="middle",
                    # Counter-flip about the anchor so the text reads upright
                    # inside the vertically flipped plot group.
                    transform=f"translate({tx},{ty}) scale(1,-1) translate({-tx},{-ty})",
                )
            )
        return g

    @staticmethod
    def _deoverlap_labels(placed: list[_PlacedLabel], iterations: int = 60) -> None:
        """Greedily push overlapping label boxes apart, in place.

        Each iteration walks every label pair plus every label/marker pair
        and, for any overlapping axis-aligned boxes, shifts the label along
        the axis of least penetration. A small spring pulls each label back
        toward its own marker so labels do not drift indefinitely. This is a
        local heuristic with no global guarantee (see
        ``_render_data_labels``).

        Args:
            placed: Label records; only their ``cx``/``cy`` are modified.
            iterations: Upper bound on relaxation passes; the loop stops
                early after the first pass in which nothing overlapped.
        """

        def overlap(
            a_cx: float,
            a_cy: float,
            a_w: float,
            a_h: float,
            b_cx: float,
            b_cy: float,
            b_w: float,
            b_h: float,
            pad: float = 2.0,
        ) -> tuple[float, float] | None:
            # Penetration depth per axis for two centre/size boxes inflated
            # by ``pad``; None when the boxes are separated on either axis.
            ox = (a_w + b_w) / 2 + pad - abs(a_cx - b_cx)
            oy = (a_h + b_h) / 2 + pad - abs(a_cy - b_cy)
            if ox > 0 and oy > 0:
                return ox, oy
            return None

        for _ in range(iterations):
            moved = False
            for i, a in enumerate(placed):
                # Label vs every other label: split the correction evenly
                # between the two labels.
                for b in placed[i + 1 :]:
                    res = overlap(
                        a["cx"],
                        a["cy"],
                        a["w"],
                        a["h"],
                        b["cx"],
                        b["cy"],
                        b["w"],
                        b["h"],
                    )
                    if res is None:
                        continue
                    ox, oy = res
                    moved = True
                    if ox < oy:
                        shift = ox / 2 + 0.1
                        sign = 1 if a["cx"] >= b["cx"] else -1
                        a["cx"] += sign * shift
                        b["cx"] -= sign * shift
                    else:
                        shift = oy / 2 + 0.1
                        sign = 1 if a["cy"] >= b["cy"] else -1
                        a["cy"] += sign * shift
                        b["cy"] -= sign * shift
                # Label vs markers (approximated by their bounding box).
                # Markers are fixed, so the label takes the whole shift; the
                # label's own marker is included.
                for b in placed:
                    m = b["marker"] * 2
                    res = overlap(
                        a["cx"],
                        a["cy"],
                        a["w"],
                        a["h"],
                        b["px"],
                        b["py"],
                        m,
                        m,
                    )
                    if res is None:
                        continue
                    ox, oy = res
                    moved = True
                    if ox < oy:
                        sign = 1 if a["cx"] >= b["px"] else -1
                        a["cx"] += sign * (ox + 0.1)
                    else:
                        sign = 1 if a["cy"] >= b["py"] else -1
                        a["cy"] += sign * (oy + 0.1)
            # Weak spring back toward the marker keeps labels from wandering.
            for a in placed:
                a["cx"] += (a["px"] - a["cx"]) * 0.01
                a["cy"] += (a["py"] - a["cy"]) * 0.01
            if not moved:
                break

    def _render_quadrant_labels(self) -> G | None:
        """Render text labels in each quadrant of the scatter plot.

        Expects a list of 4 strings: [top-left, top-right, bottom-left,
        bottom-right]. Each string may contain newlines for multi-line
        labels. Shorter lists are padded with empty labels (which are
        skipped); entries beyond the fourth are ignored so they cannot
        overprint the bottom quadrants.

        Returns:
            A group of text (and optional backplate) elements, or None when
            no quadrant labels were configured.
        """
        if not self._quadrant_labels:
            return None

        # Normalise to exactly four slots.
        labels = (list(self._quadrant_labels) + [""] * 4)[:4]

        g = G()
        font_size = max(8, self.theme.title_font_size - 4)
        font_family = self.theme.title_font_family
        font_color = self.theme.resolved_quadrant_label_color

        pw = self.plot_width
        ph = self.plot_height
        # Inset the labels away from the plot edge so they clear the axis
        # tick numbers instead of sitting flush in the corner. The inset is
        # added on top of the base corner pad.
        inset = max(0.0, float(self._quadrant_label_inset))
        padding = 8 + inset

        # In the flipped coordinate system, high Y = top of chart
        # Corner-aligned: top labels hug top edge growing down,
        # bottom labels hug bottom edge growing up
        line_height = font_size + QUADRANT_LABEL_LINE_GAP
        top_margin = padding + font_size
        bottom_margin = padding * QUADRANT_BOTTOM_MARGIN_FACTOR

        for idx, label_text in enumerate(labels):
            if not label_text:
                continue
            lines = str(label_text).split("\n")
            is_left = idx % 2 == 0
            is_top = idx < 2

            anchor = "start" if is_left else "end"
            x = padding if is_left else pw - padding

            # The plate is added before the text so it renders behind it.
            if self._quadrant_label_backplate:
                backplate = self._quadrant_backplate(
                    lines,
                    x,
                    anchor,
                    is_top,
                    font_size,
                    line_height,
                    top_margin,
                    bottom_margin,
                    ph,
                )
                if backplate is not None:
                    g.add_child(backplate)

            for line_idx, line in enumerate(lines):
                if is_top:
                    ty = ph - top_margin - line_idx * line_height
                else:
                    ty = (
                        bottom_margin
                        + font_size
                        + (len(lines) - 1 - line_idx) * line_height
                    )

                g.add_child(
                    Text(
                        text=line,
                        x=x,
                        y=ty,
                        fill=font_color,
                        font_size=font_size,
                        font_family=font_family,
                        text_anchor=anchor,
                        opacity=0.8,
                        transform=f"translate({x},{ty}) scale(1,-1) translate({-x},{-ty})",
                    )
                )

        return g

    def _quadrant_backplate(
        self,
        lines: list[str],
        x: float,
        anchor: str,
        is_top: bool,
        font_size: float,
        line_height: float,
        top_margin: float,
        bottom_margin: float,
        ph: float,
    ) -> Rect | None:
        """Build a semi-opaque rounded plate sized to a quadrant label block.

        Drawn in the flipped plot coordinate system (high Y = top). In that
        system each text line rises from its baseline toward higher Y and
        later lines sit at lower Y, so the block spans from the LAST line's
        baseline up to one font size above the FIRST line's baseline. The
        plate's ``y`` is therefore its lowest edge (last baseline minus the
        vertical pad) and it extends upward by its height.

        Args:
            lines: The label's text lines, first line topmost.
            x: Anchor x of the text block.
            anchor: ``"start"`` (block extends right of ``x``) or anything
                else, treated as ``"end"`` (block extends left of ``x``).
            is_top: Whether the label sits in a top quadrant.
            font_size, line_height, top_margin, bottom_margin, ph: The same
                layout values ``_render_quadrant_labels`` uses to place the
                text, so the plate tracks the text exactly.

        Returns:
            The plate, or None when there is no visible text to back.
        """
        if not lines:
            return None
        pad_x = font_size * 0.5
        pad_y = font_size * 0.35
        # Width estimate is intentionally font-metric-free to stay stable
        # across environments (matches the pie chart's estimator).
        text_w = max((len(line) for line in lines), default=0) * font_size * 0.55
        if text_w <= 0:
            return None
        extra_lines = len(lines) - 1
        block_h = font_size + extra_lines * line_height
        rect_w = text_w + pad_x * 2
        rect_h = block_h + pad_y * 2

        if anchor == "start":
            rect_x = x - pad_x
        else:
            rect_x = x - text_w - pad_x

        # Baseline of the first (topmost) line, mirroring the text placement
        # in ``_render_quadrant_labels``.
        if is_top:
            top_baseline = ph - top_margin
        else:
            top_baseline = bottom_margin + font_size + extra_lines * line_height
        # Lowest baseline of the block; the plate starts just beneath it and
        # reaches ``top_baseline + font_size + pad_y`` at its upper edge.
        bottom_baseline = top_baseline - extra_lines * line_height
        rect_y = bottom_baseline - pad_y

        return Rect(
            x=round(rect_x, 2),
            y=round(rect_y, 2),
            width=round(rect_w, 2),
            height=round(rect_h, 2),
            rx=round(font_size * 0.35, 2),
            fill=self.theme.background_color,
            opacity=0.7,
        )