@@ -23,6 +23,7 @@ import (
23
23
"strings"
24
24
25
25
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
26
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
26
27
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
27
28
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
28
29
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
@@ -32,23 +33,29 @@ import (
32
33
)
33
34
34
35
type AuthProvider struct {
35
- ComputeCredential azcore.TokenCredential
36
- NetworkCredential azcore. TokenCredential
37
- MultiTenantCredential azcore.TokenCredential
38
- CloudConfig cloud.Configuration
36
+ ComputeCredential azcore.TokenCredential
37
+ AdditionalComputeClientOptions [] func ( option * arm. ClientOptions )
38
+ NetworkCredential azcore.TokenCredential
39
+ CloudConfig cloud.Configuration
39
40
}
40
41
41
- func NewAuthProvider (armConfig * ARMClientConfig , config * AzureAuthConfig , clientOptionsMutFn ... func (option * policy.ClientOptions )) (* AuthProvider , error ) {
42
+ func NewAuthProvider (
43
+ armConfig * ARMClientConfig ,
44
+ config * AzureAuthConfig ,
45
+ clientOptionsMutFn ... func (option * policy.ClientOptions ),
46
+ ) (* AuthProvider , error ) {
42
47
clientOption , _ , err := GetAzCoreClientOption (armConfig )
43
48
if err != nil {
44
49
return nil , err
45
50
}
46
51
for _ , fn := range clientOptionsMutFn {
47
52
fn (clientOption )
48
53
}
49
- var computeCredential azcore.TokenCredential
50
- var networkTokenCredential azcore.TokenCredential
51
- var multiTenantCredential azcore.TokenCredential
54
+ var (
55
+ computeCredential azcore.TokenCredential
56
+ networkCredential azcore.TokenCredential
57
+ additionalComputeClientOptions []func (option * arm.ClientOptions )
58
+ )
52
59
53
60
// federatedIdentityCredential is used for workload identity federation
54
61
if aadFederatedTokenFile , enabled := config .GetAzureFederatedTokenFile (); enabled {
@@ -62,6 +69,7 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
62
69
return nil , err
63
70
}
64
71
}
72
+
65
73
// managedIdentityCredential is used for managed identity extension
66
74
if computeCredential == nil && config .UseManagedIdentityExtension {
67
75
credOptions := & azidentity.ManagedIdentityCredentialOptions {
@@ -79,52 +87,80 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
79
87
return nil , err
80
88
}
81
89
if config .AuxiliaryTokenProvider != nil && IsMultiTenant (armConfig ) {
82
- networkTokenCredential , err = armauth .NewKeyVaultCredential (
90
+ // Use AuxiliaryTokenProvider as the network credential
91
+ networkCredential , err = armauth .NewKeyVaultCredential (
83
92
computeCredential ,
84
93
config .AuxiliaryTokenProvider .SecretResourceID (),
85
94
)
86
95
if err != nil {
87
96
return nil , fmt .Errorf ("create KeyVaultCredential for auxiliary token provider: %w" , err )
88
97
}
98
+
99
+ // Additionally, we need to add the auxiliary token to the HTTP header when making requests to the compute resources
100
+ additionalComputeClientOptions = append (additionalComputeClientOptions , func (option * arm.ClientOptions ) {
101
+ option .PerRetryPolicies = append (option .PerRetryPolicies , armauth .NewAuxiliaryAuthPolicy (
102
+ []azcore.TokenCredential {networkCredential },
103
+ DefaultTokenScopeFor (clientOption .Cloud ),
104
+ ))
105
+ })
89
106
}
90
107
}
91
108
92
109
// Client secret authentication
93
110
if computeCredential == nil && len (config .GetAADClientSecret ()) > 0 {
94
- credOptions := & azidentity.ClientSecretCredentialOptions {
95
- ClientOptions : * clientOption ,
96
- }
97
- computeCredential , err = azidentity .NewClientSecretCredential (armConfig .GetTenantID (), config .GetAADClientID (), config .GetAADClientSecret (), credOptions )
98
- if err != nil {
99
- return nil , err
100
- }
101
111
if IsMultiTenant (armConfig ) {
102
- credOptions := & azidentity.ClientSecretCredentialOptions {
103
- ClientOptions : * clientOption ,
104
- }
105
- networkTokenCredential , err = azidentity .NewClientSecretCredential (armConfig .NetworkResourceTenantID , config .GetAADClientID (), config .AADClientSecret , credOptions )
106
- if err != nil {
107
- return nil , err
112
+
113
+ // Network credential for network resource access
114
+ {
115
+ credOptions := & azidentity.ClientSecretCredentialOptions {
116
+ ClientOptions : * clientOption ,
117
+ }
118
+ networkCredential , err = azidentity .NewClientSecretCredential (
119
+ armConfig .NetworkResourceTenantID ,
120
+ config .GetAADClientID (),
121
+ config .GetAADClientSecret (),
122
+ credOptions ,
123
+ )
124
+ if err != nil {
125
+ return nil , err
126
+ }
108
127
}
109
128
110
- credOptions = & azidentity.ClientSecretCredentialOptions {
111
- ClientOptions : * clientOption ,
112
- AdditionallyAllowedTenants : []string {armConfig .NetworkResourceTenantID },
129
+ // Compute credential with additional allowed tenants for cross-tenant access
130
+ {
131
+ credOptions := & azidentity.ClientSecretCredentialOptions {
132
+ ClientOptions : * clientOption ,
133
+ AdditionallyAllowedTenants : []string {armConfig .NetworkResourceTenantID },
134
+ }
135
+ computeCredential , err = azidentity .NewClientSecretCredential (
136
+ armConfig .GetTenantID (),
137
+ config .GetAADClientID (),
138
+ config .GetAADClientSecret (),
139
+ credOptions ,
140
+ )
141
+ if err != nil {
142
+ return nil , err
143
+ }
113
144
}
114
- multiTenantCredential , err = azidentity .NewClientSecretCredential (armConfig .GetTenantID (), config .GetAADClientID (), config .GetAADClientSecret (), credOptions )
145
+ } else {
146
+ // Single tenant
147
+ credOptions := & azidentity.ClientSecretCredentialOptions {
148
+ ClientOptions : * clientOption ,
149
+ }
150
+ computeCredential , err = azidentity .NewClientSecretCredential (
151
+ armConfig .GetTenantID (),
152
+ config .GetAADClientID (),
153
+ config .GetAADClientSecret (),
154
+ credOptions ,
155
+ )
115
156
if err != nil {
116
157
return nil , err
117
158
}
118
-
119
159
}
120
160
}
121
161
122
162
// ClientCertificateCredential is used for client certificate
123
163
if computeCredential == nil && len (config .AADClientCertPath ) > 0 {
124
- credOptions := & azidentity.ClientCertificateCredentialOptions {
125
- ClientOptions : * clientOption ,
126
- SendCertificateChain : true ,
127
- }
128
164
certData , err := os .ReadFile (config .AADClientCertPath )
129
165
if err != nil {
130
166
return nil , fmt .Errorf ("reading the client certificate from file %s: %w" , config .AADClientCertPath , err )
@@ -133,20 +169,58 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
133
169
if err != nil {
134
170
return nil , fmt .Errorf ("decoding the client certificate: %w" , err )
135
171
}
136
- computeCredential , err = azidentity .NewClientCertificateCredential (armConfig .GetTenantID (), config .GetAADClientID (), certificate , privateKey , credOptions )
137
- if err != nil {
138
- return nil , err
139
- }
172
+
140
173
if IsMultiTenant (armConfig ) {
141
- networkTokenCredential , err = azidentity .NewClientCertificateCredential (armConfig .NetworkResourceTenantID , config .GetAADClientID (), certificate , privateKey , credOptions )
142
- if err != nil {
143
- return nil , err
174
+
175
+ // Network credential for network resource access
176
+ {
177
+ credOptions := & azidentity.ClientCertificateCredentialOptions {
178
+ ClientOptions : * clientOption ,
179
+ SendCertificateChain : true ,
180
+ }
181
+ networkCredential , err = azidentity .NewClientCertificateCredential (
182
+ armConfig .NetworkResourceTenantID ,
183
+ config .GetAADClientID (),
184
+ certificate ,
185
+ privateKey ,
186
+ credOptions ,
187
+ )
188
+ if err != nil {
189
+ return nil , err
190
+ }
144
191
}
145
- credOptions = & azidentity.ClientCertificateCredentialOptions {
146
- ClientOptions : * clientOption ,
147
- AdditionallyAllowedTenants : []string {armConfig .NetworkResourceTenantID },
192
+
193
+ // Compute credential with additional allowed tenants for cross-tenant access
194
+ {
195
+ credOptions := & azidentity.ClientCertificateCredentialOptions {
196
+ ClientOptions : * clientOption ,
197
+ AdditionallyAllowedTenants : []string {armConfig .NetworkResourceTenantID },
198
+ SendCertificateChain : true ,
199
+ }
200
+ computeCredential , err = azidentity .NewClientCertificateCredential (
201
+ armConfig .GetTenantID (),
202
+ config .GetAADClientID (),
203
+ certificate ,
204
+ privateKey ,
205
+ credOptions ,
206
+ )
207
+ if err != nil {
208
+ return nil , err
209
+ }
148
210
}
149
- multiTenantCredential , err = azidentity .NewClientCertificateCredential (armConfig .GetTenantID (), config .GetAADClientID (), certificate , privateKey , credOptions )
211
+ } else {
212
+ // Single tenant
213
+ credOptions := & azidentity.ClientCertificateCredentialOptions {
214
+ ClientOptions : * clientOption ,
215
+ SendCertificateChain : true ,
216
+ }
217
+ computeCredential , err = azidentity .NewClientCertificateCredential (
218
+ armConfig .GetTenantID (),
219
+ config .GetAADClientID (),
220
+ certificate ,
221
+ privateKey ,
222
+ credOptions ,
223
+ )
150
224
if err != nil {
151
225
return nil , err
152
226
}
@@ -155,17 +229,21 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
155
229
156
230
// UserAssignedIdentityCredentials authentication
157
231
if computeCredential == nil && len (config .AADMSIDataPlaneIdentityPath ) > 0 {
158
- computeCredential , err = dataplane .NewUserAssignedIdentityCredential (context .Background (), config .AADMSIDataPlaneIdentityPath , dataplane .WithClientOpts (azcore.ClientOptions {Cloud : clientOption .Cloud }))
232
+ computeCredential , err = dataplane .NewUserAssignedIdentityCredential (
233
+ context .Background (),
234
+ config .AADMSIDataPlaneIdentityPath ,
235
+ dataplane .WithClientOpts (azcore.ClientOptions {Cloud : clientOption .Cloud }),
236
+ )
159
237
if err != nil {
160
238
return nil , err
161
239
}
162
240
}
163
241
164
242
return & AuthProvider {
165
- ComputeCredential : computeCredential ,
166
- NetworkCredential : networkTokenCredential ,
167
- MultiTenantCredential : multiTenantCredential ,
168
- CloudConfig : clientOption .Cloud ,
243
+ ComputeCredential : computeCredential ,
244
+ AdditionalComputeClientOptions : additionalComputeClientOptions ,
245
+ NetworkCredential : networkCredential ,
246
+ CloudConfig : clientOption .Cloud ,
169
247
}, nil
170
248
}
171
249
@@ -180,18 +258,11 @@ func (factory *AuthProvider) GetNetworkAzIdentity() azcore.TokenCredential {
180
258
return factory .ComputeCredential
181
259
}
182
260
183
- func (factory * AuthProvider ) GetMultiTenantIdentity () azcore.TokenCredential {
184
- if factory .MultiTenantCredential != nil {
185
- return factory .MultiTenantCredential
186
- }
187
- return factory .ComputeCredential
188
- }
189
-
190
- func (factory * AuthProvider ) IsMultiTenantModeEnabled () bool {
191
- return factory .MultiTenantCredential != nil
261
+ func (factory * AuthProvider ) DefaultTokenScope () string {
262
+ return DefaultTokenScopeFor (factory .CloudConfig )
192
263
}
193
264
194
- func ( factory * AuthProvider ) DefaultTokenScope ( ) string {
195
- audience := factory . CloudConfig .Services [cloud .ResourceManager ].Audience
265
+ func DefaultTokenScopeFor ( cloudCfg cloud. Configuration ) string {
266
+ audience := cloudCfg .Services [cloud .ResourceManager ].Audience
196
267
return fmt .Sprintf ("%s/.default" , strings .TrimRight (audience , "/" ))
197
268
}
0 commit comments