Skip to content

Commit a9d8c95

Browse files
committed
route: fix jwt_issuer_format
1 parent ef1bc39 commit a9d8c95

File tree

3 files changed

+210
-0
lines changed

3 files changed

+210
-0
lines changed

internal/provider/enum.go

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package provider
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/hashicorp/terraform-plugin-framework/diag"
7+
"github.com/hashicorp/terraform-plugin-framework/types"
8+
"google.golang.org/protobuf/reflect/protoreflect"
9+
)
10+
11+
// GetValidEnumValues returns a list of valid enum values for a given protobuf enum type.
12+
// it includes zero value as well to match its use in the current api
13+
func GetValidEnumValues[T protoreflect.Enum]() []string {
14+
var values []string
15+
var v T
16+
descriptor := v.Descriptor()
17+
for i := 0; i < descriptor.Values().Len(); i++ {
18+
values = append(values, string(descriptor.Values().Get(i).Name()))
19+
}
20+
return values
21+
}
22+
23+
// EnumValueToPBWithDefault converts a string to a protobuf enum value.
24+
func EnumValueToPBWithDefault[T interface {
25+
~int32
26+
protoreflect.Enum
27+
}](
28+
dst *T,
29+
src types.String,
30+
defaultValue T,
31+
diagnostics *diag.Diagnostics,
32+
) {
33+
if src.IsNull() || src.ValueString() == "" {
34+
*dst = defaultValue
35+
return
36+
}
37+
38+
var v T
39+
enumValue := v.Descriptor().Values().ByName(protoreflect.Name(src.ValueString()))
40+
if enumValue == nil {
41+
diagnostics.AddError(
42+
"InvalidEnumValue",
43+
fmt.Sprintf("The provided %s enum value %q is not valid.", v.Descriptor().FullName(), src.ValueString()),
44+
)
45+
return
46+
}
47+
48+
*dst = T(enumValue.Number())
49+
}
50+
51+
func EnumValueFromPB[T interface {
52+
~int32
53+
protoreflect.Enum
54+
}](
55+
src T,
56+
) types.String {
57+
v := src.Descriptor().Values().ByNumber(protoreflect.EnumNumber(src))
58+
if v == nil {
59+
return types.StringNull()
60+
}
61+
return types.StringValue(string(v.Name()))
62+
}

internal/provider/enum_test.go

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package provider_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/hashicorp/terraform-plugin-framework/diag"
7+
"github.com/hashicorp/terraform-plugin-framework/types"
8+
"github.com/pomerium/enterprise-client-go/pb"
9+
"github.com/pomerium/enterprise-terraform-provider/internal/provider"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestEnumValueToPB(t *testing.T) {
15+
t.Parallel()
16+
17+
defaultValue := pb.IssuerFormat(-1)
18+
tests := []struct {
19+
name types.String
20+
expect pb.IssuerFormat
21+
expectError bool
22+
}{
23+
{types.StringValue("IssuerHostOnly"), pb.IssuerFormat_IssuerHostOnly, false},
24+
{types.StringValue("IssuerURI"), pb.IssuerFormat_IssuerURI, false},
25+
{types.StringValue("InvalidInexistentTest"), pb.IssuerFormat(-2), true},
26+
{types.StringNull(), defaultValue, false},
27+
{types.StringValue(""), defaultValue, false},
28+
}
29+
30+
for _, tt := range tests {
31+
t.Run(tt.name.String(), func(t *testing.T) {
32+
var got pb.IssuerFormat
33+
var diagnostics diag.Diagnostics
34+
provider.EnumValueToPBWithDefault(&got, tt.name, defaultValue, &diagnostics)
35+
if tt.expectError {
36+
assert.True(t, diagnostics.HasError())
37+
} else {
38+
require.False(t, diagnostics.HasError(), diagnostics.Errors())
39+
assert.Equal(t, tt.expect, got)
40+
}
41+
})
42+
}
43+
}
44+
45+
func TestEnumValueFromPB(t *testing.T) {
46+
t.Parallel()
47+
48+
tests := []struct {
49+
name pb.IssuerFormat
50+
expect types.String
51+
}{
52+
{pb.IssuerFormat_IssuerHostOnly, types.StringValue("IssuerHostOnly")},
53+
{pb.IssuerFormat_IssuerURI, types.StringValue("IssuerURI")},
54+
{pb.IssuerFormat(-1), types.StringNull()},
55+
}
56+
57+
for _, tt := range tests {
58+
t.Run(tt.expect.String(), func(t *testing.T) {
59+
got := provider.EnumValueFromPB(tt.name)
60+
assert.Equal(t, tt.expect, got)
61+
})
62+
}
63+
}

internal/provider/route_model_test.go

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package provider_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/hashicorp/terraform-plugin-framework/types"
8+
"github.com/pomerium/enterprise-client-go/pb"
9+
"github.com/pomerium/enterprise-terraform-provider/internal/provider"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestConvertRouteFromPB(t *testing.T) {
15+
t.Run("jwt_issuer_format", func(t *testing.T) {
16+
testCases := []struct {
17+
name string
18+
input pb.IssuerFormat
19+
expected string
20+
isNull bool
21+
}{
22+
{
23+
name: "host_only",
24+
input: pb.IssuerFormat_IssuerHostOnly,
25+
expected: "IssuerHostOnly",
26+
},
27+
{
28+
name: "uri",
29+
input: pb.IssuerFormat_IssuerURI,
30+
expected: "IssuerURI",
31+
},
32+
{
33+
name: "invalid value",
34+
input: pb.IssuerFormat(999),
35+
isNull: true,
36+
},
37+
}
38+
39+
for _, tc := range testCases {
40+
t.Run(tc.name, func(t *testing.T) {
41+
m := &provider.RouteModel{}
42+
r := &pb.Route{
43+
JwtIssuerFormat: tc.input,
44+
}
45+
diags := provider.ConvertRouteFromPB(m, r)
46+
require.False(t, diags.HasError())
47+
if tc.isNull {
48+
assert.True(t, m.JWTIssuerFormat.IsNull())
49+
} else {
50+
assert.Equal(t, tc.expected, m.JWTIssuerFormat.ValueString())
51+
}
52+
})
53+
}
54+
})
55+
}
56+
57+
func TestConvertRouteToPB(t *testing.T) {
58+
t.Run("jwt_issuer_format", func(t *testing.T) {
59+
testCases := []struct {
60+
name string
61+
input string
62+
expected pb.IssuerFormat
63+
expectError bool
64+
}{
65+
{"host_only", "IssuerHostOnly", pb.IssuerFormat_IssuerHostOnly, false},
66+
{"uri", "IssuerURI", pb.IssuerFormat_IssuerURI, false},
67+
{"invalid_value", "invalid_value", -1, true},
68+
}
69+
70+
for _, tc := range testCases {
71+
t.Run(tc.name, func(t *testing.T) {
72+
m := &provider.RouteModel{
73+
JWTIssuerFormat: types.StringValue(tc.input),
74+
}
75+
r, diag := provider.ConvertRouteToPB(context.Background(), m)
76+
if tc.expectError {
77+
require.True(t, diag.HasError())
78+
} else {
79+
require.False(t, diag.HasError())
80+
assert.Equal(t, tc.expected, r.JwtIssuerFormat)
81+
}
82+
})
83+
}
84+
})
85+
}

0 commit comments

Comments
 (0)