From 82d6b8f91ac8bccec1fc33de2fe561a043038214 Mon Sep 17 00:00:00 2001
From: Nils1729 <45318774+Nils1729@users.noreply.github.com>
Date: Thu, 17 Oct 2024 22:13:05 +0200
Subject: [PATCH 1/2] Stringify data before sending it to `escape`

For us, the CLI constantly printed errors and terminated because
`escape` only accepts strings, bytes, etc., not JSON-like dicts.
---
 ataka/player-cli/player_cli/util.py | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/ataka/player-cli/player_cli/util.py b/ataka/player-cli/player_cli/util.py
index 7caf6ce..35fb4f3 100644
--- a/ataka/player-cli/player_cli/util.py
+++ b/ataka/player-cli/player_cli/util.py
@@ -50,12 +50,12 @@ def request(method, endpoint, data=None, params=None):
         if player_cli.state['debug']:
             print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(f"{method} {endpoint}{'' if params is None else f' with params {params}'}"))
             if data is not None:
-                print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(data))
+                print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(str(data)))
             print(f"{DEBUG_STR}: ")
         result = player_cli.ctfconfig_wrapper.request(method, endpoint, data=data)
         if player_cli.state['debug']:
-            print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(result))
+            print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(str(result)))
         return result
 
     url = f'http://{player_cli.state["host"]}/api/{endpoint}'
@@ -63,24 +63,18 @@ def request(method, endpoint, data=None, params=None):
     if player_cli.state['debug']:
         print(f"{DEBUG_STR}: " + escape(f"{method} {url}{'' if params is None else f' with params {params}'}"))
         if data is not None:
-            print(f"{DEBUG_STR}: " + escape(data))
+            print(f"{DEBUG_STR}: " + escape(str(data)))
         print(f"{DEBUG_STR}: ")
 
-    func = {
-        'GET': requests.get,
-        'PUT': requests.put,
-        'POST': requests.post,
-        'PATCH': requests.patch,
-    }[method]
-    response = func(url, json=data, params=params)
+    response = requests.request(method, url, json=data, params=params)
 
     if player_cli.state['debug']:
         print(f"{DEBUG_STR}: " + escape(f"{response.status_code} {response.reason}"))
-        print(f"{DEBUG_STR}: " + escape(response.json()))
+        print(f"{DEBUG_STR}: " + escape(str(response.json())))
 
     if response.status_code != 200:
         print(f"{ERROR_STR}: " + escape(f"{method} {endpoint} returned status code {response.status_code} {response.reason}"))
         try:
-            print(f"{ERROR_STR}: " + escape(response.json()))
+            print(f"{ERROR_STR}: " + escape(str(response.json())))
         except JSONDecodeError:
             print(f"{ERROR_STR}: " + escape(response.text))
         raise typer.Exit(code=1)

From 35db0ffa2508e44e07ad75650d8beed9cb026012 Mon Sep 17 00:00:00 2001
From: Nils1729 <45318774+Nils1729@users.noreply.github.com>
Date: Thu, 17 Oct 2024 22:16:50 +0200
Subject: [PATCH 2/2] Remove flag queue

Also rewrite quite a lot of `flags.py`. This is almost, but not quite,
the setup we ran for ECSC2024. I removed some indirections we used to
handle the new RESUBMIT flag status specific to that event.
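
The heart of the rewrite is an in-memory dedup cache: a dict mapping
each flag string to the id of the single Flag row that is allowed to
be submitted for it. A rough standalone sketch of that contract
(illustrative only; the real logic is `_cache_is_duplicate` and
`_cache_set_flag` in `flags.py`):

    # flag string -> id of the one Flag object allowed to be submitted for it
    submitted_id: dict[str, int] = {}

    def is_duplicate(flag_id: int, flag: str) -> bool:
        # False negatives are acceptable (the flag is just checked against
        # the database again); false positives would silently drop a flag.
        return flag in submitted_id and submitted_id[flag] != flag_id

    def set_flag(flag_id: int, flag: str, dont_resubmit: bool) -> None:
        if dont_resubmit:
            submitted_id.setdefault(flag, flag_id)  # start blocking resubmission
        elif submitted_id.get(flag) == flag_id:
            del submitted_id[flag]                  # stop blocking resubmission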
--- ataka/common/database/models/execution.py | 2 +- ataka/common/database/models/flag.py | 2 +- ataka/common/queue/__init__.py | 1 - ataka/common/queue/flag.py | 14 -- ataka/ctfcode/flags.py | 257 +++++++++++++--------- 5 files changed, 156 insertions(+), 120 deletions(-) delete mode 100644 ataka/common/queue/flag.py diff --git a/ataka/common/database/models/execution.py b/ataka/common/database/models/execution.py index 2c2d1e1..0b08145 100644 --- a/ataka/common/database/models/execution.py +++ b/ataka/common/database/models/execution.py @@ -14,7 +14,7 @@ class Execution(Base, JsonBase): status = Column(Enum(JobExecutionStatus)) stdout = Column(UnicodeText) stderr = Column(UnicodeText) - timestamp = Column(DateTime(timezone=True), server_default=func.now()) + timestamp = Column(DateTime(timezone=True), server_default=func.now(), index=True) job = relationship("Job", back_populates="executions") target = relationship("Target") diff --git a/ataka/common/database/models/flag.py b/ataka/common/database/models/flag.py index 622f7ab..dbcbfc6 100644 --- a/ataka/common/database/models/flag.py +++ b/ataka/common/database/models/flag.py @@ -10,7 +10,7 @@ class Flag(Base, JsonBase): id = Column(Integer, primary_key=True) flag = Column(String, index=True) - status = Column(Enum(FlagStatus)) + status = Column(Enum(FlagStatus), index=True) timestamp = Column(DateTime(timezone=True), server_default=func.now()) execution_id = Column(Integer, ForeignKey("executions.id"), index=True) diff --git a/ataka/common/queue/__init__.py b/ataka/common/queue/__init__.py index 05657c6..122b99b 100644 --- a/ataka/common/queue/__init__.py +++ b/ataka/common/queue/__init__.py @@ -5,7 +5,6 @@ import aio_pika from aio_pika import RobustConnection -from .flag import * from .job import * from .output import * diff --git a/ataka/common/queue/flag.py b/ataka/common/queue/flag.py deleted file mode 100644 index a4edb6a..0000000 --- a/ataka/common/queue/flag.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -from .queue import WorkQueue, Message - - -@dataclass -class FlagMessage(Message): - flag_id: int - flag: str - - -class FlagQueue(WorkQueue): - queue_name = "flag" - message_type = FlagMessage diff --git a/ataka/ctfcode/flags.py b/ataka/ctfcode/flags.py index babb836..a54e829 100644 --- a/ataka/ctfcode/flags.py +++ b/ataka/ctfcode/flags.py @@ -1,117 +1,175 @@ import re import time -from asyncio import TimeoutError, sleep +from asyncio import sleep, TaskGroup +from typing import NamedTuple +from collections import Counter, defaultdict +from itertools import islice, chain -import asyncio -from sqlalchemy import update, select +from sqlalchemy import update, select, func +from sqlalchemy.ext.asyncio import AsyncSession from ataka.common import database from ataka.common.database.models import Flag from ataka.common.flag_status import FlagStatus, DuplicatesDontResubmitFlagStatus -from ataka.common.queue import FlagQueue, get_channel, FlagMessage, OutputQueue +from ataka.common.queue import get_channel, OutputQueue from .ctf import CTF +class FlagInfo(NamedTuple): + flag_id: int + flag: str + status: FlagStatus + + class Flags: def __init__(self, ctf: CTF): self._ctf = ctf - self._flag_cache = {} + self._flags_submitted_id: dict[str, int] = {} + ''' + Each flag string is submitted with one Flag object. This is a cache for the id of this object. 
+    Specifically, the ID in this dict dictates which Flag object should be submitted for a flag string,
+    and blocks others with the same flag string from being submitted.
+    '''
+
+    def _cache_is_duplicate(self, flag_info: FlagInfo) -> bool:
+        '''Determines if a flag is a duplicate. May produce false negatives but no false positives.'''
+        flag_id, flag, _status = flag_info
+        return flag in self._flags_submitted_id and flag_id != self._flags_submitted_id[flag]
+
+    def _cache_set_flag(self, flag_id: int, flag: str, dont_resubmit: bool):
+        match (flag in self._flags_submitted_id, dont_resubmit):
+            case (False, True):
+                # This flag blocks resubmission
+                self._flags_submitted_id[flag] = flag_id
+            case (True, False) if flag_id == self._flags_submitted_id[flag]:
+                # This flag used to block resubmission but does not anymore
+                self._flags_submitted_id.pop(flag)
+
+    async def set_flags_status(
+        self,
+        session: AsyncSession,
+        flag_infos: list[FlagInfo],
+        status: FlagStatus
+    ):
+        '''
+        Use `session` to set the status of all flags in `flag_infos` to `status`.
+        '''
+        if len(flag_infos) == 0:
+            print(f'No flags to mark as {status}')
+            return
+        flag_ids, _, statuses = zip(*flag_infos)
+        status_counted = Counter(statuses)
+        prev_status_str = ', '.join(f'{status.name}: {count}' for status, count in status_counted.items())
+        print(f'Marking {len(flag_infos)} flags as {status}, original statuses: {prev_status_str}')
+
+        for flag_id, flag, _status in flag_infos:
+            self._cache_set_flag(flag_id, flag, status in DuplicatesDontResubmitFlagStatus)
+        # Postgres has a limit of 32k parameters in prepared statements, including `IN (?, ?, ?, ... )`
+        iterator = iter(flag_ids)
+        while batch := tuple(islice(iterator, 30_000)):
+            set_status = update(Flag) \
+                .where(Flag.id.in_(batch)) \
+                .values(status=status)
+            await session.execute(set_status)
 
     async def poll_and_submit_flags(self):
-        async with get_channel() as channel:
-            flag_queue = await FlagQueue.get(channel)
-            last_submit = time.time()
-
-            async with database.get_session() as session:
-                while True:
-                    batchsize = self._ctf.get_flag_batchsize()
-                    ratelimit = self._ctf.get_flag_ratelimit()
-
-                    queue_init = select(Flag).where(Flag.status.in_({FlagStatus.PENDING, FlagStatus.ERROR}))
-                    init_list = list((await session.execute(queue_init)).scalars())
-
-                    submit_list = [FlagMessage(flag.id, flag.flag) for flag in init_list if flag.status == FlagStatus.PENDING]
-                    resubmit_list = [FlagMessage(flag.id, flag.flag) for flag in init_list if flag.status == FlagStatus.ERROR]
-                    dupe_list = []
-                    try:
-                        async for message in flag_queue.wait_for_messages(timeout=ratelimit):
-                            flag_id = message.flag_id
-                            flag = message.flag
-                            #print(f"Got flag {flag}, cache {'NOPE' if flag not in self._flag_cache else self._flag_cache[flag]}")
-
-                            check_duplicates = select(Flag) \
-                                .where(Flag.id != flag_id) \
-                                .where(Flag.flag == flag) \
-                                .where(Flag.status.in_(DuplicatesDontResubmitFlagStatus)) \
-                                .limit(1)
-                            duplicate = (await session.execute(check_duplicates)).scalars().first()
-                            if duplicate is not None:
-                                dupe_list.append(flag_id)
-                                self._flag_cache[flag] = FlagStatus.DUPLICATE_NOT_SUBMITTED
-                            else:
-                                submit_list.append(message)
-                                self._flag_cache[flag] = FlagStatus.PENDING
-
-                            if len(submit_list) >= batchsize:
-                                break
-                    except TimeoutError as e:
-                        pass
-
-                    if len(dupe_list) > 0:
-                        print(f"Dupe list of size {len(dupe_list)}")
-                        set_duplicates = update(Flag)\
-                            .where(Flag.id.in_(dupe_list))\
-                            .values(status=FlagStatus.DUPLICATE_NOT_SUBMITTED)
-                        await session.execute(set_duplicates)
-                        await session.commit()
-
-                    if len(submit_list) < batchsize and len(resubmit_list) > 0:
-                        resubmit_amount = min(batchsize-len(submit_list), len(resubmit_list))
-                        print(f"Got leftover capacity, resubmitting {resubmit_amount} errored flags "
-                              f"({len(resubmit_list) - resubmit_amount} remaining)")
-
-                        submit_list += resubmit_list[:resubmit_amount]
-                        resubmit_list = resubmit_list[resubmit_amount:]
-
-                    if len(submit_list) > 0:
-                        set_pending = update(Flag) \
-                            .where(Flag.id.in_([x.flag_id for x in submit_list])) \
-                            .values(status=FlagStatus.PENDING) \
-                            .returning(Flag)
-                        result = list((await session.execute(set_pending)).scalars())
-                        await session.commit()
-
-                        diff = time.time() - last_submit
-                        print(f"Submitting {len(submit_list)} flags, {diff:.2f}s since last time" +
-                              (f" (sleeping {ratelimit-diff:.2f})" if diff < ratelimit else ""))
-                        if diff < ratelimit:
-                            await sleep(ratelimit-diff)
-                        last_submit = time.time()
-
-                        statuslist = self._ctf.submit_flags([flag.flag for flag in result])
-                        print(f"Done submitting ({statuslist.count(FlagStatus.OK)} ok)")
-
-                        for flag, status in zip(result, statuslist):
-                            #print(flag.id, flag.flag, status)
-                            flag.status = status
-                            self._flag_cache[flag.flag] = status
-
-                            if status == FlagStatus.ERROR:
-                                resubmit_list.append(FlagMessage(flag.id, flag.flag))
-
-                        await session.commit()
+        last_submit = time.time()
+
+        async with database.get_session() as session:
+            while True:
+                batchsize = self._ctf.get_flag_batchsize()
+                ratelimit = self._ctf.get_flag_ratelimit()
+
+                flag_status_priorities = [FlagStatus.PENDING, FlagStatus.QUEUED, FlagStatus.ERROR]
+
+                # Collect potentially submittable flags
+                flag_infos_query = select(Flag.id, Flag.flag, Flag.status) \
+                    .where(Flag.status.in_(flag_status_priorities))
+                flag_infos = map(FlagInfo._make, (await session.execute(flag_infos_query)).fetchall())
+                flag_infos_by_status = defaultdict[FlagStatus, list[FlagInfo]](list)
+                for flag_info in flag_infos:
+                    flag_infos_by_status[flag_info.status].append(flag_info)
+                flag_infos_prioritized = chain(*(flag_infos_by_status[status] for status in flag_status_priorities))
+                del flag_infos_query, flag_infos, flag_infos_by_status  # drop references that must not be reused below
+
+                # Deduplicate QUEUED flags
+                duplicates: list[FlagInfo] = []
+                maybe_duplicates: list[FlagInfo] = []
+                non_duplicates: list[FlagInfo] = []
+
+                # # Find potentially duplicate flags
+                for flag_info in flag_infos_prioritized:
+                    if flag_info.status == FlagStatus.QUEUED:
+                        if self._cache_is_duplicate(flag_info):
+                            duplicates.append(flag_info)
+                        else:
+                            maybe_duplicates.append(flag_info)
+                    else:
+                        non_duplicates.append(flag_info)
+                del flag_infos_prioritized
+
+                # # Fill cache
+                if len(maybe_duplicates) > 0:
+                    flags = set(flag_info.flag for flag_info in maybe_duplicates)
+                    iterator = iter(flags)
+                    while batch := list(islice(iterator, 30_000)):
+                        dont_resubmit_flags_query = select(func.min(Flag.id), Flag.flag) \
+                            .where(Flag.flag.in_(batch)) \
+                            .where(Flag.status.in_(DuplicatesDontResubmitFlagStatus)) \
+                            .group_by(Flag.flag)
+                        dont_resubmit_flags = (await session.execute(dont_resubmit_flags_query)).all()
+                        for flag_id, flag in dont_resubmit_flags:
+                            self._cache_set_flag(flag_id, flag, True)
+
+                # # Mark duplicates using the new cache
+                for flag_info in maybe_duplicates:
+                    if self._cache_is_duplicate(flag_info):
+                        duplicates.append(flag_info)
+                    else:
+                        non_duplicates.append(flag_info)
+
+                if len(duplicates):
+                    await self.set_flags_status(session, duplicates, FlagStatus.DUPLICATE_NOT_SUBMITTED)
+                    await session.commit()
+
+                # Take at most `batchsize` flags
+                flag_infos_to_submit = non_duplicates[:batchsize]
+
+                del non_duplicates, maybe_duplicates, duplicates
+                # Submit flags
+                if len(flag_infos_to_submit):
+                    await self.set_flags_status(session, flag_infos_to_submit, FlagStatus.PENDING)
+
+                    diff = time.time() - last_submit
+                    print(f"Prepared {len(flag_infos_to_submit)} flags for submission, {diff:.2f}s since last time" +
+                          (f" (sleeping {ratelimit-diff:.2f})" if diff < ratelimit else ""))
+                    if diff < ratelimit:
+                        await sleep(ratelimit-diff)
+                    last_submit = time.time()
+
+                    statuslist = self._ctf.submit_flags([flag_info.flag for flag_info in flag_infos_to_submit])
+                    print(f"Done submitting ({statuslist.count(FlagStatus.OK)} ok)")
+
+                    flag_infos_by_status = defaultdict[FlagStatus, list[FlagInfo]](list)
+                    for flag_info, status in zip(flag_infos_to_submit, statuslist):
+                        flag_infos_by_status[status].append(flag_info)
+
+                    async with TaskGroup() as tg:
+                        for status, flag_infos in flag_infos_by_status.items():
+                            tg.create_task(self.set_flags_status(session, flag_infos, status))
+                    await session.commit()
+                else:
+                    print("No flags for now")
+                    await sleep(ratelimit)
+                del flag_infos_to_submit
 
     async def poll_and_parse_output(self):
         async with get_channel() as channel:
-            flag_queue = await FlagQueue.get(channel)
             output_queue = await OutputQueue.get(channel)
 
             async with database.get_session() as session:
                 async for message in output_queue.wait_for_messages():
                     regex, group = self._ctf.get_flag_regex()
-                    submissions = []
-                    duplicates = []
+                    flag_objs = []
                     for match in re.finditer(regex, message.output):
                         if match.start(group) == -1 or match.end(group) == -1:
                             continue
@@ -119,20 +177,13 @@
                         flag = match.group(group)
                         flag_obj = Flag(flag=flag, status=FlagStatus.QUEUED, execution_id=message.execution_id,
                                         stdout=message.stdout, start=match.start(group), end=match.end(group))
-                        if flag in self._flag_cache and self._flag_cache[flag] in DuplicatesDontResubmitFlagStatus:
+                        if self._cache_is_duplicate(FlagInfo(flag_obj.id, flag_obj.flag, flag_obj.status)):
                             flag_obj.status = FlagStatus.DUPLICATE_NOT_SUBMITTED
-                            duplicates.append(flag_obj)
-                        else:
-                            submissions.append(flag_obj)
-                            self._flag_cache[flag] = flag_obj.status
+                        flag_objs.append(flag_obj)
 
-                    if len(submissions) + len(duplicates) == 0:
+                    if len(flag_objs) == 0:
                         continue
 
-                    session.add_all(submissions + duplicates)
+                    session.add_all(flag_objs)
                     await session.commit()
-
-                    if len(submissions) > 0:
-                        await asyncio.gather(*[
-                            flag_queue.send_message(FlagMessage(flag_id=f.id, flag=f.flag))
-                            for f in submissions])
+                    del flag_obj, flag_objs
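
Note on the 30_000-element batches in `set_flags_status` and the
cache-fill query: Postgres accepts at most 32767 bind parameters per
prepared statement, so large `IN (...)` lists have to be chunked. A
minimal standalone sketch of the pattern (`session`, `Flag`, and
`status` in the usage comment are stand-ins; only `islice` and the
30_000 chunk size mirror the patch):

    from itertools import islice
    from typing import Iterable, Iterator

    def chunked(ids: Iterable[int], size: int = 30_000) -> Iterator[tuple[int, ...]]:
        # Yield tuples of at most `size` ids until the input is exhausted.
        iterator = iter(ids)
        while batch := tuple(islice(iterator, size)):
            yield batch

    # Usage, roughly as in set_flags_status:
    #     for batch in chunked(flag_ids):
    #         await session.execute(
    #             update(Flag).where(Flag.id.in_(batch)).values(status=status))

    assert sum(len(b) for b in chunked(range(70_000))) == 70_000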