Skip to content

Commit b90d20e

Browse files
authored
Merge pull request #73 from AzureAD/release-0.3.0
Release 0.3.0
2 parents 3df9da0 + 6d2efab commit b90d20e

8 files changed

+302
-19
lines changed

.github/workflows/codeql.yml

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Runs GitHub CodeQL static analysis on every push and once a week.
name: "Code Scanning - Action"

on:
  push:
  schedule:
    # Sundays at 00:00 (cron schedules are evaluated in UTC).
    - cron: '0 0 * * 0'

jobs:
  CodeQL-Build:

    strategy:
      fail-fast: false

    # CodeQL runs on ubuntu-latest, windows-latest, and macos-latest
    runs-on: ubuntu-latest

    # The analyze step uploads results to code scanning, which requires
    # write access to security events.
    permissions:
      actions: read
      contents: read
      security-events: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      # Initializes the CodeQL tools for scanning.
      - name: Initialize CodeQL
        uses: github/codeql-action/init@v3
        # Override language selection by uncommenting this and choosing your languages
        # with:
        #   languages: go, javascript, csharp, python, cpp, java

      # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
      # If this step fails, then you should remove it and run the build manually (see below).
      - name: Autobuild
        uses: github/codeql-action/autobuild@v3

      # ℹ️ Command-line programs to run using the OS shell.
      # 📚 https://git.io/JvXDl

      # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
      #    and modify them (or add more) to build your code if your project
      #    uses a compiled language

      #- run: |
      #   make bootstrap
      #   make release

      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v3

msal_extensions/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Provides auxiliary functionality to the `msal` package."""
2-
__version__ = "0.2.2"
2+
__version__ = "0.3.0"
33

44
import sys
55

msal_extensions/persistence.py

+93-12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import abc
1010
import os
1111
import errno
12+
import logging
1213
try:
1314
from pathlib import Path # Built-in in Python 3
1415
except:
@@ -21,6 +22,9 @@
2122
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore
2223

2324

25+
logger = logging.getLogger(__name__)
26+
27+
2428
def _mkdir_p(path):
2529
"""Creates a directory, and any necessary parents.
2630
@@ -41,6 +45,20 @@ def _mkdir_p(path):
4145
raise
4246

4347

48+
# We do not aim to wrap every os-specific exception.
49+
# Here we define only the most common one,
50+
# otherwise caller would need to catch os-specific persistence exceptions.
51+
class PersistenceNotFound(IOError):
    """Cross-platform error meaning "nothing has been persisted here yet".

    The base class is IOError rather than OSError because historically an
    IOError was bubbled up and expected by callers:
    https://github.com/AzureAD/microsoft-authentication-extensions-for-python/blob/0.2.2/msal_extensions/token_cache.py#L38
    Keeping IOError maintains backward compatibility even on Python 2.x,
    and makes no difference on Python 3.3+ where IOError is an alias of
    OSError.
    """

    def __init__(self, err_no=errno.ENOENT,
                 message="Persistence not found", location=None):
        # The 3-argument form populates errno, strerror and filename.
        super(PersistenceNotFound, self).__init__(err_no, message, location)
60+
61+
4462
class BasePersistence(ABC):
4563
"""An abstract persistence defining the common interface of this family"""
4664

@@ -55,12 +73,18 @@ def save(self, content):
5573
@abc.abstractmethod
def load(self):
    # type: () -> str
    """Return the content held by this persistence.

    Implementations may raise PersistenceNotFound when save() has never
    been called.
    """
    raise NotImplementedError
6081

6182
@abc.abstractmethod
def time_last_modified(self):
    """Return the time at which this persistence was last modified.

    Implementations may raise PersistenceNotFound when save() has never
    been called.
    """
    raise NotImplementedError
6589

6690
@abc.abstractmethod
@@ -87,11 +111,32 @@ def save(self, content):
87111
def load(self):
    # type: () -> str
    """Read and return the text stored at this persistence's location.

    Raises PersistenceNotFound when the file does not exist; every other
    EnvironmentError is re-raised untouched.
    """
    try:
        with open(self._location, 'r') as source:
            content = source.read()
    except EnvironmentError as error:  # This spelling also works on Python 2.7
        if error.errno != errno.ENOENT:
            raise
        raise PersistenceNotFound(
            message=(
                "Persistence not initialized. "
                "You can recover by calling a save() first."),
            location=self._location,
            )
    return content
