Skip to content

Commit 049bc46

Browse files
committed
fix bug with symmetric matrices and Hemritian matrices
1 parent d8af82a commit 049bc46

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

src/iternz.jl

+17-11
Original file line numberDiff line numberDiff line change
@@ -209,32 +209,39 @@ end
209209
nothing
210210
end
211211

212+
const SymOrHerm = Union{<:Symmetric, <:Hermitian}
213+
Base.IteratorSize(::IterateNZ{2, <:SymOrHerm}) = Base.SizeUnknown()
212214

213-
Base.IteratorSize(::IterateNZ{2, <:Symmetric}) = Base.SizeUnknown()
214-
215-
@inline Base.iterate(x::IterateNZ{2, <:Symmetric}) =
215+
@inline Base.iterate(x::IterateNZ{2, <:SymOrHerm}) =
216216
let iterator = iternz(x.m.data)
217217
iternzsym(x.m, iterator, iterate(iterator))
218218
end
219+
# state is
220+
# - the iterator to the inner internz
221+
# - the (v, i, j) Tuple
222+
# - a boolean that indicates we have to return the transposed
219223

220-
@inline Base.iterate(x::IterateNZ{2, <:Symmetric}, state) =
224+
@inline Base.iterate(x::IterateNZ{2, <:SymOrHerm}, state) =
221225
let (iterator, (v, i, j), r, s) = state
222226
if r
223-
(v, j, i), (iterator, (v, i, j), false, s)
227+
(isa(x.m, Symmetric) ? transpose(v) : adjoint(v), j, i), (iterator, (v, i, j), false, s)
224228
else
225229
iternzsym(x.m, iterator, iterate(iterator, s))
226230
end
227231
end
228232

229-
@inline iternzsym(m::Symmetric, iterator, a) = @inbounds begin
233+
@inline iternzsym(m::SymOrHerm, iterator, a) = @inbounds begin
230234
while a !== nothing
231235
r, state = a
232-
(_, i, j) = r
233-
if m.uplo == 'U'
234-
i <= j && return r, (iterator, r, i != j, state)
236+
(v, i, j) = r
237+
if i == j
238+
v1 = isa(m, Symmetric) ? LinearAlgebra.symmetric(v, LinearAlgebra.sym_uplo(m.uplo)) : LinearAlgebra.hermitian(v, LinearAlgebra.sym_uplo(m.uplo))
239+
return (v1, i, i), (iterator, r, false, state)
240+
elseif m.uplo == 'U'
241+
i <= j && return r, (iterator, r, true, state)
235242
state = skip_col(iterator, state)
236243
elseif m.uplo == 'L'
237-
i >= j && return r, (iterator, r, i != j, state)
244+
i >= j && return r, (iterator, r, true, state)
238245
state = skip_row_to(iterator, state, j)
239246
end
240247
a = iterate(iterator, state)
@@ -244,4 +251,3 @@ end
244251

245252

246253

247-

test/runtests.jl

+15
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ test_iternz_arr(_a::AbstractArray{T, N}, it=iternz(_a)) where {T, N} = begin
2828
@test iszero(a[I]) || Tuple(I) seen
2929
end
3030
end
31+
32+
33+
34+
@testset "iternz (Symmetric/Hermitian)" begin
35+
for i in 1:20
36+
for uplo in [:U, :L]
37+
A = Symmetric([randn(4, 4) for _ in 1:i, _ in 1:i], uplo)
38+
test_iternz_arr(A)
39+
B = Hermitian([randn(i * 2, i * 2) .+ randn(i * 2, i * 2) * 1im for _ in 1:i, _ in 1:i], uplo)
40+
test_iternz_arr(B)
41+
end
42+
end
43+
end
44+
3145
@testset "iternz (Symmetric)" begin
3246
for i in 1:20
3347
for uplo in [:U, :L]
@@ -40,6 +54,7 @@ end
4054
end
4155

4256

57+
4358
@testset "iternz (SparseMatrixCSC)" begin
4459
for i in 1:20
4560
A = sprandn(100, 100, 1 / i)

0 commit comments

Comments
 (0)