Weekend Challenge - Effective Data Visualization with Polars and Matplotlib

python
pandas
polars
matplotlib
Author

Jerry Wu

Published

February 10, 2025

It was an honor to be one of the reviewers for Matt Harrison’s new book, Effective Visualization. If you’re looking to deepen your understanding of how to use Pandas and Matplotlib to craft compelling data stories, this book is a must-read.

Last weekend, I decided to convert some of the Pandas code from the book into Polars just for fun, and I’d like to share an example in this post. You can find the original Pandas code in the repo (empty link for now).

The final figure, shown below, visualizes temperature trends for the ski season in Alta over the past few decades.

Show full code
from functools import partial

import matplotlib.pyplot as plt
import polars as pl
import polars.selectors as cs
from highlight_text import ax_text
from matplotlib import colormaps

# Dataset source:
# https://github.com/mattharrison/datasets/raw/refs/heads/master/data/alta-noaa-1980-2019.csv
data_path = "alta-noaa-1980-2019.csv"
columns = ["DATE", "TOBS"]
# Scan the CSV lazily, keep only the columns listed above, then materialise.
df = (
    pl.scan_csv(data_path)
    .select(columns)
    .collect()
)


def get_season_expr(col: str = "DATE", alias: str = "SEASON") -> pl.Expr:
    """Build an expression that labels each date with its season and year.

    Months May through October are labelled ``"Summer <year>"``; every other
    month is labelled ``"Ski <year>"``. Dates in November and December are
    assigned to the following calendar year, so a ski season that spans
    November to April carries a single label (e.g. ``"Ski 2016"``).

    Args:
        col: Name of the date/datetime column to classify.
        alias: Name given to the resulting string column.

    Returns:
        A Polars expression producing the season label.
    """
    month = pl.col(col).dt.month()
    year = pl.col(col).dt.year()

    # The trailing space in each literal separates the name from the year.
    season_name = (
        pl.when(month.is_between(5, 10, closed="both"))
        .then(pl.lit("Summer "))
        .otherwise(pl.lit("Ski "))
    )
    # January-October keep their own year; November-December roll forward.
    season_year = (
        pl.when(month < 11).then(year).otherwise(year.add(1)).cast(pl.String)
    )
    return season_name.add(season_year).alias(alias)


def add_day_of_season_expr(
    col: str = "DATE", group_col: str = "SEASON", alias: str = "DAY_OF_SEASON"
) -> pl.Expr:
    """Build an expression counting whole days since each group's first date.

    The count is relative to the earliest ``col`` value present in the group,
    not to a fixed calendar date.

    Args:
        col: Name of the date/datetime column.
        group_col: Column whose values define the groups (one per season).
        alias: Name given to the resulting integer column.

    Returns:
        A Polars expression producing the elapsed days within each group.
    """
    date = pl.col(col)
    # `.over(group_col)` evaluates the expression, including `min()`, per group.
    return (date - date.min()).dt.total_days().over(group_col).alias(alias)