126+
92127

93128
def time_last_modified(self):
    """Return the modification time of the underlying file.

    Raises PersistenceNotFound when the file does not exist; every other
    EnvironmentError is re-raised untouched.
    """
    try:
        mtime = os.path.getmtime(self._location)
    except EnvironmentError as error:  # This spelling also works on Python 2.7
        if error.errno != errno.ENOENT:
            raise
        raise PersistenceNotFound(
            message=(
                "Persistence not initialized. "
                "You can recover by calling a save() first."),
            location=self._location,
            )
    return mtime
95140

96141
def touch(self):
97142
"""To touch this file-based persistence without writing content into it"""
@@ -115,13 +160,28 @@ def __init__(self, location, entropy=''):
115160

116161
def save(self, content):
    # type: (str) -> None
    """Protect ``content`` with the data-protection agent and write it out."""
    # protect() runs before the file is opened: 'wb+' truncates on open, so
    # a failing protect() never gets the chance to empty an existing file.
    protected = self._dp_agent.protect(content)
    with open(self._location, 'wb+') as sink:
        sink.write(protected)
120166

121167
def load(self):
    # type: () -> str
    """Read the protected file and return its unprotected content.

    Raises:
        PersistenceNotFound: the file does not exist.
        EnvironmentError: any other failure, re-raised unchanged. The
            DPAPI migration hint is logged only when unprotect() itself
            fails, so plain file-system errors (permissions, path is a
            directory, ...) are no longer mislabelled as DPAPI errors.
    """
    try:
        with open(self._location, 'rb') as handle:
            data = handle.read()
    except EnvironmentError as exp:  # EnvironmentError in Py 2.7 works across platform
        if exp.errno == errno.ENOENT:
            raise PersistenceNotFound(
                message=(
                    "Persistence not initialized. "
                    "You can recover by calling a save() first."),
                location=self._location,
                )
        raise  # A file-system problem; unrelated to data protection
    try:
        return self._dp_agent.unprotect(data)
    except EnvironmentError:
        logger.exception(
            "DPAPI error likely caused by file content not previously encrypted. "
            "App developer should migrate by calling save(plaintext) first.")
        raise
125185

126186

