Skip to content

Commit da0b589

Browse files
authored
fix: Output index type instead of u32 for sum_horizontal with boolean inputs (#20531)
1 parent 409f091 commit da0b589

File tree

3 files changed

+45
-12
lines changed

3 files changed

+45
-12
lines changed

crates/polars-ops/src/series/ops/horizontal.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,9 @@ pub fn sum_horizontal(
221221

222222
// If we have any null columns and null strategy is not `Ignore`, we can return immediately.
223223
if !ignore_nulls && non_null_cols.len() < columns.len() {
224-
// We must first determine the correct return dtype.
224+
// We must determine the correct return dtype.
225225
let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
226-
DataType::Boolean => DataType::UInt32,
226+
DataType::Boolean => IDX_DTYPE,
227227
dt => dt,
228228
};
229229
return Ok(Some(Column::full_null(
@@ -244,7 +244,7 @@ pub fn sum_horizontal(
244244
},
245245
1 => Ok(Some(
246246
apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
247-
non_null_cols[0].cast(&DataType::UInt32)?
247+
non_null_cols[0].cast(&IDX_DTYPE)?
248248
} else {
249249
non_null_cols[0].clone()
250250
})?

crates/polars-plan/src/dsl/function_expr/schema.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,10 @@ impl FunctionExpr {
331331
MinHorizontal => mapper.map_to_supertype(),
332332
SumHorizontal { .. } => {
333333
mapper.map_to_supertype().map(|mut f| {
334-
match f.dtype {
335-
// Booleans sum to UInt32.
336-
DataType::Boolean => { f.dtype = DataType::UInt32; f},
337-
_ => f,
334+
if f.dtype == DataType::Boolean {
335+
f.dtype = IDX_DTYPE;
338336
}
337+
f
339338
})
340339
},
341340
MeanHorizontal { .. } => {

py-polars/tests/unit/operations/aggregation/test_horizontal.py

+39-5
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,39 @@ def test_sum_single_col() -> None:
319319
)
320320

321321

322+
@pytest.mark.parametrize("ignore_nulls", [False, True])
323+
def test_sum_correct_supertype(ignore_nulls: bool) -> None:
324+
values = [1, 2] if ignore_nulls else [None, None] # type: ignore[list-item]
325+
lf = pl.LazyFrame(
326+
{
327+
"null": [None, None],
328+
"int": pl.Series(values, dtype=pl.Int32),
329+
"float": pl.Series(values, dtype=pl.Float32),
330+
}
331+
)
332+
333+
# null + int32 should produce int32
334+
out = lf.select(pl.sum_horizontal("null", "int", ignore_nulls=ignore_nulls))
335+
expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Int32)})
336+
assert_frame_equal(out.collect(), expected.collect())
337+
assert out.collect_schema() == expected.collect_schema()
338+
339+
# null + float32 should produce float32
340+
out = lf.select(pl.sum_horizontal("null", "float", ignore_nulls=ignore_nulls))
341+
expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float32)})
342+
assert_frame_equal(out.collect(), expected.collect())
343+
assert out.collect_schema() == expected.collect_schema()
344+
345+
# null + int32 + float32 should produce float64
346+
values = [2, 4] if ignore_nulls else [None, None] # type: ignore[list-item]
347+
out = lf.select(
348+
pl.sum_horizontal("null", "int", "float", ignore_nulls=ignore_nulls)
349+
)
350+
expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float64)})
351+
assert_frame_equal(out.collect(), expected.collect())
352+
assert out.collect_schema() == expected.collect_schema()
353+
354+
322355
def test_cum_sum_horizontal() -> None:
323356
df = pl.DataFrame(
324357
{
@@ -541,17 +574,17 @@ def test_horizontal_sum_boolean_with_null() -> None:
541574

542575
expected_schema = pl.Schema(
543576
{
544-
"null_first": pl.UInt32,
545-
"bool_first": pl.UInt32,
577+
"null_first": pl.get_index_type(),
578+
"bool_first": pl.get_index_type(),
546579
}
547580
)
548581

549582
assert out.collect_schema() == expected_schema
550583

551584
expected_df = pl.DataFrame(
552585
{
553-
"null_first": pl.Series([1, 0], dtype=pl.UInt32),
554-
"bool_first": pl.Series([1, 0], dtype=pl.UInt32),
586+
"null_first": pl.Series([1, 0], dtype=pl.get_index_type()),
587+
"bool_first": pl.Series([1, 0], dtype=pl.get_index_type()),
555588
}
556589
)
557590

@@ -563,7 +596,7 @@ def test_horizontal_sum_boolean_with_null() -> None:
563596
("dtype_in", "dtype_out"),
564597
[
565598
(pl.Null, pl.Null),
566-
(pl.Boolean, pl.UInt32),
599+
(pl.Boolean, pl.get_index_type()),
567600
(pl.UInt8, pl.UInt8),
568601
(pl.Float32, pl.Float32),
569602
(pl.Float64, pl.Float64),
@@ -589,6 +622,7 @@ def test_horizontal_sum_with_null_col_ignore_strategy(
589622
values = [None, None, None] # type: ignore[list-item]
590623
expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))
591624
assert_frame_equal(result, expected)
625+
assert result.collect_schema() == expected.collect_schema()
592626

593627

594628
@pytest.mark.parametrize("ignore_nulls", [True, False])

0 commit comments

Comments
 (0)