Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ECSC Geramy #12

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ataka/common/database/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Execution(Base, JsonBase):
status = Column(Enum(JobExecutionStatus))
stdout = Column(UnicodeText)
stderr = Column(UnicodeText)
timestamp = Column(DateTime(timezone=True), server_default=func.now())
timestamp = Column(DateTime(timezone=True), server_default=func.now(), index=True)

job = relationship("Job", back_populates="executions")
target = relationship("Target")
Expand Down
2 changes: 1 addition & 1 deletion ataka/common/database/models/flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Flag(Base, JsonBase):

id = Column(Integer, primary_key=True)
flag = Column(String, index=True)
status = Column(Enum(FlagStatus))
status = Column(Enum(FlagStatus), index=True)
timestamp = Column(DateTime(timezone=True), server_default=func.now())

execution_id = Column(Integer, ForeignKey("executions.id"), index=True)
Expand Down
1 change: 0 additions & 1 deletion ataka/common/queue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import aio_pika
from aio_pika import RobustConnection

from .flag import *
from .job import *
from .output import *

Expand Down
14 changes: 0 additions & 14 deletions ataka/common/queue/flag.py

This file was deleted.

257 changes: 154 additions & 103 deletions ataka/ctfcode/flags.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,189 @@
import re
import time
from asyncio import TimeoutError, sleep
from asyncio import sleep, TaskGroup
from typing import NamedTuple
from collections import Counter, defaultdict
from itertools import islice, chain

import asyncio
from sqlalchemy import update, select
from sqlalchemy import update, select, func
from sqlalchemy.ext.asyncio import AsyncSession

from ataka.common import database
from ataka.common.database.models import Flag
from ataka.common.flag_status import FlagStatus, DuplicatesDontResubmitFlagStatus
from ataka.common.queue import FlagQueue, get_channel, FlagMessage, OutputQueue
from ataka.common.queue import get_channel, OutputQueue
from .ctf import CTF


# Lightweight, immutable snapshot of a Flag database row as it moves
# through the submission pipeline: the row's primary key, the flag
# string itself, and its current submission status.
FlagInfo = NamedTuple('FlagInfo', [
    ('flag_id', int),
    ('flag', str),
    ('status', FlagStatus),
])

class Flags:
    '''Deduplicates harvested flags and drives batched submission to the
    CTF gameserver (see poll_and_submit_flags / poll_and_parse_output).'''

    def __init__(self, ctf: CTF):
        # CTF adapter providing batch size, rate limit, flag regex and the
        # actual submit_flags() call.
        self._ctf = ctf
        # NOTE(review): `_flag_cache` is the pre-PR per-flag-string status
        # cache; this diff appears to replace it with `_flags_submitted_id`.
        # Confirm it is fully removed and not left as dead/stale state.
        self._flag_cache = {}
        self._flags_submitted_id: dict[str, int] = {}
        '''
        Each flag string is submitted with one Flag object. This is a cache for the id of this object.
        Specifically, the ID in this dict dictates which Flag object should be submitted for a flag string,
        and block others with the same flag string from being submitted.
        '''

def _cache_is_duplicate(self, flag_info: FlagInfo) -> bool:
'''Determines if a flag is a duplicate. May produce false negatives but no false positives.'''
flag_id, flag, _status = flag_info
return flag in self._flags_submitted_id and flag_id != self._flags_submitted_id[flag]

def _cache_set_flag(self, flag_id: int, flag: str, dont_resubmit: bool):
match (flag in self._flags_submitted_id, dont_resubmit):
case (False, True):
# This flag blocks resubmission
self._flags_submitted_id[flag] = flag_id
case (True, False) if flag_id == self._flags_submitted_id[flag]:
# This flag used to block resubmission but does not anymore
self._flags_submitted_id.pop(flag)

