from functools import cache
from typing import Any
import polars as pl
def case_when(
    caselist: list[tuple[pl.Expr, pl.Expr]], otherwise: pl.Expr | None = None
) -> pl.Expr:
    """
    Simplifies conditional logic in Polars by chaining multiple `when-then` expressions.

    Parameters
    ----------
    caselist
        A list of (condition, value) pairs. Each condition is evaluated in order,
        and the corresponding value is returned when a condition is met.
        Must contain at least one pair.
    otherwise
        The fallback value to use if none of the conditions match. It is
        forwarded unchanged to `.otherwise()`, including the default `None`
        (presumably yielding nulls for unmatched rows -- confirm against the
        Polars documentation).

    Returns
    -------
    pl.Expr
        The chained `when-then-otherwise` expression.

    Raises
    ------
    ValueError
        If `caselist` is empty.

    Examples
    --------
    ```python
    import polars as pl

    df = pl.DataFrame({"x": [1, 2, 3, 4]})
    expr = case_when(
        caselist=[
            (pl.col("x") < 2, pl.lit("small")),
            (pl.col("x") < 4, pl.lit("medium")),
        ],
        otherwise=pl.lit("large"),
    ).alias("size")

    # This is equivalent to writing:
    # expr = (
    #     pl.when(pl.col("x") < 2)
    #     .then(pl.lit("small"))
    #     .when(pl.col("x") < 4)
    #     .then(pl.lit("medium"))
    #     .otherwise(pl.lit("large"))
    #     .alias("size")
    # )

    df.with_columns(expr)
    ```

    shape: (4, 2)
    ┌─────┬────────┐
    │ x   ┆ size   │
    │ --- ┆ ---    │
    │ i64 ┆ str    │
    ╞═════╪════════╡
    │ 1   ┆ small  │
    │ 2   ┆ medium │
    │ 3   ┆ medium │
    │ 4   ┆ large  │
    └─────┴────────┘
    """
    # Fail with a descriptive message instead of the opaque unpacking error
    # an empty list would otherwise trigger on the next statement.
    if not caselist:
        raise ValueError(
            "caselist must contain at least one (condition, value) pair"
        )

    (first_when, first_then), *cases = caselist

    # first: only `pl.when` can open the chain
    expr = pl.when(first_when).then(first_then)

    # middles: every remaining pair extends the chain
    for when, then in cases:
        expr = expr.when(when).then(then)

    # last: close the chain with the fallback value
    return expr.otherwise(otherwise)
This is a follow-up to my previous post.
While the conditional branching mechanism of pl.when().then().otherwise() is quite powerful, I often find it a bit verbose—especially when the conditions are complex. In those cases, it becomes harder to validate the correctness of each branch at a glance.
On the other hand, I find the pd.Series.case_when() pattern in Pandas slightly more concise and readable. However, I’ve always wished it supported a fallback mechanism like Polars’ .otherwise().
In the end, I thought it would be interesting to borrow the concept behind pd.Series.case_when() and implement it as a standalone utility function in Polars.
case_when()
The case_when() function accepts two arguments:
- caselist: A list of two-element tuples, where the first item is the condition (used in pl.when()), and the second is the corresponding result expression (used in .then()).
- otherwise: A fallback expression used in .otherwise() if no conditions match.
The given example demonstrates how case_when() can simplify conditional logic compared to the more verbose pl.when().then().otherwise() chain.
Custom Expression Namespace
With case_when() in place, we can refactor the DiscreteSplitter expression namespace like this:
@cache
def _mod_expr(n: int) -> pl.Expr:
    """Return an expression for each row's position modulo ``n``.

    The positions come from ``pl.int_range(pl.len(), dtype=pl.UInt32)``.
    Results are memoized per ``n`` by ``functools.cache``, so repeated calls
    with the same ``n`` reuse the same expression object.
    """
    row_position = pl.int_range(pl.len(), dtype=pl.UInt32)
    return row_position.mod(n)
def _litify(lits: list[Any]) -> list[pl.Expr]:
    """Wrap every value in ``lits`` with ``pl.lit``, preserving order.

    Parameters
    ----------
    lits
        Plain Python values to turn into literal expressions.

    Returns
    -------
    list[pl.Expr]
        One ``pl.lit(...)`` result per input value, in the same order.
    """
    # `pl.lit` is a function, not a type, so the return annotation names the
    # expression type it produces.
    return [pl.lit(lit) for lit in lits]
