Weekend Challenge – Recreating a Data Visualization with Polars and Plotnine

python
polars
matplotlib
plotnine
Author

Jerry Wu

Published

April 12, 2025

This post is part of a visualization recreation challenge using Polars and plotnine, inspired by my earlier work.

It marks my first serious dive into plotnine—an impressive library with a bit of a learning curve.
I’ll walk through the journey I took to recreate the visualization. Some parts may overlap with the earlier post, but I believe that’s acceptable to keep this one self-contained.

The final figure, shown below, visualizes temperature trends for the ski season in Alta over the past few decades. Alta ski resort

Show full code
import matplotlib.pyplot as plt
import polars as pl
import polars.selectors as cs
from highlight_text import ax_text
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from plotnine import (
    aes,
    element_blank,
    element_text,
    geom_line,
    geom_point,
    geom_segment,
    geom_text,
    ggplot,
    labs,
    scale_color_cmap,
    scale_x_continuous,
    scale_y_continuous,
    theme,
    theme_classic,
)


# Source of the CSV read below (the path is relative, so a local copy is expected):
# https://github.com/mattharrison/datasets/raw/refs/heads/master/data/alta-noaa-1980-2019.csv
data_path = "alta-noaa-1980-2019.csv"
# CSV columns kept by `tweak_df()`: the date and the observed temperature.
columns = ["DATE", "TOBS"]
# Name of the derived day-of-season column; used as the x axis of the plot.
idx_colname = "DAY_OF_SEASON"

# Typography shared by the title, subtitle, source note and axis text.
heading_fontsize = 9.5
heading_fontweight = "bold"
subheading_fontsize = 8
subheading_fontweight = "normal"
source_fontsize = 6.5
source_fontweight = "light"
axis_fontsize = 7
axis_fontweight = "normal"
# Text properties reused for every sub-heading fragment in `add_ax_text()`.
sub_props = {"fontsize": subheading_fontsize, "fontweight": subheading_fontweight}

# Palette: grey for the source note, red for 2019, blue for the "decade" word.
grey = "#aaaaaa"
red = "#e3120b"
blue = "#0000ff"


def get_season_expr(col: str = "DATE", alias: str = "SEASON") -> pl.Expr:
    """Build an expression that labels each date with its season and year.

    Dates in May through October are labelled ``"Summer <year>"``; every
    other month is labelled ``"Ski <year>"``. November and December are
    attributed to the following calendar year, so a ski season running from
    November to April carries a single label (e.g. 2015-11-01 -> "Ski 2016").

    Args:
        col: Name of the date/datetime column to classify.
        alias: Name of the resulting string column.

    Returns:
        A Polars expression yielding the season label.
    """
    month = pl.col(col).dt.month()
    year = pl.col(col).dt.year()

    # Trailing space in the literals separates the name from the year.
    season_name = (
        pl.when(month.is_between(5, 10, closed="both"))
        .then(pl.lit("Summer "))
        .otherwise(pl.lit("Ski "))
    )
    # Jan-Oct keep their own year; Nov/Dec roll over to the next one.
    season_year = (
        pl.when(month < 11)
        .then(year.cast(pl.String))
        .otherwise(year.add(1).cast(pl.String))
    )
    return season_name.add(season_year).alias(alias)


def add_day_of_season_expr(
    col: str = "DATE", group_col: str = "SEASON", alias: str = "DAY_OF_SEASON"
) -> pl.Expr:
    """Build an expression counting whole days since each group's first date.

    Args:
        col: Name of the date/datetime column.
        group_col: Column whose values define the groups (seasons).
        alias: Name of the resulting integer column.

    Returns:
        A Polars expression giving, per row, the number of days between the
        row's date and the earliest date within the same ``group_col`` value.
    """
    date = pl.col(col)
    # `.over()` windows the whole expression, so `min()` is taken per group.
    return (date - date.min()).dt.total_days().over(group_col).alias(alias)


