Skip to content

Commit 0fbc70b

Browse files
committed
make replace/replace! work with count
1 parent 245d224 commit 0fbc70b

File tree

2 files changed

+90
-5
lines changed

2 files changed

+90
-5
lines changed

src/chainedvector.jl

+41-4
Original file line numberDiff line numberDiff line change
@@ -979,9 +979,46 @@ function Base.filter!(f, a::ChainedVector)
979979
return a
980980
end
981981

982-
Base.replace(f::Base.Callable, a::ChainedVector) = ChainedVector([replace(f, A) for A in a.arrays])
983-
Base.replace!(f::Base.Callable, a::ChainedVector) = (foreach(A -> replace!(f, A), a.arrays); return a)
984-
Base.replace(a::ChainedVector, old_new::Pair...; count::Union{Integer,Nothing}=nothing) = ChainedVector([replace(A, old_new...; count=count) for A in a.arrays])
985-
Base.replace!(a::ChainedVector, old_new::Pair...; count::Integer=typemax(Int)) = (foreach(A -> replace!(A, old_new...; count=count), a.arrays); return a)
982+
function _check_count(count::Integer)
983+
count < 0 && throw(DomainError(count, "`count` must not be negative"))
984+
return min(count, typemax(Int)) % Int
985+
end
986+
987+
Base.replace(f::Base.Callable, a::ChainedVector; count::Integer=typemax(Int)) =
988+
_replace!(f, copy(a), a, _check_count(count))
989+
990+
Base.replace!(f::Base.Callable, a::ChainedVector; count::Integer=typemax(Int)) =
991+
_replace!(f, a, a, _check_count(count))
992+
993+
Base.replace(A::ChainedVector, old_new::Pair...; count::Integer=typemax(Int)) =
994+
_replace_pairs!(copy(A), A, _check_count(count), old_new)
995+
996+
Base.replace!(A::ChainedVector, old_new::Pair...; count::Integer=typemax(Int)) =
997+
_replace_pairs!(A, A, _check_count(count), old_new)
998+
999+
function _replace_pairs!(res, A::ChainedVector{T}, count::Int, old_new::Tuple{Vararg{Pair}}) where {T}
1000+
@inline function new(x)
1001+
for (old, new) in old_new
1002+
isequal(x, old) && return new
1003+
end
1004+
return x # no replace
1005+
end
1006+
_replace!(new, res, A, count)
1007+
end
1008+
1009+
function _replace!(new::Base.Callable, res, A::ChainedVector{T}, count::Int) where {T}
1010+
count == 0 && return res
1011+
c = 0
1012+
for i in eachindex(A)
1013+
x = A[i]
1014+
y = new(x)
1015+
if x !== y
1016+
res[i] = y
1017+
c += 1
1018+
c == count && break
1019+
end
1020+
end
1021+
return res
1022+
end
9861023

9871024
Base.Broadcast.broadcasted(f::F, A::ChainedVector) where {F} = map(f, A)

test/chainedvector.jl

+49-1
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,13 @@
350350
@test map(x -> x == 1 ? 2.0 : x, x) == replace!(x, 1 => 2)
351351
@test isempty(x)
352352

353+
@test replace!(ChainedVector([[1,2], [1,2]]), 2=>20) == [1,20,1,20]
354+
@test replace!(ChainedVector([[1,2], [1,2]]), 2=>20, count=1) == [1,20,1,2]
355+
@test replace!(ChainedVector([[1,2], [1,2]]), 2=>20, count=2) == [1,20,1,20]
356+
x = [1,2]
357+
@test replace!(ChainedVector([x,[2,3]]), 2=>99) == [1,99,99,3]
358+
@test x == [1,99]
359+
353360
# copyto!
354361
# ChainedVector dest: doffs, soffs, n
355362
x = ChainedVector([[1,2,3], [4,5,6], [7,8,9,10]])
@@ -593,6 +600,47 @@ end
593600
end
594601
end
595602

603+
@testset "replace[!] comparison with Vector" begin
604+
605+
testvecs = (
606+
[[1, 2], [3, 2, 5]],
607+
[[1, 2]],
608+
[[2],[2],[2],[2,3]],
609+
[[1,2,missing]],
610+
[[missing,1],[missing,2,1]],
611+
[[missing]]
612+
)
613+
function missing_equal(a,b)
614+
ismissing(a) && ismissing(b) && return true
615+
ismissing(a) ismissing(b) && return false
616+
return all(skipmissing(a) .== skipmissing(b))
617+
end
618+
gen_cv_v(x) = (c = ChainedVector(x); (c, collect(c)))
619+
for f in (replace, replace!)
620+
for x in testvecs
621+
cv, v = gen_cv_v(x)
622+
@test missing_equal(f(v, 2 => 22),f(cv, 2 => 22))
623+
@test missing_equal(v,cv)
624+
625+
cv, v = gen_cv_v(x)
626+
@test missing_equal(f(x -> x ÷ 2, v), f(x -> x ÷ 2, cv))
627+
@test missing_equal(v,cv)
628+
629+
for c in (0, 1, 2, 3)
630+
cv, v = gen_cv_v(x)
631+
@test missing_equal(f(x -> x ÷ 2, v, count=c), f(x -> x ÷ 2, cv, count=c))
632+
@test missing_equal(v,cv)
633+
634+
for p in ((2=>2,),(2 => 22,), (2 => 22, 3 => 33))
635+
cv, v = gen_cv_v(x)
636+
@test missing_equal(f(v, p..., count=c), f(cv, p..., count=c))
637+
@test missing_equal(v,cv)
638+
end
639+
end
640+
end
641+
end
642+
end
643+
596644

597645
@testset "iteration protocol on ChainedVector" begin
598646
for len in 0:6
@@ -752,7 +800,7 @@ end
752800
end
753801

754802
@testset "getindex with UnitRange" begin
755-
x = ChainedVector([collect(1:i) for i = 10:100])
803+
x = ChainedVector([collect(1:i) for i = 1:10])
756804
@test isempty(x[1:0])
757805
@test x[1:1] == [1]
758806
@test x[1:end] == x

0 commit comments

Comments
 (0)