@pl.api.register_expr_namespace("spt")
class DiscreteSplitter:
    """Expression namespace ``spt`` that cycles a list of literals over rows.

    Row ``i`` (zero-based position) receives ``lits[i % len(lits)]``. The
    wrapped expression is stored on ``self._expr`` but is not read by any
    method here; the result depends only on row position.
    """

    def __init__(self, expr: pl.Expr) -> None:
        self._expr = expr

    def _get_expr(self, lits: list[Any], name: str) -> pl.Expr:
        """Build the aliased ``case_when`` expression mapping rows to ``lits``.

        Raises
        ------
        ValueError
            If fewer than two literals are given.
        """
        n = len(lits)
        # One literal must serve as the fallback and at least one more as a
        # case; fail early with a clear message instead of an unpacking error.
        if n < 2:
            raise ValueError("lits must contain at least two values")
        mod_expr = _mod_expr(n)
        # The last literal becomes the `otherwise` branch; the others become
        # explicit `position % n == i` cases.
        *litified, litified_otherwise = _litify(lits)
        caselist = [(mod_expr.eq(i), lit) for i, lit in enumerate(litified)]
        return case_when(caselist, litified_otherwise).alias(name)

    def binarize(self, lit1: Any, lit2: Any, name: str = "binarized") -> pl.Expr:
        """Alternate between two literals; convenience wrapper over `bucketize`."""
        return self.bucketize([lit1, lit2], name=name)

    def trinarize(
        self, lit1: Any, lit2: Any, lit3: Any, name: str = "trinarized"
    ) -> pl.Expr:
        """Cycle through three literals; convenience wrapper over `bucketize`."""
        return self.bucketize([lit1, lit2, lit3], name=name)

    def bucketize(self, lits: list[Any], name: str = "bucketized") -> pl.Expr:
        """Cycle through ``lits`` row by row and alias the result as ``name``."""
        return self._get_expr(lits, name)