def tweak_df(
    data_path: str, columns: list[str], idx_colname: str = "DAY_OF_SEASON"
) -> pl.DataFrame:
    """Load the CSV and derive the columns needed for plotting.

    Args:
        data_path: Path to the CSV file.
        columns: Columns to keep; must include ``"DATE"`` and ``"TOBS"``.
        idx_colname: Name for the derived day-of-season column.

    Returns:
        A frame sorted by ``DATE`` with the selected columns plus ``TMEAN``
        (28-row centered rolling mean of the interpolated ``TOBS``),
        ``SEASON`` and ``idx_colname``.
    """
    return (
        pl.scan_csv(data_path)
        .select(columns)
        .with_columns(pl.col("DATE").str.to_datetime())
        # Sort before interpolating: `interpolate()` fills nulls from the
        # neighbouring rows, so rows must already be in date order.
        .sort("DATE")
        .with_columns(pl.col("TOBS").interpolate())
        .with_columns(
            # Caveat: Cannot be placed in the previous `with_columns()`
            # because the rolling mean must see the interpolated `TOBS`.
            pl.col("TOBS").rolling_mean(window_size=28, center=True).alias("TMEAN"),
            get_season_expr(col="DATE", alias="SEASON"),
        )
        .with_columns(
            add_day_of_season_expr(col="DATE", group_col="SEASON", alias=idx_colname)
        )
        .collect()
    )


def plot_temps(df_: pl.DataFrame, idx_colname: str = "DAY_OF_SEASON") -> ggplot:
    """Build the layered (un-themed) temperature plot for the ski seasons.

    Layers, in drawing order: one faint line per ski season other than 2019,
    one blue line per decade average, a red line for ski season 2019, a
    dashed reference line at 32, start/end dots for the decade and 2019
    lines, and a text label for each decade.

    Args:
        df_: Frame with ``SEASON``, ``TMEAN`` and ``idx_colname`` columns
            (the output of ``tweak_df()``). It must contain a ``"Ski 2019"``
            season and seasons for each of the four decades.
        idx_colname: Name of the day-of-season column used as the x axis.

    Returns:
        The assembled ``ggplot`` object.
    """
    # Wide layout: one row per day of season, one column per ski season.
    season_temps = df_.filter(pl.col("SEASON").str.contains("Ski")).pivot(
        "SEASON", index=idx_colname, values="TMEAN", aggregate_function="first"
    )

    # main: every ski season except 2019, back in long form with a numeric year
    df_main = season_temps.unpivot(
        (cs.starts_with("Ski") - cs.by_name("Ski 2019")),
        index=idx_colname,
        variable_name="year",
        value_name="temp",
    ).select(idx_colname, "temp", pl.col("year").str.slice(-4).cast(pl.Int32))

    # decades
    decades = [1980, 1990, 2000, 2010]
    blues = ["#0055EE", "#0033CC", "#0011AA", "#3377FF"]

    # One frame per decade. Dropping the last digit of the decade ("198",
    # "199", ...) gives a prefix that matches all ten season columns.
    df_decade_sep = [
        season_temps.select(
            idx_colname,
            pl.mean_horizontal(cs.contains(str(decade)[:-1])).alias("temp"),
            pl.lit(int(str(decade)[:-1])).alias("DECADE"),
            pl.lit(b).alias("color"),
        )
        for b, decade in zip(blues, decades)
    ]

    df_decade = pl.concat(df_decade_sep, how="vertical")

    #  decade annotations: label each decade line next to its last value
    decade_annts = [
        one_df.select(pl.col("temp").last()).item() for one_df in df_decade_sep
    ]
    df_decade_annt = pl.DataFrame(
        {
            "x": [185] * len(decade_annts),
            # adjust y position for better appearance
            "y": [
                decade_annts[0],
                decade_annts[1] + 0.5,
                decade_annts[2] - 3,
                decade_annts[3],
            ],
            "color": blues,
            "label": decades,
        }
    )

    # ski_2019
    ski_2019 = (
        season_temps.select(idx_colname, pl.col("Ski 2019").alias("temp"))
        .drop_nulls()  # "DAY_OF_SEASON"=181, "temp"=null
        .with_columns(pl.lit(2019).alias("year"))
    )

    # start and end points
    decade_pts = [
        pl.concat([one_df.head(1), one_df.tail(1)], how="vertical").select(
            idx_colname, "temp", "color"
        )
        for one_df in df_decade_sep
    ]

    df_decade_pts = pl.concat(decade_pts, how="vertical")
    df_ski_2019_pts = (
        pl.concat([ski_2019.head(1), ski_2019.tail(1)])
        .select(idx_colname, "temp")
        .with_columns(pl.lit(red).alias("color"))
    )

    # ggplot
    return (
        ggplot(mapping=aes(x=idx_colname, y="temp"))
        # multiple grey lines
        + geom_line(
            mapping=aes(color="factor(year)"),
            data=df_main,
            alpha=0.2,
            size=0.5,
        )
        + scale_color_cmap("Greys", guide=None, labels=[10, 32, 50])
        # 4 blue lines
        # NOTE(review): `fill` appears to be mapped only to split the data
        # into one line per decade; `group` may express that more directly.
        + geom_line(
            mapping=aes(fill="factor(DECADE)"),
            data=df_decade,
            color=df_decade["color"],
            size=0.5,
            lineend="round",
        )
        # 2019 red line
        + geom_line(
            data=ski_2019,
            color=red,
            size=0.8,
            lineend="round",
        )
        # 1 black dashed line for temp=32F
        + geom_segment(
            mapping=aes(x=0, xend=200, y=32, yend=32),
            color="black",
            size=0.5,
            linetype="dashed",
        )
        # start and end dots for 4 blue lines
        + geom_point(
            mapping=aes(x=idx_colname, y="temp"),
            data=df_decade_pts,
            color=df_decade_pts["color"],
            size=0.2,
        )
        # start and end dots for 2019 red line
        + geom_point(
            mapping=aes(x=idx_colname, y="temp"),
            data=df_ski_2019_pts,
            color=df_ski_2019_pts["color"],
            size=1,
        )
        + labs(x="Day of season", y="")
        + scale_x_continuous(
            breaks=[0, 50, 100, 150], limits=(0, 200), expand=(0, 10, 0, 15)
        )
        + scale_y_continuous(breaks=[10, 32, 40], limits=(10, 70), expand=(0, 0))
        # annotations for 4 blue lines
        + geom_text(
            mapping=aes(x="x", y="y", label="label"),
            data=df_decade_annt,
            color=df_decade_annt["color"],
            size=axis_fontsize,
            fontweight=axis_fontweight,
            ha="left",
            va="center",
        )
    )


