Skip to content

Commit a71ebb2

Browse files
authored
perf: Implement linear-time rolling_min/max (#21770)
1 parent 38b298b commit a71ebb2

File tree

9 files changed

+223
-697
lines changed

9 files changed

+223
-697
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
use std::collections::VecDeque;
2+
use std::marker::PhantomData;
3+
4+
use arrow::bitmap::Bitmap;
5+
use arrow::types::NativeType;
6+
use polars_utils::min_max::MinMaxPolicy;
7+
8+
use super::RollingFnParams;
9+
use super::no_nulls::RollingAggWindowNoNulls;
10+
use super::nulls::RollingAggWindowNulls;
11+
12+
// Algorithm: https://cs.stackexchange.com/questions/120915/interview-question-with-arrays-and-consecutive-subintervals/120936#120936
13+
pub struct MinMaxWindow<'a, T, P> {
14+
values: &'a [T],
15+
validity: Option<&'a Bitmap>,
16+
// values[monotonic_idxs[i]] is better than values[monotonic_idxs[i+1]] for
17+
// all i, as per the policy.
18+
monotonic_idxs: VecDeque<usize>,
19+
nonnulls_in_window: usize,
20+
last_end: usize,
21+
policy: PhantomData<P>,
22+
}
23+
24+
impl<T: NativeType, P: MinMaxPolicy> MinMaxWindow<'_, T, P> {
25+
/// # Safety
26+
/// The index must be in-bounds.
27+
unsafe fn insert_nonnull_value(&mut self, idx: usize) {
28+
unsafe {
29+
let value = self.values.get_unchecked(idx);
30+
31+
// Remove values which are older and worse.
32+
while let Some(tail_idx) = self.monotonic_idxs.back() {
33+
let tail_value = self.values.get_unchecked(*tail_idx);
34+
if !P::is_better(value, tail_value) {
35+
break;
36+
}
37+
self.monotonic_idxs.pop_back();
38+
}
39+
40+
self.monotonic_idxs.push_back(idx);
41+
self.nonnulls_in_window += 1;
42+
}
43+
}
44+
45+
fn remove_old_values(&mut self, window_start: usize) {
46+
// Remove values which have fallen outside the window start.
47+
while let Some(head_idx) = self.monotonic_idxs.front() {
48+
if *head_idx >= window_start {
49+
break;
50+
}
51+
self.monotonic_idxs.pop_front();
52+
self.nonnulls_in_window -= 1;
53+
}
54+
}
55+
}
56+
57+
impl<'a, T: NativeType, P: MinMaxPolicy> RollingAggWindowNulls<'a, T> for MinMaxWindow<'a, T, P> {
58+
unsafe fn new(
59+
slice: &'a [T],
60+
validity: &'a Bitmap,
61+
start: usize,
62+
end: usize,
63+
params: Option<RollingFnParams>,
64+
) -> Self {
65+
assert!(params.is_none());
66+
let mut slf = Self {
67+
values: slice,
68+
validity: Some(validity),
69+
monotonic_idxs: VecDeque::new(),
70+
nonnulls_in_window: 0,
71+
last_end: 0,
72+
policy: PhantomData,
73+
};
74+
unsafe {
75+
RollingAggWindowNulls::update(&mut slf, start, end);
76+
}
77+
slf
78+
}
79+
80+
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
81+
unsafe {
82+
let v = self.validity.unwrap_unchecked();
83+
self.remove_old_values(start);
84+
for i in start.max(self.last_end)..end {
85+
if v.get_bit_unchecked(i) {
86+
self.insert_nonnull_value(i);
87+
}
88+
}
89+
self.last_end = end;
90+
self.monotonic_idxs
91+
.front()
92+
.map(|idx| *self.values.get_unchecked(*idx))
93+
}
94+
}
95+
96+
fn is_valid(&self, min_periods: usize) -> bool {
97+
self.nonnulls_in_window >= min_periods
98+
}
99+
}
100+
101+
impl<'a, T: NativeType, P: MinMaxPolicy> RollingAggWindowNoNulls<'a, T> for MinMaxWindow<'a, T, P> {
102+
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
103+
assert!(params.is_none());
104+
let mut slf = Self {
105+
values: slice,
106+
validity: None,
107+
monotonic_idxs: VecDeque::new(),
108+
nonnulls_in_window: 0,
109+
last_end: 0,
110+
policy: PhantomData,
111+
};
112+
unsafe {
113+
RollingAggWindowNoNulls::update(&mut slf, start, end);
114+
}
115+
slf
116+
}
117+
118+
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
119+
unsafe {
120+
self.remove_old_values(start);
121+
for i in start.max(self.last_end)..end {
122+
self.insert_nonnull_value(i);
123+
}
124+
self.last_end = end;
125+
self.monotonic_idxs
126+
.front()
127+
.map(|idx| *self.values.get_unchecked(*idx))
128+
}
129+
}
130+
}

crates/polars-compute/src/rolling/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod min_max;
12
pub mod no_nulls;
23
pub mod nulls;
34
pub mod quantile_filter;
@@ -10,7 +11,6 @@ use arrow::bitmap::{Bitmap, MutableBitmap};
1011
use arrow::types::NativeType;
1112
use num_traits::{Bounded, Float, NumCast, One, Zero};
1213
use polars_utils::float::IsFloat;
13-
use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min};
1414
#[cfg(feature = "serde")]
1515
use serde::{Deserialize, Serialize};
1616
use strum_macros::IntoStaticStr;

0 commit comments

Comments
 (0)