async def set_flags_status(
self,
session: AsyncSession,
flag_infos: list[FlagInfo],
status: FlagStatus
):
'''
Use `session` to set the status of all flags in `flag_infos` to `status`.
'''
if len(flag_infos) == 0:
print(f'No flags to mark as {status}')
return
flag_ids, _, statuses = zip(*flag_infos)
status_counted = Counter(statuses)
prev_status_str = ', '.join(f'{status.name}: {count}' for status, count in status_counted.items())
print(f'Marking {len(flag_infos)} flags as {status}, original statuses: {prev_status_str}')

for flag_id, flag, _status in flag_infos:
self._cache_set_flag(flag_id, flag, status in DuplicatesDontResubmitFlagStatus)
# Postgres has a limit of 32k parameters in prepared statements, including `IN (?, ?, ?, ... )`
iterator = iter(flag_ids)
while batch := tuple(islice(iterator, 30_000)):
set_status = update(Flag) \
.where(Flag.id.in_(batch)) \
.values(status=status)
await session.execute(set_status)

    async def poll_and_submit_flags(self):
        '''Forever-loop that submits flags to the gameserver in rate-limited batches.

        NOTE(review): this span contains BOTH bodies from the PR diff, interleaved
        by the scrape. The first `async with` block is the pre-PR, queue-driven
        implementation (removed); the second is the new DB-polling implementation
        (added). Only one of them exists in the real file.
        '''
        # --- pre-PR implementation (removed by this diff) --------------------
        # Consumed FlagMessages from the RabbitMQ FlagQueue, checked each flag
        # against the database for duplicates, and submitted batches.
        async with get_channel() as channel:
            flag_queue = await FlagQueue.get(channel)
            last_submit = time.time()

            async with database.get_session() as session:
                while True:
                    batchsize = self._ctf.get_flag_batchsize()
                    ratelimit = self._ctf.get_flag_ratelimit()

                    # Pick up flags left PENDING/ERROR from a previous run.
                    queue_init = select(Flag).where(Flag.status.in_({FlagStatus.PENDING, FlagStatus.ERROR}))
                    init_list = list((await session.execute(queue_init)).scalars())

                    submit_list = [FlagMessage(flag.id, flag.flag) for flag in init_list if flag.status == FlagStatus.PENDING]
                    resubmit_list = [FlagMessage(flag.id, flag.flag) for flag in init_list if flag.status == FlagStatus.ERROR]
                    dupe_list = []
                    try:
                        # Drain the queue until a full batch or the rate-limit window elapses.
                        async for message in flag_queue.wait_for_messages(timeout=ratelimit):
                            flag_id = message.flag_id
                            flag = message.flag
                            #print(f"Got flag {flag}, cache {'NOPE' if flag not in self._flag_cache else self._flag_cache[flag]}")

                            # One DB round-trip per incoming flag to detect duplicates.
                            check_duplicates = select(Flag) \
                                .where(Flag.id != flag_id) \
                                .where(Flag.flag == flag) \
                                .where(Flag.status.in_(DuplicatesDontResubmitFlagStatus)) \
                                .limit(1)
                            duplicate = (await session.execute(check_duplicates)).scalars().first()
                            if duplicate is not None:
                                dupe_list.append(flag_id)
                                self._flag_cache[flag] = FlagStatus.DUPLICATE_NOT_SUBMITTED
                            else:
                                submit_list.append(message)
                                self._flag_cache[flag] = FlagStatus.PENDING

                            if len(submit_list) >= batchsize:
                                break
                    except TimeoutError as e:
                        # Rate-limit window elapsed without filling the batch.
                        pass

                    if len(dupe_list) > 0:
                        print(f"Dupe list of size {len(dupe_list)}")
                        set_duplicates = update(Flag)\
                            .where(Flag.id.in_(dupe_list))\
                            .values(status=FlagStatus.DUPLICATE_NOT_SUBMITTED)
                        await session.execute(set_duplicates)
                        await session.commit()

                    # Fill leftover batch capacity with errored flags to retry.
                    if len(submit_list) < batchsize and len(resubmit_list) > 0:
                        resubmit_amount = min(batchsize-len(submit_list), len(resubmit_list))
                        print(f"Got leftover capacity, resubmitting {resubmit_amount} errored flags "
                              f"({len(resubmit_list) - resubmit_amount} remaining)")

                        submit_list += resubmit_list[:resubmit_amount]
                        resubmit_list = resubmit_list[resubmit_amount:]

                    if len(submit_list) > 0:
                        set_pending = update(Flag) \
                            .where(Flag.id.in_([x.flag_id for x in submit_list])) \
                            .values(status=FlagStatus.PENDING) \
                            .returning(Flag)
                        result = list((await session.execute(set_pending)).scalars())
                        await session.commit()

                        # Honor the gameserver rate limit between submissions.
                        diff = time.time() - last_submit
                        print(f"Submitting {len(submit_list)} flags, {diff:.2f}s since last time" +
                              (f" (sleeping {ratelimit-diff:.2f})" if diff < ratelimit else ""))
                        if diff < ratelimit:
                            await sleep(ratelimit-diff)
                        last_submit = time.time()

                        statuslist = self._ctf.submit_flags([flag.flag for flag in result])
                        print(f"Done submitting ({statuslist.count(FlagStatus.OK)} ok)")

                        for flag, status in zip(result, statuslist):
                            #print(flag.id, flag.flag, status)
                            flag.status = status
                            self._flag_cache[flag.flag] = status

                            if status == FlagStatus.ERROR:
                                resubmit_list.append(FlagMessage(flag.id, flag.flag))

                        await session.commit()
                        last_submit = time.time()

        # --- new implementation (added by this diff) -------------------------
        # Polls the database directly (the flag queue was removed): gathers
        # PENDING/QUEUED/ERROR flags, deduplicates QUEUED ones via the
        # in-memory cache, and submits at most one batch per rate-limit window.
        # NOTE(review): `last_submit` is initialized in the removed block above;
        # confirm the new implementation still initializes it before this loop.
        async with database.get_session() as session:
            while True:
                batchsize = self._ctf.get_flag_batchsize()
                ratelimit = self._ctf.get_flag_ratelimit()

                # Submission priority: retry PENDING first, then new QUEUED,
                # then previously errored flags.
                flag_status_priorities = [FlagStatus.PENDING, FlagStatus.QUEUED, FlagStatus.ERROR]

                # Collect potentially submittable flags
                flag_infos_query = select(Flag.id, Flag.flag, Flag.status) \
                    .where(Flag.status.in_(flag_status_priorities))
                flag_infos = map(lambda tuple : FlagInfo(*tuple), (await session.execute(flag_infos_query)).fetchall())
                flag_infos_by_status = defaultdict[str, list[FlagInfo]](list)
                for flag_info in flag_infos:
                    flag_infos_by_status[flag_info.status].append(flag_info)
                flag_infos_prioritized = chain(*(flag_infos_by_status[status] for status in flag_status_priorities))
                del flag_infos_query, flag_infos, flag_infos_by_status

                # Deduplicate QUEUED flags
                duplicates: list[FlagInfo] = []
                maybe_duplicates: list[FlagInfo] = []
                non_duplicates: list[FlagInfo] = []

                # # Find potentially duplicate flags
                for flag_info in flag_infos_prioritized:
                    if flag_info.status == FlagStatus.QUEUED:
                        if self._cache_is_duplicate(flag_info):
                            duplicates.append(flag_info)
                        else:
                            maybe_duplicates.append(flag_info)
                    else:
                        non_duplicates.append(flag_info)
                del flag_infos_prioritized

                # # Fill cache
                if len(maybe_duplicates) > 0:
                    flags = set(flag_info.flag for flag_info in maybe_duplicates)
                    iterator = iter(flags)
                    # Chunked to stay under Postgres' 32k-parameter limit.
                    while batch := list(islice(iterator, 30_000)):
                        dont_resubmit_flags_query = select(func.min(Flag.id), Flag.flag) \
                            .where(Flag.flag.in_(batch)) \
                            .where(Flag.status.in_(DuplicatesDontResubmitFlagStatus)) \
                            .group_by(Flag.flag)
                        dont_resubmit_flags = (await session.execute(dont_resubmit_flags_query)).all()
                        for flag_id, flag_info in dont_resubmit_flags:
                            self._cache_set_flag(flag_id, flag_info, True)

                # # Mark duplicates using the new cache
                for flag_info in maybe_duplicates:
                    if self._cache_is_duplicate(flag_info):
                        duplicates.append(flag_info)
                    else:
                        # NOTE(review): diff artifact — this print appears to
                        # belong to the old implementation's idle branch, not
                        # to this else-branch; confirm against the real file.
                        print("No flags for now")
                        non_duplicates.append(flag_info)

                if len(duplicates):
                    await self.set_flags_status(session, duplicates, FlagStatus.DUPLICATE_NOT_SUBMITTED)
                    await session.commit()

                # Take first max. batchsize flags
                flag_infos_to_submit = non_duplicates[:batchsize]

                del non_duplicates, maybe_duplicates, duplicates
                # Submit flags
                if len(flag_infos_to_submit):
                    await self.set_flags_status(session, flag_infos_to_submit, FlagStatus.PENDING)

                    # Honor the gameserver rate limit between submissions.
                    diff = time.time() - last_submit
                    print(f"Prepared {len(flag_infos_to_submit)} flags for submission, {diff:.2f}s since last time" +
                          (f" (sleeping {ratelimit-diff:.2f})" if diff < ratelimit else ""))
                    if diff < ratelimit:
                        await sleep(ratelimit-diff)
                    last_submit = time.time()

                    statuslist = self._ctf.submit_flags([flag_info.flag for flag_info in flag_infos_to_submit])
                    print(f"Done submitting ({statuslist.count(FlagStatus.OK)} ok)")

                    # Group results by returned status so each group can be
                    # written back with one bulk UPDATE.
                    flag_infos_by_status = defaultdict[str, list[FlagInfo]](list)
                    for flag_info, status in zip(flag_infos_to_submit, statuslist):
                        flag_infos_by_status[status].append(flag_info)

                    # NOTE(review): these tasks all share one AsyncSession;
                    # SQLAlchemy's AsyncSession is not safe for concurrent use
                    # from multiple tasks — confirm this is serialized or
                    # intended to run sequentially.
                    async with TaskGroup() as tg:
                        for status, flag_infos in flag_infos_by_status.items():
                            tg.create_task(self.set_flags_status(session, flag_infos, status))
                    await session.commit()
                else:
                    print("No flags for now")
                    await sleep(ratelimit)
                del flag_infos_to_submit

    async def poll_and_parse_output(self):
        '''Consume exploit output from the OutputQueue, extract flags with the
        CTF-provided regex, and persist them as QUEUED (or duplicate) Flag rows.

        NOTE(review): this span interleaves old and new lines from the PR diff;
        pairs are marked below. Only one line of each pair exists in the real file.
        '''
        async with get_channel() as channel:
            # NOTE(review): old line — the FlagQueue was removed by this PR.
            flag_queue = await FlagQueue.get(channel)
            output_queue = await OutputQueue.get(channel)
            async with database.get_session() as session:
                async for message in output_queue.wait_for_messages():
                    regex, group = self._ctf.get_flag_regex()
                    # Old-version accumulators (removed):
                    submissions = []
                    duplicates = []
                    # New-version accumulator (added):
                    flags_objs = []
                    for match in re.finditer(regex, message.output):
                        # Skip matches where the capture group did not participate.
                        if match.start(group) == -1 or match.end(group) == -1:
                            continue

                        flag = match.group(group)
                        flag_obj = Flag(flag=flag, status=FlagStatus.QUEUED, execution_id=message.execution_id,
                        stdout=message.stdout, start=match.start(group), end=match.end(group))
                        # NOTE(review): the next two `if` lines are the old and
                        # new duplicate checks interleaved by the diff.
                        if flag in self._flag_cache and self._flag_cache[flag] in DuplicatesDontResubmitFlagStatus:
                        if self._cache_is_duplicate((flag_obj.id, flag_obj.flag, flag_obj.status)):
                            flag_obj.status = FlagStatus.DUPLICATE_NOT_SUBMITTED
                            duplicates.append(flag_obj)
                        else:
                            submissions.append(flag_obj)
                        # NOTE(review): old line — `_flag_cache` is removed by this PR.
                        self._flag_cache[flag] = flag_obj.status
                        flags_objs.append(flag_obj)

                    # NOTE(review): old/new pair — emptiness check.
                    if len(submissions) + len(duplicates) == 0:
                    if len(flags_objs) == 0:
                        continue

                    # NOTE(review): old/new pair — rows persisted to the DB.
                    session.add_all(submissions + duplicates)
                    session.add_all(flags_objs)
                    await session.commit()

                    # NOTE(review): old block (removed) — flags were forwarded to
                    # the submitter via RabbitMQ; the new code relies on DB polling.
                    if len(submissions) > 0:
                        await asyncio.gather(*[
                            flag_queue.send_message(FlagMessage(flag_id=f.id, flag=f.flag))
                            for f in submissions])
                    del flag_obj, flags_objs
