@@ -319,6 +319,39 @@ def test_sum_single_col() -> None:
319
319
)
320
320
321
321
322
+ @pytest .mark .parametrize ("ignore_nulls" , [False , True ])
323
+ def test_sum_correct_supertype (ignore_nulls : bool ) -> None :
324
+ values = [1 , 2 ] if ignore_nulls else [None , None ] # type: ignore[list-item]
325
+ lf = pl .LazyFrame (
326
+ {
327
+ "null" : [None , None ],
328
+ "int" : pl .Series (values , dtype = pl .Int32 ),
329
+ "float" : pl .Series (values , dtype = pl .Float32 ),
330
+ }
331
+ )
332
+
333
+ # null + int32 should produce int32
334
+ out = lf .select (pl .sum_horizontal ("null" , "int" , ignore_nulls = ignore_nulls ))
335
+ expected = pl .LazyFrame ({"null" : pl .Series (values , dtype = pl .Int32 )})
336
+ assert_frame_equal (out .collect (), expected .collect ())
337
+ assert out .collect_schema () == expected .collect_schema ()
338
+
339
+ # null + float32 should produce float32
340
+ out = lf .select (pl .sum_horizontal ("null" , "float" , ignore_nulls = ignore_nulls ))
341
+ expected = pl .LazyFrame ({"null" : pl .Series (values , dtype = pl .Float32 )})
342
+ assert_frame_equal (out .collect (), expected .collect ())
343
+ assert out .collect_schema () == expected .collect_schema ()
344
+
345
+ # null + int32 + float32 should produce float64
346
+ values = [2 , 4 ] if ignore_nulls else [None , None ] # type: ignore[list-item]
347
+ out = lf .select (
348
+ pl .sum_horizontal ("null" , "int" , "float" , ignore_nulls = ignore_nulls )
349
+ )
350
+ expected = pl .LazyFrame ({"null" : pl .Series (values , dtype = pl .Float64 )})
351
+ assert_frame_equal (out .collect (), expected .collect ())
352
+ assert out .collect_schema () == expected .collect_schema ()
353
+
354
+
322
355
def test_cum_sum_horizontal () -> None :
323
356
df = pl .DataFrame (
324
357
{
@@ -541,17 +574,17 @@ def test_horizontal_sum_boolean_with_null() -> None:
541
574
542
575
expected_schema = pl .Schema (
543
576
{
544
- "null_first" : pl .UInt32 ,
545
- "bool_first" : pl .UInt32 ,
577
+ "null_first" : pl .get_index_type () ,
578
+ "bool_first" : pl .get_index_type () ,
546
579
}
547
580
)
548
581
549
582
assert out .collect_schema () == expected_schema
550
583
551
584
expected_df = pl .DataFrame (
552
585
{
553
- "null_first" : pl .Series ([1 , 0 ], dtype = pl .UInt32 ),
554
- "bool_first" : pl .Series ([1 , 0 ], dtype = pl .UInt32 ),
586
+ "null_first" : pl .Series ([1 , 0 ], dtype = pl .get_index_type () ),
587
+ "bool_first" : pl .Series ([1 , 0 ], dtype = pl .get_index_type () ),
555
588
}
556
589
)
557
590
@@ -563,7 +596,7 @@ def test_horizontal_sum_boolean_with_null() -> None:
563
596
("dtype_in" , "dtype_out" ),
564
597
[
565
598
(pl .Null , pl .Null ),
566
- (pl .Boolean , pl .UInt32 ),
599
+ (pl .Boolean , pl .get_index_type () ),
567
600
(pl .UInt8 , pl .UInt8 ),
568
601
(pl .Float32 , pl .Float32 ),
569
602
(pl .Float64 , pl .Float64 ),
@@ -589,6 +622,7 @@ def test_horizontal_sum_with_null_col_ignore_strategy(
589
622
values = [None , None , None ] # type: ignore[list-item]
590
623
expected = pl .LazyFrame (pl .Series ("null" , values , dtype = dtype_out ))
591
624
assert_frame_equal (result , expected )
625
+ assert result .collect_schema () == expected .collect_schema ()
592
626
593
627
594
628
@pytest .mark .parametrize ("ignore_nulls" , [True , False ])
0 commit comments