Skip to content

Commit 604a03a

Browse files
morambrocopybara-github
authored andcommitted
Some refactoring of aws_kms_client.*
* Make AwsKmsClient move-only * Simplify GetAead logic; this results in calling GetKeyArn only once * Other minor fixes PiperOrigin-RevId: 528710559 Change-Id: Id4266b2cc750f4fb58693a6cf07b0f61955fbf58
1 parent 10b152a commit 604a03a

File tree

2 files changed

+64
-62
lines changed

2 files changed

+64
-62
lines changed

tink/integration/awskms/aws_kms_client.cc

+50-52
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ namespace {
4949
constexpr absl::string_view kKeyUriPrefix = "aws-kms://";
5050

5151
// Returns AWS key ARN contained in `key_uri`. If `key_uri` does not refer to an
52-
// AWS key, returns an empty string.
52+
// AWS key, returns an error.
5353
util::StatusOr<std::string> GetKeyArn(absl::string_view key_uri) {
5454
if (!absl::StartsWithIgnoreCase(key_uri, kKeyUriPrefix)) {
5555
return util::Status(absl::StatusCode::kInvalidArgument,
@@ -62,8 +62,8 @@ util::StatusOr<std::string> GetKeyArn(absl::string_view key_uri) {
6262
// `key_arn`.
6363
// An AWS key ARN is of the form
6464
// arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab.
65-
util::StatusOr<Aws::Client::ClientConfiguration>
66-
GetAwsClientConfig(absl::string_view key_arn) {
65+
util::StatusOr<Aws::Client::ClientConfiguration> GetAwsClientConfig(
66+
absl::string_view key_arn) {
6767
std::vector<std::string> key_arn_parts = absl::StrSplit(key_arn, ':');
6868
if (key_arn_parts.size() < 6) {
6969
return util::Status(absl::StatusCode::kInvalidArgument,
@@ -98,9 +98,9 @@ util::StatusOr<std::string> GetValue(absl::string_view name,
9898
absl::string_view line) {
9999
std::vector<std::string> parts = absl::StrSplit(line, '=');
100100
if (parts.size() != 2 || absl::StripAsciiWhitespace(parts[0]) != name) {
101-
return util::Status(
102-
absl::StatusCode::kInvalidArgument,
103-
absl::StrCat("Expected line in format ", name, " = value"));
101+
return util::Status(absl::StatusCode::kInvalidArgument,
102+
absl::StrCat("Expected line to have the format: ", name,
103+
" = value. Found: ", line));
104104
}
105105
return std::string(absl::StripAsciiWhitespace(parts[1]));
106106
}
@@ -132,40 +132,42 @@ util::StatusOr<std::string> GetValue(absl::string_view name,
132132
// Aws::Auth::ProfileConfigFileAWSCredentialsProvider.
133133
util::StatusOr<Aws::Auth::AWSCredentials> GetAwsCredentials(
134134
absl::string_view credentials_path) {
135-
if (!credentials_path.empty()) { // Read credentials from given file.
136-
auto creds_result = ReadFile(std::string(credentials_path));
137-
if (!creds_result.ok()) {
138-
return creds_result.status();
139-
}
140-
std::vector<std::string> creds_lines =
141-
absl::StrSplit(creds_result.value(), '\n');
142-
if (creds_lines.size() < 3) {
143-
return util::Status(absl::StatusCode::kInvalidArgument,
144-
absl::StrCat("Invalid format of credentials in file ",
145-
credentials_path));
146-
}
147-
auto key_id_result = GetValue("aws_access_key_id", creds_lines[1]);
148-
if (!key_id_result.ok()) {
149-
return util::Status(absl::StatusCode::kInvalidArgument,
150-
absl::StrCat("Invalid format of credentials in file ",
151-
credentials_path, " : ",
152-
key_id_result.status().message()));
153-
}
154-
auto secret_key_result = GetValue("aws_secret_access_key", creds_lines[2]);
155-
if (!secret_key_result.ok()) {
156-
return util::Status(
157-
absl::StatusCode::kInvalidArgument,
158-
absl::StrCat("Invalid format of credentials in file ",
159-
credentials_path, " : ",
160-
secret_key_result.status().message()));
161-
}
162-
return Aws::Auth::AWSCredentials(key_id_result.value().c_str(),
163-
secret_key_result.value().c_str());
135+
if (credentials_path.empty()) {
136+
// Get default credentials.
137+
Aws::Auth::DefaultAWSCredentialsProviderChain provider_chain;
138+
return provider_chain.GetAWSCredentials();
164139
}
165-
166-
// Get default credentials.
167-
Aws::Auth::DefaultAWSCredentialsProviderChain provider_chain;
168-
return provider_chain.GetAWSCredentials();
140+
// Read credentials from the given file.
141+
util::StatusOr<std::string> creds_result =
142+
ReadFile(std::string(credentials_path));
143+
if (!creds_result.ok()) {
144+
return creds_result.status();
145+
}
146+
std::vector<std::string> creds_lines =
147+
absl::StrSplit(creds_result.value(), '\n');
148+
if (creds_lines.size() < 3) {
149+
return util::Status(absl::StatusCode::kInvalidArgument,
150+
absl::StrCat("Invalid format of credentials in file ",
151+
credentials_path));
152+
}
153+
util::StatusOr<std::string> key_id_result =
154+
GetValue("aws_access_key_id", creds_lines[1]);
155+
if (!key_id_result.ok()) {
156+
return util::Status(
157+
absl::StatusCode::kInvalidArgument,
158+
absl::StrCat("Invalid format of credentials in file ", credentials_path,
159+
" : ", key_id_result.status().message()));
160+
}
161+
util::StatusOr<std::string> secret_key_result =
162+
GetValue("aws_secret_access_key", creds_lines[2]);
163+
if (!secret_key_result.ok()) {
164+
return util::Status(
165+
absl::StatusCode::kInvalidArgument,
166+
absl::StrCat("Invalid format of credentials in file ", credentials_path,
167+
" : ", secret_key_result.status().message()));
168+
}
169+
return Aws::Auth::AWSCredentials(key_id_result.value().c_str(),
170+
secret_key_result.value().c_str());
169171
}
170172

171173
void InitAwsApi() {
@@ -225,26 +227,22 @@ bool AwsKmsClient::DoesSupport(absl::string_view key_uri) const {
225227
return key_arn_.empty() ? true : key_arn_ == *key_arn;
226228
}
227229

228-
util::StatusOr<std::unique_ptr<Aead>>
229-
AwsKmsClient::GetAead(absl::string_view key_uri) const {
230-
if (!DoesSupport(key_uri)) {
231-
if (!key_arn_.empty()) {
230+
util::StatusOr<std::unique_ptr<Aead>> AwsKmsClient::GetAead(
231+
absl::string_view key_uri) const {
232+
util::StatusOr<std::string> key_arn = GetKeyArn(key_uri);
233+
if (!key_arn.ok()) {
234+
return key_arn.status();
235+
}
236+
// This client is bound to a specific key.
237+
if (!key_arn_.empty()) {
238+
if (key_arn_ != *key_arn) {
232239
return util::Status(absl::StatusCode::kInvalidArgument,
233240
absl::StrCat("This client is bound to ", key_arn_,
234241
" and cannot use key ", key_uri));
235242
}
236-
return util::Status(
237-
absl::StatusCode::kInvalidArgument,
238-
absl::StrCat("This client does not support key ", key_uri));
239-
}
240-
241-
// This client is bound to a specific key.
242-
if (!key_arn_.empty()) {
243243
return AwsKmsAead::New(key_arn_, aws_client_);
244244
}
245245

246-
// Create an Aws::KMS::KMSClient for the given key.
247-
util::StatusOr<std::string> key_arn = GetKeyArn(key_uri);
248246
util::StatusOr<Aws::Client::ClientConfiguration> client_config =
249247
GetAwsClientConfig(*key_arn);
250248
if (!client_config.ok()) {

tink/integration/awskms/aws_kms_client.h

+14-10
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919

2020
#include <memory>
2121

22-
#include "absl/strings/string_view.h"
23-
#include "absl/synchronization/mutex.h"
2422
#include "aws/core/auth/AWSCredentialsProvider.h"
2523
#include "aws/kms/KMSClient.h"
24+
#include "absl/strings/string_view.h"
25+
#include "absl/synchronization/mutex.h"
2626
#include "tink/aead.h"
2727
#include "tink/kms_client.h"
2828
#include "tink/kms_clients.h"
@@ -34,17 +34,23 @@ namespace tink {
3434
namespace integration {
3535
namespace awskms {
3636

37-
// AwsKmsClient is an implementation of KmsClient for
38-
// <a href="https://aws.amazon.com/kms/">AWS KMS</a>
37+
// AwsKmsClient is an implementation of KmsClient for AWS KMS
38+
// (https://aws.amazon.com/kms/).
3939
class AwsKmsClient : public crypto::tink::KmsClient {
4040
public:
41+
// Move only.
42+
AwsKmsClient(AwsKmsClient&& other) = default;
43+
AwsKmsClient& operator=(AwsKmsClient&& other) = default;
44+
AwsKmsClient(const AwsKmsClient&) = delete;
45+
AwsKmsClient& operator=(const AwsKmsClient&) = delete;
46+
4147
// Creates a new AwsKmsClient that is bound to the key specified in `key_uri`,
4248
// if not empty, and that uses the credentials in `credentials_path`, if not
4349
// empty, or the default ones to authenticate to the KMS.
4450
//
4551
// If `key_uri` is empty, then the client is not bound to any particular key.
46-
static crypto::tink::util::StatusOr<std::unique_ptr<AwsKmsClient>>
47-
New(absl::string_view key_uri, absl::string_view credentials_path);
52+
static crypto::tink::util::StatusOr<std::unique_ptr<AwsKmsClient>> New(
53+
absl::string_view key_uri, absl::string_view credentials_path);
4854

4955
// Creates a new client and registers it in KMSClients.
5056
static crypto::tink::util::Status RegisterNewClient(
@@ -55,10 +61,8 @@ class AwsKmsClient : public crypto::tink::KmsClient {
5561
// to a specific key.
5662
bool DoesSupport(absl::string_view key_uri) const override;
5763

58-
// Returns an Aead-primitive backed by KMS key specified by `key_uri`,
59-
// provided that this KmsClient does support `key_uri`.
60-
crypto::tink::util::StatusOr<std::unique_ptr<Aead>>
61-
GetAead(absl::string_view key_uri) const override;
64+
crypto::tink::util::StatusOr<std::unique_ptr<Aead>> GetAead(
65+
absl::string_view key_uri) const override;
6266

6367
private:
6468
AwsKmsClient(absl::string_view key_arn, Aws::Auth::AWSCredentials credentials)

0 commit comments

Comments
 (0)