Skip to content

Commit 67678c2

Browse files
authored
Cache output of asm() (#2358)
* Cache output of `asm()` To speed up repeated runs of an exploit, cache the assembled output. Use a sha1 hash of the shellcode as well as relevant context values like `context.arch` and `context.bits` to see if the exact same shellcode was assembled for the same context before. Fixes #2312 * Return path to cache file if `not extract` * Update CHANGELOG * Create temporary copy of cached file * Add debug log about using the cache * Include full assembler and linker commandlines in hash This should catch any changes across pwntools updates and system environment changes. * Include pwntools version in hash
1 parent 734cb3b commit 67678c2

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ The table below shows which release corresponds to each branch, and what date th
7272

7373
## 4.15.0 (`dev`)
7474

75+
- [#2358][2358] Cache output of `asm()`
7576

77+
[2358]: https://github.com/Gallopsled/pwntools/pull/2358
7678

7779
## 4.14.0 (`beta`)
7880

pwnlib/asm.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
from pwnlib.context import LocalContext
6060
from pwnlib.context import context
6161
from pwnlib.log import getLogger
62+
from pwnlib.util.hashes import sha1sumhex
63+
from pwnlib.util.packing import _encode
64+
from pwnlib.version import __version__
6265

6366
log = getLogger(__name__)
6467

@@ -758,8 +761,21 @@ def asm(shellcode, vma = 0, extract = True, shared = False):
758761
b'0@*\x00'
759762
>>> asm("la %r0, 42", arch = 's390', bits=64)
760763
b'A\x00\x00*'
764+
765+
The output is cached:
766+
767+
>>> start = time.time()
768+
>>> asm("lea rax, [rip+0]", arch = 'amd64')
769+
b'H\x8d\x05\x00\x00\x00\x00'
770+
>>> uncached_time = time.time() - start
771+
>>> start = time.time()
772+
>>> asm("lea rax, [rip+0]", arch = 'amd64')
773+
b'H\x8d\x05\x00\x00\x00\x00'
774+
>>> cached_time = time.time() - start
775+
>>> uncached_time > cached_time
776+
True
761777
"""
762-
result = ''
778+
result = b''
763779

764780
assembler = _assembler()
765781
linker = _linker()
@@ -770,6 +786,30 @@ def asm(shellcode, vma = 0, extract = True, shared = False):
770786

771787
log.debug('Assembling\n%s' % code)
772788

789+
cache_file = None
790+
if context.cache_dir:
791+
cache_dir = os.path.join(context.cache_dir, 'asm-cache')
792+
if not os.path.isdir(cache_dir):
793+
os.makedirs(cache_dir)
794+
795+
# Include the context in the hash in addition to the shellcode
796+
hash_params = '{}_{}_{}_{}'.format(vma, extract, shared, __version__)
797+
fingerprint_params = _encode(code) + _encode(hash_params) + _encode(' '.join(assembler)) + _encode(' '.join(linker)) + _encode(' '.join(objcopy))
798+
asm_hash = sha1sumhex(fingerprint_params)
799+
cache_file = os.path.join(cache_dir, asm_hash)
800+
if os.path.exists(cache_file):
801+
log.debug('Using cached assembly output from %r', cache_file)
802+
if extract:
803+
with open(cache_file, 'rb') as f:
804+
return f.read()
805+
806+
# Create a temporary copy of the cached file to avoid modification.
807+
tmpdir = tempfile.mkdtemp(prefix = 'pwn-asm-')
808+
atexit.register(shutil.rmtree, tmpdir)
809+
step3 = os.path.join(tmpdir, 'step3')
810+
shutil.copy(cache_file, step3)
811+
return step3
812+
773813
tmpdir = tempfile.mkdtemp(prefix = 'pwn-asm-')
774814
step1 = path.join(tmpdir, 'step1')
775815
step2 = path.join(tmpdir, 'step2')
@@ -817,6 +857,8 @@ def asm(shellcode, vma = 0, extract = True, shared = False):
817857
shutil.copy(step2, step3)
818858

819859
if not extract:
860+
if cache_file is not None:
861+
shutil.copy(step3, cache_file)
820862
return step3
821863

822864
_run(objcopy + [step3, step4])
@@ -830,6 +872,10 @@ def asm(shellcode, vma = 0, extract = True, shared = False):
830872
else:
831873
atexit.register(lambda: shutil.rmtree(tmpdir))
832874

875+
if cache_file is not None and result != b'':
876+
with open(cache_file, 'wb') as f:
877+
f.write(result)
878+
833879
return result
834880

835881
@LocalContext

0 commit comments

Comments
 (0)