Source code for charted.charts.heatmap

"""Heatmap chart for visualizing matrix data as colored cells.

Displays a 2D grid where each cell is colored according to its value,
using a configurable color scale from low (cool) to high (warm) values.
Supports row and column labels, value annotations, and auto color scaling.
"""

from __future__ import annotations

from charted.charts.chart import Chart
from charted.constants import DEFAULT_CHART_HEIGHT, DEFAULT_CHART_WIDTH
from charted.html.element import G, Path, Text
from charted.themes.core import ColorScale, Theme
from charted.utils.types import Labels, SeriesStyleConfig


def _lerp_color(c1: str, c2: str, t: float) -> str:
    from charted.utils.colors import hex_to_rgb, rgb_to_hex

    r1, g1, b1 = hex_to_rgb(c1)
    r2, g2, b2 = hex_to_rgb(c2)

    t = max(0.0, min(1.0, t))
    return rgb_to_hex(
        (
            int(r1 + (r2 - r1) * t),
            int(g1 + (g2 - g1) * t),
            int(b1 + (b2 - b1) * t),
        )
    )


[docs] class HeatmapChart(Chart): """Heatmap chart for visualizing matrix data as colored cells. Renders a 2D grid where each cell's color represents its value. Supports row and column labels, value annotations inside cells, and automatic color scaling based on data range. Args: data: 2D matrix (list of lists) where each inner list is a row. x_labels: Labels for each column (optional, auto-generated if omitted). y_labels: Labels for each row (optional, auto-generated if omitted). width, height: Chart dimensions in pixels. title: Optional chart title. theme: Optional theme configuration. series_names: Names for each series (shown in legend). series_styles: Per-series style overrides. low_color: Color for the lowest value in the data. Defaults to None, which derives the low endpoint from the first theme palette colour. high_color: Color for the highest value in the data. Defaults to None, which derives the high endpoint from the last theme palette colour. color_scale: Optional continuous color scale. Pass a ColorScale, a named palette string (e.g. 'viridis'), or a list of hex stops to color cells along a multi-stop gradient. Overrides low_color and high_color. Defaults to None (two-color low/high behavior). show_values: If True, display the numeric value in each cell (default True). value_format: Format string for displayed values (default '.1f'). cell_gap: Gap between cells as fraction of cell size (default 0.04). label_font_size: Font size for row/column labels (default 11). cell_border_width: Stroke width of the thin border around each cell, in pixels (default 0.25). colorbar_ticks: Number of evenly spaced tick labels on the colorbar, including the min and max endpoints (default 5, minimum 2). colorbar_title: Optional title rendered vertically beside the colorbar, e.g. a unit or measure name. Defaults to None (no title). colorbar_width: Width of the gradient strip in pixels (default 16). Example: >>> from charted import HeatmapChart >>> chart = HeatmapChart( ... data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], ... x_labels=['A', 'B', 'C'], ... y_labels=['X', 'Y', 'Z'], ... ) >>> chart.save('matrix.svg') """ render_axes = False def __init__( self, data: list[list[float]], x_labels: Labels | None = None, y_labels: Labels | None = None, width: float = DEFAULT_CHART_WIDTH, height: float = DEFAULT_CHART_HEIGHT, title: str | None = None, theme: Theme | None = None, series_names: list[str] | None = None, series_styles: list[SeriesStyleConfig] | None = None, low_color: str | None = None, high_color: str | None = None, color_scale: "ColorScale | str | list[str] | None" = None, show_values: bool = True, value_format: str = ".1f", cell_gap: float = 0.04, label_font_size: int = 11, cell_border_width: float = 0.25, colorbar_ticks: int = 5, colorbar_title: str | None = None, colorbar_width: float = 16, ): if not data or not isinstance(data, list) or len(data) == 0: raise ValueError("Data must be a non-empty 2D list") if not isinstance(data[0], list): raise ValueError("Data must be a 2D matrix (list of lists)") n_rows = len(data) n_cols = len(data[0]) if n_rows > 0 else 0 if n_cols == 0: raise ValueError("Each row must have at least one column") for i, row in enumerate(data): if len(row) != n_cols: raise ValueError(f"Row {i} has {len(row)} columns, expected {n_cols}") self._matrix = data self._n_rows = n_rows self._n_cols = n_cols # Defer resolving low/high colours until after super().__init__ has # loaded the theme. None means "derive the gradient endpoints from the # theme palette" (so presets like high-contrast and custom palettes # drive the heatmap); explicit hex strings are honoured as-is. self._low_color_override = low_color self._high_color_override = high_color self.show_values = show_values self.value_format = value_format self.cell_gap = cell_gap self._label_font_size = label_font_size self.cell_border_width = max(0.0, float(cell_border_width)) self.colorbar_ticks = max(2, int(colorbar_ticks)) self.colorbar_title = colorbar_title self.colorbar_width = max(1.0, float(colorbar_width)) if x_labels is None: x_labels = [str(i + 1) for i in range(n_cols)] if y_labels is None: y_labels = [str(i + 1) for i in range(n_rows)] if len(x_labels) != n_cols: raise ValueError( f"x_labels count ({len(x_labels)}) must match columns ({n_cols})" ) if len(y_labels) != n_rows: raise ValueError( f"y_labels count ({len(y_labels)}) must match rows ({n_rows})" ) self._x_labels = list(x_labels) self._y_labels = list(y_labels) all_values = [v for row in data for v in row] self._data_min = min(all_values) self._data_max = max(all_values) self._data_range = self._data_max - self._data_min self.color_scale = self._resolve_color_scale(color_scale) # Resolve the theme palette up front (super().__init__ triggers # _build_children, which renders cells before super returns) so the # gradient endpoints can derive from the theme when the caller did not # override them. The default theme palette equals the historical # low/high defaults' source, keeping default renders unchanged. from charted.utils.theme_manager import ThemeManager resolved_theme = ThemeManager.load_theme(theme, "heatmap") palette = list(resolved_theme.colors) if resolved_theme.colors else [] # low = first palette colour, high = second. With the default palette # (#5fab9e, #f58b51, ...) this reproduces the historical low/high # defaults exactly, so default renders are byte-for-byte unchanged, # while presets like high-contrast pick up their own two endpoints. self.low_color = self._low_color_override or ( palette[0] if palette else "#5fab9e" ) self.high_color = self._high_color_override or ( palette[1] if len(palette) > 1 else "#f58b51" ) x_data = [[float(i) for i in range(n_cols)] for _ in range(n_rows)] y_data = [[float(i) for i in range(n_rows)] for _ in range(n_cols)] super().__init__( width=width, height=height, x_data=x_data, y_data=y_data, x_labels=x_labels, y_labels=y_labels, title=title, zero_index=True, theme=theme, chart_type="heatmap", series_styles=series_styles, series_names=series_names, ) self.layout.h_padding = 0.07 self.children.clear() children = [self.container, self.title, self.representation, self.legend] self.add_children(*children) # Colorbar geometry. The gap between the plot and the gradient strip, # the gap between the strip and its tick marks, and the tick mark length. _COLORBAR_GAP = 14 _COLORBAR_TICK_GAP = 5 _COLORBAR_TICK_LEN = 4 def _colorbar_label_width(self) -> float: """Estimate the pixel width of the widest colorbar tick label. Used to reserve enough right padding so the colorbar, its tick labels and (optional) title stay inside the SVG bounds. """ labels = [ format( self._data_min + (self._data_max - self._data_min) * (i / (self.colorbar_ticks - 1)), self.value_format, ) for i in range(self.colorbar_ticks) ] longest = max((len(s) for s in labels), default=1) # ~0.6em per character at the label font size. return longest * self._label_font_size * 0.6 def _legend_layout_position(self) -> str: # Always reserve a right-hand band for the colorbar. return "right" def _legend_layout_extent(self) -> float: band = ( self._COLORBAR_GAP + self.colorbar_width + self._COLORBAR_TICK_GAP + self._COLORBAR_TICK_LEN + self._colorbar_label_width() ) if self.colorbar_title: band += self._label_font_size + 6 return band @property def cell_width(self) -> float: return self.plot_width / self._n_cols @property def cell_height(self) -> float: return self.plot_height / self._n_rows @property def cell_gap_x(self) -> float: return self.cell_width * self.cell_gap @property def cell_gap_y(self) -> float: return self.cell_height * self.cell_gap @property def draw_cell_width(self) -> float: return self.cell_width - self.cell_gap_x @property def draw_cell_height(self) -> float: return self.cell_height - self.cell_gap_y def _value_labels_fit(self) -> bool: """Whether per-cell value labels can be drawn legibly. On dense grids the cells shrink and the formatted numbers collapse into an unreadable smear because adjacent labels are wider than their cells. Hide them once a cell can no longer hold its widest formatted value: the drawable cell must be at least one font-size tall and wide enough for the longest label string. Normal small heatmaps (3x3, 4x2, etc.) keep their values because their cells stay comfortably larger than the rendered text. """ min_height = float(self._label_font_size) if self.draw_cell_height < min_height: return False longest = max( (len(format(v, self.value_format)) for row in self._matrix for v in row), default=1, ) # ~0.6em per character at the label font size. min_width = longest * self._label_font_size * 0.6 return self.draw_cell_width >= min_width def _resolve_color_scale( self, color_scale: "ColorScale | str | list[str] | None" ) -> ColorScale | None: """Normalize the color_scale argument into a ColorScale or None. A None argument keeps the two-color low/high behavior. A string or list is wrapped into a ColorScale spanning the data range. Note: the heatmap always pins the color-scale domain to the data range (min, max). When a ColorScale is passed in, only its palette is reused; its own ``domain`` is discarded so cell colors stay aligned with the displayed value range and legend bar. """ if color_scale is None: return None domain = (self._data_min, self._data_max) if isinstance(color_scale, ColorScale): return ColorScale(palette=color_scale.palette, domain=domain) return ColorScale(palette=color_scale, domain=domain) def _value_to_color(self, value: float) -> str: if self.color_scale is not None: return self.color_scale(value) if self._data_range == 0: return self.low_color t = (value - self._data_min) / self._data_range return _lerp_color(self.low_color, self.high_color, t) @property def representation(self) -> G: result = G( transform=f"translate({self.left_padding}, {self.top_padding})", ) grid_color = self.theme.grid_color label_color = self.theme.title_color font_family = self.theme.title_font_family label_font_size = self._label_font_size show_values = self.show_values and self._value_labels_fit() for row_idx in range(self._n_rows): for col_idx in range(self._n_cols): value = self._matrix[row_idx][col_idx] fill = self._value_to_color(value) x = col_idx * self.cell_width + self.cell_gap_x / 2 y = row_idx * self.cell_height + self.cell_gap_y / 2 cell = Path( fill=fill, stroke=grid_color, stroke_width=self.cell_border_width, d=Path.get_path( x, y, self.draw_cell_width, self.draw_cell_height, ), ) result.add_child(cell) if show_values: text_x = x + self.draw_cell_width / 2 text_y = y + self.draw_cell_height / 2 formatted = format(value, self.value_format) from charted.utils.colors import get_contrast_color text_color = get_contrast_color(fill) result.add_child( Text( text=formatted, x=text_x, y=text_y, fill=text_color, font_family=font_family, font_size=label_font_size, text_anchor="middle", dominant_baseline="central", ) ) for col_idx in range(self._n_cols): label_x = col_idx * self.cell_width + self.cell_width / 2 result.add_child( Text( text=self._x_labels[col_idx], x=label_x, y=-self.cell_gap_y, fill=label_color, font_family=font_family, font_size=label_font_size, text_anchor="middle", dominant_baseline="bottom", ) ) for row_idx in range(self._n_rows): label_y = row_idx * self.cell_height + self.cell_height / 2 result.add_child( Text( text=self._y_labels[row_idx], x=-self.cell_gap_x, y=label_y, fill=label_color, font_family=font_family, font_size=label_font_size, text_anchor="end", dominant_baseline="central", ) ) self._add_colorbar(result, label_color, font_family) return result def _colorbar_color_at(self, t: float) -> str: """Colour at fraction ``t`` (0 = data min, 1 = data max).""" if self.color_scale is not None: return self._value_to_color(self._data_min + t * self._data_range) return _lerp_color(self.low_color, self.high_color, t) def _add_colorbar(self, result: G, label_color: str, font_family: str) -> None: """Render a gradient colorbar with tick marks, labels and a title. Drawn in the right-hand band reserved by ``_legend_layout_extent``. The strip runs from the data max at the top to the data min at the bottom. ``colorbar_ticks`` evenly spaced labels (including both endpoints) annotate the scale, each with a short tick mark. """ bar_x = self.plot_width + self._COLORBAR_GAP bar_y = 0 bar_width = self.colorbar_width bar_height = self.plot_height font_size = self._label_font_size # Gradient strip: many thin bands so the transition reads as smooth. n_stops = 64 stop_height = bar_height / n_stops for i in range(n_stops): # i = 0 is the top band -> data max; i = n_stops-1 -> data min. t = 1.0 - (i / (n_stops - 1) if n_stops > 1 else 0) result.add_child( Path( fill=self._colorbar_color_at(t), d=Path.get_path( bar_x, bar_y + i * stop_height, bar_width, # Overlap by 1px to avoid hairline seams between bands. stop_height + 1, ), ) ) # Outline around the strip for a crisp edge. result.add_child( Path( fill="none", stroke=label_color, stroke_width=0.75, d=Path.get_path(bar_x, bar_y, bar_width, bar_height), ) ) # Tick marks + labels. tick_x_end = bar_x + bar_width for i in range(self.colorbar_ticks): frac = i / (self.colorbar_ticks - 1) value = self._data_min + frac * self._data_range # frac = 0 is data min -> bottom of the strip. ty = bar_y + bar_height * (1.0 - frac) result.add_child( Path( stroke=label_color, stroke_width=0.75, d=" ".join( [ f"M{tick_x_end} {ty}", f"h{self._COLORBAR_TICK_LEN}", ] ), ) ) result.add_child( Text( text=format(value, self.value_format), x=tick_x_end + self._COLORBAR_TICK_LEN + self._COLORBAR_TICK_GAP, y=ty, fill=label_color, font_family=font_family, font_size=font_size, text_anchor="start", dominant_baseline="central", ) ) # Optional vertical title to the right of the tick labels. if self.colorbar_title: title_x = ( tick_x_end + self._COLORBAR_TICK_LEN + self._COLORBAR_TICK_GAP + self._colorbar_label_width() + font_size ) title_y = bar_y + bar_height / 2 result.add_child( Text( text=self.colorbar_title, x=title_x, y=title_y, fill=label_color, font_family=font_family, font_size=font_size, text_anchor="middle", dominant_baseline="central", transform=f"rotate(-90 {title_x} {title_y})", ) ) @property def legend(self) -> None: return None