Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ft] Add continue unit test generation workflow #2023

Merged
merged 37 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9d6b09d
Save WIP 2
rohitvinnakota-codecov Mar 1, 2025
e672c0f
Update
rohitvinnakota-codecov Mar 3, 2025
fd7cea0
Update
rohitvinnakota-codecov Mar 3, 2025
21de77f
Update
rohitvinnakota-codecov Mar 3, 2025
7796fee
Update types
rohitvinnakota-codecov Mar 3, 2025
3e24d94
Fix errors
rohitvinnakota-codecov Mar 3, 2025
dcf71f5
Update tests
rohitvinnakota-codecov Mar 3, 2025
0a3f868
Update
rohitvinnakota-codecov Mar 4, 2025
ee46e5d
fix tests
rohitvinnakota-codecov Mar 4, 2025
fa84677
Fix test
rohitvinnakota-codecov Mar 4, 2025
0d673ad
Update
rohitvinnakota-codecov Mar 5, 2025
687dd9b
Update - no tests
rohitvinnakota-codecov Mar 6, 2025
464fd0a
Update failing test
rohitvinnakota-codecov Mar 6, 2025
2687d19
Type fix
rohitvinnakota-codecov Mar 6, 2025
5238a1a
Merge branch 'main' of https://github.com/getsentry/seer into rvinnak…
rohitvinnakota-codecov Mar 6, 2025
7c0a2fc
Update
rohitvinnakota-codecov Mar 6, 2025
9f962d5
Fix some issues
rohitvinnakota-codecov Mar 17, 2025
7c2d0c9
Update
rohitvinnakota-codecov Mar 17, 2025
0b46e6c
Update
rohitvinnakota-codecov Mar 17, 2025
a2d2a79
Merge branch 'main' of https://github.com/getsentry/seer into rvinnak…
rohitvinnakota-codecov Mar 17, 2025
4cdb3e1
Add migration
rohitvinnakota-codecov Mar 17, 2025
3103122
Fix tests
rohitvinnakota-codecov Mar 17, 2025
32b79f1
Update tests
rohitvinnakota-codecov Mar 17, 2025
120c164
Add test file
rohitvinnakota-codecov Mar 17, 2025
2d8fc2a
Fix test
rohitvinnakota-codecov Mar 17, 2025
57ed151
update
rohitvinnakota-codecov Mar 17, 2025
fc03d98
Fix
rohitvinnakota-codecov Mar 17, 2025
11d0796
Update
rohitvinnakota-codecov Mar 17, 2025
ea51f91
Actually fix tests
rohitvinnakota-codecov Mar 18, 2025
60d9d59
Update tests
rohitvinnakota-codecov Mar 18, 2025
7dcb140
Update with feedback
rohitvinnakota-codecov Mar 19, 2025
cce291c
More tests
rohitvinnakota-codecov Mar 19, 2025
c2a827d
Update repo client
rohitvinnakota-codecov Mar 20, 2025
2660b7e
Merge branch 'main' of https://github.com/getsentry/seer into rvinnak…
rohitvinnakota-codecov Mar 20, 2025
6dfe889
Update
rohitvinnakota-codecov Mar 20, 2025
8388c6a
Update
rohitvinnakota-codecov Mar 20, 2025
ee97014
:hammer_and_wrench: apply pre-commit fixes
getsantry[bot] Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/integrations/codecov/codecov_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ def fetch_test_results_for_commit(
}
response = requests.get(url, headers=headers)
if response.status_code == 200:
if response.json()["count"] == 0:
data = response.json()
if data["count"] == 0:
return None
return response.text
else:
return None
return [
{"name": r["name"], "failure_message": r["failure_message"]}
for r in data["results"]
]
return None
55 changes: 55 additions & 0 deletions src/migrations/versions/f1970ed945bd_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Migration