127187
class KeychainPersistence(BasePersistence):
@@ -136,9 +196,10 @@ def __init__(self, signal_location, service_name, account_name):
136196
"""
137197
if not (service_name and account_name): # It would hang on OSX
138198
raise ValueError("service_name and account_name are required")
139-
from .osx import Keychain # pylint: disable=import-outside-toplevel
199+
from .osx import Keychain, KeychainError # pylint: disable=import-outside-toplevel
140200
self._file_persistence = FilePersistence(signal_location) # Favor composition
141201
self._Keychain = Keychain # pylint: disable=invalid-name
202+
self._KeychainError = KeychainError # pylint: disable=invalid-name
142203
self._service_name = service_name
143204
self._account_name = account_name
144205

@@ -150,8 +211,21 @@ def save(self, content):
150211

151212
def load(self):
    """Return the password stored under this service/account pair.

    Raises:
        PersistenceNotFound: the keychain reports ITEM_NOT_FOUND, which
            happens when load() is called before any save().
        KeychainError: any other keychain failure, re-raised unchanged.
    """
    with self._Keychain() as locker:
        try:
            return locker.get_generic_password(
                self._service_name, self._account_name)
        except self._KeychainError as ex:
            if ex.exit_status == self._KeychainError.ITEM_NOT_FOUND:
                # This happens when a load() is called before a save().
                # We map it into cross-platform error for unified catching.
                raise PersistenceNotFound(
                    location="Service:{} Account:{}".format(
                        self._service_name, self._account_name),
                    message=(
                        "Keychain persistence not initialized. "
                        "You can recover by calling a save() first."),
                    )
            raise  # We do not intend to hide any other underlying exceptions
155229

156230
def time_last_modified(self):
157231
return self._file_persistence.time_last_modified()
@@ -188,7 +262,14 @@ def save(self, content):
188262
self._file_persistence.touch() # For time_last_modified()
189263

190264
def load(self):
    """Return the content held by the underlying agent.

    Raises PersistenceNotFound when the agent's load() returns None.
    """
    data = self._agent.load()
    if data is None:
        # Lower level libsecret would return None when found nothing. Here
        # in persistence layer, we convert it to a unified error for consistency.
        raise PersistenceNotFound(message=(
            "Keyring persistence not initialized. "
            "You can recover by calling a save() first."))
    return data
192273

193274
def time_last_modified(self):
194275
return self._file_persistence.time_last_modified()

msal_extensions/token_cache.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import os
33
import warnings
44
import time
5-
import errno
65
import logging
76

87
import msal
98

109
from .cache_lock import CrossPlatLock
1110
from .persistence import (
12-
_mkdir_p, FilePersistence,
11+
_mkdir_p, PersistenceNotFound, FilePersistence,
1312
FilePersistenceWithDataProtection, KeychainPersistence)
1413

1514

@@ -35,10 +34,10 @@ def _reload_if_necessary(self):
3534
if self._last_sync < self._persistence.time_last_modified():
3635
self.deserialize(self._persistence.load())
3736
self._last_sync = time.time()
38-
except IOError as exp:
39-
if exp.errno != errno.ENOENT:
40-
raise
41-
# Otherwise, from cache's perspective, a nonexistent file is a NO-OP
37+
except PersistenceNotFound:
38+
# From cache's perspective, a nonexistent persistence is a NO-OP.
39+
pass
40+
# However, existing data unable to be decrypted will still be bubbled up.
4241

4342
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
4443
with CrossPlatLock(self._lock_location):

tests/cache_file_generator.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Usage: cache_file_generator.py cache_file_path sleep_interval
3+
4+
This is a console application which is to be used for cross-platform lock performance testing.
5+
The app will acquire lock for the cache file, log the process id and then release the lock.
6+
7+
It takes in two arguments - cache file path and the sleep interval.
8+
The cache file path is the path of cache file.
9+
The sleep interval is the time in seconds for which the lock is held by a process.
10+
"""
11+
12+
import logging
13+
import os
14+
import sys
15+
import time
16+
17+
from portalocker import exceptions
18+
19+
from msal_extensions import FilePersistence, CrossPlatLock
20+
21+
22+
def _acquire_lock_and_write_to_cache(cache_location, sleep_interval):
    """Append this process's "<"/">" marker pair to the cache under the lock.

    While holding the lock, the existing cache content is read, a
    "< pid" line is appended, the process sleeps for ``sleep_interval``
    seconds, a "> pid" line is appended, and the result is saved back.
    A cache file that does not exist yet is treated as empty.
    Failure to acquire the lock is logged as a warning, not raised.
    """
    # Imported here so the module-level import block stays untouched.
    from msal_extensions.persistence import PersistenceNotFound

    cache_accessor = FilePersistence(cache_location)
    lock_file_path = cache_accessor.get_location() + ".lockfile"
    pid = str(os.getpid())
    try:
        with CrossPlatLock(lock_file_path):
            try:
                data = cache_accessor.load()
            except PersistenceNotFound:
                # load() signals a missing file by raising, never by
                # returning None, so start from an empty payload.
                data = ""
            data += "< " + pid + "\n"
            time.sleep(sleep_interval)
            data += "> " + pid + "\n"
            cache_accessor.save(data)
    except exceptions.LockException as e:
        logging.warning("Unable to acquire lock %s", e)
37+
38+
if __name__ == "__main__":
    # Both positional arguments are required; otherwise show the usage text.
    if len(sys.argv) < 3:
        print(__doc__)
        sys.exit(0)
    cache_file_path, hold_seconds = sys.argv[1], float(sys.argv[2])
    _acquire_lock_and_write_to_cache(cache_file_path, hold_seconds)
43+

tests/test_agnostic_backend.py

+5
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,8 @@ def test_current_platform_cache_roundtrip_with_alias_class(temp_location):
4444
def test_persisted_token_cache(temp_location):
    """Round-trip a token cache backed by a plain file persistence."""
    file_persistence = FilePersistence(temp_location)
    token_cache = PersistedTokenCache(file_persistence)
    _test_token_cache_roundtrip(token_cache)
