Skip to content

Commit b337d91

Browse files
committed
add consistency checks in iterate
* ensure that iterate generates `length(I)` values * ensure that iterate always makes it to the end of all skips addresses the most egregious bad behaviors in #7 and #31
1 parent 2efe37d commit b337d91

File tree

2 files changed

+61
-14
lines changed

2 files changed

+61
-14
lines changed

src/InvertedIndices.jl

+26-14
Original file line numberDiff line numberDiff line change
@@ -97,41 +97,53 @@ end
9797
struct InvertedIndexIterator{T,S,P} <: AbstractVector{T}
9898
skips::S
9999
picks::P
100+
length::Int
100101
end
101-
InvertedIndexIterator(skips, picks) = InvertedIndexIterator{eltype(picks), typeof(skips), typeof(picks)}(skips, picks)
102-
Base.size(III::InvertedIndexIterator) = (length(III.picks) - length(III.skips),)
102+
InvertedIndexIterator(skips, picks) = InvertedIndexIterator{eltype(picks), typeof(skips), typeof(picks)}(skips, picks, length(picks) - length(skips))
103+
Base.size(III::InvertedIndexIterator) = (III.length,)
104+
105+
# Ensure iteration consumes all skips by the time it hits the end of the picks
106+
assert_iteration_finished(I, n, ::Nothing) = (@assert n == I.length "InvertedIndexIterator iterated $n values but expected $(I.length)"; true)
107+
assert_iteration_finished(I, _, (skipvalue, _)) = throw(ArgumentError("did not find index $skipvalue in axis $(I.picks), so could not skip it"))
108+
# Ensure iteration does not generate more than I.length values
109+
assert_iteration_not_finished(I, n, ::Nothing) = @assert n <= I.length "InvertedIndexIterator iterated more values than expected"
110+
assert_iteration_not_finished(I, n, (skipvalue, _)) = n <= I.length || throw(ArgumentError("did not find index $skipvalue in axis $(I.picks), so could not skip it"))
103111

104112
@inline function Base.iterate(I::InvertedIndexIterator)
113+
n = 0
105114
skipitr = iterate(I.skips)
106115
pickitr = iterate(I.picks)
107-
pickitr === nothing && return nothing
116+
pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing
108117
while should_skip(skipitr, pickitr)
109118
skipitr = iterate(I.skips, skipitr[2])
110119
pickitr = iterate(I.picks, pickitr[2])
111-
pickitr === nothing && return nothing
120+
pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing
112121
end
122+
n += 1; assert_iteration_not_finished(I, n, skipitr)
113123
# This is a little silly, but splitting the tuple here allows inference to normalize
114124
# Tuple{Union{Nothing, Tuple}, Tuple} to Union{Tuple{Nothing, Tuple}, Tuple{Tuple, Tuple}}
115125
return skipitr === nothing ?
116-
(pickitr[1], (nothing, pickitr[2])) :
117-
(pickitr[1], (skipitr, pickitr[2]))
126+
(pickitr[1], (nothing, pickitr[2], n)) :
127+
(pickitr[1], (skipitr, pickitr[2], n))
118128
end
119-
@inline function Base.iterate(I::InvertedIndexIterator, (_, pickstate)::Tuple{Nothing, Any})
129+
@inline function Base.iterate(I::InvertedIndexIterator, (_, pickstate, n)::Tuple{Nothing, Any, Any})
120130
pickitr = iterate(I.picks, pickstate)
121-
pickitr === nothing && return nothing
122-
return (pickitr[1], (nothing, pickitr[2]))
131+
pickitr === nothing && assert_iteration_finished(I, n, nothing) && return nothing
132+
n += 1; assert_iteration_not_finished(I, n, nothing)
133+
return (pickitr[1], (nothing, pickitr[2], n))
123134
end
124-
@inline function Base.iterate(I::InvertedIndexIterator, (skipitr, pickstate)::Tuple)
135+
@inline function Base.iterate(I::InvertedIndexIterator, (skipitr, pickstate, n)::Tuple)
125136
pickitr = iterate(I.picks, pickstate)
126-
pickitr === nothing && return nothing
137+
pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing
127138
while should_skip(skipitr, pickitr)
128139
skipitr = iterate(I.skips, tail(skipitr)...)
129140
pickitr = iterate(I.picks, tail(pickitr)...)
130-
pickitr === nothing && return nothing
141+
pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing
131142
end
143+
n += 1; assert_iteration_not_finished(I, n, skipitr)
132144
return skipitr === nothing ?
133-
(pickitr[1], (nothing, pickitr[2])) :
134-
(pickitr[1], (skipitr, pickitr[2]))
145+
(pickitr[1], (nothing, pickitr[2], n)) :
146+
(pickitr[1], (skipitr, pickitr[2], n))
135147
end
136148
function Base.collect(III::InvertedIndexIterator{T}) where {T}
137149
!isconcretetype(T) && return [i for i in III] # use widening if T is not concrete

test/runtests.jl

+35
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,38 @@ returns(val) = _->val
203203
@test @inferred(LinearIndices(arr)[collect(I)]) == vec(filter(!iseven, arr))
204204
end
205205
end
206+
207+
struct NamedVector{T,A,B} <: AbstractArray{T,1}
208+
data::A
209+
names::B
210+
end
211+
function NamedVector(data, names)
212+
@assert size(data) == size(names)
213+
NamedVector{eltype(data), typeof(data), typeof(names)}(data, names)
214+
end
215+
Base.size(n::NamedVector) = size(n.data)
216+
Base.getindex(n::NamedVector, i::Int) = n.data[i]
217+
Base.to_index(n::NamedVector, name::Symbol) = findfirst(==(name), n.names)
218+
Base.checkbounds(::Type{Bool}, n::NamedVector, names::AbstractArray{Symbol}) = all(name in n.names for name in names)
219+
220+
@testset "ensure skipped indices are skipped" begin
221+
@test_throws "did not find" [1, 2, 3, 4][Not([1.5])]
222+
@test_throws "did not find" [1, 2, 3, 4][Not(Not([1.5]))]
223+
# Without error checking/checkbounds, this segfaults with a large enough array:
224+
@test_throws "did not find" rand(100)[Not(begin+.5:end)]
225+
@test_broken @test_throws "invalid index" [1, 2, 3, 4][Not(Integer[true, 2])]
226+
227+
n = NamedVector(1:4, [:a, :b, :c, :d]);
228+
@test_broken n[Not([:a,:b])] == n[Not(1:2)] == [3, 4]
229+
@test_broken n[Not([:c,:d])] == n[Not(3:4)] == [1, 2]
230+
@test n[Not(:a)] == n[Not(1)] == [2,3,4]
231+
@test n[Not(:b)] == n[Not(2)] == [1,3,4]
232+
233+
n = NamedVector(1:4, [:d, :b, :c, :a]);
234+
@test_broken n[Not([:a,:b])] == n[Not([4,2])]== n[[:d,:c]] == [1, 3]
235+
@test_broken n[Not([:c,:d])] == n[Not([3,1])] == n[[:b,:a]] == [2, 4]
236+
@test n[Not(:a)] == n[Not(4)] == [1,2,3]
237+
@test n[Not(:b)] == n[Not(2)] == [1,3,4]
238+
@test n[Not(:c)] == n[Not(3)] == [1,2,4]
239+
@test n[Not(:d)] == n[Not(1)] == [2,3,4]
240+
end

0 commit comments

Comments
 (0)