Revision ID: f1970ed945bd
Revises: cac994b711d3
Create Date: 2025-03-20 13:48:53.089598

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "f1970ed945bd"
down_revision = "cac994b711d3"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"codegen_unit_test_generation_pr_context_to_run_id",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("provider", sa.String(), nullable=False),
sa.Column("owner", sa.String(), nullable=False),
sa.Column("pr_id", sa.BigInteger(), nullable=False),
sa.Column("repo", sa.String(), nullable=False),
sa.Column("run_id", sa.Integer(), nullable=False),
sa.Column("iterations", sa.Integer(), nullable=False),
sa.Column("original_pr_url", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["run_id"], ["run_state.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("provider", "pr_id", "repo", "owner", "original_pr_url"),
)
with op.batch_alter_table(
"codegen_unit_test_generation_pr_context_to_run_id", schema=None
) as batch_op:
batch_op.create_index(
"ix_unit_test_context_repo_owner_pr_id_pr_url",
["owner", "repo", "pr_id", "original_pr_url"],
unique=False,
)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table(
"codegen_unit_test_generation_pr_context_to_run_id", schema=None
) as batch_op:
batch_op.drop_index("ix_unit_test_context_repo_owner_pr_id_pr_url")

op.drop_table("codegen_unit_test_generation_pr_context_to_run_id")
# ### end Alembic commands ###
4 changes: 4 additions & 0 deletions src/seer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from seer.automation.codegen.tasks import (
codegen_pr_review,
codegen_relevant_warnings,
codegen_retry_unittest,
codegen_unittest,
get_unittest_state,
)
Expand Down Expand Up @@ -329,6 +330,9 @@ def codecov_request_endpoint(
return codegen_pr_review_endpoint(data.data)
elif data.request_type == "unit-tests":
return codegen_unit_tests_endpoint(data.data)
elif data.request_type == "retry-unit-tests":
return codegen_retry_unittest(data.data)

raise ValueError(f"Unsupported request_type: {data.request_type}")


Expand Down
65 changes: 61 additions & 4 deletions src/seer/automation/codebase/repo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,12 +545,18 @@ def get_commit_patch_for_file(

return matching_file.patch

def _create_branch(self, branch_name):
def _create_branch(self, branch_name, from_feature_branch=False):
if from_feature_branch:
return self._create_branch_from_feature_branch(branch_name)

ref = self.repo.create_git_ref(
ref=f"refs/heads/{branch_name}", sha=self.get_default_branch_head_sha()
)
return ref

def _create_branch_from_feature_branch(self, branch_name):
return self.repo.create_git_ref(ref=f"refs/heads/{branch_name}", sha=self.base_commit_sha)

def process_one_file_for_git_commit(
self, *, branch_ref: str, patch: FilePatch | None = None, change: FileChange | None = None
) -> InputGitTreeElement | None:
Expand Down Expand Up @@ -610,19 +616,20 @@ def create_branch_from_changes(
file_patches: list[FilePatch] | None = None,
file_changes: list[FileChange] | None = None,
branch_name: str | None = None,
from_feature_branch: bool = False,
) -> GitRef | None:
if not file_patches and not file_changes:
raise ValueError("Either file_patches or file_changes must be provided")

new_branch_name = sanitize_branch_name(branch_name or pr_title)

try:
branch_ref = self._create_branch(new_branch_name)
branch_ref = self._create_branch(new_branch_name, from_feature_branch)
except GithubException as e:
# only use the random suffix if the branch already exists
if e.status == 409 or e.status == 422:
new_branch_name = f"{new_branch_name}-{generate_random_string(n=6)}"
branch_ref = self._create_branch(new_branch_name)
branch_ref = self._create_branch(new_branch_name, from_feature_branch)
else:
raise e

Expand Down Expand Up @@ -824,11 +831,24 @@ def post_unit_test_reference_to_original_pr(self, original_pr_url: str, unit_tes
response.raise_for_status()
return response.json()["html_url"]

def post_unit_test_reference_to_original_pr_codecov_app(
self, original_pr_url: str, unit_test_pr_url: str
):
original_pr_id = int(original_pr_url.split("/")[-1])
repo_name = original_pr_url.split("github.com/")[1].split("/pull")[0]
url = f"https://api.github.com/repos/{repo_name}/issues/{original_pr_id}/comments"
comment = f"Codecov has generated a new [PR]({unit_test_pr_url}) with unit tests for this PR. View the new PR({unit_test_pr_url}) to review the changes."
params = {"body": comment}
headers = self._get_auth_headers()
response = requests.post(url, headers=headers, json=params)
response.raise_for_status()
return response.json()["html_url"]

def post_unit_test_not_generated_message_to_original_pr(self, original_pr_url: str):
original_pr_id = int(original_pr_url.split("/")[-1])
repo_name = original_pr_url.split("github.com/")[1].split("/pull")[0]
url = f"https://api.github.com/repos/{repo_name}/issues/{original_pr_id}/comments"
comment = "Sentry has determined that unit tests already exist on this PR or that they are not necessary."
comment = f"Sentry has determined that unit tests are not necessary for this PR."
params = {"body": comment}
headers = self._get_auth_headers()
response = requests.post(url, headers=headers, json=params)
Expand Down Expand Up @@ -866,3 +886,40 @@ def post_pr_review_comment(self, pr_url: str, comment: GithubPrReviewComment):
start_line=comment.get("start_line", GithubObject.NotSet),
)
return review_comment.html_url

def push_new_commit_to_pr(
self,
pr,
commit_message: str,
file_patches: list[FilePatch] | None = None,
file_changes: list[FileChange] | None = None,
):
if not file_patches and not file_changes:
raise ValueError("Must provide file_patches or file_changes")
branch_name = pr.head.ref
tree_elements = []
if file_patches:
for patch in file_patches:
element = self.process_one_file_for_git_commit(branch_ref=branch_name, patch=patch)
if element:
tree_elements.append(element)
elif file_changes:
for change in file_changes:
element = self.process_one_file_for_git_commit(
branch_ref=branch_name, change=change
)
if element:
tree_elements.append(element)
if not tree_elements:
logger.warning("No valid changes to commit")
return None
latest_sha = self.get_branch_head_sha(branch_name)
latest_commit = self.repo.get_git_commit(latest_sha)
base_tree = latest_commit.tree
new_tree = self.repo.create_git_tree(tree_elements, base_tree)
new_commit = self.repo.create_git_commit(
message=commit_message, tree=new_tree, parents=[latest_commit]
)
branch_ref = self.repo.get_git_ref(f"heads/{branch_name}")
branch_ref.edit(sha=new_commit.sha)
return new_commit
67 changes: 66 additions & 1 deletion src/seer/automation/codegen/codegen_context.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging

from seer.automation.agent.models import Message
from seer.automation.codebase.repo_client import RepoClient, RepoClientType
from seer.automation.codegen.codegen_event_manager import CodegenEventManager
from seer.automation.codegen.models import CodegenContinuation
from seer.automation.codegen.models import CodegenContinuation, UnitTestRunMemory
from seer.automation.codegen.state import CodegenContinuationState
from seer.automation.models import RepoDefinition
from seer.automation.pipeline import PipelineContext
from seer.automation.state import DbStateRunTypes
from seer.db import DbPrContextToUnitTestGenerationRunIdMapping, DbRunMemory, Session

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -77,3 +79,66 @@ def get_file_contents(
file_contents = file_change.apply(file_contents)

return file_contents

def store_memory(self, key: str, memory: list[Message]):
with Session() as session:
memory_record = (
session.query(DbRunMemory).where(DbRunMemory.run_id == self.run_id).one_or_none()
)

if not memory_record:
memory_model = UnitTestRunMemory(run_id=self.run_id)
else:
memory_model = UnitTestRunMemory.from_db_model(memory_record)

memory_model.memory[key] = memory
memory_record = memory_model.to_db_model()

session.merge(memory_record)
session.commit()

def update_stored_memory(self, key: str, memory: list[Message], original_run_id: int):
with Session() as session:
memory_record = (
session.query(DbRunMemory)
.where(DbRunMemory.run_id == original_run_id)
.one_or_none()
)

if not memory_record:
raise RuntimeError(
f"No memory record found for run_id {original_run_id}. Cannot update stored memory."
)
else:
memory_model = UnitTestRunMemory.from_db_model(memory_record)

memory_model.memory[key] = memory
memory_record = memory_model.to_db_model()

session.merge(memory_record)
session.commit()

def get_memory(self, key: str, past_run_id: int) -> list[Message]:
with Session() as session:
memory_record = (
session.query(DbRunMemory).where(DbRunMemory.run_id == past_run_id).one_or_none()
)

if not memory_record:
return []

return UnitTestRunMemory.from_db_model(memory_record).memory.get(key, [])

def get_previous_run_context(
self, owner: str, repo: str, pr_id: int
) -> DbPrContextToUnitTestGenerationRunIdMapping | None:
with Session() as session:
previous_context = (
session.query(DbPrContextToUnitTestGenerationRunIdMapping)
.where(DbPrContextToUnitTestGenerationRunIdMapping.owner == owner)
.where(DbPrContextToUnitTestGenerationRunIdMapping.repo == repo)
.where(DbPrContextToUnitTestGenerationRunIdMapping.pr_id == pr_id)
.one_or_none()
)

return previous_context
19 changes: 17 additions & 2 deletions src/seer/automation/codegen/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import datetime
from enum import Enum
from typing import Literal
from typing import Literal, Optional

from pydantic import BaseModel, Field

from seer.automation.agent.models import Message
from seer.automation.codebase.models import StaticAnalysisWarning
from seer.automation.component import BaseComponentOutput, BaseComponentRequest
from seer.automation.models import FileChange, IssueDetails, RepoDefinition
from seer.db import DbRunMemory


class CodegenStatus(str, Enum):
Expand Down Expand Up @@ -45,6 +47,7 @@ class CodeUnitTestOutput(BaseComponentOutput):
class CodegenBaseRequest(BaseModel):
repo: RepoDefinition
pr_id: int # The PR number
codecov_status: dict[str, bool] | None = None


class CodegenUnitTestsRequest(CodegenBaseRequest):
Expand Down Expand Up @@ -210,4 +213,16 @@ class CodePredictRelevantWarningsOutput(BaseComponentOutput):
class CodecovTaskRequest(BaseModel):
data: CodegenUnitTestsRequest | CodegenPrReviewRequest | CodegenRelevantWarningsRequest
external_owner_id: str
request_type: Literal["unit-tests", "pr-review", "relevant-warnings"]
request_type: Literal["unit-tests", "pr-review", "relevant-warnings", "retry-unit-tests"]


class UnitTestRunMemory(BaseModel):
run_id: int
memory: dict[str, list[Message]] = Field(default_factory=dict)

def to_db_model(self) -> DbRunMemory:
return DbRunMemory(run_id=self.run_id, value=self.model_dump(mode="json"))

@classmethod
def from_db_model(cls, model: DbRunMemory) -> "UnitTestRunMemory":
return cls.model_validate(model.value)
23 changes: 23 additions & 0 deletions src/seer/automation/codegen/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,26 @@ def format_prompt(formatted_warning: str, formatted_error: str):
error_prompt=_RelevantWarningsPromptPrefix.format_prompt_error(formatted_error),
formatted_warning=formatted_warning,
)


class RetryUnitTestPrompts:
@staticmethod
def format_continue_unit_tests_prompt(code_coverage_info: str, test_result_info: str):
return textwrap.dedent(
"""\
The tests you have generated so far are not sufficient to cover all the changes in the codebase. You need to continue generating unit tests to address the gaps in coverage and fix any failing tests.

To help you with this, you have access to code coverage information at a file level attached as a JSON in addtion to test result information also in a JSON format.

Using the information and instructions provided, update the unit tests to ensure robust code coverage as well as fix any failing tests. Use the exact same format you used previously to regenerate tests. Your changes will be appended as a new commit to the branch of the existing PR.

Here is the code coverage information:
{code_coverage_info}

Here is the test result information:
{test_result_info}
"""
).format(
code_coverage_info=code_coverage_info,
test_result_info=test_result_info,
)
Loading
Loading