Now, bucketize() is the primary method that encapsulates the core logic for categorical mapping. binarize() and trinarize() are just convenient wrappers for common cases.
Here’s a simple example of using the custom expression namespace:
# `pl.col("")` is only a handle to reach the `spt` namespace; the generated
# expressions depend on row position, not on any column's values.
df = (
    pl.DataFrame({"n": [100, 50, 72, 83, 97, 42, 20, 51, 77]})
    .with_row_index(offset=1)
    .with_columns(
        pl.col("").spt.binarize("lightblue", "papayawhip"),
        pl.col("").spt.trinarize("one", "two", "three"),
        pl.col("").spt.bucketize([1, 2, 3, 4]),
    )
)
shape: (9, 5)
┌───────┬─────┬────────────┬────────────┬────────────┐
│ index ┆ n   ┆ binarized  ┆ trinarized ┆ bucketized │
│ ---   ┆ --- ┆ ---        ┆ ---        ┆ ---        │
│ u32   ┆ i64 ┆ str        ┆ str        ┆ i32        │
╞═══════╪═════╪════════════╪════════════╪════════════╡
│ 1     ┆ 100 ┆ lightblue  ┆ one        ┆ 1          │
│ 2     ┆ 50  ┆ papayawhip ┆ two        ┆ 2          │
│ 3     ┆ 72  ┆ lightblue  ┆ three      ┆ 3          │
│ 4     ┆ 83  ┆ papayawhip ┆ one        ┆ 4          │
│ 5     ┆ 97  ┆ lightblue  ┆ two        ┆ 1          │
│ 6     ┆ 42  ┆ papayawhip ┆ three      ┆ 2          │
│ 7     ┆ 20  ┆ lightblue  ┆ one        ┆ 3          │
│ 8     ┆ 51  ┆ papayawhip ┆ two        ┆ 4          │
│ 9     ┆ 77  ┆ lightblue  ┆ three      ┆ 1          │
└───────┴─────┴────────────┴────────────┴────────────┘
Custom DataFrame Namespace
Instead of relying on pl.DataFrame.with_row_index(), we can also use _mod_expr() directly to enable similar categorization.
Here’s how the DiscreteSplitter can be implemented as a custom DataFrame namespace:
@pl.api.register_dataframe_namespace("spt")
class DiscreteSplitter:
    """DataFrame namespace ``spt`` that adds a column cycling through literals.

    Row ``i`` (zero-based position) of the new column receives
    ``lits[i % len(lits)]``. Every public method returns a new DataFrame, so
    calls can be chained.
    """

    def __init__(self, df: pl.DataFrame) -> None:
        self._df = df

    def _get_expr(self, lits: list[Any], name: str) -> pl.Expr:
        """Build the aliased ``case_when`` expression mapping rows to ``lits``.

        Raises
        ------
        ValueError
            If fewer than two literals are given.
        """
        n = len(lits)
        # One literal must serve as the fallback and at least one more as a
        # case; fail early with a clear message instead of an unpacking error.
        if n < 2:
            raise ValueError("lits must contain at least two values")
        mod_expr = _mod_expr(n)
        # The last literal becomes the `otherwise` branch; the others become
        # explicit `position % n == i` cases.
        *litified, litified_otherwise = _litify(lits)
        caselist = [(mod_expr.eq(i), lit) for i, lit in enumerate(litified)]
        return case_when(caselist, litified_otherwise).alias(name)

    def _get_final_df(self, lits: list[Any], name: str) -> pl.DataFrame:
        """Return the wrapped DataFrame with the generated column added."""
        # Wrapping the result in a fresh namespace instance only to read its
        # `_df` back is a no-op, so the new frame is returned directly.
        return self._df.with_columns(self._get_expr(lits, name))

    def binarize(self, lit1: Any, lit2: Any, name: str = "binarized") -> pl.DataFrame:
        """Alternate between two literals; convenience wrapper over `bucketize`."""
        return self.bucketize([lit1, lit2], name=name)

    def trinarize(
        self, lit1: Any, lit2: Any, lit3: Any, name: str = "trinarized"
    ) -> pl.DataFrame:
        """Cycle through three literals; convenience wrapper over `bucketize`."""
        return self.bucketize([lit1, lit2, lit3], name=name)

    def bucketize(self, lits: list[Any], name: str = "bucketized") -> pl.DataFrame:
        """Add a column named ``name`` that cycles through ``lits`` row by row."""
        return self._get_final_df(lits, name)
Example usage:
# Each `spt` call returns a new DataFrame, so the calls chain directly; the
# row index is added last purely for display.
df = (
    pl.DataFrame({"n": [100, 50, 72, 83, 97, 42, 20, 51, 77]})
    .spt.binarize("lightblue", "papayawhip")
    .spt.trinarize("one", "two", "three")
    .spt.bucketize([1, 2, 3, 4])
    .with_row_index(offset=1)
)
shape: (9, 5)
┌───────┬─────┬────────────┬────────────┬────────────┐
│ index ┆ n   ┆ binarized  ┆ trinarized ┆ bucketized │
│ ---   ┆ --- ┆ ---        ┆ ---        ┆ ---        │
│ u32   ┆ i64 ┆ str        ┆ str        ┆ i32        │
╞═══════╪═════╪════════════╪════════════╪════════════╡
│ 1     ┆ 100 ┆ lightblue  ┆ one        ┆ 1          │
│ 2     ┆ 50  ┆ papayawhip ┆ two        ┆ 2          │
│ 3     ┆ 72  ┆ lightblue  ┆ three      ┆ 3          │
│ 4     ┆ 83  ┆ papayawhip ┆ one        ┆ 4          │
│ 5     ┆ 97  ┆ lightblue  ┆ two        ┆ 1          │
│ 6     ┆ 42  ┆ papayawhip ┆ three      ┆ 2          │
│ 7     ┆ 20  ┆ lightblue  ┆ one        ┆ 3          │
│ 8     ┆ 51  ┆ papayawhip ┆ two        ┆ 4          │
│ 9     ┆ 77  ┆ lightblue  ┆ three      ┆ 1          │
└───────┴─────┴────────────┴────────────┴────────────┘
Conclusion
Extracting the conditional logic into a standalone case_when() function turned out to be both a practical and satisfying exercise—perfect for a rainy afternoon of coding. It not only improves readability but also makes the branching logic easier to reuse and reason about.
This post was drafted by me, with AI assistance to refine the content.