18 changes: 6 additions & 12 deletions ataka/player-cli/player_cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,37 +50,31 @@ def request(method, endpoint, data=None, params=None):
    # NOTE(review): interior of `request(method, endpoint, data=None, params=None)`
    # (def line is outside this hunk). Old/new diff pairs are marked below; the
    # PR wraps non-string values in str() before escape() and replaces the
    # method-dispatch dict with requests.request().
    if player_cli.state['debug']:
        print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(f"{method} {endpoint}{'' if params is None else f' with params {params}'}"))
        if data is not None:
            # NOTE(review): old/new pair — str() guards against non-str payloads.
            print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(data))
            print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(str(data)))
        print(f"{DEBUG_STR}: ")

    # Bypass path: talk to the gameserver via the CTF config wrapper directly.
    result = player_cli.ctfconfig_wrapper.request(method, endpoint, data=data)
    if player_cli.state['debug']:
        # NOTE(review): old/new pair.
        print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(result))
        print(f"{DEBUG_STR}: {yellowfy('BYPASS')} " + escape(str(result)))
    return result

    url = f'http://{player_cli.state["host"]}/api/{endpoint}'

    if player_cli.state['debug']:
        print(f"{DEBUG_STR}: " + escape(f"{method} {url}{'' if params is None else f' with params {params}'}"))
        if data is not None:
            # NOTE(review): old/new pair.
            print(f"{DEBUG_STR}: " + escape(data))
            print(f"{DEBUG_STR}: " + escape(str(data)))
        print(f"{DEBUG_STR}: ")

    # NOTE(review): old block (removed) — manual verb-to-function dispatch,
    # replaced by the single requests.request() call below.
    func = {
        'GET': requests.get,
        'PUT': requests.put,
        'POST': requests.post,
        'PATCH': requests.patch,
    }[method]
    response = func(url, json=data, params=params)
    response = requests.request(method, url, json=data, params=params)
    if player_cli.state['debug']:
        print(f"{DEBUG_STR}: " + escape(f"{response.status_code} {response.reason}"))
        # NOTE(review): old/new pair.
        print(f"{DEBUG_STR}: " + escape(response.json()))
        print(f"{DEBUG_STR}: " + escape(str(response.json())))

    if response.status_code != 200:
        print(f"{ERROR_STR}: " + escape(f"{method} {endpoint} returned status code {response.status_code} {response.reason}"))
        try:
            # NOTE(review): old/new pair — body may not be valid JSON.
            print(f"{ERROR_STR}: " + escape(response.json()))
            print(f"{ERROR_STR}: " + escape(str(response.json())))
        except JSONDecodeError:
            print(f"{ERROR_STR}: " + escape(response.text))
        raise typer.Exit(code=1)
Expand Down