Skip to content

Commit f229f75

Browse files
authored
Merge pull request #25 from invenia/rf/vec-dates
Allow passing vectors of dates
2 parents 845b79f + 0dc394a commit f229f75

9 files changed

+33
-9
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DateSelectors"
22
uuid = "c900ad91-5c7c-41b8-bce1-46b0264fae1d"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"

src/NoneSelector.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,11 @@ Assign all dates to the validation set, select no holdout dates.
55
"""
66
struct NoneSelector <: DateSelector end
77

8-
Iterators.partition(dates::StepRange{Date, Day}, ::NoneSelector) = _getdatesets(dates, Date[])
8+
function Iterators.partition(dates::AbstractVector{Date}, ::NoneSelector)
9+
# Just to maintain consistency between selectors
10+
if dates isa StepRange && step(dates) != Day(1)
11+
throw(ArgumentError("Expected step range over days, not ($(step(dates)))."))
12+
end
13+
14+
return _getdatesets(dates, Date[])
15+
end

src/PeriodicSelector.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ struct PeriodicSelector <: DateSelector
3232
end
3333
end
3434

35+
function Iterators.partition(dates::AbstractVector{Date}, s::PeriodicSelector)
36+
if dates isa StepRange && step(dates) != Day(1)
37+
throw(ArgumentError("Expected step range over days, not ($(step(dates)))."))
38+
end
3539

36-
function Iterators.partition(dates::StepRange{Date, Day}, s::PeriodicSelector)
3740
initial_time = _initial_date(s, dates)
3841
sd, ed = extrema(dates)
3942

@@ -43,7 +46,7 @@ function Iterators.partition(dates::StepRange{Date, Day}, s::PeriodicSelector)
4346
# is still is well under 1/4 second. so keeping it simple
4447

4548
holdout_dates = Date[]
46-
curr_window = initial_time:step(dates):(initial_time + s.stride - step(dates))
49+
curr_window = initial_time:Day(1):(initial_time + s.stride - Day(1))
4750
while first(curr_window) <= ed
4851
# optimization: only creating holdout window if intersect not empty
4952
if last(curr_window) >= sd

src/RandomSelector.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,18 @@ struct RandomSelector <: DateSelector
3333
end
3434
end
3535

36-
function Iterators.partition(dates::StepRange{Date, Day}, s::RandomSelector)
36+
function Iterators.partition(dates::AbstractVector{Date}, s::RandomSelector)
37+
if dates isa StepRange && step(dates) != Day(1)
38+
throw(ArgumentError("Expected step range over days, not ($(step(dates)))."))
39+
end
40+
3741
sd, ed = extrema(dates)
3842

3943
rng = MersenneTwister(s.seed)
4044

4145
holdout_dates = Date[]
4246
initial_time = _initial_date(s, dates)
43-
curr_window = initial_time:step(dates):(initial_time + s.block_size - step(dates))
47+
curr_window = initial_time:Day(1):(initial_time + s.block_size - Day(1))
4448
while first(curr_window) <= ed
4549
# Important: we must generate a random number for every block even before the start
4650
# so that the `rng` state is updated constistently no matter when the start is

src/common.jl

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ abstract type DateSelector end
1010
"""
1111
partition(dates::AbstractInterval{Date}, s::DateSelector)
1212
partition(dates::StepRange{Date, Day}, selector::DateSelector)
13+
partition(dates::AbstractVector{Date}, s::DateSelector)
1314
1415
Partition the set of `dates` into disjoint `validation` and `holdout` sets according to the
1516
`selector` and return a `NamedTuple({:validation, :holdout})` of iterators.

test/NoneSelector.jl

+3
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,7 @@
99
validation, holdout = partition(date_range, selector)
1010
@test validation == date_range
1111
@test isempty(holdout)
12+
13+
# Test that we can also handle any abstract vector
14+
@test first(partition(collect(date_range), selector)) == validation
1215
end

test/PeriodicSelector.jl

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
@test isempty(intersect(result...))
1414

1515
@test all(isequal(Week(1)), diff(result.holdout))
16+
17+
# Test that we can also handle any abstract vector
18+
@test partition(collect(date_range), selector) == result
1619
end
1720

1821
@testset "2 week period, 5 day stride" begin

test/RandomSelector.jl

+3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
@test result.validation == result2.validation
4646
@test result.holdout == result2.holdout
4747

48+
# Test that we can also handle any abstract vector
49+
@test partition(collect(date_range), selector) == result
50+
4851
@testset "holdout fraction" begin
4952
# Setting holdout_fraction 1 all days leaves the validation set empty
5053
validation, holdout = partition(date_range, RandomSelector(42, 1))

test/sensibility_checks.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
end
3232
end
3333

34-
@testset "Vector of days is not allowed" begin
34+
@testset "Vector of days is allowed" begin
3535
date_range = collect(Date(2019, 1, 1):Day(1):Date(2019, 2, 1))
36-
@test_throws MethodError partition(date_range, NoneSelector())
36+
@test first(partition(date_range, NoneSelector())) == date_range
3737
end
3838

3939
@testset "Weekly intervals are not allowed" begin
@@ -46,7 +46,7 @@
4646
PeriodicSelector(Week(2), Week(1)),
4747
RandomSelector(42),
4848
)
49-
@test_throws MethodError partition(weekly_dates, selector)
49+
@test_throws ArgumentError partition(weekly_dates, selector)
5050
end
5151
end
5252
end

0 commit comments

Comments
 (0)