Skip to content

Commit 9bb7660

Browse files
committed
update code to pass lbARN for sheild protection setup and add tests
Signed-off-by: Saurabh Choudhary <[email protected]>
1 parent 3e25b5f commit 9bb7660

3 files changed

+127
-16
lines changed

pkg/service/model_build_load_balancer_addons.go

+10-14
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@ package service
33
import (
44
"context"
55

6-
"github.com/pkg/errors"
76
"sigs.k8s.io/aws-load-balancer-controller/pkg/annotations"
7+
"sigs.k8s.io/aws-load-balancer-controller/pkg/model/core"
88
shieldmodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/shield"
99
)
1010

11-
func (t *defaultModelBuildTask) buildLoadBalancerAddOns(ctx context.Context) error {
12-
if _, err := t.buildShieldProtection(ctx); err != nil {
11+
func (t *defaultModelBuildTask) buildLoadBalancerAddOns(ctx context.Context, lbARN core.StringToken) error {
12+
if _, err := t.buildShieldProtection(ctx, lbARN); err != nil {
1313
return err
1414
}
1515
return nil
1616
}
1717

18-
func (t *defaultModelBuildTask) buildShieldProtection(_ context.Context) (*shieldmodel.Protection, error) {
18+
func (t *defaultModelBuildTask) buildShieldProtection(_ context.Context, lbARN core.StringToken) (*shieldmodel.Protection, error) {
1919
explicitEnableProtections := make(map[bool]struct{})
2020
rawEnableProtection := false
2121
exists, err := t.annotationParser.ParseBoolAnnotation(annotations.SvcLBSuffixShieldAdvancedProtection, &rawEnableProtection, t.service.Annotations)
@@ -28,14 +28,10 @@ func (t *defaultModelBuildTask) buildShieldProtection(_ context.Context) (*shiel
2828
if len(explicitEnableProtections) == 0 {
2929
return nil, nil
3030
}
31-
if len(explicitEnableProtections) > 1 {
32-
return nil, errors.New("conflicting enable shield advanced protection")
33-
}
34-
if _, enableProtection := explicitEnableProtections[true]; enableProtection {
35-
protection := shieldmodel.NewProtection(t.stack, resourceIDLoadBalancer, shieldmodel.ProtectionSpec{
36-
ResourceARN: t.loadBalancer.LoadBalancerARN(),
37-
})
38-
return protection, nil
39-
}
40-
return nil, nil
31+
_, enableProtection := explicitEnableProtections[true]
32+
protection := shieldmodel.NewProtection(t.stack, resourceIDLoadBalancer, shieldmodel.ProtectionSpec{
33+
Enabled: enableProtection,
34+
ResourceARN: lbARN,
35+
})
36+
return protection, nil
4137
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package service
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/google/go-cmp/cmp"
8+
"github.com/google/go-cmp/cmp/cmpopts"
9+
"github.com/stretchr/testify/assert"
10+
corev1 "k8s.io/api/core/v1"
11+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
12+
"sigs.k8s.io/aws-load-balancer-controller/pkg/annotations"
13+
"sigs.k8s.io/aws-load-balancer-controller/pkg/model/core"
14+
shieldmodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/shield"
15+
)
16+
17+
func Test_defaultModelBuildTask_buildShieldProtection(t *testing.T) {
18+
type args struct {
19+
lbARN core.StringToken
20+
}
21+
tests := []struct {
22+
testName string
23+
svc *corev1.Service
24+
args args
25+
want *shieldmodel.Protection
26+
wantError bool
27+
}{
28+
{
29+
testName: "when shield-advanced-protection annotation is not specified",
30+
svc: &corev1.Service{
31+
ObjectMeta: metav1.ObjectMeta{
32+
Annotations: map[string]string{},
33+
},
34+
},
35+
args: args{
36+
lbARN: core.LiteralStringToken("awesome-lb-arn"),
37+
},
38+
want: nil,
39+
wantError: false,
40+
},
41+
{
42+
testName: "when shield-advanced-protection annotation set to true",
43+
svc: &corev1.Service{
44+
ObjectMeta: metav1.ObjectMeta{
45+
Annotations: map[string]string{
46+
"service.beta.kubernetes.io/aws-load-balancer-nlb-shield-advanced-protection": "true",
47+
},
48+
},
49+
},
50+
args: args{
51+
lbARN: core.LiteralStringToken("awesome-lb-arn"),
52+
},
53+
want: &shieldmodel.Protection{
54+
Spec: shieldmodel.ProtectionSpec{
55+
Enabled: true,
56+
ResourceARN: core.LiteralStringToken("awesome-lb-arn"),
57+
},
58+
},
59+
wantError: false,
60+
},
61+
{
62+
testName: "when shield-advanced-protection annotation set to false",
63+
svc: &corev1.Service{
64+
ObjectMeta: metav1.ObjectMeta{
65+
Annotations: map[string]string{
66+
"service.beta.kubernetes.io/aws-load-balancer-nlb-shield-advanced-protection": "false",
67+
},
68+
},
69+
},
70+
args: args{
71+
lbARN: core.LiteralStringToken("awesome-lb-arn"),
72+
},
73+
want: &shieldmodel.Protection{
74+
Spec: shieldmodel.ProtectionSpec{
75+
Enabled: false,
76+
ResourceARN: core.LiteralStringToken("awesome-lb-arn"),
77+
},
78+
},
79+
wantError: false,
80+
},
81+
{
82+
testName: "when shield-advanced-protection annotation has non boolean value",
83+
svc: &corev1.Service{
84+
ObjectMeta: metav1.ObjectMeta{
85+
Annotations: map[string]string{
86+
"service.beta.kubernetes.io/aws-load-balancer-nlb-shield-advanced-protection": "FalSe1",
87+
},
88+
},
89+
},
90+
args: args{
91+
lbARN: core.LiteralStringToken("awesome-lb-arn"),
92+
},
93+
wantError: true,
94+
},
95+
}
96+
for _, tt := range tests {
97+
t.Run(tt.testName, func(t *testing.T) {
98+
stack := core.NewDefaultStack(core.StackID{Name: "awesome-stack"})
99+
annotationParser := annotations.NewSuffixAnnotationParser("service.beta.kubernetes.io")
100+
task := &defaultModelBuildTask{
101+
service: tt.svc,
102+
annotationParser: annotationParser,
103+
stack: stack,
104+
}
105+
got, err := task.buildShieldProtection(context.Background(), tt.args.lbARN)
106+
if tt.wantError {
107+
assert.Error(t, err)
108+
} else {
109+
opts := cmpopts.IgnoreTypes(core.ResourceMeta{})
110+
assert.True(t, cmp.Equal(tt.want, got, opts), "diff", cmp.Diff(tt.want, got, opts))
111+
}
112+
})
113+
}
114+
}

pkg/service/model_builder.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ package service
22

33
import (
44
"context"
5-
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
65
"strconv"
76
"sync"
87

8+
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
9+
910
"github.com/go-logr/logr"
1011
"github.com/pkg/errors"
1112
corev1 "k8s.io/api/core/v1"
@@ -249,7 +250,7 @@ func (t *defaultModelBuildTask) buildModel(ctx context.Context) error {
249250
if err != nil {
250251
return err
251252
}
252-
if err := t.buildLoadBalancerAddOns(ctx); err != nil {
253+
if err := t.buildLoadBalancerAddOns(ctx, t.loadBalancer.LoadBalancerARN()); err != nil {
253254
return err
254255
}
255256
return nil

0 commit comments

Comments
 (0)