Improve Path Matching and Add State Caching for Performance #2093

Draft · wants to merge 2 commits into main
20 changes: 16 additions & 4 deletions src/seer/automation/autofix/autofix_context.py
@@ -187,11 +187,23 @@ def get_commit_patch_for_file(

    def _process_stacktrace_paths(self, stacktrace: Stacktrace):
        """
        Annotate a stacktrace with the correct repo each frame is pointing to and fix the filenames
        """
        for repo in self.repos:
            if repo.provider not in RepoClient.supported_providers:
                continue

            best_match = None
            best_score = 0.0

            matches, score = potential_frame_match(valid_path, frame)
            if matches and score > best_score:
                best_match = valid_path
                best_score = score

            # Use the match only if the confidence score is above the threshold
            if best_match and best_score >= 0.4:
                # Log before reassigning so the original frame path is captured
                logger.debug(
                    f"Matched frame path {frame.filename or frame.package} to {best_match} with score {best_score:.2f}"
                )
                frame.repo_name = repo.full_name
                frame.filename = best_match

        try:
            repo_client = self.get_repo_client(
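The hunk above does not show the surrounding loop that supplies valid_path and frame, so the selection pattern it implies is worth spelling out. A minimal sketch, where frame, repo, and candidate_paths are hypothetical stand-ins for that hidden context:

# Sketch only: `frame`, `repo`, and `candidate_paths` stand in for the loop
# context that is not visible in this hunk.
best_match, best_score = None, 0.0
for valid_path in candidate_paths:
    matches, score = potential_frame_match(valid_path, frame)
    if matches and score > best_score:
        best_match, best_score = valid_path, score

if best_match and best_score >= 0.4:  # same 0.4 threshold as _autocorrect_path
    frame.repo_name = repo.full_name
    frame.filename = best_match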
101 changes: 73 additions & 28 deletions src/seer/automation/codebase/repo_client.py
@@ -324,16 +324,56 @@ def _autocorrect_path(self, path: str, sha: str | None = None) -> tuple[str, bool]:
"""
if sha is None:
sha = self.base_commit_sha

path = path.lstrip("/")
valid_paths = self.get_valid_file_paths(sha)

# If path is valid, return it unchanged
if path in valid_paths:
return path, False

# Check for partial matches if no exact match and path is long enough
if len(path) > 3:
path_normalized = path_lower.lstrip("./")
candidates = []

for valid_path in valid_paths:
valid_path_normalized = valid_path.lower()
score = 0.0

# Strategy 1: Exact filename match
if path_normalized.split('/')[-1] == valid_path_normalized.split('/')[-1]:
score += 0.5

# Strategy 2: Path containment
if path_normalized in valid_path_normalized:
score += 0.3
elif valid_path_normalized in path_normalized:
score += 0.2

# Strategy 3: Component matching from the end
path_components = path_normalized.split('/')
valid_components = valid_path_normalized.split('/')
max_check = min(len(path_components), len(valid_components))
matches = 0

for i in range(1, max_check + 1):
if path_components[-i] == valid_components[-i]:
matches += 1
else:
break

if matches > 0:
score += 0.1 * matches

if score > 0:
candidates.append((valid_path, score))

# Sort by score, descending
candidates.sort(key=lambda x: x[1], reverse=True)

if candidates:
best_match, confidence = candidates[0]
if confidence >= 0.4: # Threshold for accepting a match
logger.info(
f"Path '{path}' not found exactly, using best match: '{best_match}' with confidence {confidence:.2f}"
)
path = best_match
autocorrected_path = True
else:
logger.warning(
f"No confident match found for path '{path}'. Best candidate '{best_match}' had low confidence ({confidence:.2f})"
)
path_lower = path.lower()
partial_matches = [
valid_path for valid_path in valid_paths if path_lower in valid_path.lower()
@@ -355,28 +395,33 @@ def get_file_content(
    ) -> tuple[str | None, str]:
        logger.debug(f"Getting file contents for {path} in {self.repo.full_name} on sha {sha}")
        if sha is None:
            sha = self.base_commit_sha

        autocorrected_path = False
        if autocorrect:
            path, autocorrected_path = self._autocorrect_path(path, sha)
        if not autocorrected_path and path not in self.get_valid_file_paths(sha):
            return None, "utf-8"

        try:
            contents = self.repo.get_contents(path, ref=sha)

            if isinstance(contents, list):
                raise Exception(f"Expected a single ContentFile but got a list for path {path}")

            detected_encoding = detect_encoding(contents.decoded_content) if contents else "utf-8"
            content = contents.decoded_content.decode(detected_encoding)
            if autocorrected_path:
                content = f"Showing results instead for {path}\n=====\n{content}"
            return content, detected_encoding
        except Exception as e:
            logger.exception(f"Error getting file contents: {e}")
            return None, "utf-8"

    @functools.lru_cache(maxsize=32)  # Increased from 8 to 32
    def get_valid_file_paths(self, sha: str | None = None, files_only=False) -> set[str]:
        try:
            tree = self.repo.get_git_tree(sha, recursive=True)

            if tree.raw_data["truncated"]:
                sentry_sdk.capture_message(
                    f"Truncated tree for {self.repo.full_name}. This may cause issues with autofix."
                )

            valid_file_paths: set[str] = set()
            valid_file_extensions = get_all_supported_extensions()

            for file in tree.tree:
                if file.type == "blob" and any(
                    file.path.endswith(ext) for ext in valid_file_extensions
                ):
                    valid_file_paths.add(file.path)

            return valid_file_paths
        except Exception as e:
            logger.exception(f"Error getting valid file paths: {e}")
            sentry_sdk.capture_exception(e)
            return set()  # Return an empty set instead of failing
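The scoring in _autocorrect_path combines three additive heuristics. A standalone sketch of just the scoring step, with hypothetical example paths (the real method draws candidates from the repo tree and then applies the 0.4 acceptance threshold):

# Sketch of the candidate-scoring heuristics above; example paths are hypothetical.
def score_candidate(path: str, valid_path: str) -> float:
    path_n = path.lower().lstrip("./")
    valid_n = valid_path.lower()
    score = 0.0
    # Strategy 1: exact filename match
    if path_n.split("/")[-1] == valid_n.split("/")[-1]:
        score += 0.5
    # Strategy 2: path containment
    if path_n in valid_n:
        score += 0.3
    elif valid_n in path_n:
        score += 0.2
    # Strategy 3: matching components from the end
    path_parts, valid_parts = path_n.split("/"), valid_n.split("/")
    for i in range(1, min(len(path_parts), len(valid_parts)) + 1):
        if path_parts[-i] != valid_parts[-i]:
            break
        score += 0.1
    return score

print(score_candidate("app/models/user.py", "src/app/models/user.py"))  # 1.1
print(score_candidate("user.py", "src/app/models/user.py"))             # 0.9
print(score_candidate("docs/readme.md", "src/app/models/user.py"))      # 0.0

Because the heuristics are additive, a strong candidate can score above 1.0; candidates are ranked by score, and the best one must still clear the absolute 0.4 cutoff.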
85 changes: 64 additions & 21 deletions src/seer/automation/codebase/utils.py
@@ -127,27 +127,70 @@ def cleanup_dir(directory: str):
logger.info(f"Directory {directory} already cleaned!")


def potential_frame_match(src_file: str, frame: StacktraceFrame) -> bool:
    """Determine if the frame filename represents a source code file."""
    match = False

    src_split = src_file.split("/")[::-1]

    filename = frame.filename or frame.package
    if filename:
        # Remove leading './' or '.' from filename
        filename = filename.lstrip("./")
        frame_split = filename.split("/")[::-1]

        if len(src_split) > 0 and len(frame_split) > 0 and len(src_split) >= len(frame_split):
            for i in range(len(frame_split)):
                if src_split[i] == frame_split[i]:
                    match = True
                else:
                    match = False
                    break

    return match
def potential_frame_match(src_file: str, frame: StacktraceFrame) -> tuple[bool, float]:
    """
    Determine if the frame filename represents a source code file.
    Returns a tuple of (match_found, confidence_score), where confidence_score is a value
    from 0.0 to 1.0 indicating how confident we are in the match.
    """

    # Normalize paths for comparison
    def normalize_path(path):
        if not path:
            return ""
        # Strip leading './' and '/'
        path = path.lstrip("./").lstrip("/")
        # Convert to lowercase for case-insensitive comparison
        return path.lower()

    src_normalized = normalize_path(src_file)
    frame_path = frame.filename or frame.package
    frame_normalized = normalize_path(frame_path)

    if not frame_normalized:
        return False, 0.0

    # Quick exact-match check
    if src_normalized == frame_normalized:
        return True, 1.0

    # Component-wise matching (from the end)
    src_components = src_normalized.split("/")
    frame_components = frame_normalized.split("/")

    # Filename matching (highest priority)
    if src_components and frame_components and src_components[-1] == frame_components[-1]:
        # Matching filenames are a good sign
        base_score = 0.6
    else:
        # If the filenames don't match, start from a lower score
        base_score = 0.3

    # Check for a path suffix match (e.g., "src/module/file.py" matches "module/file.py")
    max_components = min(len(src_components), len(frame_components))
    matching_components = 0

    for i in range(1, max_components + 1):
        if src_components[-i] == frame_components[-i]:
            matching_components += 1
        else:
            break

    if matching_components == 0:
        return False, 0.0

    # Score based on the fraction of matching components
    component_score = matching_components / max(len(src_components), len(frame_components))

    # Check whether one path is contained in the other (lower priority, but still useful)
    containment_score = 0.0
    if src_normalized in frame_normalized or frame_normalized in src_normalized:
        containment_score = 0.2

    # Combine the scores with appropriate weighting
    final_score = base_score * 0.5 + component_score * 0.4 + containment_score * 0.1

    # Only report a match if we have reasonable confidence
    return final_score >= 0.4, final_score

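A quick sanity check of the new scoring, using a hypothetical stand-in for StacktraceFrame (scores rounded; 0.4 is the same threshold _process_stacktrace_paths applies):

# Sketch: FakeFrame is hypothetical; the real StacktraceFrame comes from seer's models.
from dataclasses import dataclass

@dataclass
class FakeFrame:
    filename: str | None = None
    package: str | None = None

print(potential_frame_match("src/app/main.py", FakeFrame(filename="src/app/main.py")))
# -> (True, 1.0): exact match short-circuits

print(potential_frame_match("src/app/main.py", FakeFrame(filename="app/main.py")))
# -> (True, ~0.59): filename match + two trailing components + containment

print(potential_frame_match("src/app/main.py", FakeFrame(filename="src/app/other.py")))
# -> (False, 0.0): last components differ, so there is no trailing match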

def group_documents_by_language(documents: list[Document]) -> dict[str, list[Document]]:
49 changes: 49 additions & 0 deletions src/seer/automation/state.py
@@ -3,6 +3,7 @@
import dataclasses
import functools
import threading
import time
from enum import Enum
from typing import Any, ContextManager, Generic, Iterator, Type, TypeVar

@@ -22,6 +23,43 @@ class DbStateRunTypes(str, Enum):
RELEVANT_WARNINGS = "relevant-warnings"


# Create an in-memory cache for state objects
_state_cache = {}
_state_cache_lock = threading.RLock()


def memoize_state_get(ttl_seconds=2):
    """Decorator that caches state.get() calls for a short time to reduce DB queries."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            cache_key = (self.__class__.__name__, self.id)

            with _state_cache_lock:
                now = time.time()
                if cache_key in _state_cache:
                    cached_result, timestamp = _state_cache[cache_key]
                    if now - timestamp < ttl_seconds:
                        return cached_result

            result = func(self, *args, **kwargs)

            with _state_cache_lock:
                _state_cache[cache_key] = (result, time.time())

                # Clean out expired cache entries occasionally
                if len(_state_cache) > 1000:  # Arbitrary limit
                    now = time.time()
                    for k in list(_state_cache.keys()):
                        if now - _state_cache[k][1] > ttl_seconds * 2:
                            del _state_cache[k]

            return result

        return wrapper

    return decorator
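A minimal sketch of the TTL behavior, with a hypothetical class standing in for DbState (the cache key uses only the class name and id, so any object with those attributes works):

# Sketch: FakeDbState is hypothetical; the real DbState.get() issues a DB query.
class FakeDbState:
    def __init__(self, id: int):
        self.id = id
        self.db_hits = 0

    @memoize_state_get(ttl_seconds=2)
    def get(self):
        self.db_hits += 1  # stands in for the real Session query
        return {"id": self.id}

s = FakeDbState(id=1)
s.get()
s.get()                 # served from _state_cache; still one "DB" hit
assert s.db_hits == 1
time.sleep(2.1)
s.get()                 # TTL expired; falls through to the "DB" again
assert s.db_hits == 2

Keying on (class name, id) means every instance pointing at the same run shares one entry, which is what makes the invalidation in update() below effective.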


class State(abc.ABC, Generic[_State]):
"""
An abstract state buffer that attempts to push state changes to a sink.
@@ -87,6 +125,7 @@ def new(
            session.commit()
            return cls(id=db_state.id, model=type(value), type=t)

    @memoize_state_get(ttl_seconds=2)
    def get(self) -> _State:
        with Session() as session:
            db_state = session.get(DbRunState, self.id)
@@ -120,6 +159,12 @@ def update(self):
        of interrelated locks), the database may reach a deadlock state that lasts until the lock timeout configured
        on the Postgres database.
        """
        # Clear the cache for this state before updating
        cache_key = (self.__class__.__name__, self.id)
        with _state_cache_lock:
            if cache_key in _state_cache:
                del _state_cache[cache_key]

        with Session() as session:
            r = session.execute(
                select(DbRunState).where(DbRunState.id == self.id).with_for_update()
@@ -134,6 +179,10 @@ def update(self):
            session.merge(db_state)
            session.commit()

        # Update the cache with the new value after a successful commit
        with _state_cache_lock:
            _state_cache[cache_key] = (value, time.time())

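Note the write-path ordering: the entry is cleared before the locked read-modify-write and repopulated only after the commit succeeds, so a failed update leaves no stale entry behind and readers in between fall through to the database.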

@functools.total_ordering
class BufferedMemoryState(State[_State]):