def points_to_inches(points):
    """Convert a length in points to inches (72 points per inch)."""
    points_per_inch = 72
    return points / points_per_inch


def themify(p: ggplot) -> Figure:
    """Apply the post's theme to *p* and draw it without showing it.

    Starts from ``theme_classic()``, removes the y-axis line, sets the axis
    title/text styling, font family, dpi, figure size and aspect ratio, and
    returns the result of ``draw(show=False)``.
    """
    width_pt, height_pt = 160, 165  # pts
    size_inches = [points_to_inches(width_pt), points_to_inches(height_pt)]

    overrides = theme(
        axis_line_y=element_blank(),
        axis_title_x=element_text(weight=axis_fontweight, size=axis_fontsize),
        axis_title_y=element_text(weight=axis_fontweight, size=axis_fontsize),
        axis_text_x=element_text(color="black"),
        axis_text_y=element_text(color="black"),
        dpi=300,
        figure_size=size_inches,
        aspect_ratio=2 / 3,
        text=element_text("Roboto"),
    )
    styled = p + theme_classic() + overrides
    return styled.draw(show=False)


def add_ax_text(ax: Axes) -> Axes:
    """Draw the highlighted title block and the source note on *ax*.

    The title is rendered with ``ax_text()``; each ``<...>`` fragment in the
    markup takes the matching entry of ``highlight_textprops`` in order.
    The source note is a plain ``ax.text()`` call. Returns the same axes.
    """
    title_markup = "<Alta Ski Resort>\n<Temperature trends by >\n<decade>< and ><2019>"
    heading_props = {"fontsize": heading_fontsize, "fontweight": heading_fontweight}
    decade_props = {"color": blue, **sub_props}
    year_props = {"color": red, **sub_props}

    ax_text(
        x=-5,
        y=55,
        s=title_markup,
        ax=ax,
        fontsize=heading_fontsize,
        ha="left",
        va="bottom",
        zorder=5,
        # One entry per <...> fragment, in the order they appear above.
        highlight_textprops=[
            heading_props,
            sub_props,
            decade_props,
            sub_props,
            year_props,
        ],
    )

    ax.text(
        0,
        -10,
        "Source: NOAA",
        color=grey,
        fontsize=source_fontsize,
        fontweight=source_fontweight,
    )
    return ax


# Prepare the data, build the layered plot, then theme and draw it.
df = tweak_df(data_path, columns, idx_colname)
p = plot_temps(df, idx_colname)
fig = themify(p)
# The first axes of the drawn figure receives the title and source note.
ax = fig.axes[0]
ax = add_ax_text(ax)
# Bare expression; presumably left last so a notebook cell displays the figure.
fig

