Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Promote TokenCache._find() to TokenCache.search() #693

Merged
1 commit merged into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ def _find_msal_accounts(self, environment):
"local_account_id": a.get("local_account_id"), # Tenant-specific
"realm": a.get("realm"), # Tenant-specific
}
for a in self.token_cache.find(
for a in self.token_cache.search(
TokenCache.CredentialType.ACCOUNT,
query={"environment": environment})
if a["authority_type"] in interested_authority_types
Expand Down Expand Up @@ -1188,18 +1188,22 @@ def _sign_out(self, home_account):
"home_account_id": home_account["home_account_id"],} # realm-independent
app_metadata = self._get_app_metadata(home_account["environment"])
# Remove RTs/FRTs, and they are realm-independent
for rt in [rt for rt in self.token_cache.find(
for rt in [ # Remove RTs from a static list (rather than from a dynamic generator),
# to avoid changing self.token_cache while it is being iterated
rt for rt in self.token_cache.search(
TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account)
# Do RT's app ownership check as a precaution, in case family apps
# and 3rd-party apps share same token cache, although they should not.
if rt["client_id"] == self.client_id or (
app_metadata.get("family_id") # Now let's settle family business
and rt.get("family_id") == app_metadata["family_id"])
]:
]:
self.token_cache.remove_rt(rt)
for at in self.token_cache.find( # Remove ATs
# Regardless of realm, b/c we've removed realm-independent RTs anyway
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account):
for at in list(self.token_cache.search( # Remove ATs from a static list,
# to avoid changing self.token_cache while it is being iterated
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account,
# Regardless of realm, b/c we've removed realm-independent RTs anyway
)):
# To avoid the complexity of locating sibling family app's AT,
# we skip AT's app ownership check.
# It means ATs for other apps will also be removed, it is OK because:
Expand All @@ -1213,11 +1217,15 @@ def _forget_me(self, home_account):
owned_by_home_account = {
"environment": home_account["environment"],
"home_account_id": home_account["home_account_id"],} # realm-independent
for idt in self.token_cache.find( # Remove IDTs, regardless of realm
TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account):
for idt in list(self.token_cache.search( # Remove IDTs from a static list,
# to avoid changing self.token_cache while it is being iterated
TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account, # regardless of realm
)):
self.token_cache.remove_idt(idt)
for a in self.token_cache.find( # Remove Accounts, regardless of realm
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
for a in list(self.token_cache.search( # Remove Accounts from a static list,
# to avoid changing self.token_cache while it is being iterated
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account, # regardless of realm
)):
self.token_cache.remove_account(a)

def _acquire_token_by_cloud_shell(self, scopes, data=None):
Expand Down Expand Up @@ -1350,12 +1358,12 @@ def _acquire_token_silent_with_error(
return result
final_result = result
for alias in self._get_authority_aliases(self.authority.instance):
if not self.token_cache.find(
if not list(self.token_cache.search( # Need a list to test emptiness
self.token_cache.CredentialType.REFRESH_TOKEN,
# target=scopes, # MUST NOT filter by scopes, because:
# 1. AAD RTs are scope-independent;
# 2. therefore target is optional per schema;
query={"environment": alias}):
query={"environment": alias})):
# Skip heavy weight logic when RT for this alias doesn't exist
continue
the_authority = Authority(
Expand Down Expand Up @@ -1410,11 +1418,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
query["key_id"] = key_id
now = time.time()
refresh_reason = msal.telemetry.AT_ABSENT
for entry in self.token_cache._find( # It returns a generator
for entry in self.token_cache.search( # A generator allows us to
# break early in cache-hit without finding a full list
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query=query,
): # Note that _find() holds a lock during this for loop;
): # This loop is about token search, not about token deletion.
# Note that search() holds a lock during this loop;
# that is fine because this loop is fast
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60: # Then consider it expired
Expand Down Expand Up @@ -1552,10 +1562,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
rt_remover=None, break_condition=lambda response: False,
refresh_reason=None, correlation_id=None, claims_challenge=None,
**kwargs):
matches = self.token_cache.find(
matches = list(self.token_cache.search( # We want a list to test emptiness
self.token_cache.CredentialType.REFRESH_TOKEN,
# target=scopes, # AAD RTs are scope-independent
query=query)
query=query))
logger.debug("Found %d RTs matching %s", len(matches), {
k: _pii_less_home_account_id(v) if k == "home_account_id" and v else v
for k, v in query.items()
Expand Down Expand Up @@ -2252,11 +2262,12 @@ def remove_tokens_for_client(self):
:func:`~acquire_token_for_client()` for the current client."""
for env in [self.authority.instance] + self._get_authority_aliases(
self.authority.instance):
for at in self.token_cache.find(TokenCache.CredentialType.ACCESS_TOKEN, query={
for at in list(self.token_cache.search( # Remove ATs from a snapshot
TokenCache.CredentialType.ACCESS_TOKEN, query={
"client_id": self.client_id,
"environment": env,
"home_account_id": None, # These are mostly app-only tokens
}):
})):
self.token_cache.remove_at(at)
# acquire_token_for_client() obtains no RTs, so we have no RT to remove

Expand Down
11 changes: 8 additions & 3 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import threading
import time
import logging
import warnings

from .authority import canonicalize
from .oauth2cli.oidc import decode_part, decode_id_token
Expand Down Expand Up @@ -117,7 +118,7 @@ def _get(self, credential_type, key, default=None): # O(1)
with self._lock:
return self._cache.get(credential_type, {}).get(key, default)

def _find(self, credential_type, target=None, query=None): # O(n) generator
def search(self, credential_type, target=None, query=None): # O(n) generator
"""Returns a generator of matching entries.

It is O(1) for AT hits, and O(n) for other types.
Expand Down Expand Up @@ -150,8 +151,12 @@ def _find(self, credential_type, target=None, query=None): # O(n) generator
if entry != preferred_result: # Avoid yielding the same entry twice
yield entry

def find(self, credential_type, target=None, query=None): # Obsolete. Use _find() instead.
return list(self._find(credential_type, target=target, query=query))
def find(self, credential_type, target=None, query=None):
    """Equivalent to list(search(...)).

    Deprecated. Call ``list(search(...))`` instead to explicitly get a list.

    :param credential_type: Forwarded unchanged to :func:`search`.
    :param target: Forwarded unchanged to :func:`search` as ``target``.
    :param query: Forwarded unchanged to :func:`search` as ``query``.
    :return: A list holding every entry yielded by :func:`search`.
    """
    warnings.warn(
        "Use list(search(...)) instead to explicitly get a list.",
        DeprecationWarning,
        # Attribute the warning to the caller of find(), not to this line,
        # so that the report points at the code which needs migrating.
        stacklevel=2,
    )
    return list(self.search(credential_type, target=target, query=query))

def add(self, event, now=None):
"""Handle a token obtaining event, and add tokens into cache."""
Expand Down