4646

47+
def test_file_not_found_error_is_not_raised():
    """Looking up a cache whose backing file is missing must be a NO-OP.

    The missing file lives in a fresh temporary directory rather than at a
    path relative to the current working directory, so the test neither
    depends on what happens to exist there nor leaves files behind in it.
    """
    # Imported locally so this test does not rely on the module's imports.
    import os
    import shutil
    import tempfile

    folder = tempfile.mkdtemp(prefix="test_file_not_found")
    try:
        persistence = FilePersistence(os.path.join(folder, 'non_existing_file'))
        cache = PersistedTokenCache(persistence=persistence)
        # An exception raised here will fail the test case as it is supposed to be a NO-OP
        cache.find('')
    finally:
        shutil.rmtree(folder, ignore_errors=True)

tests/test_cache_lock_file_perf.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import multiprocessing
2+
import os
3+
import shutil
4+
import tempfile
5+
6+
import pytest
7+
8+
from cache_file_generator import _acquire_lock_and_write_to_cache
9+
10+
11+
@pytest.fixture
def temp_location():
    """Yield a file path inside a fresh temp folder; remove the folder afterwards."""
    scratch_dir = tempfile.mkdtemp(prefix="test_persistence_roundtrip")
    persistence_path = os.path.join(scratch_dir, 'persistence.bin')
    yield persistence_path
    shutil.rmtree(scratch_dir, ignore_errors=True)
16+
17+
18+
def _validate_result_in_cache(cache_location):
19+
with open(cache_location) as handle:
20+
data = handle.read()
21+
prev_process_id = None
22+
count = 0
23+
for line in data.split("\n"):
24+
if line:
25+
count += 1
26+
tag, process_id = line.split(" ")
27+
if prev_process_id is not None:
28+
assert process_id == prev_process_id, "Process overlap found"
29+
assert tag == '>', "Process overlap_found"
30+
prev_process_id = None
31+
else:
32+
assert tag == '<', "Opening bracket not found"
33+
prev_process_id = process_id
34+
return count
35+
36+
37+
def _run_multiple_processes(no_of_processes, cache_location, sleep_interval):
38+
open(cache_location, "w+")
39+
processes = []
40+
for i in range(no_of_processes):
41+
process = multiprocessing.Process(
42+
target=_acquire_lock_and_write_to_cache,
43+
args=(cache_location, sleep_interval))
44+
processes.append(process)
45+
46+
for process in processes:
47+
process.start()
48+
49+
for process in processes:
50+
process.join()
51+
52+
53+
def test_lock_for_normal_workload(temp_location):
    """Four writers holding the lock for 0.1s each must all get their turn."""
    num_of_processes, sleep_interval = 4, 0.1
    _run_multiple_processes(num_of_processes, temp_location, sleep_interval)
    observed_lines = _validate_result_in_cache(temp_location)
    # Each successful writer contributes one "<" line and one ">" line.
    expected_lines = num_of_processes * 2
    assert observed_lines == expected_lines, "Should not observe starvation"
59+
60+
61+
def test_lock_for_high_workload(temp_location):
    """Eighty writers with no hold time must never garble the cache payload."""
    num_of_processes, sleep_interval = 80, 0
    _run_multiple_processes(num_of_processes, temp_location, sleep_interval)
    observed_lines = _validate_result_in_cache(temp_location)
    # Some writers may fail to get the lock, so only an upper bound holds.
    upper_bound = num_of_processes * 2
    assert observed_lines <= upper_bound, "Starvation or not, we should not observe garbled payload"
67+
68+
69+
def test_lock_for_timeout(temp_location):
    """Ten writers holding the lock for 1s each: not all are expected to write."""
    num_of_processes, sleep_interval = 10, 1
    _run_multiple_processes(num_of_processes, temp_location, sleep_interval)
    observed_lines = _validate_result_in_cache(temp_location)
    # Fewer than two lines per process means at least one writer gave up.
    full_house = num_of_processes * 2
    assert observed_lines < full_house, "Should observe starvation"
75+

0 commit comments

Comments
 (0)