Data Processing Pipeline

Below is the data pipeline used to generate the DataFrame for the upcoming visualization stage:

def tweak_df(
    data_path: str, columns: list[str], idx_colname: str = "DAY_OF_SEASON"
) -> pl.DataFrame:
    """Load the CSV and derive the columns needed for plotting.

    The ``# (n)`` markers refer to the numbered notes that follow this
    listing. ``columns`` must include ``"DATE"`` and ``"TOBS"``.
    """
    return (
        pl.scan_csv(data_path)
        .select(columns)
        .with_columns(pl.col("DATE").str.to_datetime())  # (1)
        # Sort before interpolating: `interpolate()` fills nulls from the
        # neighbouring rows, so rows must already be in date order.
        .sort("DATE")
        .with_columns(pl.col("TOBS").interpolate())  # (2)
        .with_columns(
            # Caveat: Cannot be placed in the previous `with_columns()`
            # because the rolling mean must see the interpolated `TOBS`.
            pl.col("TOBS").rolling_mean(window_size=28, center=True).alias("TMEAN"),  # (3)
            get_season_expr(col="DATE", alias="SEASON"),  # (4)
        )
        .with_columns(
            add_day_of_season_expr(col="DATE", group_col="SEASON", alias=idx_colname)  # (5)
        )
        .collect()
    )
1
Convert the DATE column to a datetime format.
2
Perform interpolation on the TOBS column.
3
Compute a 28-day rolling average for TOBS.
4
Use get_season_expr() to categorize each date into a SEASON.
5
Apply add_day_of_season_expr() to calculate DAY_OF_SEASON, representing days elapsed since the start of the season.

The first three steps involve straightforward Polars expressions. In the following two sub-sections, we’ll dive deeper into steps 4 and 5.

Categorizing Dates into Summer and Ski Seasons

To analyze seasonal trends, we classify dates into two categories:

  • Summer: Covers May through October.
  • Ski: Covers November through April.

If a date falls in November or December, it is assigned to the following year’s season. For example, 2015-10-31 is categorized as Summer 2015, while 2015-11-01 belongs to Ski 2016.

To implement this logic, we define get_season_expr(), which leverages Polars’ when-then-otherwise expressions to determine the season and year.

def get_season_expr(col: str = "DATE", alias: str = "SEASON") -> pl.Expr:
    """Build an expression that labels each date with its season and year.

    Dates in May through October are labelled ``"Summer <year>"``; every
    other month is labelled ``"Ski <year>"``. November and December are
    attributed to the following calendar year, so a ski season running from
    November to April carries a single label (e.g. 2015-11-01 -> "Ski 2016").

    Args:
        col: Name of the date/datetime column to classify.
        alias: Name of the resulting string column.

    Returns:
        A Polars expression yielding the season label.
    """
    month = pl.col(col).dt.month()
    year = pl.col(col).dt.year()

    # Trailing space in the literals separates the name from the year.
    season_name = (
        pl.when(month.is_between(5, 10, closed="both"))
        .then(pl.lit("Summer "))
        .otherwise(pl.lit("Ski "))
    )
    # Jan-Oct keep their own year; Nov/Dec roll over to the next one.
    season_year = (
        pl.when(month < 11)
        .then(year.cast(pl.String))
        .otherwise(year.add(1).cast(pl.String))
    )
    return season_name.add(season_year).alias(alias)

In this function:

  • If the month is between May and October, the function assigns "Summer ". Otherwise, it assigns "Ski " (with a trailing space for concatenation).
  • The year is determined based on the month: dates from January to October retain their current year, while those in November and December are shifted to the next year.

By applying this function, we can add a SEASON column to a DataFrame, ensuring each date is categorized correctly.

Calculating the Total Days for Each Season

Once we have the seasonal categories, we calculate DAY_OF_SEASON, which tracks the number of days elapsed within each season. This is achieved using the pl.Expr.over() expression, which operates similarly to Pandas’ groupby().transform(), applying transformations within groups.

