case_when() in Polars

Tags: python, polars
Author: Jerry Wu
Published: May 25, 2025

This is a follow-up to my previous post.

While Polars' pl.when().then().otherwise() chain is quite powerful, I often find it verbose, especially when the conditions are complex. In those cases, it becomes harder to check the correctness of each branch at a glance.

On the other hand, I find the pd.Series.case_when() pattern in Pandas slightly more concise and readable. However, I’ve always wished it supported a fallback mechanism like Polars’ .otherwise().

In the end, I thought it would be interesting to borrow the concept behind pd.Series.case_when() and implement it as a standalone utility function in Polars.

case_when()

The case_when() function accepts two arguments:

  • caselist: A list of two-element tuples, where the first item is the condition (used in pl.when()), and the second is the corresponding result expression (used in .then()).
  • otherwise: A fallback expression used in .otherwise() if no conditions match.

The example in the docstring below shows how case_when() condenses the equivalent, more verbose pl.when().then().otherwise() chain.

from functools import cache
from typing import Any

import polars as pl


def case_when(
    caselist: list[tuple[pl.Expr, pl.Expr]], otherwise: pl.Expr | None = None
) -> pl.Expr:
    """
    Simplifies conditional logic in Polars by chaining multiple `when-then` expressions.

    Parameters
    ----------
    caselist
        A list of (condition, value) pairs. Each condition is evaluated in order,
        and the corresponding value is returned when a condition is met.
    otherwise
        The fallback value to use if none of the conditions match.

    Returns
    -------
    pl.Expr

    Examples
    --------
    ```python
    import polars as pl

    df = pl.DataFrame({"x": [1, 2, 3, 4]})

    expr = case_when(
        caselist=[
            (pl.col("x") < 2, pl.lit("small")),
            (pl.col("x") < 4, pl.lit("medium"))
        ],
        otherwise=pl.lit("large"),
    ).alias("size")

    # This is equivalent to writing:
    # expr = (
    #     pl.when(pl.col("x") < 2)
    #       .then(pl.lit("small"))
    #       .when(pl.col("x") < 4)
    #       .then(pl.lit("medium"))
    #       .otherwise(pl.lit("large"))
    #       .alias("size")
    # )

    df.with_columns(expr)
    ```
    shape: (4, 2)
    ┌─────┬────────┐
    │ x   ┆ size   │
    │ --- ┆ ---    │
    │ i64 ┆ str    │
    ╞═════╪════════╡
    │ 1   ┆ small  │
    │ 2   ┆ medium │
    │ 3   ┆ medium │
    │ 4   ┆ large  │
    └─────┴────────┘
    """
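    if not caselist:
        raise ValueError("caselist must contain at least one (condition, value) pair")

    (first_when, first_then), *cases = caselist

    # Start the chain with the first (condition, value) pair.
    expr = pl.when(first_when).then(first_then)

    # Append each remaining pair as a further when/then branch.
    for when, then in cases:
        expr = expr.when(when).then(then)

    # Close the chain; with otherwise=None, unmatched rows become null.
    return expr.otherwise(otherwise)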

Custom Expression Namespace

With case_when() in place, we can refactor the DiscreteSplitter expression namespace like this:

@cache
def _mod_expr(n: int) -> pl.Expr:
    return pl.int_range(pl.len(), dtype=pl.UInt32).mod(n)


def _litify(lits: list[Any]) -> list[pl.Expr]:
    return [pl.lit(lit) for lit in lits]


@pl.api.register_expr_namespace("spt")
class DiscreteSplitter:
    def __init__(self, expr: pl.Expr) -> None:
        self._expr = expr

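    def _get_expr(self, lits: list[Any], name: str) -> pl.Expr:
        n = len(lits)
        # Row position modulo n selects which literal a row receives.
        mod_expr = _mod_expr(n)
        # The last literal becomes the fallback; the others get explicit branches.
        *litified, litified_otherwise = _litify(lits)
        caselist = [(mod_expr.eq(i), lit) for i, lit in enumerate(litified)]
        return case_when(caselist, litified_otherwise).alias(name)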

    def binarize(self, lit1: Any, lit2: Any, name: str = "binarized") -> pl.Expr:
        return self.bucketize([lit1, lit2], name)

    def trinarize(
        self, lit1: Any, lit2: Any, lit3: Any, name: str = "trinarized"
    ) -> pl.Expr:
        return self.bucketize([lit1, lit2, lit3], name)

    def bucketize(self, lits: list[Any], name: str = "bucketized") -> pl.Expr:
        return self._get_expr(lits, name)

Now, bucketize() is the primary method that encapsulates the core logic for categorical mapping. binarize() and trinarize() are just convenient wrappers for common cases.

Here’s a simple example of using the custom expression namespace:

df = (
    pl.DataFrame({"n": [100, 50, 72, 83, 97, 42, 20, 51, 77]})
    .with_row_index(offset=1)
    .with_columns(
        pl.col("").spt.binarize("lightblue", "papayawhip"),
        pl.col("").spt.trinarize("one", "two", "three"),
        pl.col("").spt.bucketize([1, 2, 3, 4]),
    )
)
shape: (9, 5)
┌───────┬─────┬────────────┬────────────┬────────────┐
│ index ┆ n   ┆ binarized  ┆ trinarized ┆ bucketized │
│ ---   ┆ --- ┆ ---        ┆ ---        ┆ ---        │
│ u32   ┆ i64 ┆ str        ┆ str        ┆ i32        │
╞═══════╪═════╪════════════╪════════════╪════════════╡
│ 1     ┆ 100 ┆ lightblue  ┆ one        ┆ 1          │
│ 2     ┆ 50  ┆ papayawhip ┆ two        ┆ 2          │
│ 3     ┆ 72  ┆ lightblue  ┆ three      ┆ 3          │
│ 4     ┆ 83  ┆ papayawhip ┆ one        ┆ 4          │
│ 5     ┆ 97  ┆ lightblue  ┆ two        ┆ 1          │
│ 6     ┆ 42  ┆ papayawhip ┆ three      ┆ 2          │
│ 7     ┆ 20  ┆ lightblue  ┆ one        ┆ 3          │
│ 8     ┆ 51  ┆ papayawhip ┆ two        ┆ 4          │
│ 9     ┆ 77  ┆ lightblue  ┆ three      ┆ 1          │
└───────┴─────┴────────────┴────────────┴────────────┘

Custom DataFrame Namespace

The same categorization can also be exposed at the DataFrame level. Because _mod_expr() derives row positions itself, it does not rely on a prior call to pl.DataFrame.with_row_index().

Here’s how the DiscreteSplitter can be implemented as a custom DataFrame namespace:

@pl.api.register_dataframe_namespace("spt")
class DiscreteSplitter:
    def __init__(self, df: pl.DataFrame) -> None:
        self._df = df

    def _get_expr(self, lits: list[Any], name: str) -> pl.Expr:
        n = len(lits)
        # Row position modulo n selects which literal a row receives.
        mod_expr = _mod_expr(n)
        # The last literal becomes the fallback; the others get explicit branches.
        *litified, litified_otherwise = _litify(lits)
        caselist = [(mod_expr.eq(i), lit) for i, lit in enumerate(litified)]
        return case_when(caselist, litified_otherwise).alias(name)

    def _get_final_df(self, lits: list[Any], name: str) -> pl.DataFrame:
        # Add the generated column to the wrapped DataFrame and return the result.
        expr = self._get_expr(lits, name)
        return self._df.with_columns(expr)

    def binarize(self, lit1: Any, lit2: Any, name: str = "binarized") -> pl.DataFrame:
        return self.bucketize([lit1, lit2], name=name)

    def trinarize(
        self, lit1: Any, lit2: Any, lit3: Any, name: str = "trinarized"
    ) -> pl.DataFrame:
        return self.bucketize([lit1, lit2, lit3], name=name)

    def bucketize(self, lits: list[Any], name: str = "bucketized") -> pl.DataFrame:
        return self._get_final_df(lits, name)

Example usage:

df = (
    pl.DataFrame({"n": [100, 50, 72, 83, 97, 42, 20, 51, 77]})
    .spt.binarize("lightblue", "papayawhip")
    .spt.trinarize("one", "two", "three")
    .spt.bucketize([1, 2, 3, 4])
    .with_row_index(offset=1)
)
shape: (9, 5)
┌───────┬─────┬────────────┬────────────┬────────────┐
│ index ┆ n   ┆ binarized  ┆ trinarized ┆ bucketized │
│ ---   ┆ --- ┆ ---        ┆ ---        ┆ ---        │
│ u32   ┆ i64 ┆ str        ┆ str        ┆ i32        │
╞═══════╪═════╪════════════╪════════════╪════════════╡
│ 1     ┆ 100 ┆ lightblue  ┆ one        ┆ 1          │
│ 2     ┆ 50  ┆ papayawhip ┆ two        ┆ 2          │
│ 3     ┆ 72  ┆ lightblue  ┆ three      ┆ 3          │
│ 4     ┆ 83  ┆ papayawhip ┆ one        ┆ 4          │
│ 5     ┆ 97  ┆ lightblue  ┆ two        ┆ 1          │
│ 6     ┆ 42  ┆ papayawhip ┆ three      ┆ 2          │
│ 7     ┆ 20  ┆ lightblue  ┆ one        ┆ 3          │
│ 8     ┆ 51  ┆ papayawhip ┆ two        ┆ 4          │
│ 9     ┆ 77  ┆ lightblue  ┆ three      ┆ 1          │
└───────┴─────┴────────────┴────────────┴────────────┘

Conclusion

Extracting the conditional logic into a standalone case_when() function turned out to be both a practical and satisfying exercise, perfect for a rainy afternoon of coding. It improves readability and makes the branching logic easier to reuse and reason about.

Disclaimer

This post was drafted by me, with AI assistance to refine the content.