|
97 | 97 | struct InvertedIndexIterator{T,S,P} <: AbstractVector{T}
|
98 | 98 | skips::S
|
99 | 99 | picks::P
|
| 100 | + length::Int |
100 | 101 | end
|
101 |
| -InvertedIndexIterator(skips, picks) = InvertedIndexIterator{eltype(picks), typeof(skips), typeof(picks)}(skips, picks) |
102 |
| -Base.size(III::InvertedIndexIterator) = (length(III.picks) - length(III.skips),) |
| 102 | +InvertedIndexIterator(skips, picks) = InvertedIndexIterator{eltype(picks), typeof(skips), typeof(picks)}(skips, picks, length(picks) - length(skips)) |
| 103 | +Base.size(III::InvertedIndexIterator) = (III.length,) |
| 104 | + |
| 105 | +# Ensure iteration consumes all skips by the time it hits the end of the picks |
| 106 | +assert_iteration_finished(I, n, ::Nothing) = (@assert n == I.length "InvertedIndexIterator iterated $n values but expected $(I.length)"; true) |
| 107 | +assert_iteration_finished(I, _, (skipvalue, _)) = throw(ArgumentError("did not find index $skipvalue in axis $(I.picks), so could not skip it")) |
| 108 | +# Ensure iteration does not generate more than I.length values |
| 109 | +assert_iteration_not_finished(I, n, ::Nothing) = @assert n <= I.length "InvertedIndexIterator iterated more values than expected" |
| 110 | +assert_iteration_not_finished(I, n, (skipvalue, _)) = n <= I.length || throw(ArgumentError("did not find index $skipvalue in axis $(I.picks), so could not skip it")) |
103 | 111 |
|
104 | 112 | @inline function Base.iterate(I::InvertedIndexIterator)
|
| 113 | + n = 0 |
105 | 114 | skipitr = iterate(I.skips)
|
106 | 115 | pickitr = iterate(I.picks)
|
107 |
| - pickitr === nothing && return nothing |
| 116 | + pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing |
108 | 117 | while should_skip(skipitr, pickitr)
|
109 | 118 | skipitr = iterate(I.skips, skipitr[2])
|
110 | 119 | pickitr = iterate(I.picks, pickitr[2])
|
111 |
| - pickitr === nothing && return nothing |
| 120 | + pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing |
112 | 121 | end
|
| 122 | + n += 1; assert_iteration_not_finished(I, n, skipitr) |
113 | 123 | # This is a little silly, but splitting the tuple here allows inference to normalize
|
114 | 124 | # Tuple{Union{Nothing, Tuple}, Tuple} to Union{Tuple{Nothing, Tuple}, Tuple{Tuple, Tuple}}
|
115 | 125 | return skipitr === nothing ?
|
116 |
| - (pickitr[1], (nothing, pickitr[2])) : |
117 |
| - (pickitr[1], (skipitr, pickitr[2])) |
| 126 | + (pickitr[1], (nothing, pickitr[2], n)) : |
| 127 | + (pickitr[1], (skipitr, pickitr[2], n)) |
118 | 128 | end
|
119 |
| -@inline function Base.iterate(I::InvertedIndexIterator, (_, pickstate)::Tuple{Nothing, Any}) |
| 129 | +@inline function Base.iterate(I::InvertedIndexIterator, (_, pickstate, n)::Tuple{Nothing, Any, Any}) |
120 | 130 | pickitr = iterate(I.picks, pickstate)
|
121 |
| - pickitr === nothing && return nothing |
122 |
| - return (pickitr[1], (nothing, pickitr[2])) |
| 131 | + pickitr === nothing && assert_iteration_finished(I, n, nothing) && return nothing |
| 132 | + n += 1; assert_iteration_not_finished(I, n, nothing) |
| 133 | + return (pickitr[1], (nothing, pickitr[2], n)) |
123 | 134 | end
|
124 |
| -@inline function Base.iterate(I::InvertedIndexIterator, (skipitr, pickstate)::Tuple) |
| 135 | +@inline function Base.iterate(I::InvertedIndexIterator, (skipitr, pickstate, n)::Tuple) |
125 | 136 | pickitr = iterate(I.picks, pickstate)
|
126 |
| - pickitr === nothing && return nothing |
| 137 | + pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing |
127 | 138 | while should_skip(skipitr, pickitr)
|
128 | 139 | skipitr = iterate(I.skips, tail(skipitr)...)
|
129 | 140 | pickitr = iterate(I.picks, tail(pickitr)...)
|
130 |
| - pickitr === nothing && return nothing |
| 141 | + pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing |
131 | 142 | end
|
| 143 | + n += 1; assert_iteration_not_finished(I, n, skipitr) |
132 | 144 | return skipitr === nothing ?
|
133 |
| - (pickitr[1], (nothing, pickitr[2])) : |
134 |
| - (pickitr[1], (skipitr, pickitr[2])) |
| 145 | + (pickitr[1], (nothing, pickitr[2], n)) : |
| 146 | + (pickitr[1], (skipitr, pickitr[2], n)) |
135 | 147 | end
|
136 | 148 | function Base.collect(III::InvertedIndexIterator{T}) where {T}
|
137 | 149 | !isconcretetype(T) && return [i for i in III] # use widening if T is not concrete
|
|
0 commit comments