def add_day_of_season_expr(
    col: str = "DATE", group_col: str = "SEASON", alias: str = "DAY_OF_SEASON"
) -> pl.Expr:
    """Build an expression counting whole days since each group's first date.

    Args:
        col: Name of the date/datetime column.
        group_col: Column whose values define the groups (seasons).
        alias: Name of the resulting integer column.

    Returns:
        A Polars expression giving, per row, the number of days between the
        row's date and the earliest date within the same ``group_col`` value.
    """
    date = pl.col(col)
    # `.over()` windows the whole expression, so `min()` is taken per group.
    return (date - date.min()).dt.total_days().over(group_col).alias(alias)

Touchups

Adding a Theme

We apply a custom theme using the themify() function, adjusting various themeable elements to refine the plot’s appearance:

def points_to_inches(points):
    """Convert a length in points to inches (72 points per inch)."""
    points_per_inch = 72
    return points / points_per_inch


def themify(p: ggplot) -> Figure:
    """Apply the post's theme to *p* and draw it without showing it.

    Starts from ``theme_classic()``, removes the y-axis line, sets the axis
    title/text styling, font family, dpi, figure size and aspect ratio, and
    returns the result of ``draw(show=False)``.
    """
    width_pt, height_pt = 160, 165  # pts
    size_inches = [points_to_inches(width_pt), points_to_inches(height_pt)]

    overrides = theme(
        axis_line_y=element_blank(),
        axis_title_x=element_text(weight=axis_fontweight, size=axis_fontsize),
        axis_title_y=element_text(weight=axis_fontweight, size=axis_fontsize),
        axis_text_x=element_text(color="black"),
        axis_text_y=element_text(color="black"),
        dpi=300,
        figure_size=size_inches,
        aspect_ratio=2 / 3,
        text=element_text("Roboto"),
    )
    styled = p + theme_classic() + overrides
    return styled.draw(show=False)

Adding a Title and Source Note

For the title, we use ax_text() from the HighlightText library. It allows inline text highlighting using < >, letting us emphasize specific parts of the title like <Alta Ski Resort>, <Temperature trends by >, <decade>, < and >, and <2019> with customized styles.

To add a source note, we simply use Matplotlib’s ax.text():

def add_ax_text(ax: Axes) -> Axes:
    """Draw the highlighted title block and the source note on *ax*.

    The title is rendered with ``ax_text()``; each ``<...>`` fragment in the
    markup takes the matching entry of ``highlight_textprops`` in order.
    The source note is a plain ``ax.text()`` call. Returns the same axes.
    """
    title_markup = "<Alta Ski Resort>\n<Temperature trends by >\n<decade>< and ><2019>"
    heading_props = {"fontsize": heading_fontsize, "fontweight": heading_fontweight}
    decade_props = {"color": blue, **sub_props}
    year_props = {"color": red, **sub_props}

    ax_text(
        x=-5,
        y=55,
        s=title_markup,
        ax=ax,
        fontsize=heading_fontsize,
        ha="left",
        va="bottom",
        zorder=5,
        # One entry per <...> fragment, in the order they appear above.
        highlight_textprops=[
            heading_props,
            sub_props,
            decade_props,
            sub_props,
            year_props,
        ],
    )

    ax.text(
        0,
        -10,
        "Source: NOAA",
        color=grey,
        fontsize=source_fontsize,
        fontweight=source_fontweight,
    )
    return ax

Rendering the Plot

Now we put everything together and render the final plot. A key trick here is retrieving the ax object using fig.axes[0], which allows us to apply both HighlightText and regular Matplotlib functions.

# Prepare the data, build the layered plot, then theme and draw it.
df = tweak_df(data_path, columns, idx_colname)
p = plot_temps(df, idx_colname)
fig = themify(p)
# The first axes of the drawn figure receives the title and source note.
ax = fig.axes[0]
ax = add_ax_text(ax)
# Bare expression; presumably left last so a notebook cell displays the figure.
fig

Takeaways

Wrapping up this post, I’ve come to appreciate how powerful the plotnine library truly is. While its aesthetic system requires a bit of mental shift, it offers a clean, expressive way to build layered visualizations.

One key takeaway for me is that each layer can operate on its own dataset, which adds a lot of flexibility. What I enjoyed most, though, is the theme system—it makes it easy to define a consistent visual style that can be reused across different plots.

One limitation I ran into was the lack of a plotnine-native alternative to plt.subplot_mosaic(). This feature allows for more granular layout control—for example, dividing the figure into separate axes with custom height ratios for the title, main plot, and source note using gridspec_kw={"height_ratios": [6, 12, 1]}.

Disclaimer

This post was drafted by me, with AI assistance to refine the content.