Skip to content

Commit d551a21

Browse files
committed
adds jwt_groups_filter
1 parent 842803d commit d551a21

7 files changed

+225
-1
lines changed

example/main.tf

+8-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ terraform {
22
required_providers {
33
pomerium = {
44
source = "pomerium/pomerium"
5-
version = "0.0.5"
5+
version = "0.0.7"
66
}
77
}
88
}
@@ -52,6 +52,10 @@ resource "pomerium_settings" "settings" {
5252
proxy_log_level = "info"
5353

5454
timeout_idle = "5m"
55+
56+
jwt_groups_filter = {
57+
groups = ["id1", "id2"]
58+
}
5559
}
5660

5761
resource "pomerium_service_account" "test_sa" {
@@ -82,6 +86,9 @@ resource "pomerium_route" "test_route" {
8286
from = "https://verify-tf.localhost.pomerium.io"
8387
to = ["https://verify.pomerium.com"]
8488
policies = [pomerium_policy.test_policy.id]
89+
jwt_groups_filter = {
90+
infer_from_ppl = true
91+
}
8592
}
8693

8794
resource "pomerium_key_pair" "test_key_pair" {

internal/provider/jwt_group_filter.go

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
6+
"github.com/hashicorp/terraform-plugin-framework/attr"
7+
"github.com/hashicorp/terraform-plugin-framework/diag"
8+
"github.com/hashicorp/terraform-plugin-framework/resource/schema"
9+
"github.com/hashicorp/terraform-plugin-framework/types"
10+
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
11+
"github.com/pomerium/enterprise-client-go/pb"
12+
)
13+
14+
var (
15+
jwtGroupsFilterSchema = schema.SingleNestedAttribute{
16+
Optional: true,
17+
Description: "JWT Groups Filter",
18+
Attributes: map[string]schema.Attribute{
19+
"groups": schema.SetAttribute{
20+
ElementType: types.StringType,
21+
Optional: true,
22+
Computed: false,
23+
Sensitive: false,
24+
Description: "Group IDs to filter",
25+
},
26+
"infer_from_ppl": schema.BoolAttribute{
27+
Optional: true,
28+
},
29+
},
30+
}
31+
jwtGroupsFilterSchemaAttr = map[string]attr.Type{
32+
"groups": types.SetType{
33+
ElemType: types.StringType,
34+
},
35+
"infer_from_ppl": types.BoolType,
36+
}
37+
)
38+
39+
func JWTGroupsFilterFromPB(
40+
dst *types.Object,
41+
src *pb.JwtGroupsFilter,
42+
diags *diag.Diagnostics,
43+
) {
44+
if src == nil {
45+
*dst = types.ObjectNull(jwtGroupsFilterSchemaAttr)
46+
return
47+
}
48+
49+
attrs := make(map[string]attr.Value)
50+
if src.Groups == nil {
51+
attrs["groups"] = types.SetNull(types.StringType)
52+
} else {
53+
var vals []attr.Value
54+
for _, v := range src.Groups {
55+
vals = append(vals, types.StringValue(v))
56+
}
57+
attrs["groups"] = types.SetValueMust(types.StringType, vals)
58+
}
59+
60+
attrs["infer_from_ppl"] = types.BoolValue(src.InferFromPpl)
61+
62+
*dst = types.ObjectValueMust(jwtGroupsFilterSchemaAttr, attrs)
63+
}
64+
65+
func JWTGroupsFilterToPB(
66+
ctx context.Context,
67+
dst **pb.JwtGroupsFilter,
68+
src types.Object,
69+
diags *diag.Diagnostics,
70+
) {
71+
if src.IsNull() {
72+
dst = nil
73+
return
74+
}
75+
76+
type jwtOptions struct {
77+
Groups []string `tfsdk:"groups"`
78+
InferFromPpl bool `tfsdk:"infer_from_ppl"`
79+
}
80+
var opts jwtOptions
81+
d := src.As(ctx, &opts, basetypes.ObjectAsOptions{
82+
UnhandledNullAsEmpty: true,
83+
UnhandledUnknownAsEmpty: false,
84+
})
85+
diags.Append(d...)
86+
if d.HasError() {
87+
return
88+
}
89+
90+
*dst = &pb.JwtGroupsFilter{
91+
Groups: opts.Groups,
92+
InferFromPpl: opts.InferFromPpl,
93+
}
94+
}
+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/google/go-cmp/cmp"
8+
"github.com/hashicorp/terraform-plugin-framework/attr"
9+
"github.com/hashicorp/terraform-plugin-framework/diag"
10+
"github.com/hashicorp/terraform-plugin-framework/types"
11+
"github.com/pomerium/enterprise-client-go/pb"
12+
"github.com/stretchr/testify/assert"
13+
"google.golang.org/protobuf/testing/protocmp"
14+
)
15+
16+
func TestJWTGroupsFilterFromPB(t *testing.T) {
17+
tests := []struct {
18+
name string
19+
input *pb.JwtGroupsFilter
20+
expected types.Object
21+
}{
22+
{
23+
name: "nil input",
24+
input: nil,
25+
expected: types.ObjectNull(jwtGroupsFilterSchemaAttr),
26+
},
27+
{
28+
name: "empty groups",
29+
input: &pb.JwtGroupsFilter{
30+
Groups: []string{},
31+
InferFromPpl: false,
32+
},
33+
expected: types.ObjectValueMust(jwtGroupsFilterSchemaAttr, map[string]attr.Value{
34+
"groups": types.SetValueMust(types.StringType, []attr.Value{}),
35+
"infer_from_ppl": types.BoolValue(false),
36+
}),
37+
},
38+
{
39+
name: "with groups",
40+
input: &pb.JwtGroupsFilter{
41+
Groups: []string{"group1", "group2"},
42+
InferFromPpl: true,
43+
},
44+
expected: types.ObjectValueMust(jwtGroupsFilterSchemaAttr, map[string]attr.Value{
45+
"groups": types.SetValueMust(types.StringType, []attr.Value{
46+
types.StringValue("group1"),
47+
types.StringValue("group2"),
48+
}),
49+
"infer_from_ppl": types.BoolValue(true),
50+
}),
51+
},
52+
}
53+
54+
for _, tc := range tests {
55+
t.Run(tc.name, func(t *testing.T) {
56+
var diags diag.Diagnostics
57+
var result types.Object
58+
JWTGroupsFilterFromPB(&result, tc.input, &diags)
59+
assert.False(t, diags.HasError())
60+
diff := cmp.Diff(tc.expected, result)
61+
assert.Empty(t, diff)
62+
})
63+
}
64+
}
65+
66+
func TestJWTGroupsFilterToPB(t *testing.T) {
67+
ctx := context.Background()
68+
tests := []struct {
69+
name string
70+
input types.Object
71+
expected *pb.JwtGroupsFilter
72+
}{
73+
{
74+
name: "null input",
75+
input: types.ObjectNull(jwtGroupsFilterSchemaAttr),
76+
expected: nil,
77+
},
78+
{
79+
name: "empty groups",
80+
input: types.ObjectValueMust(jwtGroupsFilterSchemaAttr, map[string]attr.Value{
81+
"groups": types.SetValueMust(types.StringType, []attr.Value{}),
82+
"infer_from_ppl": types.BoolValue(false),
83+
}),
84+
expected: &pb.JwtGroupsFilter{
85+
Groups: []string{},
86+
InferFromPpl: false,
87+
},
88+
},
89+
{
90+
name: "with groups",
91+
input: types.ObjectValueMust(jwtGroupsFilterSchemaAttr, map[string]attr.Value{
92+
"groups": types.SetValueMust(types.StringType, []attr.Value{
93+
types.StringValue("group1"),
94+
types.StringValue("group2"),
95+
}),
96+
"infer_from_ppl": types.BoolValue(true),
97+
}),
98+
expected: &pb.JwtGroupsFilter{
99+
Groups: []string{"group1", "group2"},
100+
InferFromPpl: true,
101+
},
102+
},
103+
}
104+
105+
for _, tc := range tests {
106+
t.Run(tc.name, func(t *testing.T) {
107+
var diags diag.Diagnostics
108+
var result *pb.JwtGroupsFilter
109+
JWTGroupsFilterToPB(ctx, &result, tc.input, &diags)
110+
assert.False(t, diags.HasError())
111+
diff := cmp.Diff(tc.expected, result, protocmp.Transform())
112+
assert.Empty(t, diff)
113+
})
114+
}
115+
}

