diff --git a/msal/application.py b/msal/application.py index a722b28..6a6100f 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1142,7 +1142,7 @@ def _find_msal_accounts(self, environment): "local_account_id": a.get("local_account_id"), # Tenant-specific "realm": a.get("realm"), # Tenant-specific } - for a in self.token_cache.find( + for a in self.token_cache.search( TokenCache.CredentialType.ACCOUNT, query={"environment": environment}) if a["authority_type"] in interested_authority_types @@ -1188,18 +1188,22 @@ def _sign_out(self, home_account): "home_account_id": home_account["home_account_id"],} # realm-independent app_metadata = self._get_app_metadata(home_account["environment"]) # Remove RTs/FRTs, and they are realm-independent - for rt in [rt for rt in self.token_cache.find( + for rt in [ # Remove RTs from a static list (rather than from a dynamic generator), + # to avoid changing self.token_cache while it is being iterated + rt for rt in self.token_cache.search( TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account) # Do RT's app ownership check as a precaution, in case family apps # and 3rd-party apps share same token cache, although they should not. if rt["client_id"] == self.client_id or ( app_metadata.get("family_id") # Now let's settle family business and rt.get("family_id") == app_metadata["family_id"]) - ]: + ]: self.token_cache.remove_rt(rt) - for at in self.token_cache.find( # Remove ATs - # Regardless of realm, b/c we've removed realm-independent RTs anyway - TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account): + for at in list(self.token_cache.search( # Remove ATs from a static list, + # to avoid changing self.token_cache while it is being iterated + TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account, + # Regardless of realm, b/c we've removed realm-independent RTs anyway + )): # To avoid the complexity of locating sibling family app's AT, # we skip AT's app ownership check. # It means ATs for other apps will also be removed, it is OK because: @@ -1213,11 +1217,15 @@ def _forget_me(self, home_account): owned_by_home_account = { "environment": home_account["environment"], "home_account_id": home_account["home_account_id"],} # realm-independent - for idt in self.token_cache.find( # Remove IDTs, regardless of realm - TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account): + for idt in list(self.token_cache.search( # Remove IDTs from a static list, + # to avoid changing self.token_cache while it is being iterated + TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account, # regardless of realm + )): self.token_cache.remove_idt(idt) - for a in self.token_cache.find( # Remove Accounts, regardless of realm - TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account): + for a in list(self.token_cache.search( # Remove Accounts from a static list, + # to avoid changing self.token_cache while it is being iterated + TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account, # regardless of realm + )): self.token_cache.remove_account(a) def _acquire_token_by_cloud_shell(self, scopes, data=None): @@ -1350,12 +1358,12 @@ def _acquire_token_silent_with_error( return result final_result = result for alias in self._get_authority_aliases(self.authority.instance): - if not self.token_cache.find( + if not list(self.token_cache.search( # Need a list to test emptiness self.token_cache.CredentialType.REFRESH_TOKEN, # target=scopes, # MUST NOT filter by scopes, because: # 1. AAD RTs are scope-independent; # 2. therefore target is optional per schema; - query={"environment": alias}): + query={"environment": alias})): # Skip heavy weight logic when RT for this alias doesn't exist continue the_authority = Authority( @@ -1410,11 +1418,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( query["key_id"] = key_id now = time.time() refresh_reason = msal.telemetry.AT_ABSENT - for entry in self.token_cache._find( # It returns a generator + for entry in self.token_cache.search( # A generator allows us to + # break early in cache-hit without finding a full list self.token_cache.CredentialType.ACCESS_TOKEN, target=scopes, query=query, - ): # Note that _find() holds a lock during this for loop; + ): # This loop is about token search, not about token deletion. + # Note that search() holds a lock during this loop; # that is fine because this loop is fast expires_in = int(entry["expires_on"]) - now if expires_in < 5*60: # Then consider it expired @@ -1552,10 +1562,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token( rt_remover=None, break_condition=lambda response: False, refresh_reason=None, correlation_id=None, claims_challenge=None, **kwargs): - matches = self.token_cache.find( + matches = list(self.token_cache.search( # We want a list to test emptiness self.token_cache.CredentialType.REFRESH_TOKEN, # target=scopes, # AAD RTs are scope-independent - query=query) + query=query)) logger.debug("Found %d RTs matching %s", len(matches), { k: _pii_less_home_account_id(v) if k == "home_account_id" and v else v for k, v in query.items() @@ -2252,11 +2262,12 @@ def remove_tokens_for_client(self): :func:`~acquire_token_for_client()` for the current client.""" for env in [self.authority.instance] + self._get_authority_aliases( self.authority.instance): - for at in self.token_cache.find(TokenCache.CredentialType.ACCESS_TOKEN, query={ + for at in list(self.token_cache.search( # Remove ATs from a snapshot + TokenCache.CredentialType.ACCESS_TOKEN, query={ "client_id": self.client_id, "environment": env, "home_account_id": None, # These are mostly app-only tokens - }): + })): self.token_cache.remove_at(at) # acquire_token_for_client() obtains no RTs, so we have no RT to remove diff --git a/msal/token_cache.py b/msal/token_cache.py index 444aa2d..ffa0090 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -2,6 +2,7 @@ import threading import time import logging +import warnings from .authority import canonicalize from .oauth2cli.oidc import decode_part, decode_id_token @@ -117,7 +118,7 @@ def _get(self, credential_type, key, default=None): # O(1) with self._lock: return self._cache.get(credential_type, {}).get(key, default) - def _find(self, credential_type, target=None, query=None): # O(n) generator + def search(self, credential_type, target=None, query=None): # O(n) generator """Returns a generator of matching entries. It is O(1) for AT hits, and O(n) for other types. @@ -150,8 +151,12 @@ def _find(self, credential_type, target=None, query=None): # O(n) generator if entry != preferred_result: # Avoid yielding the same entry twice yield entry - def find(self, credential_type, target=None, query=None): # Obsolete. Use _find() instead. - return list(self._find(credential_type, target=target, query=query)) + def find(self, credential_type, target=None, query=None): + """Equivalent to list(search(...)).""" + warnings.warn( + "Use list(search(...)) instead to explicitly get a list.", + DeprecationWarning) + return list(self.search(credential_type, target=target, query=query)) def add(self, event, now=None): """Handle a token obtaining event, and add tokens into cache."""