def plot_temps(df_: pl.DataFrame, idx_colname: str = "DAY_OF_SEASON") -> pl.DataFrame:
    """Draw the ski-season temperature figure and return ``df_`` unchanged.

    Rows whose ``SEASON`` contains ``"Ski"`` are pivoted so that every season
    becomes a column of ``TMEAN`` values indexed by ``idx_colname``. The figure
    then shows each season except ``"Ski 2019"`` as a faint line, one row-wise
    mean line per decade (1980, 1990, 2000, 2010), and ``"Ski 2019"`` in red.

    Args:
        df_: Frame containing ``SEASON``, ``TMEAN`` and ``idx_colname`` columns.
        idx_colname: Name of the day-of-season column used as the x axis.

    Returns:
        ``df_`` itself, so the function can be used as the last ``.pipe()`` step.

    Raises:
        ValueError: If the pivoted frame has no ``"Ski 2019"`` column
            (raised by ``list.remove``).

    Side effects:
        Sets ``plt.rcParams["font.family"]`` to ``"Roboto"`` and creates a new
        Matplotlib figure; the figure is neither shown nor saved here.
    """
    plt.rcParams["font.family"] = "Roboto"
    figsize = (160, 165)  # width, height in points

    def points_to_inches(points):
        # Matplotlib figure sizes are in inches; 72 points per inch.
        return points / 72

    figsize_inches = [points_to_inches(dim) for dim in figsize]

    heading_fontsize = 9.5
    heading_fontweight = "bold"
    subheading_fontsize = 8
    subheading_fontweight = "normal"
    source_fontsize = 6.5
    source_fontweight = "light"
    axis_fontsize = 7
    axis_fontweight = "normal"

    grey = "#aaaaaa"
    red = "#e3120b"
    blue = "#0000ff"
    cmap = colormaps.get_cmap("Grays")

    # Three stacked axes: title on top, the plot, then a thin notes strip.
    layout = [["title"], ["plot"], ["notes"]]
    fig, axs = plt.subplot_mosaic(
        layout,
        gridspec_kw={"height_ratios": [6, 12, 1]},
        figsize=figsize_inches,
        dpi=300,
        constrained_layout=True,
    )

    # ----- Title -----
    ax_title = axs["title"]
    ax_title.axis("off")
    sub_props = {"fontsize": subheading_fontsize, "fontweight": subheading_fontweight}
    # Each <...> segment of the string is styled by the entry at the same
    # position in `highlight_textprops` (five segments, five entries).
    ax_text(
        s="<Alta Ski Resort>\n<Temperature trends by >\n<decade>< and ><2019>",
        x=0,
        y=0,
        fontsize=heading_fontsize,
        ax=ax_title,
        va="bottom",
        ha="left",
        zorder=5,
        highlight_textprops=[
            {"fontsize": heading_fontsize, "fontweight": heading_fontweight},
            sub_props,
            {"color": blue, **sub_props},
            sub_props,
            {"color": red, **sub_props},
        ],
    )

    # ----- Plot -----
    ax = axs["plot"]
    # Wide frame: one row per day of season, one TMEAN column per ski season.
    season_temps = df_.filter(pl.col("SEASON").str.contains("Ski")).pivot(
        "SEASON", index=idx_colname, values="TMEAN", aggregate_function="first"
    )
    season_temps_index = season_temps[idx_colname]

    # Background lines: every season column except the index and "Ski 2019".
    columns = season_temps.columns
    columns.remove(idx_colname)
    columns.remove("Ski 2019")
    for i, column in enumerate(columns):
        # Shade is picked from the colormap by the column's position.
        color = cmap(i / len(columns))
        ax.plot(
            season_temps_index,
            season_temps[column],
            color=color,
            linewidth=1,
            alpha=0.2,
            zorder=1,
        )

    # ----- Decade averages -----
    decades = [1980, 1990, 2000, 2010]
    blues = ["#0055EE", "#0033CC", "#0011AA", "#3377FF"]
    for decade, color in zip(decades, blues):
        match = str(decade)[:-1]  # 1980 -> "198", 2010 -> "201"
        # Row-wise mean over all season columns whose name contains `match`.
        # NOTE(review): `season_temps` still holds "Ski 2019", so the 2010
        # average appears to include it - confirm this is intended.
        decade_temps = season_temps.select(cs.contains(match)).mean_horizontal()
        ax.plot(season_temps_index, decade_temps, color=color, linewidth=1)

        # Label to the right of the line, at the height of its last value.
        last_y_label = decade_temps.last()

        # Manual vertical nudges; presumably to keep adjacent labels from
        # overlapping - TODO confirm.
        if decade == 2000:
            last_y_label -= 3
        elif decade == 2010:
            last_y_label -= 0.3

        # x=185 is in data coordinates (day of season).
        ax.text(
            185,
            last_y_label,
            f"{decade}",
            va="center",
            ha="left",
            fontsize=axis_fontsize,
            fontweight=axis_fontweight,
            color=color,
        )
        # Dot at the start and end of each decade line.
        ax.plot(
            season_temps_index.first(),
            decade_temps.first(),
            marker="o",
            color=color,
            markersize=1,
            zorder=2,
        )
        ax.plot(
            season_temps_index.last(),
            decade_temps.last(),
            marker="o",
            color=color,
            markersize=1,
            zorder=2,
        )

    # ----- Ski 2019 -----
    # Plot `Ski 2019` in red; rows with nulls are dropped first.
    ski_2019 = season_temps.select(idx_colname, cs.by_name("Ski 2019")).drop_nulls()
    ski_2019_index = ski_2019[idx_colname]
    ski_2019 = ski_2019.drop([idx_colname]).to_series()
    ax.plot(ski_2019_index, ski_2019, color="red", linewidth=1)

    # Dot at the start and end of the 2019 line.
    ax.plot(
        ski_2019_index.first(),
        ski_2019.first(),
        marker="o",
        color="red",
        markersize=2,
        zorder=2,
    )
    ax.plot(
        ski_2019_index.last(),
        ski_2019.last(),
        marker="o",
        color="red",
        markersize=2,
        zorder=2,
    )

    # ----- Ticks & lines -----
    # Remove every spine except the bottom one.
    for side in ["top", "left", "right"]:
        ax.spines[side].set_visible(False)

    # Dashed horizontal reference line at 32F.
    ax.axhline(32, color="black", linestyle="--", linewidth=1, zorder=1)

    # y ticks
    ax.set_yticks(ticks=[10, 32, 40])

    # y limits
    ax.set_ylim([10, 55])

    # x label
    ax.set_xlabel("Day of season", fontsize=axis_fontsize, fontweight=axis_fontweight)

    # ----- Source -----
    ax_notes = axs["notes"]
    # Data-source credit in the bottom strip.
    ax_notes.axis("off")
    ax_notes.text(
        0,
        0,
        "Source: NOAA",
        fontsize=source_fontsize,
        fontweight=source_fontweight,
        color=grey,
    )
    # Hand the input frame back so the call composes with `.pipe()`.
    return df_


