
mosaic_heatmap

mosaic_heatmap(data, ...)

Plot mosaic data as a color-encoded matrix.

Creates a mosaic heatmap where the column widths and row heights are proportional to the marginal sums of the data matrix. This provides a visualization that encodes both the cell values through color and the marginal distributions through cell sizes.
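The layout can be sketched with NumPy. The helper `mosaic_edges` below is hypothetical and not part of mheatmap. It assumes that each column's width is its column sum divided by the grand total and each row's height is its row sum divided by the grand total, so that the cells tile the unit square. This is a sketch of the geometry only, not the package's implementation.

```python
import numpy as np


def mosaic_edges(data):
    """Hypothetical helper (not part of mheatmap): cell boundaries for a mosaic.

    Assumes non-negative values with a positive total. Each column's width is
    its column sum divided by the grand total, and each row's height is its
    row sum divided by the grand total, so the cells tile the unit square.
    """
    values = np.asarray(data, dtype=float)
    total = values.sum()
    col_widths = values.sum(axis=0) / total   # one width per column
    row_heights = values.sum(axis=1) / total  # one height per row
    # Cumulative sums turn widths/heights into boundaries that start at 0.
    x_edges = np.concatenate([[0.0], np.cumsum(col_widths)])
    y_edges = np.concatenate([[0.0], np.cumsum(row_heights)])
    return x_edges, y_edges


# The confusion matrix used in the Examples section: column sums 11, 11, 15
# and row sums 12, 12, 13, out of a total of 37.
data = np.array([[10, 2, 0], [1, 8, 3], [0, 1, 12]])
x_edges, y_edges = mosaic_edges(data)
print(np.round(x_edges * 37, 6))  # column boundaries in units of counts
print(np.round(y_edges * 37, 6))  # row boundaries in units of counts
```

Each edge array has one more entry than there are columns (or rows), which is the shape `matplotlib.axes.Axes.pcolormesh(X, Y, C)` expects for 1D `X` and `Y`, so a call such as `ax.pcolormesh(x_edges, y_edges, data)` would draw this non-uniform grid.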

Parameters:

Name Type Description Default
data array-like

2D dataset that can be coerced into an ndarray. If a pandas DataFrame is provided, the index and column labels are used to label the rows and columns, respectively.

required
vmin float

Minimum value to anchor the colormap. If not provided, it is inferred from the data and other keyword arguments.

None
vmax float

Maximum value to anchor the colormap. If not provided, it is inferred from the data and other keyword arguments.

None
cmap str or Colormap

The mapping from data values to color space. If not provided, the default depends on whether center is set.

None
center float

The value at which to center the colormap for divergent data. Changes the default cmap if none is specified.

None
robust bool

If True and vmin or vmax are absent, compute colormap range using robust quantiles instead of extreme values.

False
annot bool or array-like

If True, write the data value in each cell. If array-like with the same shape as data, use it for the annotations instead of the data. DataFrames match on position, not index.

None
fmt str

String formatting code for annotation values. Default: '.2g'

'.2g'
annot_kws dict

Keyword arguments for matplotlib.axes.Axes.text when annot is True.

None
linewidths float

Width of cell divider lines. Default: 0

0
linecolor color

Color of cell divider lines. Default: 'white'

'white'
cbar bool

Whether to draw a colorbar. Default: True

True
cbar_kws dict

Keyword arguments for matplotlib.figure.Figure.colorbar.

None
cbar_ax Axes

Axes in which to draw the colorbar. If None, takes space from main Axes.

None
square bool

If True, set the Axes aspect ratio to "equal". Individual cells are not necessarily square, because their widths and heights follow the marginal sums. Default: False

False
xticklabels 'auto', bool, array-like, or int
  • True: plot the column names
  • False: don't plot labels
  • array-like: plot custom labels
  • int: plot every nth label
  • 'auto': plot non-overlapping labels
'auto'
yticklabels 'auto', bool, array-like, or int
  • True: plot the row names
  • False: don't plot labels
  • array-like: plot custom labels
  • int: plot every nth label
  • 'auto': plot non-overlapping labels
'auto'
mask bool array or DataFrame

Cells where the mask is True are not drawn. Cells with missing values are masked automatically.

None
ax Axes

Axes in which to draw the plot. Uses current axes if None.

None
**kwargs dict

Additional keyword arguments passed to matplotlib.axes.Axes.pcolormesh.

{}
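The parameter descriptions above do not spell out how vmin, vmax, robust, center and mask combine. The sketch below assumes the conventions of seaborn's heatmap, whose signature this function mirrors: missing values are always masked, robust=True uses the 2nd and 98th percentiles instead of the minimum and maximum, explicit vmin/vmax always win, and center recenters the colormap so that values are normalized over a range symmetric around it. `color_limits` is a hypothetical name, and mheatmap's `_MosaicHeatMapper` may handle these cases differently.

```python
import numpy as np


def color_limits(data, mask=None, vmin=None, vmax=None, robust=False, center=None):
    """Hypothetical sketch of seaborn-style colormap limits (not mheatmap's code).

    Returns (vmin, vmax, norm_range), where norm_range is the interval the
    values are normalized over so that `center`, if given, sits at the
    midpoint of the colormap.
    """
    values = np.asarray(data, dtype=float)
    # Missing values are always hidden; a user-supplied mask hides more cells.
    hidden = np.isnan(values)
    if mask is not None:
        hidden |= np.asarray(mask, dtype=bool)
    visible = values[~hidden]

    # Only limits that were not passed explicitly are inferred from the data.
    if vmin is None:
        vmin = np.percentile(visible, 2) if robust else visible.min()
    if vmax is None:
        vmax = np.percentile(visible, 98) if robust else visible.max()

    if center is None:
        norm_range = (vmin, vmax)
    else:
        # Widen the range symmetrically around `center`.
        half = max(vmax - center, center - vmin)
        norm_range = (center - half, center + half)
    return float(vmin), float(vmax), (float(norm_range[0]), float(norm_range[1]))


data = np.array([[10, 2, 0], [1, 8, 3], [0, 1, 12]])
print(color_limits(data))                   # (0.0, 12.0, (0.0, 12.0))
print(color_limits(data, center=5))         # (0.0, 12.0, (-2.0, 12.0))
print(color_limits(data, robust=True))
print(color_limits(data, mask=data == 12))  # the 12 is hidden, so vmax is 10.0
```

Recentering widens the normalization range rather than moving vmin and vmax, so with center=5 the largest count still gets the strongest color while 5 lands on the colormap's midpoint.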

Returns:

Type Description
Axes

The Axes object with the heatmap.

Examples:

>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from mheatmap import mosaic_heatmap
>>>
>>> # Generate sample confusion matrix data
>>> data = np.array([[10, 2, 0], [1, 8, 3], [0, 1, 12]])
>>>
>>> # Create mosaic heatmap with annotations
>>> fig, ax = plt.subplots(figsize=(8, 6))
>>> ax = mosaic_heatmap(data, annot=True, cmap='YlOrRd', fmt='d',
...                     xticklabels=['A', 'B', 'C'],
...                     yticklabels=['A', 'B', 'C'], ax=ax)
>>> ax.set_title('Mosaic Confusion Matrix')
>>> plt.show()
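The example passes fmt='d' because the cells are integer counts. Assuming fmt is used as an ordinary Python format specification (as seaborn's heatmap does with its own fmt argument), its effect can be checked without drawing anything. The helper `annotation_labels` below is hypothetical and only mimics how each annotation string would be produced.

```python
import numpy as np


def annotation_labels(values, fmt=".2g"):
    """Hypothetical helper: the text each cell would get for a given fmt,
    assuming fmt is a standard Python format specification."""
    return [[format(v, fmt) for v in row] for row in np.asarray(values)]


data = np.array([[10, 2, 0], [1, 8, 3], [0, 1, 12]])
print(annotation_labels(data, "d")[2])  # ['0', '1', '12']
print(annotation_labels(data * 10)[2])  # default '.2g': ['0', '10', '1.2e+02']

try:
    annotation_labels(data / data.sum(), "d")  # 'd' only accepts integers
except ValueError as exc:
    print("ValueError:", exc)
```

With the default '.2g', any count of 100 or more is shown in scientific notation, while 'd' raises ValueError on floating-point values such as normalized rates, so fmt should match the dtype of the values being annotated.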

Notes

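The mosaic heatmap is particularly useful for confusion matrices and contingency tables, where the marginal distributions add context that the cell values alone do not show. For a confusion matrix with the true classes on the rows, the row heights show how many samples belong to each class and the column widths show how many samples were predicted as each class.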

Source code in mheatmap/matrix.py
def mosaic_heatmap(
    data,
    *,
    vmin=None,
    vmax=None,
    cmap=None,
    center=None,
    robust=False,
    annot=None,
    fmt=".2g",
    annot_kws=None,
    linewidths=0,
    linecolor="white",
    cbar=True,
    cbar_kws=None,
    cbar_ax=None,
    square=False,
    xticklabels="auto",
    yticklabels="auto",
    mask=None,
    ax=None,
    **kwargs,
):
    """`mosaic_heatmap(data, ...)`

    Plot mosaic data as a color-encoded matrix.

    Creates a mosaic heatmap where the column widths and row heights are proportional
    to the marginal sums of the data matrix. This provides a visualization that
    encodes both the cell values through color and the marginal distributions
    through cell sizes.

    Parameters
    ----------
    data : array-like
        2D dataset that can be coerced into an ndarray. If a pandas DataFrame
        is provided, the index and column labels are used to label the rows
        and columns, respectively.
    vmin, vmax : float, optional
        Values to anchor the colormap. If not provided, they are inferred from the
        data and other keyword arguments.
    cmap : str or matplotlib.colors.Colormap, optional
        The mapping from data values to color space. If not provided, the
        default depends on whether ``center`` is set.
    center : float, optional
        The value at which to center the colormap for divergent data.
        Changes the default ``cmap`` if none is specified.
    robust : bool, optional
        If True and ``vmin`` or ``vmax`` are absent, compute colormap range using
        robust quantiles instead of extreme values.
    annot : bool or array-like, optional
        If True, write the data value in each cell. If array-like with same shape
        as ``data``, use this for annotation instead of the data. DataFrames match
        on position, not index.
    fmt : str, optional
        String formatting code for annotation values. Default: '.2g'
    annot_kws : dict, optional
        Keyword arguments for matplotlib.axes.Axes.text when ``annot`` is True.
    linewidths : float, optional
        Width of cell divider lines. Default: 0
    linecolor : color, optional
        Color of cell divider lines. Default: 'white'
    cbar : bool, optional
        Whether to draw a colorbar. Default: True
    cbar_kws : dict, optional
        Keyword arguments for matplotlib.figure.Figure.colorbar.
    cbar_ax : matplotlib.axes.Axes, optional
        Axes in which to draw the colorbar. If None, takes space from main Axes.
    square : bool, optional
        If True, set the Axes aspect ratio to "equal". Individual cells are not
        necessarily square, because their widths and heights follow the marginal
        sums. Default: False
    xticklabels, yticklabels : 'auto', bool, array-like, or int, optional
        - True: plot column/row names
        - False: don't plot labels
        - array-like: plot custom labels
        - int: plot every nth label
        - 'auto': plot non-overlapping labels
    mask : bool array or DataFrame, optional
        Cells where the mask is True are not drawn. Cells with missing values are
        masked automatically.
    ax : matplotlib.axes.Axes, optional
        Axes in which to draw the plot. Uses current axes if None.
    **kwargs : dict
        Additional keyword arguments passed to matplotlib.axes.Axes.pcolormesh.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes object with the heatmap.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from mheatmap import mosaic_heatmap
    >>>
    >>> # Generate sample confusion matrix data
    >>> data = np.array([[10, 2, 0], [1, 8, 3], [0, 1, 12]])
    >>>
    >>> # Create mosaic heatmap with annotations
    >>> fig, ax = plt.subplots(figsize=(8, 6))
    >>> ax = mosaic_heatmap(data, annot=True, cmap='YlOrRd', fmt='d',
    ...                     xticklabels=['A', 'B', 'C'],
    ...                     yticklabels=['A', 'B', 'C'], ax=ax)
    >>> ax.set_title('Mosaic Confusion Matrix')
    >>> plt.show()

    Notes
    -----
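    The mosaic heatmap is particularly useful for confusion matrices and contingency
    tables, where the marginal distributions add context that the cell values alone
    do not show. For a confusion matrix with the true classes on the rows, the row
    heights show how many samples belong to each class and the column widths show
    how many samples were predicted as each class.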
    """
    # Create the _MosaicHeatMapper instance that prepares the data for plotting
    plotter = _MosaicHeatMapper(
        data,
        vmin,
        vmax,
        cmap,
        center,
        robust,
        annot,
        fmt,
        annot_kws,
        cbar,
        cbar_kws,
        xticklabels,
        yticklabels,
        mask,
    )

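    # Forward the cell-divider settings to pcolormesh so that the documented
    # linewidths/linecolor parameters take effect. QuadMesh has no "linecolor"
    # property, so the color is passed as "edgecolor".
    kwargs["linewidths"] = linewidths
    kwargs["edgecolor"] = linecolor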

    # Draw the plot and return the Axes
    if ax is None:
        ax = plt.gca()
    if square:
        ax.set_aspect("equal")
    plotter.plot(ax, cbar_ax, kwargs)
    return ax