diff --git a/docs/flags.md b/docs/flags.md index bbe78c71d2..5b1fedfd8b 100644 --- a/docs/flags.md +++ b/docs/flags.md @@ -177,3 +177,4 @@ | `--webhook-provider-read-timeout=5s` | The read timeout for the webhook provider in duration format (default: 5s) | | `--webhook-provider-write-timeout=10s` | The write timeout for the webhook provider in duration format (default: 10s) | | `--[no-]webhook-server` | When enabled, runs as a webhook server instead of a controller. (default: false). | +| `--aws-domain-roles=AWS-DOMAIN-ROLES` | When using the AWS provider, specify the domain roles to use for the hosted zone (optional) | diff --git a/docs/tutorials/aws.md b/docs/tutorials/aws.md index 4e96bdde36..b1728614d6 100644 --- a/docs/tutorials/aws.md +++ b/docs/tutorials/aws.md @@ -1045,3 +1045,110 @@ Because those limits are in place, `aws-batch-change-size` can be set to any val ## Using CRD source to manage DNS records in AWS Please refer to the [CRD source documentation](../sources/crd.md#example) for more information. + +## Strategies for Scoping Zones + +> Without specifying these flags, management applies to all zones. + +In order to manage specific zones, you may need to combine multiple options + +| Argument | Description | Flow Control | +|:----------------------------|:----------------------------------------------------------------------------|:------------:| +| `--zone-id-filter` | Specify multiple times if needed | OR | +| `--domain-filter` | By domain suffix - specify multiple times if needed | OR | +| `--regex-domain-filter` | By domain suffix but as a regex - overrides domain-filter | AND | +| `--exclude-domains` | To exclude a domain or subdomain | OR | +| `--regex-domain-exclusion` | Subtracts its matches from `regex-domain-filter`'s matches | AND | +| `--aws-zone-type` | Only sync zones of this type `[public\|private]` | OR | +| `--aws-zone-tags` | Only sync zones with this tag | AND | + +Minimum required configuration + +```sh +args: + --provider=aws + --registry=txt + --source=service +``` + +### Filter by Zone Type + +> If this flag is not specified, management applies to both public and private zones. + +```sh +args: + --aws-zone-type=private|public # choose between public or private + ... +``` + +### Filter by Domain + +> Specify multiple times if needed. + +```sh +args: + --domain-filter=example.com + --domain-filter=.paradox.example.com + ... +``` + +Example `--domain-filter=example.com` will allow for zone `example.com` and any zones that end in `.example.com`, including `an.example.com`, i.e., the subdomains of example.com. + +When there are multiple domains, filter `--domain-filter=example.com` will match domains `example.com`, `ex.par.example.com`, `par.example.com`, `x.par.eu-west-1.example.com`. + +And if the filter is prepended with `.` e.g., `--domain-filter=.example.com` it will allow *only* zones that end in `.example.com`, i.e., the subdomains of example.com but not the `example.com` zone itself. Example result: `ex.par.eu-west-1.example.com`, `ex.par.example.com`, `par.example.com`. + +> Note: if you prepend the filter with ".", it will not attempt to match parent zones. + +### Filter by Zone ID + +> Specify multiple times if needed, the flow logic is OR + +```sh +args: + --zone-id-filter=ABCDEF12345678 + --zone-id-filter=XYZDEF12345888 + ... +``` + +### Filter by Tag + +> Specify multiple times if needed, the flow logic is AND + +Keys only + +```sh +args: + --aws-zone-tags=owner + --aws-zone-tags=vertical +``` + +Or specify keys with values + +```sh +args: + --aws-zone-tags=owner=k8s + --aws-zone-tags=vertical=k8s +``` + +Can't specify multiple or separate values with commas: `key1=val1,key2=val2` at the moment. +Filter only by value `--aws-zone-tags==tag-value` is not supported. + +```sh +args: + --aws-zone-tags=team=k8s,vertical=platform # this is not supported + --aws-zone-tags==tag-value # this is not supported +``` + +### Add Roles specific to the zone + +If you have multiple zones and want to manage them with different roles, you can configure `external-dns` with the following option: + +```sh +args: + --aws-domain-roles=example.com=arn:aws:iam::123456789012:role/external-dns-role + --aws-domain-roles=example.org=arn:aws:iam::123456789011:role/external-dns-role +``` + +`--aws-domain-roles` is a map of domain names to IAM roles. The domain/hosted zone names should match the `--domain-filter` values. +AWS also sets STS rate limits on a per account per region basis i.e. for a single account on a single region you can make 600 requests per second. diff --git a/main.go b/main.go index 47d1695dc0..035cfd7c25 100644 --- a/main.go +++ b/main.go @@ -212,9 +212,12 @@ func main() { p, err = alibabacloud.NewAlibabaCloudProvider(cfg.AlibabaCloudConfigFile, domainFilter, zoneIDFilter, cfg.AlibabaCloudZoneType, cfg.DryRun) case "aws": configs := aws.CreateV2Configs(cfg) - clients := make(map[string]aws.Route53API, len(configs)) - for profile, config := range configs { - clients[profile] = route53.NewFromConfig(config) + clients := make(map[string][]*aws.AWSZoneConfig, len(configs)) + for profile, configZones := range configs { + for _, configZone := range configZones { + configZone.Route53Config = route53.NewFromConfig(configZone.Config) + clients[profile] = append(clients[profile], configZone) + } } p, err = aws.NewAWSProvider( @@ -241,7 +244,7 @@ func main() { log.Infof("Registry \"%s\" cannot be used with AWS Cloud Map. Switching to \"aws-sd\".", cfg.Registry) cfg.Registry = "aws-sd" } - p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID, cfg.AWSSDCreateTag, sd.NewFromConfig(aws.CreateDefaultV2Config(cfg))) + p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID, cfg.AWSSDCreateTag, sd.NewFromConfig(aws.CreateDefaultV2Config(cfg).Config)) case "azure-dns", "azure": p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.DryRun) case "azure-private-dns": @@ -400,7 +403,7 @@ func main() { }, } } - r, err = registry.NewDynamoDBRegistry(p, cfg.TXTOwnerID, dynamodb.NewFromConfig(aws.CreateDefaultV2Config(cfg), dynamodbOpts...), cfg.AWSDynamoDBTable, cfg.TXTPrefix, cfg.TXTSuffix, cfg.TXTWildcardReplacement, cfg.ManagedDNSRecordTypes, cfg.ExcludeDNSRecordTypes, []byte(cfg.TXTEncryptAESKey), cfg.TXTCacheInterval) + r, err = registry.NewDynamoDBRegistry(p, cfg.TXTOwnerID, dynamodb.NewFromConfig(aws.CreateDefaultV2Config(cfg).Config, dynamodbOpts...), cfg.AWSDynamoDBTable, cfg.TXTPrefix, cfg.TXTSuffix, cfg.TXTWildcardReplacement, cfg.ManagedDNSRecordTypes, cfg.ExcludeDNSRecordTypes, []byte(cfg.TXTEncryptAESKey), cfg.TXTCacheInterval) case "noop": r, err = registry.NewNoopRegistry(p) case "txt": diff --git a/pkg/apis/externaldns/types.go b/pkg/apis/externaldns/types.go index f0697d4982..bf14683b97 100644 --- a/pkg/apis/externaldns/types.go +++ b/pkg/apis/externaldns/types.go @@ -215,6 +215,7 @@ type Config struct { TraefikDisableLegacy bool TraefikDisableNew bool NAT64Networks []string + AWSDomainRoles map[string]string } var defaultConfig = &Config{ @@ -375,12 +376,14 @@ var defaultConfig = &Config{ TraefikDisableLegacy: false, TraefikDisableNew: false, NAT64Networks: []string{}, + AWSDomainRoles: map[string]string{}, } // NewConfig returns new Config object func NewConfig() *Config { return &Config{ AWSSDCreateTag: map[string]string{}, + AWSDomainRoles: map[string]string{}, } } @@ -638,6 +641,7 @@ func App(cfg *Config) *kingpin.Application { app.Flag("webhook-provider-write-timeout", "The write timeout for the webhook provider in duration format (default: 10s)").Default(defaultConfig.WebhookProviderWriteTimeout.String()).DurationVar(&cfg.WebhookProviderWriteTimeout) app.Flag("webhook-server", "When enabled, runs as a webhook server instead of a controller. (default: false).").BoolVar(&cfg.WebhookServer) + app.Flag("aws-domain-roles", "When using the AWS provider, specify the domain roles to use for the hosted zone (optional)").StringMapVar(&cfg.AWSDomainRoles) return app } diff --git a/pkg/apis/externaldns/types_test.go b/pkg/apis/externaldns/types_test.go index 9ce16a6beb..4592c1b954 100644 --- a/pkg/apis/externaldns/types_test.go +++ b/pkg/apis/externaldns/types_test.go @@ -131,6 +131,7 @@ var ( WebhookProviderURL: "http://localhost:8888", WebhookProviderReadTimeout: 5 * time.Second, WebhookProviderWriteTimeout: 10 * time.Second, + AWSDomainRoles: map[string]string{}, } overriddenConfig = &Config{ @@ -245,6 +246,7 @@ var ( WebhookProviderURL: "http://localhost:8888", WebhookProviderReadTimeout: 5 * time.Second, WebhookProviderWriteTimeout: 10 * time.Second, + AWSDomainRoles: map[string]string{"example.com": "arn:aws:iam::123456789012:role/role1", "example.org": "arn:aws:iam::123456789012:role/role2"}, } ) @@ -351,6 +353,8 @@ func TestParseFlags(t *testing.T) { "--aws-sd-service-cleanup", "--aws-sd-create-tag=key1=value1", "--aws-sd-create-tag=key2=value2", + "--aws-domain-roles=example.com=arn:aws:iam::123456789012:role/role1", + "--aws-domain-roles=example.org=arn:aws:iam::123456789012:role/role2", "--no-aws-evaluate-target-health", "--policy=upsert-only", "--registry=noop", @@ -508,6 +512,7 @@ func TestParseFlags(t *testing.T) { "EXTERNAL_DNS_IBMCLOUD_CONFIG_FILE": "ibmcloud.json", "EXTERNAL_DNS_TENCENT_CLOUD_CONFIG_FILE": "tencent-cloud.json", "EXTERNAL_DNS_TENCENT_CLOUD_ZONE_TYPE": "private", + "EXTERNAL_DNS_AWS_DOMAIN_ROLES": "example.com=arn:aws:iam::123456789012:role/role1\nexample.org=arn:aws:iam::123456789012:role/role2", }, expected: overriddenConfig, }, diff --git a/provider/aws/aws.go b/provider/aws/aws.go index cec056c38f..db56f59f5c 100644 --- a/provider/aws/aws.go +++ b/provider/aws/aws.go @@ -225,8 +225,10 @@ type Route53Change struct { type Route53Changes []*Route53Change type profiledZone struct { - profile string - zone *route53types.HostedZone + profile string + zone *route53types.HostedZone + zoneName string + client Route53API } func (cs Route53Changes) Route53Changes() []route53types.Change { @@ -269,7 +271,7 @@ type zonesListCache struct { // AWSProvider is an implementation of Provider for AWS Route53. type AWSProvider struct { provider.BaseProvider - clients map[string]Route53API + clients map[string][]*AWSZoneConfig dryRun bool batchChangeSize int batchChangeSizeBytes int @@ -310,7 +312,7 @@ type AWSConfig struct { } // NewAWSProvider initializes a new AWS Route53 based Provider. -func NewAWSProvider(awsConfig AWSConfig, clients map[string]Route53API) (*AWSProvider, error) { +func NewAWSProvider(awsConfig AWSConfig, clients map[string][]*AWSZoneConfig) (*AWSProvider, error) { provider := &AWSProvider{ clients: clients, domainFilter: awsConfig.DomainFilter, @@ -356,55 +358,13 @@ func (p *AWSProvider) zones(ctx context.Context) (map[string]*profiledZone, erro zones := make(map[string]*profiledZone) - for profile, client := range p.clients { - paginator := route53.NewListHostedZonesPaginator(client, &route53.ListHostedZonesInput{}) + for profile, hostedZoneClients := range p.clients { + var err error + for _, client := range hostedZoneClients { + zones, err = p.fetchFilteredZonesForClient(ctx, client, profile) - for paginator.HasMorePages() { - resp, err := paginator.NextPage(ctx) if err != nil { - var te *route53types.ThrottlingException - if errors.As(err, &te) { - log.Infof("Skipping AWS profile %q due to provider side throttling: %v", profile, te.ErrorMessage()) - continue - } - // nothing to do here. Falling through to general error handling - return nil, provider.NewSoftError(fmt.Errorf("failed to list hosted zones: %w", err)) - } - var zonesToTagFilter []string - for _, zone := range resp.HostedZones { - if !p.zoneIDFilter.Match(*zone.Id) { - continue - } - - if !p.zoneTypeFilter.Match(zone) { - continue - } - - if !p.domainFilter.Match(*zone.Name) { - if !p.zoneMatchParent { - continue - } - if !p.domainFilter.MatchParent(*zone.Name) { - continue - } - } - - if !p.zoneTagFilter.IsEmpty() { - zonesToTagFilter = append(zonesToTagFilter, cleanZoneID(*zone.Id)) - } - - zones[*zone.Id] = &profiledZone{ - profile: profile, - zone: &zone, - } - } - - if len(zonesToTagFilter) > 0 { - if zTags, err := p.tagsForZone(ctx, zonesToTagFilter, profile); err != nil { - return nil, provider.NewSoftErrorf("failed to list tags for zones %w", err) - } else { - zTags.filterZonesByTags(p, zones) - } + return nil, provider.NewSoftErrorf("failed to list zones tags: %w", err) } } } @@ -423,6 +383,64 @@ func (p *AWSProvider) zones(ctx context.Context) (map[string]*profiledZone, erro return zones, nil } +func (p *AWSProvider) fetchFilteredZonesForClient(ctx context.Context, client *AWSZoneConfig, profile string) (map[string]*profiledZone, error) { + profileZones := make(map[string]*profiledZone) + paginator := route53.NewListHostedZonesPaginator(client.Route53Config, &route53.ListHostedZonesInput{}) + + for paginator.HasMorePages() { + resp, err := paginator.NextPage(ctx) + if err != nil { + var te *route53types.ThrottlingException + if errors.As(err, &te) { + log.Infof("Skipping AWS profile %q due to provider side throttling: %v", profile, te.ErrorMessage()) + continue + } + // nothing to do here. Falling through to general error handling + return nil, provider.NewSoftError(fmt.Errorf("failed to list hosted zones: %w", err)) + } + var zonesToTagFilter []string + for _, zone := range resp.HostedZones { + if !p.zoneIDFilter.Match(*zone.Id) { + continue + } + + if !p.zoneTypeFilter.Match(zone) { + continue + } + + if !p.domainFilter.Match(*zone.Name) { + if !p.zoneMatchParent { + continue + } + if !p.domainFilter.MatchParent(*zone.Name) { + continue + } + } + + if !p.zoneTagFilter.IsEmpty() { + zonesToTagFilter = append(zonesToTagFilter, cleanZoneID(*zone.Id)) + } + + profileZones[*zone.Id] = &profiledZone{ + profile: profile, + zone: &zone, + zoneName: client.HostedZoneName, + client: client.Route53Config, + } + } + + if len(zonesToTagFilter) > 0 { + if zTags, err := p.tagsForZone(ctx, zonesToTagFilter, client.Route53Config); err != nil { + return nil, provider.NewSoftErrorf("failed to list tags for zones %w", err) + } else { + zTags.filterZonesByTags(p, profileZones) + } + } + } + + return profileZones, nil +} + // wildcardUnescape converts \\052.abc back to *.abc // Route53 stores wildcards escaped: http://docs.aws.amazon.com/Route53/latest/DeveloperGuide/DomainNameFormat.html?shortFooter=true#domain-name-format-asterisk func wildcardUnescape(s string) string { @@ -465,7 +483,7 @@ func (p *AWSProvider) records(ctx context.Context, zones map[string]*profiledZon endpoints := make([]*endpoint.Endpoint, 0) for _, z := range zones { - client := p.clients[z.profile] + client := z.client paginator := route53.NewListResourceRecordSetsPaginator(client, &route53.ListResourceRecordSetsInput{ HostedZoneId: z.zone.Id, @@ -699,7 +717,7 @@ func (p *AWSProvider) submitChanges(ctx context.Context, changes Route53Changes, successfulChanges := 0 - client := p.clients[zones[z].profile] + client := zones[z].client if _, err := client.ChangeResourceRecordSets(ctx, params); err != nil { log.Errorf("Failure in zone %s when submitting change batch: %v", *zones[z].zone.Name, err) @@ -975,9 +993,7 @@ func groupChangesByNameAndOwnershipRelation(cs Route53Changes) map[string]Route5 return changesByOwnership } -func (p *AWSProvider) tagsForZone(ctx context.Context, zoneIDs []string, profile string) (zoneTags, error) { - client := p.clients[profile] - +func (p *AWSProvider) tagsForZone(ctx context.Context, zoneIDs []string, client Route53API) (zoneTags, error) { result := zoneTags{} for i := 0; i < len(zoneIDs); i += batchSize { diff --git a/provider/aws/aws_test.go b/provider/aws/aws_test.go index 7305f06487..bf40f9617e 100644 --- a/provider/aws/aws_test.go +++ b/provider/aws/aws_test.go @@ -352,7 +352,7 @@ func TestAWSZones(t *testing.T) { func TestAWSZonesWithTagFilterError(t *testing.T) { client := NewRoute53APIStub(t) provider := &AWSProvider{ - clients: map[string]Route53API{defaultAWSProfile: client}, + clients: map[string][]*AWSZoneConfig{defaultAWSProfile: {{Route53Config: client}}}, zoneTagFilter: provider.NewZoneTagFilter([]string{"zone=2"}), dryRun: false, zonesCache: &zonesListCache{duration: 1 * time.Minute}, @@ -982,14 +982,14 @@ func TestAWSApplyChanges(t *testing.T) { ctx := tt.setup(provider) provider.zonesCache = &zonesListCache{duration: 0 * time.Minute} - counter := NewRoute53APICounter(provider.clients[defaultAWSProfile]) - provider.clients[defaultAWSProfile] = counter + counter := NewRoute53APICounter(provider.clients[defaultAWSProfile][0].Route53Config) + provider.clients[defaultAWSProfile][0].Route53Config = counter require.NoError(t, provider.ApplyChanges(ctx, changes)) assert.Equal(t, 1, counter.calls["ListHostedZonesPages"], tt.name) assert.Equal(t, tt.listRRSets, counter.calls["ListResourceRecordSetsPages"], tt.name) - validateRecords(t, listAWSRecords(t, provider.clients[defaultAWSProfile], "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do."), []route53types.ResourceRecordSet{ + validateRecords(t, listAWSRecords(t, provider.clients[defaultAWSProfile][0].Route53Config, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do."), []route53types.ResourceRecordSet{ { Name: aws.String("create-test.zone-1.ext-dns-test-2.teapot.zalan.do."), Type: route53types.RRTypeA, @@ -1119,7 +1119,7 @@ func TestAWSApplyChanges(t *testing.T) { ResourceRecords: []route53types.ResourceRecord{{Value: aws.String("10 mailhost1.foo.elb.amazonaws.com")}}, }, }) - validateRecords(t, listAWSRecords(t, provider.clients[defaultAWSProfile], "/hostedzone/zone-2.ext-dns-test-2.teapot.zalan.do."), []route53types.ResourceRecordSet{ + validateRecords(t, listAWSRecords(t, provider.clients[defaultAWSProfile][0].Route53Config, "/hostedzone/zone-2.ext-dns-test-2.teapot.zalan.do."), []route53types.ResourceRecordSet{ { Name: aws.String("escape-\\045\\041s\\050\\074nil\\076\\051-codes.zone-2.ext-dns-test-2.teapot.zalan.do."), Type: route53types.RRTypeA, @@ -1320,8 +1320,8 @@ func TestAWSApplyChangesDryRun(t *testing.T) { validateRecords(t, append( - listAWSRecords(t, provider.clients[defaultAWSProfile], "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do."), - listAWSRecords(t, provider.clients[defaultAWSProfile], "/hostedzone/zone-2.ext-dns-test-2.teapot.zalan.do.")...), + listAWSRecords(t, provider.clients[defaultAWSProfile][0].Route53Config, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do."), + listAWSRecords(t, provider.clients[defaultAWSProfile][0].Route53Config, "/hostedzone/zone-2.ext-dns-test-2.teapot.zalan.do.")...), originalRecords) } @@ -1945,7 +1945,7 @@ func TestAWSCreateRecordsWithCNAME(t *testing.T) { Create: adjusted, })) - recordSets := listAWSRecords(t, provider.clients[defaultAWSProfile], "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.") + recordSets := listAWSRecords(t, provider.clients[defaultAWSProfile][0].Route53Config, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.") validateRecords(t, recordSets, []route53types.ResourceRecordSet{ { @@ -2006,7 +2006,7 @@ func TestAWSCreateRecordsWithALIAS(t *testing.T) { Create: adjusted, })) - recordSets := listAWSRecords(t, provider.clients[defaultAWSProfile], "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.") + recordSets := listAWSRecords(t, provider.clients[defaultAWSProfile][0].Route53Config, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.") validateRecords(t, recordSets, []route53types.ResourceRecordSet{ { @@ -2160,7 +2160,7 @@ func createAWSZone(t *testing.T, provider *AWSProvider, zone *route53types.Hoste HostedZoneConfig: zone.Config, } - if _, err := provider.clients[defaultAWSProfile].CreateHostedZone(context.Background(), params); err != nil { + if _, err := provider.clients[defaultAWSProfile][0].Route53Config.CreateHostedZone(context.Background(), params); err != nil { var hzExists *route53types.HostedZoneAlreadyExists require.ErrorAs(t, err, &hzExists) } @@ -2216,7 +2216,7 @@ func newAWSProviderWithTagFilter(t *testing.T, domainFilter endpoint.DomainFilte client := NewRoute53APIStub(t) provider := &AWSProvider{ - clients: map[string]Route53API{defaultAWSProfile: client}, + clients: map[string][]*AWSZoneConfig{defaultAWSProfile: {{Route53Config: client}}}, batchChangeSize: defaultBatchChangeSize, batchChangeSizeBytes: defaultBatchChangeSizeBytes, batchChangeSizeValues: defaultBatchChangeSizeValues, @@ -2256,7 +2256,7 @@ func newAWSProviderWithTagFilter(t *testing.T, domainFilter endpoint.DomainFilte Config: &route53types.HostedZoneConfig{PrivateZone: false}, }) - setupZoneTags(provider.clients[defaultAWSProfile].(*Route53APIStub)) + setupZoneTags(provider.clients[defaultAWSProfile][0].Route53Config.(*Route53APIStub)) setAWSRecords(t, provider, records) diff --git a/provider/aws/aws_utils_test.go b/provider/aws/aws_utils_test.go index a7c7e2aac1..f59c1aa1fd 100644 --- a/provider/aws/aws_utils_test.go +++ b/provider/aws/aws_utils_test.go @@ -50,7 +50,7 @@ type Route53APIFixtureStub struct { func providerFilters(client *Route53APIFixtureStub, options ...func(awsProvider *AWSProvider)) *AWSProvider { p := &AWSProvider{ - clients: map[string]Route53API{defaultAWSProfile: client}, + clients: map[string][]*AWSZoneConfig{defaultAWSProfile: {{Route53Config: client}}}, evaluateTargetHealth: false, dryRun: false, domainFilter: endpoint.NewDomainFilter([]string{}), diff --git a/provider/aws/config.go b/provider/aws/config.go index 5908150e77..5328736913 100644 --- a/provider/aws/config.go +++ b/provider/aws/config.go @@ -33,53 +33,76 @@ import ( "sigs.k8s.io/external-dns/pkg/apis/externaldns" ) +// STSClient is an interface that defines the methods used by the AWS provider to assume roles. +// This is defined as an interface to make testing easier by allowing the use of a mock client. It is in accordance with how AWS SDK defines the STS client. +type STSClient interface { + AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + // AWSSessionConfig contains configuration to create a new AWS provider. type AWSSessionConfig struct { AssumeRole string AssumeRoleExternalID string APIRetries int Profile string + DomainRolesMap map[string]string +} + +type AWSZoneConfig struct { + Config awsv2.Config + HostedZoneName string + Route53Config Route53API } -func CreateDefaultV2Config(cfg *externaldns.Config) awsv2.Config { +func CreateDefaultV2Config(cfg *externaldns.Config) *AWSZoneConfig { result, err := newV2Config( AWSSessionConfig{ AssumeRole: cfg.AWSAssumeRole, AssumeRoleExternalID: cfg.AWSAssumeRoleExternalID, APIRetries: cfg.AWSAPIRetries, - }, + DomainRolesMap: cfg.AWSDomainRoles, + }, nil, ) if err != nil { logrus.Fatal(err) } - return result + + if len(result) == 0 { + logrus.Fatal("No AWS credentials found") + } + + return result[0] } -func CreateV2Configs(cfg *externaldns.Config) map[string]awsv2.Config { - result := make(map[string]awsv2.Config) +func CreateV2Configs(cfg *externaldns.Config) map[string][]*AWSZoneConfig { + result := make(map[string][]*AWSZoneConfig) if len(cfg.AWSProfiles) == 0 || (len(cfg.AWSProfiles) == 1 && cfg.AWSProfiles[0] == "") { cfg := CreateDefaultV2Config(cfg) - result[defaultAWSProfile] = cfg + result[defaultAWSProfile] = make([]*AWSZoneConfig, 0) + result[defaultAWSProfile] = append(result[defaultAWSProfile], cfg) } else { for _, profile := range cfg.AWSProfiles { - cfg, err := newV2Config( + configs, err := newV2Config( AWSSessionConfig{ AssumeRole: cfg.AWSAssumeRole, AssumeRoleExternalID: cfg.AWSAssumeRoleExternalID, APIRetries: cfg.AWSAPIRetries, Profile: profile, + DomainRolesMap: cfg.AWSDomainRoles, }, + nil, ) if err != nil { logrus.Fatal(err) } - result[profile] = cfg + result[profile] = configs } } return result } -func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) { +func newV2Config(awsConfig AWSSessionConfig, stsClient STSClient) ([]*AWSZoneConfig, error) { + hostedZonesConfigs := make([]*AWSZoneConfig, 0) defaultOpts := []func(*config.LoadOptions) error{ config.WithRetryer(func() awsv2.Retryer { return retry.AddWithMaxAttempts(retry.NewStandard(), awsConfig.APIRetries) @@ -95,26 +118,46 @@ func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) { cfg, err := config.LoadDefaultConfig(context.Background(), defaultOpts...) if err != nil { - return awsv2.Config{}, fmt.Errorf("instantiating AWS config: %w", err) + return nil, fmt.Errorf("instantiating AWS config: %w", err) } - if awsConfig.AssumeRole != "" { - stsSvc := sts.NewFromConfig(cfg) - var assumeRoleOpts []func(*stscredsv2.AssumeRoleOptions) - if awsConfig.AssumeRoleExternalID != "" { - logrus.Infof("Assuming role %s with external id", awsConfig.AssumeRole) - logrus.Debugf("External id: %s", awsConfig.AssumeRoleExternalID) - assumeRoleOpts = []func(*stscredsv2.AssumeRoleOptions){ - func(opts *stscredsv2.AssumeRoleOptions) { - opts.ExternalID = &awsConfig.AssumeRoleExternalID - }, + if len(awsConfig.DomainRolesMap) == 0 { + // If AssumeRole is set, use it to assume the role and return the config. This is kept for backward compatibility. + if awsConfig.AssumeRole != "" { + if stsClient == nil { + stsClient = sts.NewFromConfig(cfg) } - } else { - logrus.Infof("Assuming role: %s", awsConfig.AssumeRole) + creds := stscredsv2.NewAssumeRoleProvider(stsClient, awsConfig.AssumeRole) + cfg.Credentials = awsv2.NewCredentialsCache(creds) + } + hostedZonesConfigs = append(hostedZonesConfigs, &AWSZoneConfig{Config: cfg}) + return hostedZonesConfigs, nil + } + + for domain, role := range awsConfig.DomainRolesMap { + if role != "" { + cfCopy := cfg.Copy() + if stsClient == nil { + stsClient = sts.NewFromConfig(cfCopy) + } + var assumeRoleOpts []func(*stscredsv2.AssumeRoleOptions) + if awsConfig.AssumeRoleExternalID != "" { + logrus.Infof("Assuming role %s with external id", awsConfig.AssumeRole) + logrus.Debugf("External id: %q", awsConfig.AssumeRoleExternalID) + assumeRoleOpts = []func(*stscredsv2.AssumeRoleOptions){ + func(opts *stscredsv2.AssumeRoleOptions) { + opts.ExternalID = &awsConfig.AssumeRoleExternalID + }, + } + } else { + logrus.Infof("Assuming role: %s", role) + } + creds := stscredsv2.NewAssumeRoleProvider(stsClient, role, assumeRoleOpts...) + cfCopy.Credentials = awsv2.NewCredentialsCache(creds) + + hostedZonesConfigs = append(hostedZonesConfigs, &AWSZoneConfig{Config: cfCopy, HostedZoneName: domain}) } - creds := stscredsv2.NewAssumeRoleProvider(stsSvc, awsConfig.AssumeRole, assumeRoleOpts...) - cfg.Credentials = awsv2.NewCredentialsCache(creds) } - return cfg, nil + return hostedZonesConfigs, nil } diff --git a/provider/aws/config_test.go b/provider/aws/config_test.go index 00b3b46aac..8849676f54 100644 --- a/provider/aws/config_test.go +++ b/provider/aws/config_test.go @@ -18,13 +18,25 @@ package aws import ( "context" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts/types" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type mockSTSClient struct { + AssumeRoleFunc func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + +func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + return m.AssumeRoleFunc(ctx, params, optFns...) +} + func Test_newV2Config(t *testing.T) { t.Run("should use profile from credentials file", func(t *testing.T) { // setup @@ -35,9 +47,13 @@ func Test_newV2Config(t *testing.T) { defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE") // when - cfg, err := newV2Config(AWSSessionConfig{Profile: "profile2"}) + cfgs, err := newV2Config(AWSSessionConfig{Profile: "profile2"}, nil) require.NoError(t, err) - creds, err := cfg.Credentials.Retrieve(context.Background()) + + assert.GreaterOrEqual(t, len(cfgs), 1) + cfg := cfgs[0] + + creds, err := cfg.Config.Credentials.Retrieve(context.Background()) // then assert.NoError(t, err) @@ -53,15 +69,93 @@ func Test_newV2Config(t *testing.T) { defer os.Unsetenv("AWS_SECRET_ACCESS_KEY") // when - cfg, err := newV2Config(AWSSessionConfig{}) + cfgs, err := newV2Config(AWSSessionConfig{}, nil) require.NoError(t, err) - creds, err := cfg.Credentials.Retrieve(context.Background()) + assert.GreaterOrEqual(t, len(cfgs), 1) + cfg := cfgs[0] + + creds, err := cfg.Config.Credentials.Retrieve(context.Background()) // then assert.NoError(t, err) assert.Equal(t, "AKIAIOSFODNN7EXAMPLE", creds.AccessKeyID) assert.Equal(t, "topsecret", creds.SecretAccessKey) }) + + t.Run("should use roles for different domains", func(t *testing.T) { + os.Setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE") + os.Setenv("AWS_SECRET_ACCESS_KEY", "topsecret") + defer os.Unsetenv("AWS_ACCESS_KEY_ID") + defer os.Unsetenv("AWS_SECRET_ACCESS_KEY") + + roles := make([]string, 0) + mockClient := &mockSTSClient{ + AssumeRoleFunc: func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + roles = append(roles, aws.ToString(params.RoleArn)) + return &sts.AssumeRoleOutput{ + Credentials: &types.Credentials{ + AccessKeyId: aws.String("AKIAIOSFODNN7EXAMPLE"), + SecretAccessKey: aws.String("topsecret"), + SessionToken: aws.String("session-token"), + Expiration: aws.Time(time.Now().Add(1 * time.Hour)), + }, + }, nil + }, + } + + cfgs, err := newV2Config(AWSSessionConfig{ + DomainRolesMap: map[string]string{ + "example.com": "arn:aws:iam::123456789012:role/role1", + "example.org": "arn:aws:iam::123456789012:role/role2", + }, + }, mockClient) + + for _, cfg := range cfgs { + _, err := cfg.Config.Credentials.Retrieve(context.Background()) + require.NoError(t, err) + } + + require.NoError(t, err) + assert.Contains(t, roles, "arn:aws:iam::123456789012:role/role1") + assert.Contains(t, roles, "arn:aws:iam::123456789012:role/role2") + assert.NotNil(t, cfgs, "expected at least one config") + }) + + t.Run("should use assume role", func(t *testing.T) { + // setup + os.Setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE") + os.Setenv("AWS_SECRET_ACCESS_KEY", "topsecret") + defer os.Unsetenv("AWS_ACCESS_KEY_ID") + defer os.Unsetenv("AWS_SECRET_ACCESS_KEY") + + roles := make([]string, 0) + mockClient := &mockSTSClient{ + AssumeRoleFunc: func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + roles = append(roles, aws.ToString(params.RoleArn)) + return &sts.AssumeRoleOutput{ + Credentials: &types.Credentials{ + AccessKeyId: aws.String("AKIAIOSFODNN7EXAMPLE"), + SecretAccessKey: aws.String("topsecret"), + SessionToken: aws.String("session-token"), + Expiration: aws.Time(time.Now().Add(1 * time.Hour)), + }, + }, nil + }, + } + + cfgs, err := newV2Config(AWSSessionConfig{ + AssumeRole: "arn:aws:iam::123456789012:role/role1", + }, mockClient) + + for _, cfg := range cfgs { + _, err := cfg.Config.Credentials.Retrieve(context.Background()) + require.NoError(t, err) + } + + require.NoError(t, err) + assert.Contains(t, roles, "arn:aws:iam::123456789012:role/role1") + assert.NotNil(t, cfgs, "expected at least one config") + }) } func prepareCredentialsFile(t *testing.T) (*os.File, error) {