Skip to content

Commit 7d2006f

Browse files
[ft] Add continue unit test generation workflow (#2023)
This change lays out the groundwork for running unit tests against code coverage and test analytics metrics to generate better tests for a given PR. The retry unit tests endpoint will be invoked from Codecov's `Shelter` service This change also adds a `DbPrContextToUnitTestGenerationRunIdMapping` table that will store information(owner, repo, etc) + agent context about past runs, which will be leveraged as a memory store for subsequent unit test runs. Some other updates: - Fix for an issue where a generated tests PR is based off of the main branch. It should be built on top of the feature branch tests are being requested on - Update Claude version - Add logic that lets the codecov-ai GH app post updates instead of autofix under certain conditions --------- Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
1 parent be054bd commit 7d2006f

23 files changed

+1320
-36
lines changed

src/integrations/codecov/codecov_client.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@ def fetch_test_results_for_commit(
3030
}
3131
response = requests.get(url, headers=headers)
3232
if response.status_code == 200:
33-
if response.json()["count"] == 0:
33+
data = response.json()
34+
if data["count"] == 0:
3435
return None
35-
return response.text
36-
else:
37-
return None
36+
return [
37+
{"name": r["name"], "failure_message": r["failure_message"]}
38+
for r in data["results"]
39+
]
40+
return None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Migration
2+
3+
Revision ID: f1970ed945bd
4+
Revises: cac994b711d3
5+
Create Date: 2025-03-20 13:48:53.089598
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "f1970ed945bd"
14+
down_revision = "cac994b711d3"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
op.create_table(
22+
"codegen_unit_test_generation_pr_context_to_run_id",
23+
sa.Column("id", sa.Integer(), nullable=False),
24+
sa.Column("provider", sa.String(), nullable=False),
25+
sa.Column("owner", sa.String(), nullable=False),
26+
sa.Column("pr_id", sa.BigInteger(), nullable=False),
27+
sa.Column("repo", sa.String(), nullable=False),
28+
sa.Column("run_id", sa.Integer(), nullable=False),
29+
sa.Column("iterations", sa.Integer(), nullable=False),
30+
sa.Column("original_pr_url", sa.String(), nullable=False),
31+
sa.ForeignKeyConstraint(["run_id"], ["run_state.id"], ondelete="CASCADE"),
32+
sa.PrimaryKeyConstraint("id"),
33+
sa.UniqueConstraint("provider", "pr_id", "repo", "owner", "original_pr_url"),
34+
)
35+
with op.batch_alter_table(
36+
"codegen_unit_test_generation_pr_context_to_run_id", schema=None
37+
) as batch_op:
38+
batch_op.create_index(
39+
"ix_unit_test_context_repo_owner_pr_id_pr_url",
40+
["owner", "repo", "pr_id", "original_pr_url"],
41+
unique=False,
42+
)
43+
44+
# ### end Alembic commands ###
45+
46+
47+
def downgrade():
48+
# ### commands auto generated by Alembic - please adjust! ###
49+
with op.batch_alter_table(
50+
"codegen_unit_test_generation_pr_context_to_run_id", schema=None
51+
) as batch_op:
52+
batch_op.drop_index("ix_unit_test_context_repo_owner_pr_id_pr_url")
53+
54+
op.drop_table("codegen_unit_test_generation_pr_context_to_run_id")
55+
# ### end Alembic commands ###

src/seer/app.py

+4
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from seer.automation.codegen.tasks import (
6868
codegen_pr_review,
6969
codegen_relevant_warnings,
70+
codegen_retry_unittest,
7071
codegen_unittest,
7172
get_unittest_state,
7273
)
@@ -333,6 +334,9 @@ def codecov_request_endpoint(
333334
return codegen_pr_review_endpoint(data.data)
334335
elif data.request_type == "unit-tests":
335336
return codegen_unit_tests_endpoint(data.data)
337+
elif data.request_type == "retry-unit-tests":
338+
return codegen_retry_unittest(data.data)
339+
336340
raise ValueError(f"Unsupported request_type: {data.request_type}")
337341

338342

src/seer/automation/codebase/repo_client.py

+61-4
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,18 @@ def get_commit_patch_for_file(
545545

546546
return matching_file.patch
547547

548-
def _create_branch(self, branch_name):
548+
def _create_branch(self, branch_name, from_feature_branch=False):
549+
if from_feature_branch:
550+
return self._create_branch_from_feature_branch(branch_name)
551+
549552
ref = self.repo.create_git_ref(
550553
ref=f"refs/heads/{branch_name}", sha=self.get_default_branch_head_sha()
551554
)
552555
return ref
553556

557+
def _create_branch_from_feature_branch(self, branch_name):
558+
return self.repo.create_git_ref(ref=f"refs/heads/{branch_name}", sha=self.base_commit_sha)
559+
554560
def process_one_file_for_git_commit(
555561
self, *, branch_ref: str, patch: FilePatch | None = None, change: FileChange | None = None
556562
) -> InputGitTreeElement | None:
@@ -610,19 +616,20 @@ def create_branch_from_changes(
610616
file_patches: list[FilePatch] | None = None,
611617
file_changes: list[FileChange] | None = None,
612618
branch_name: str | None = None,
619+
from_feature_branch: bool = False,
613620
) -> GitRef | None:
614621
if not file_patches and not file_changes:
615622
raise ValueError("Either file_patches or file_changes must be provided")
616623

617624
new_branch_name = sanitize_branch_name(branch_name or pr_title)
618625

619626
try:
620-
branch_ref = self._create_branch(new_branch_name)
627+
branch_ref = self._create_branch(new_branch_name, from_feature_branch)
621628
except GithubException as e:
622629
# only use the random suffix if the branch already exists
623630
if e.status == 409 or e.status == 422:
624631
new_branch_name = f"{new_branch_name}-{generate_random_string(n=6)}"
625-
branch_ref = self._create_branch(new_branch_name)
632+
branch_ref = self._create_branch(new_branch_name, from_feature_branch)
626633
else:
627634
raise e
628635

@@ -824,11 +831,24 @@ def post_unit_test_reference_to_original_pr(self, original_pr_url: str, unit_tes
824831
response.raise_for_status()
825832
return response.json()["html_url"]
826833

834+
def post_unit_test_reference_to_original_pr_codecov_app(
835+
self, original_pr_url: str, unit_test_pr_url: str
836+
):
837+
original_pr_id = int(original_pr_url.split("/")[-1])
838+
repo_name = original_pr_url.split("github.com/")[1].split("/pull")[0]
839+
url = f"https://api.github.com/repos/{repo_name}/issues/{original_pr_id}/comments"
840+
comment = f"Codecov has generated a new [PR]({unit_test_pr_url}) with unit tests for this PR. View the new PR({unit_test_pr_url}) to review the changes."
841+
params = {"body": comment}
842+
headers = self._get_auth_headers()
843+
response = requests.post(url, headers=headers, json=params)
844+
response.raise_for_status()
845+
return response.json()["html_url"]
846+
827847
def post_unit_test_not_generated_message_to_original_pr(self, original_pr_url: str):
828848
original_pr_id = int(original_pr_url.split("/")[-1])
829849
repo_name = original_pr_url.split("github.com/")[1].split("/pull")[0]
830850
url = f"https://api.github.com/repos/{repo_name}/issues/{original_pr_id}/comments"
831-
comment = "Sentry has determined that unit tests already exist on this PR or that they are not necessary."
851+
comment = f"Sentry has determined that unit tests are not necessary for this PR."
832852
params = {"body": comment}
833853
headers = self._get_auth_headers()
834854
response = requests.post(url, headers=headers, json=params)
@@ -866,3 +886,40 @@ def post_pr_review_comment(self, pr_url: str, comment: GithubPrReviewComment):
866886
start_line=comment.get("start_line", GithubObject.NotSet),
867887
)
868888
return review_comment.html_url
889+
890+
def push_new_commit_to_pr(
891+
self,
892+
pr,
893+
commit_message: str,
894+
file_patches: list[FilePatch] | None = None,
895+
file_changes: list[FileChange] | None = None,
896+
):
897+
if not file_patches and not file_changes:
898+
raise ValueError("Must provide file_patches or file_changes")
899+
branch_name = pr.head.ref
900+
tree_elements = []
901+
if file_patches:
902+
for patch in file_patches:
903+
element = self.process_one_file_for_git_commit(branch_ref=branch_name, patch=patch)
904+
if element:
905+
tree_elements.append(element)
906+
elif file_changes:
907+
for change in file_changes:
908+
element = self.process_one_file_for_git_commit(
909+
branch_ref=branch_name, change=change
910+
)
911+
if element:
912+
tree_elements.append(element)
913+
if not tree_elements:
914+
logger.warning("No valid changes to commit")
915+
return None
916+
latest_sha = self.get_branch_head_sha(branch_name)
917+
latest_commit = self.repo.get_git_commit(latest_sha)
918+
base_tree = latest_commit.tree
919+
new_tree = self.repo.create_git_tree(tree_elements, base_tree)
920+
new_commit = self.repo.create_git_commit(
921+
message=commit_message, tree=new_tree, parents=[latest_commit]
922+
)
923+
branch_ref = self.repo.get_git_ref(f"heads/{branch_name}")
924+
branch_ref.edit(sha=new_commit.sha)
925+
return new_commit

src/seer/automation/codegen/codegen_context.py

+66-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import logging
22

3+
from seer.automation.agent.models import Message
34
from seer.automation.codebase.repo_client import RepoClient, RepoClientType
45
from seer.automation.codegen.codegen_event_manager import CodegenEventManager
5-
from seer.automation.codegen.models import CodegenContinuation
6+
from seer.automation.codegen.models import CodegenContinuation, UnitTestRunMemory
67
from seer.automation.codegen.state import CodegenContinuationState
78
from seer.automation.models import RepoDefinition
89
from seer.automation.pipeline import PipelineContext
910
from seer.automation.state import DbStateRunTypes
11+
from seer.db import DbPrContextToUnitTestGenerationRunIdMapping, DbRunMemory, Session
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -77,3 +79,66 @@ def get_file_contents(
7779
file_contents = file_change.apply(file_contents)
7880

7981
return file_contents
82+
83+
def store_memory(self, key: str, memory: list[Message]):
84+
with Session() as session:
85+
memory_record = (
86+
session.query(DbRunMemory).where(DbRunMemory.run_id == self.run_id).one_or_none()
87+
)
88+
89+
if not memory_record:
90+
memory_model = UnitTestRunMemory(run_id=self.run_id)
91+
else:
92+
memory_model = UnitTestRunMemory.from_db_model(memory_record)
93+
94+
memory_model.memory[key] = memory
95+
memory_record = memory_model.to_db_model()
96+
97+
session.merge(memory_record)
98+
session.commit()
99+
100+
def update_stored_memory(self, key: str, memory: list[Message], original_run_id: int):
101+
with Session() as session:
102+
memory_record = (
103+
session.query(DbRunMemory)
104+
.where(DbRunMemory.run_id == original_run_id)
105+
.one_or_none()
106+
)
107+
108+
if not memory_record:
109+
raise RuntimeError(
110+
f"No memory record found for run_id {original_run_id}. Cannot update stored memory."
111+
)
112+
else:
113+
memory_model = UnitTestRunMemory.from_db_model(memory_record)
114+
115+
memory_model.memory[key] = memory
116+
memory_record = memory_model.to_db_model()
117+
118+
session.merge(memory_record)
119+
session.commit()
120+
121+
def get_memory(self, key: str, past_run_id: int) -> list[Message]:
122+
with Session() as session:
123+
memory_record = (
124+
session.query(DbRunMemory).where(DbRunMemory.run_id == past_run_id).one_or_none()
125+
)
126+
127+
if not memory_record:
128+
return []
129+
130+
return UnitTestRunMemory.from_db_model(memory_record).memory.get(key, [])
131+
132+
def get_previous_run_context(
133+
self, owner: str, repo: str, pr_id: int
134+
) -> DbPrContextToUnitTestGenerationRunIdMapping | None:
135+
with Session() as session:
136+
previous_context = (
137+
session.query(DbPrContextToUnitTestGenerationRunIdMapping)
138+
.where(DbPrContextToUnitTestGenerationRunIdMapping.owner == owner)
139+
.where(DbPrContextToUnitTestGenerationRunIdMapping.repo == repo)
140+
.where(DbPrContextToUnitTestGenerationRunIdMapping.pr_id == pr_id)
141+
.one_or_none()
142+
)
143+
144+
return previous_context

src/seer/automation/codegen/models.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import datetime
22
from enum import Enum
3-
from typing import Literal
3+
from typing import Literal, Optional
44

55
from pydantic import BaseModel, Field
66

7+
from seer.automation.agent.models import Message
78
from seer.automation.codebase.models import StaticAnalysisWarning
89
from seer.automation.component import BaseComponentOutput, BaseComponentRequest
910
from seer.automation.models import FileChange, IssueDetails, RepoDefinition
11+
from seer.db import DbRunMemory
1012

1113

1214
class CodegenStatus(str, Enum):
@@ -45,6 +47,7 @@ class CodeUnitTestOutput(BaseComponentOutput):
4547
class CodegenBaseRequest(BaseModel):
4648
repo: RepoDefinition
4749
pr_id: int # The PR number
50+
codecov_status: dict[str, bool] | None = None
4851

4952

5053
class CodegenUnitTestsRequest(CodegenBaseRequest):
@@ -210,4 +213,16 @@ class CodePredictRelevantWarningsOutput(BaseComponentOutput):
210213
class CodecovTaskRequest(BaseModel):
211214
data: CodegenUnitTestsRequest | CodegenPrReviewRequest | CodegenRelevantWarningsRequest
212215
external_owner_id: str
213-
request_type: Literal["unit-tests", "pr-review", "relevant-warnings"]
216+
request_type: Literal["unit-tests", "pr-review", "relevant-warnings", "retry-unit-tests"]
217+
218+
219+
class UnitTestRunMemory(BaseModel):
220+
run_id: int
221+
memory: dict[str, list[Message]] = Field(default_factory=dict)
222+
223+
def to_db_model(self) -> DbRunMemory:
224+
return DbRunMemory(run_id=self.run_id, value=self.model_dump(mode="json"))
225+
226+
@classmethod
227+
def from_db_model(cls, model: DbRunMemory) -> "UnitTestRunMemory":
228+
return cls.model_validate(model.value)

src/seer/automation/codegen/prompts.py

+23
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,26 @@ def format_prompt(formatted_warning: str, formatted_error: str):
321321
error_prompt=_RelevantWarningsPromptPrefix.format_prompt_error(formatted_error),
322322
formatted_warning=formatted_warning,
323323
)
324+
325+
326+
class RetryUnitTestPrompts:
327+
@staticmethod
328+
def format_continue_unit_tests_prompt(code_coverage_info: str, test_result_info: str):
329+
return textwrap.dedent(
330+
"""\
331+
The tests you have generated so far are not sufficient to cover all the changes in the codebase. You need to continue generating unit tests to address the gaps in coverage and fix any failing tests.
332+
333+
To help you with this, you have access to code coverage information at a file level attached as a JSON in addtion to test result information also in a JSON format.
334+
335+
Using the information and instructions provided, update the unit tests to ensure robust code coverage as well as fix any failing tests. Use the exact same format you used previously to regenerate tests. Your changes will be appended as a new commit to the branch of the existing PR.
336+
337+
Here is the code coverage information:
338+
{code_coverage_info}
339+
340+
Here is the test result information:
341+
{test_result_info}
342+
"""
343+
).format(
344+
code_coverage_info=code_coverage_info,
345+
test_result_info=test_result_info,
346+
)

0 commit comments

Comments
 (0)