Skip to content

Commit 02845dc

Browse files
committed
Refactor AuthProvider multi-tenant token credential
1 parent c6cb327 commit 02845dc

File tree

1 file changed

+129
-58
lines changed

1 file changed

+129
-58
lines changed

pkg/azclient/auth.go

+129-58
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"strings"
2424

2525
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
26+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
2627
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
2728
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
2829
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
@@ -32,23 +33,29 @@ import (
3233
)
3334

3435
type AuthProvider struct {
35-
ComputeCredential azcore.TokenCredential
36-
NetworkCredential azcore.TokenCredential
37-
MultiTenantCredential azcore.TokenCredential
38-
CloudConfig cloud.Configuration
36+
ComputeCredential azcore.TokenCredential
37+
AdditionalComputeClientOptions []func(option *arm.ClientOptions)
38+
NetworkCredential azcore.TokenCredential
39+
CloudConfig cloud.Configuration
3940
}
4041

41-
func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, clientOptionsMutFn ...func(option *policy.ClientOptions)) (*AuthProvider, error) {
42+
func NewAuthProvider(
43+
armConfig *ARMClientConfig,
44+
config *AzureAuthConfig,
45+
clientOptionsMutFn ...func(option *policy.ClientOptions),
46+
) (*AuthProvider, error) {
4247
clientOption, _, err := GetAzCoreClientOption(armConfig)
4348
if err != nil {
4449
return nil, err
4550
}
4651
for _, fn := range clientOptionsMutFn {
4752
fn(clientOption)
4853
}
49-
var computeCredential azcore.TokenCredential
50-
var networkTokenCredential azcore.TokenCredential
51-
var multiTenantCredential azcore.TokenCredential
54+
var (
55+
computeCredential azcore.TokenCredential
56+
networkCredential azcore.TokenCredential
57+
additionalComputeClientOptions []func(option *arm.ClientOptions)
58+
)
5259

5360
// federatedIdentityCredential is used for workload identity federation
5461
if aadFederatedTokenFile, enabled := config.GetAzureFederatedTokenFile(); enabled {
@@ -62,6 +69,7 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
6269
return nil, err
6370
}
6471
}
72+
6573
// managedIdentityCredential is used for managed identity extension
6674
if computeCredential == nil && config.UseManagedIdentityExtension {
6775
credOptions := &azidentity.ManagedIdentityCredentialOptions{
@@ -79,52 +87,80 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
7987
return nil, err
8088
}
8189
if config.AuxiliaryTokenProvider != nil && IsMultiTenant(armConfig) {
82-
networkTokenCredential, err = armauth.NewKeyVaultCredential(
90+
// Use AuxiliaryTokenProvider as the network credential
91+
networkCredential, err = armauth.NewKeyVaultCredential(
8392
computeCredential,
8493
config.AuxiliaryTokenProvider.SecretResourceID(),
8594
)
8695
if err != nil {
8796
return nil, fmt.Errorf("create KeyVaultCredential for auxiliary token provider: %w", err)
8897
}
98+
99+
// Additionally, we need to add the auxiliary token to the HTTP header when making requests to the compute resources
100+
additionalComputeClientOptions = append(additionalComputeClientOptions, func(option *arm.ClientOptions) {
101+
option.PerRetryPolicies = append(option.PerRetryPolicies, armauth.NewAuxiliaryAuthPolicy(
102+
[]azcore.TokenCredential{networkCredential},
103+
DefaultTokenScopeFor(clientOption.Cloud),
104+
))
105+
})
89106
}
90107
}
91108

92109
// Client secret authentication
93110
if computeCredential == nil && len(config.GetAADClientSecret()) > 0 {
94-
credOptions := &azidentity.ClientSecretCredentialOptions{
95-
ClientOptions: *clientOption,
96-
}
97-
computeCredential, err = azidentity.NewClientSecretCredential(armConfig.GetTenantID(), config.GetAADClientID(), config.GetAADClientSecret(), credOptions)
98-
if err != nil {
99-
return nil, err
100-
}
101111
if IsMultiTenant(armConfig) {
102-
credOptions := &azidentity.ClientSecretCredentialOptions{
103-
ClientOptions: *clientOption,
104-
}
105-
networkTokenCredential, err = azidentity.NewClientSecretCredential(armConfig.NetworkResourceTenantID, config.GetAADClientID(), config.AADClientSecret, credOptions)
106-
if err != nil {
107-
return nil, err
112+
113+
// Network credential for network resource access
114+
{
115+
credOptions := &azidentity.ClientSecretCredentialOptions{
116+
ClientOptions: *clientOption,
117+
}
118+
networkCredential, err = azidentity.NewClientSecretCredential(
119+
armConfig.NetworkResourceTenantID,
120+
config.GetAADClientID(),
121+
config.GetAADClientSecret(),
122+
credOptions,
123+
)
124+
if err != nil {
125+
return nil, err
126+
}
108127
}
109128

110-
credOptions = &azidentity.ClientSecretCredentialOptions{
111-
ClientOptions: *clientOption,
112-
AdditionallyAllowedTenants: []string{armConfig.NetworkResourceTenantID},
129+
// Compute credential with additional allowed tenants for cross-tenant access
130+
{
131+
credOptions := &azidentity.ClientSecretCredentialOptions{
132+
ClientOptions: *clientOption,
133+
AdditionallyAllowedTenants: []string{armConfig.NetworkResourceTenantID},
134+
}
135+
computeCredential, err = azidentity.NewClientSecretCredential(
136+
armConfig.GetTenantID(),
137+
config.GetAADClientID(),
138+
config.GetAADClientSecret(),
139+
credOptions,
140+
)
141+
if err != nil {
142+
return nil, err
143+
}
113144
}
114-
multiTenantCredential, err = azidentity.NewClientSecretCredential(armConfig.GetTenantID(), config.GetAADClientID(), config.GetAADClientSecret(), credOptions)
145+
} else {
146+
// Single tenant
147+
credOptions := &azidentity.ClientSecretCredentialOptions{
148+
ClientOptions: *clientOption,
149+
}
150+
computeCredential, err = azidentity.NewClientSecretCredential(
151+
armConfig.GetTenantID(),
152+
config.GetAADClientID(),
153+
config.GetAADClientSecret(),
154+
credOptions,
155+
)
115156
if err != nil {
116157
return nil, err
117158
}
118-
119159
}
120160
}
121161

122162
// ClientCertificateCredential is used for client certificate
123163
if computeCredential == nil && len(config.AADClientCertPath) > 0 {
124-
credOptions := &azidentity.ClientCertificateCredentialOptions{
125-
ClientOptions: *clientOption,
126-
SendCertificateChain: true,
127-
}
128164
certData, err := os.ReadFile(config.AADClientCertPath)
129165
if err != nil {
130166
return nil, fmt.Errorf("reading the client certificate from file %s: %w", config.AADClientCertPath, err)
@@ -133,20 +169,58 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
133169
if err != nil {
134170
return nil, fmt.Errorf("decoding the client certificate: %w", err)
135171
}
136-
computeCredential, err = azidentity.NewClientCertificateCredential(armConfig.GetTenantID(), config.GetAADClientID(), certificate, privateKey, credOptions)
137-
if err != nil {
138-
return nil, err
139-
}
172+
140173
if IsMultiTenant(armConfig) {
141-
networkTokenCredential, err = azidentity.NewClientCertificateCredential(armConfig.NetworkResourceTenantID, config.GetAADClientID(), certificate, privateKey, credOptions)
142-
if err != nil {
143-
return nil, err
174+
175+
// Network credential for network resource access
176+
{
177+
credOptions := &azidentity.ClientCertificateCredentialOptions{
178+
ClientOptions: *clientOption,
179+
SendCertificateChain: true,
180+
}
181+
networkCredential, err = azidentity.NewClientCertificateCredential(
182+
armConfig.NetworkResourceTenantID,
183+
config.GetAADClientID(),
184+
certificate,
185+
privateKey,
186+
credOptions,
187+
)
188+
if err != nil {
189+
return nil, err
190+
}
144191
}
145-
credOptions = &azidentity.ClientCertificateCredentialOptions{
146-
ClientOptions: *clientOption,
147-
AdditionallyAllowedTenants: []string{armConfig.NetworkResourceTenantID},
192+
193+
// Compute credential with additional allowed tenants for cross-tenant access
194+
{
195+
credOptions := &azidentity.ClientCertificateCredentialOptions{
196+
ClientOptions: *clientOption,
197+
AdditionallyAllowedTenants: []string{armConfig.NetworkResourceTenantID},
198+
SendCertificateChain: true,
199+
}
200+
computeCredential, err = azidentity.NewClientCertificateCredential(
201+
armConfig.GetTenantID(),
202+
config.GetAADClientID(),
203+
certificate,
204+
privateKey,
205+
credOptions,
206+
)
207+
if err != nil {
208+
return nil, err
209+
}
148210
}
149-
multiTenantCredential, err = azidentity.NewClientCertificateCredential(armConfig.GetTenantID(), config.GetAADClientID(), certificate, privateKey, credOptions)
211+
} else {
212+
// Single tenant
213+
credOptions := &azidentity.ClientCertificateCredentialOptions{
214+
ClientOptions: *clientOption,
215+
SendCertificateChain: true,
216+
}
217+
computeCredential, err = azidentity.NewClientCertificateCredential(
218+
armConfig.GetTenantID(),
219+
config.GetAADClientID(),
220+
certificate,
221+
privateKey,
222+
credOptions,
223+
)
150224
if err != nil {
151225
return nil, err
152226
}
@@ -155,17 +229,21 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
155229

156230
// UserAssignedIdentityCredentials authentication
157231
if computeCredential == nil && len(config.AADMSIDataPlaneIdentityPath) > 0 {
158-
computeCredential, err = dataplane.NewUserAssignedIdentityCredential(context.Background(), config.AADMSIDataPlaneIdentityPath, dataplane.WithClientOpts(azcore.ClientOptions{Cloud: clientOption.Cloud}))
232+
computeCredential, err = dataplane.NewUserAssignedIdentityCredential(
233+
context.Background(),
234+
config.AADMSIDataPlaneIdentityPath,
235+
dataplane.WithClientOpts(azcore.ClientOptions{Cloud: clientOption.Cloud}),
236+
)
159237
if err != nil {
160238
return nil, err
161239
}
162240
}
163241

164242
return &AuthProvider{
165-
ComputeCredential: computeCredential,
166-
NetworkCredential: networkTokenCredential,
167-
MultiTenantCredential: multiTenantCredential,
168-
CloudConfig: clientOption.Cloud,
243+
ComputeCredential: computeCredential,
244+
AdditionalComputeClientOptions: additionalComputeClientOptions,
245+
NetworkCredential: networkCredential,
246+
CloudConfig: clientOption.Cloud,
169247
}, nil
170248
}
171249

@@ -180,18 +258,11 @@ func (factory *AuthProvider) GetNetworkAzIdentity() azcore.TokenCredential {
180258
return factory.ComputeCredential
181259
}
182260

183-
func (factory *AuthProvider) GetMultiTenantIdentity() azcore.TokenCredential {
184-
if factory.MultiTenantCredential != nil {
185-
return factory.MultiTenantCredential
186-
}
187-
return factory.ComputeCredential
188-
}
189-
190-
func (factory *AuthProvider) IsMultiTenantModeEnabled() bool {
191-
return factory.MultiTenantCredential != nil
261+
func (factory *AuthProvider) DefaultTokenScope() string {
262+
return DefaultTokenScopeFor(factory.CloudConfig)
192263
}
193264

194-
func (factory *AuthProvider) DefaultTokenScope() string {
195-
audience := factory.CloudConfig.Services[cloud.ResourceManager].Audience
265+
func DefaultTokenScopeFor(cloudCfg cloud.Configuration) string {
266+
audience := cloudCfg.Services[cloud.ResourceManager].Audience
196267
return fmt.Sprintf("%s/.default", strings.TrimRight(audience, "/"))
197268
}

0 commit comments

Comments
 (0)