diff --git a/msal/application.py b/msal/application.py
index 49de7cd..afa08f5 100644
--- a/msal/application.py
+++ b/msal/application.py
@@ -1357,13 +1357,14 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
             key_id = kwargs.get("data", {}).get("key_id")
             if key_id:  # Some token types (SSH-certs, POP) are bound to a key
                 query["key_id"] = key_id
-            matches = self.token_cache.find(
-                self.token_cache.CredentialType.ACCESS_TOKEN,
-                target=scopes,
-                query=query)
             now = time.time()
             refresh_reason = msal.telemetry.AT_ABSENT
-            for entry in matches:
+            for entry in self.token_cache._find(  # It returns a generator
+                self.token_cache.CredentialType.ACCESS_TOKEN,
+                target=scopes,
+                query=query,
+            ):  # Note that _find() holds a lock during this for loop;
+                # that is fine because this loop is fast
                 expires_in = int(entry["expires_on"]) - now
                 if expires_in < 5*60:  # Then consider it expired
                     refresh_reason = msal.telemetry.AT_EXPIRED
@@ -1492,10 +1493,8 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
             **kwargs) or last_resp
 
     def _get_app_metadata(self, environment):
-        apps = self.token_cache.find(  # Use find(), rather than token_cache.get(...)
-            TokenCache.CredentialType.APP_METADATA, query={
-                "environment": environment, "client_id": self.client_id})
-        return apps[0] if apps else {}
+        return self.token_cache._get_app_metadata(
+            environment=environment, client_id=self.client_id, default={})
 
     def _acquire_token_silent_by_finding_specific_refresh_token(
             self, authority, scopes, query,
diff --git a/msal/token_cache.py b/msal/token_cache.py
index bd6d8a6..d19a1db 100644
--- a/msal/token_cache.py
+++ b/msal/token_cache.py
@@ -88,20 +88,81 @@ def __init__(self):
                 "appmetadata-{}-{}".format(environment or "", client_id or ""),
             }
 
-    def find(self, credential_type, target=None, query=None):
+    def _get_access_token(
+        self,
+        home_account_id, environment, client_id, realm, target,  # Together they form a compound key
+        default=None,
+    ):  # O(1)
+        """Return the AT stored under the given compound key, or ``default``."""
+        return self._get(
+            self.CredentialType.ACCESS_TOKEN,
+            self.key_makers[TokenCache.CredentialType.ACCESS_TOKEN](
+                home_account_id=home_account_id,
+                environment=environment,
+                client_id=client_id,
+                realm=realm,
+                target=" ".join(target),
+            ),
+            default=default)
+
+    def _get_app_metadata(self, environment, client_id, default=None):  # O(1)
+        """Return the app metadata of this environment and client, or ``default``."""
+        return self._get(
+            self.CredentialType.APP_METADATA,
+            self.key_makers[TokenCache.CredentialType.APP_METADATA](
+                environment=environment,
+                client_id=client_id,
+            ),
+            default=default)
+
+    def _get(self, credential_type, key, default=None):  # O(1)
+        """Return the ``credential_type`` entry stored under ``key``, or ``default``."""
+        with self._lock:
+            return self._cache.get(credential_type, {}).get(key, default)
+
+    def _find(self, credential_type, target=None, query=None):  # O(n) generator
+        """Returns a generator of matching entries.
+
+        It is O(1) for AT hits, and O(n) for other types.
+        Note that it holds a lock during the entire search.
+
+        :param credential_type: One of the ``CredentialType`` values.
+        :param list target: Scopes that a matching entry must contain.
+        :param dict query: Key-value pairs that a matching entry must contain.
+            ``None`` means there is no such constraint.
+        """
         target = target or []
         assert isinstance(target, list), "Invalid parameter type"
+        target = sorted(target)  # Match the order sorted by add()
+        query = query or {}  # Tolerate None in the membership tests below
+
+        preferred_result = None
+        if (credential_type == self.CredentialType.ACCESS_TOKEN
+            and "home_account_id" in query and "environment" in query
+            and "client_id" in query and "realm" in query and target
+        ):  # Special case for O(1) AT lookup
+            preferred_result = self._get_access_token(
+                query["home_account_id"], query["environment"],
+                query["client_id"], query["realm"], target)
+            # The key lookup pins the target but not the remaining query
+            # fields (such as key_id), so verify them before yielding.
+            if preferred_result and is_subdict_of(query, preferred_result):
+                yield preferred_result
+
         target_set = set(target)
         with self._lock:
-            # Since the target inside token cache key is (per schema) unsorted,
-            # there is no point to attempt an O(1) key-value search here.
-            # So we always do an O(n) in-memory search.
-            return [entry
-                for entry in self._cache.get(credential_type, {}).values()
-                if is_subdict_of(query or {}, entry)
-                and (target_set <= set(entry.get("target", "").split())
-                if target else True)
-                ]
+            # The lookup above only hits an AT stored under exactly this target,
+            # so an O(n) in-memory search is still needed for everything else.
+            for entry in self._cache.get(credential_type, {}).values():
+                if is_subdict_of(query, entry) and (
+                        target_set <= set(entry.get("target", "").split())
+                        if target else True):
+                    if entry != preferred_result:  # Avoid yielding the same entry twice
+                        yield entry
+
+    def find(self, credential_type, target=None, query=None):  # Obsolete. Use _find() instead.
+        """Return a list of all matching entries."""
+        return list(self._find(credential_type, target=target, query=query))
 
     def add(self, event, now=None):
         """Handle a token obtaining event, and add tokens into cache."""
@@ -160,7 +221,7 @@ def __add(self, event, now=None):
             decode_id_token(id_token, client_id=event["client_id"])
             if id_token else {})
         client_info, home_account_id = self.__parse_account(response, id_token_claims)
-        target = ' '.join(event.get("scope") or [])  # Per schema, we don't sort it
+        target = ' '.join(sorted(event.get("scope") or []))  # Schema should have required sorting
 
         with self._lock:
             now = int(time.time() if now is None else now)
diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py
index 2fe486c..94bf496 100644
--- a/tests/test_token_cache.py
+++ b/tests/test_token_cache.py
@@ -76,11 +76,11 @@ def testAddByAad(self):
                 'home_account_id': "uid.utid",
                 'realm': 'contoso',
                 'secret': 'an access token',
-                'target': 's2 s1 s3',
+                'target': 's1 s2 s3',  # Sorted
                 'token_type': 'some type',
             },
             self.cache._cache["AccessToken"].get(
-                'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3')
+                'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3')
             )
         self.assertEqual(
             {
@@ -90,10 +90,10 @@ def testAddByAad(self):
                 'home_account_id': "uid.utid",
                 'last_modification_time': '1000',
                 'secret': 'a refresh token',
-                'target': 's2 s1 s3',
+                'target': 's1 s2 s3',  # Sorted
             },
             self.cache._cache["RefreshToken"].get(
-                'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3')
+                'uid.utid-login.example.com-refreshtoken-my_client_id--s1 s2 s3')
             )
         self.assertEqual(
             {
@@ -150,11 +150,11 @@ def testAddByAdfs(self):
                 'home_account_id': "subject",
                 'realm': 'adfs',
                 'secret': 'an access token',
-                'target': 's2 s1 s3',
+                'target': 's1 s2 s3',  # Sorted
                 'token_type': 'some type',
             },
             self.cache._cache["AccessToken"].get(
-                'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s2 s1 s3')
+                'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3')
             )
         self.assertEqual(
             {
@@ -164,10 +164,10 @@ def testAddByAdfs(self):
                 'home_account_id': "subject",
                 'last_modification_time': "1000",
                 'secret': 'a refresh token',
-                'target': 's2 s1 s3',
+                'target': 's1 s2 s3',  # Sorted
             },
             self.cache._cache["RefreshToken"].get(
-                'subject-fs.msidlab8.com-refreshtoken-my_client_id--s2 s1 s3')
+                'subject-fs.msidlab8.com-refreshtoken-my_client_id--s1 s2 s3')
             )
         self.assertEqual(
             {
@@ -214,7 +214,7 @@ def test_key_id_is_also_recorded(self):
                 refresh_token="a refresh token"),
             }, now=1000)
         cached_key_id = self.cache._cache["AccessToken"].get(
-            'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
+            'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
             {}).get("key_id")
         self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")
 
@@ -229,7 +229,7 @@ def test_refresh_in_should_be_recorded_as_refresh_on(self):  # Sounds weird. Yep
                 ),  #refresh_token="a refresh token"),
             }, now=1000)
         refresh_on = self.cache._cache["AccessToken"].get(
-            'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
+            'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
             {}).get("refresh_on")
         self.assertEqual("2800", refresh_on, "Should save refresh_on")
 