# Name of the day-of-season column, shared by the pipeline and the plot.
idx_colname = "DAY_OF_SEASON"

# Parse dates, fill gaps, smooth, label seasons, count days, then plot.
data = (
    df.with_columns(
        pl.col("DATE").str.to_datetime(),
        pl.col("TOBS").interpolate(),
    )
    .sort("DATE")
    .with_columns(
        # Caveat: this cannot be merged into the previous `with_columns()`.
        # Expressions in one call all read the input frame, so the rolling
        # mean would see the raw TOBS instead of the interpolated one.
        pl.col("TOBS").rolling_mean(window_size=28, center=True).alias("TMEAN"),
        get_season_expr(col="DATE", alias="SEASON"),
    )
    .with_columns(
        # Separate step: this expression groups by the SEASON column created above.
        add_day_of_season_expr(col="DATE", group_col="SEASON", alias=idx_colname)
    )
    # `plot_temps` returns its input, so `data` is the processed frame.
    .pipe(partial(plot_temps, idx_colname=idx_colname))
)

Alta ski resort

Loading the Data

To begin, we load the dataset, focusing on two key columns:

  • DATE: The dates.
  • TOBS: The recorded temperature in Fahrenheit.
from functools import partial

import matplotlib.pyplot as plt
import polars as pl
import polars.selectors as cs
from highlight_text import ax_text
from matplotlib import colormaps

data_path = "alta-noaa-1980-2019.csv"
columns = ["DATE", "TOBS"]
# Scan the CSV lazily, keep only the columns listed above, then materialise.
df = (
    pl.scan_csv(data_path)
    .select(columns)
    .collect()
)
print(df.head())
shape: (5, 2)
┌────────────┬──────┐
│ DATE       ┆ TOBS │
│ ---        ┆ ---  │
│ str        ┆ i64  │
╞════════════╪══════╡
│ 1980-01-01 ┆ 25   │
│ 1980-01-02 ┆ 18   │
│ 1980-01-03 ┆ 18   │
│ 1980-01-04 ┆ 27   │
│ 1980-01-05 ┆ 34   │
└────────────┴──────┘

Data Processing Pipeline

Here’s the pipeline for generating the final figure:

idx_colname = "DAY_OF_SEASON"