internal/provider/route.go

+1
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ func (r *RouteResource) Schema(_ context.Context, _ resource.SchemaRequest, resp
198198
Optional: true,
199199
Computed: true,
200200
},
201+
"jwt_groups_filter": jwtGroupsFilterSchema,
201202
},
202203
}
203204
}

internal/provider/route_model.go

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ type RouteModel struct {
4747
IDPClientID types.String `tfsdk:"idp_client_id"`
4848
IDPClientSecret types.String `tfsdk:"idp_client_secret"`
4949
ShowErrorDetails types.Bool `tfsdk:"show_error_details"`
50+
JWTGroupsFilter types.Object `tfsdk:"jwt_groups_filter"`
5051
}
5152

5253
func ConvertRouteToPB(
@@ -89,6 +90,7 @@ func ConvertRouteToPB(
8990
pbRoute.IdpClientId = src.IDPClientID.ValueStringPointer()
9091
pbRoute.IdpClientSecret = src.IDPClientSecret.ValueStringPointer()
9192
pbRoute.ShowErrorDetails = src.ShowErrorDetails.ValueBool()
93+
JWTGroupsFilterToPB(ctx, &pbRoute.JwtGroupsFilter, src.JWTGroupsFilter, &diagnostics)
9294

9395
diags := src.To.ElementsAs(ctx, &pbRoute.To, false)
9496
diagnostics.Append(diags...)
@@ -147,6 +149,7 @@ func ConvertRouteFromPB(
147149
dst.IDPClientID = types.StringPointerValue(src.IdpClientId)
148150
dst.IDPClientSecret = types.StringPointerValue(src.IdpClientSecret)
149151
dst.ShowErrorDetails = types.BoolValue(src.ShowErrorDetails)
152+
JWTGroupsFilterFromPB(&dst.JWTGroupsFilter, src.JwtGroupsFilter, &diagnostics)
150153

151154
return diagnostics
152155
}

internal/provider/settings_model.go

+3
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ type SettingsModel struct {
8080
TimeoutIdle timetypes.GoDuration `tfsdk:"timeout_idle"`
8181
TimeoutRead timetypes.GoDuration `tfsdk:"timeout_read"`
8282
TimeoutWrite timetypes.GoDuration `tfsdk:"timeout_write"`
83+
JWTGroupsFilter types.Object `tfsdk:"jwt_groups_filter"`
8384
}
8485

8586
func ConvertSettingsToPB(
@@ -151,6 +152,7 @@ func ConvertSettingsToPB(
151152
ToDuration(&pbSettings.TimeoutIdle, src.TimeoutIdle, &diagnostics)
152153
ToDuration(&pbSettings.TimeoutRead, src.TimeoutRead, &diagnostics)
153154
ToDuration(&pbSettings.TimeoutWrite, src.TimeoutWrite, &diagnostics)
155+
JWTGroupsFilterToPB(ctx, &pbSettings.JwtGroupsFilter, src.JWTGroupsFilter, &diagnostics)
154156

155157
return pbSettings, diagnostics
156158
}
@@ -223,6 +225,7 @@ func ConvertSettingsFromPB(
223225
dst.TimeoutIdle = FromDuration(src.TimeoutIdle)
224226
dst.TimeoutRead = FromDuration(src.TimeoutRead)
225227
dst.TimeoutWrite = FromDuration(src.TimeoutWrite)
228+
JWTGroupsFilterFromPB(&dst.JWTGroupsFilter, src.JwtGroupsFilter, &diagnostics)
226229

227230
return diagnostics
228231
}

internal/provider/settings_schema.go

+1
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ var SettingsResourceSchema = schema.Schema{
190190
Optional: true,
191191
Description: "JWT claims headers mapping",
192192
},
193+
"jwt_groups_filter": jwtGroupsFilterSchema,
193194
"default_upstream_timeout": schema.StringAttribute{
194195
Optional: true,
195196
Description: "Default upstream timeout",

0 commit comments

Comments
 (0)