# The trailing `# (n)` markers refer to the numbered notes below the snippet.
data = (
    df.with_columns(
        pl.col("DATE").str.to_datetime(),  # (1)
        pl.col("TOBS").interpolate(),  # (2)
    )
    .sort("DATE")
    .with_columns(
        # Caveat: this cannot be merged into the previous `with_columns()`.
        # Expressions in one call all read the input frame, so the rolling
        # mean would see the raw TOBS instead of the interpolated one.
        pl.col("TOBS").rolling_mean(window_size=28, center=True).alias("TMEAN"),  # (3)
        get_season_expr(col="DATE", alias="SEASON"),  # (4)
    )
    .with_columns(
        add_day_of_season_expr(col="DATE", group_col="SEASON", alias=idx_colname)  # (5)
    )
    .pipe(partial(plot_temps, idx_colname=idx_colname))  # (6)
)
1
Convert the DATE column to a datetime format.
2
Perform interpolation on the TOBS column.
3
Compute a 28-day rolling average for TOBS.
4
Use get_season_expr() to categorize each date into a SEASON.
5
Apply add_day_of_season_expr() to calculate DAY_OF_SEASON, representing days elapsed since the start of the season.
6
Use plot_temps() to generate the final visualization with Matplotlib.

The first three steps involve straightforward Polars expressions. In the following sections, we’ll dive deeper into steps 4 to 6.

Categorizing Dates into Summer and Ski Seasons

To analyze seasonal trends, we classify dates into two categories:

  • Summer: Covers May through October.
  • Ski: Covers November through April.

If a date falls in November or December, it is assigned to the following year’s season. For example, 2015-10-31 is categorized as Summer 2015, while 2015-11-01 belongs to Ski 2016.

To implement this logic, we define get_season_expr(), which leverages Polars’ when-then-otherwise expressions to determine the season and year.

def get_season_expr(col: str = "DATE", alias: str = "SEASON") -> pl.Expr:
    """Build an expression that labels each date with its season and year.

    Months May through October are labelled ``"Summer <year>"``; every other
    month is labelled ``"Ski <year>"``. Dates in November and December are
    assigned to the following calendar year, so a ski season that spans
    November to April carries a single label (e.g. ``"Ski 2016"``).

    Args:
        col: Name of the date/datetime column to classify.
        alias: Name given to the resulting string column.

    Returns:
        A Polars expression producing the season label.
    """
    month = pl.col(col).dt.month()
    year = pl.col(col).dt.year()

    # The trailing space in each literal separates the name from the year.
    season_name = (
        pl.when(month.is_between(5, 10, closed="both"))
        .then(pl.lit("Summer "))
        .otherwise(pl.lit("Ski "))
    )
    # January-October keep their own year; November-December roll forward.
    season_year = (
        pl.when(month < 11).then(year).otherwise(year.add(1)).cast(pl.String)
    )
    return season_name.add(season_year).alias(alias)

In this function:

  • If the month is between May and October, the function assigns "Summer ". Otherwise, it assigns "Ski " (with a trailing space for concatenation).
  • The year is determined based on the month: dates from January to October retain their current year, while those in November and December are shifted to the next year.

By applying this function, we can add a SEASON column to a Polars DataFrame, ensuring each date is categorized correctly.

Calculating the Total Days for Each Season

Once we have the seasonal categories, we calculate DAY_OF_SEASON, which tracks the number of days elapsed within each season. This is achieved using the pl.Expr.over() method, which operates similarly to Pandas’ groupby().transform(), applying transformations within groups.

def add_day_of_season_expr(
    col: str = "DATE", group_col: str = "SEASON", alias: str = "DAY_OF_SEASON"
) -> pl.Expr:
    """Build an expression counting whole days since each group's first date.

    The count is relative to the earliest ``col`` value present in the group,
    not to a fixed calendar date.

    Args:
        col: Name of the date/datetime column.
        group_col: Column whose values define the groups (one per season).
        alias: Name given to the resulting integer column.

    Returns:
        A Polars expression producing the elapsed days within each group.
    """
    date = pl.col(col)
    # `.over(group_col)` evaluates the expression, including `min()`, per group.
    return (date - date.min()).dt.total_days().over(group_col).alias(alias)

Wrapping Up

Recreating this figure with Polars turned out to be more involved than I initially expected. However, the process was incredibly rewarding, as it deepened my understanding of Pandas, Polars, and Matplotlib. Switching between Pandas and Polars required a shift in mindset, but it also reinforced key concepts in both libraries. I look forward to exploring more of these challenges in the future.

Disclaimer

This post was drafted by me, with AI assistance to refine the content.