Skip to content

Commit 3e24d94

Browse files
Fix errors
1 parent 7796fee commit 3e24d94

File tree

3 files changed

+96
-5
lines changed

3 files changed

+96
-5
lines changed

src/seer/automation/codegen/codegen_context.py

+11
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,14 @@ def get_unit_test_memory(
9090
.one_or_none()
9191
)
9292
return pr_context
93+
94+
def store_unit_test_memory(
95+
self, owner: str, repo: str, pr_id: int
96+
) -> DbPrContextToUnitTestGenerationRunIdMapping | None:
97+
with Session() as session:
98+
pr_context = (
99+
session.query(DbPrContextToUnitTestGenerationRunIdMapping)
100+
.filter_by(owner=owner, repo=repo, pr_id=pr_id)
101+
.one_or_none()
102+
)
103+
return pr_context

src/seer/automation/codegen/unit_test_coding_component.py

+82-1
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,85 @@ class UnitTestCodingComponent(BaseComponent[CodeUnitTestRequest, CodeUnitTestOut
3434
def invoke(
3535
self, request: CodeUnitTestRequest, llm_client: LlmClient = injected
3636
) -> CodeUnitTestOutput | None:
37-
pass
37+
with BaseTools(self.context, repo_client_type=RepoClientType.CODECOV_UNIT_TEST) as tools:
38+
agent = LlmAgent(
39+
tools=tools.get_tools(),
40+
config=AgentConfig(interactive=False),
41+
)
42+
43+
codecov_client_params = request.codecov_client_params
44+
45+
code_coverage_data = CodecovClient.fetch_coverage(
46+
repo_name=codecov_client_params["repo_name"],
47+
pullid=codecov_client_params["pullid"],
48+
owner_username=codecov_client_params["owner_username"],
49+
)
50+
51+
test_result_data = CodecovClient.fetch_test_results_for_commit(
52+
repo_name=codecov_client_params["repo_name"],
53+
owner_username=codecov_client_params["owner_username"],
54+
latest_commit_sha=codecov_client_params["head_sha"],
55+
)
56+
57+
existing_test_design_response = llm_client.generate_text(
58+
model=AnthropicProvider.model("claude-3-7-sonnet@20250219"),
59+
prompt=CodingUnitTestPrompts.format_find_unit_test_pattern_step_msg(
60+
diff_str=request.diff
61+
),
62+
)
63+
64+
llm_client.generate_text(
65+
model=AnthropicProvider.model("claude-3-7-sonnet@20250219"),
66+
prompt=CodingUnitTestPrompts.format_plan_step_msg(
67+
diff_str=request.diff,
68+
has_coverage_info=code_coverage_data,
69+
has_test_result_info=test_result_data,
70+
),
71+
)
72+
73+
final_response = agent.run(
74+
run_config=RunConfig(
75+
prompt=CodingUnitTestPrompts.format_unit_test_msg(
76+
diff_str=request.diff, test_design_hint=existing_test_design_response
77+
),
78+
system_prompt=CodingUnitTestPrompts.format_system_msg(),
79+
model=AnthropicProvider.model("claude-3-7-sonnet@20250219"),
80+
run_name="Generate Unit Tests",
81+
),
82+
)
83+
84+
if not final_response:
85+
return None
86+
plan_steps_content = extract_text_inside_tags(final_response, "plan_steps")
87+
88+
if len(plan_steps_content) == 0:
89+
raise ValueError("Failed to extract plan_steps from the planning step of LLM")
90+
91+
coding_output = PlanStepsPromptXml.from_xml(
92+
f"<plan_steps>{escape_multi_xml(plan_steps_content, ['diff', 'description', 'commit_message'])}</plan_steps>"
93+
).to_model()
94+
95+
if not coding_output.tasks:
96+
raise ValueError("No tasks found in coding output")
97+
file_changes: list[FileChange] = []
98+
for task in coding_output.tasks:
99+
repo_client = self.context.get_repo_client(
100+
task.repo_name, type=RepoClientType.CODECOV_UNIT_TEST
101+
)
102+
if task.type == "file_change":
103+
file_content, _ = repo_client.get_file_content(task.file_path)
104+
if not file_content:
105+
logger.warning(f"Failed to get content for {task.file_path}")
106+
continue
107+
108+
changes, _ = task_to_file_change(task, file_content)
109+
file_changes += changes
110+
elif task.type == "file_delete":
111+
change = task_to_file_delete(task)
112+
file_changes.append(change)
113+
elif task.type == "file_create":
114+
change = task_to_file_create(task)
115+
file_changes.append(change)
116+
else:
117+
logger.warning(f"Unsupported task type: {task.type}")
118+
return CodeUnitTestOutput(diffs=file_changes)

src/seer/automation/codegen/unit_test_github_pr_creator.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def create_github_pull_request(self):
6060
new_pr_url = new_pr.html_url
6161
self.repo_client.post_unit_test_reference_to_original_pr(original_pr_url, new_pr_url)
6262

63-
def store_pr_context(self, new_pr: PullRequest, original_pr: PullRequest):
63+
def store_pr_context(self, new_pr: PullRequest):
6464
with Session() as session:
65-
pr_id_mapping = DbPrContextToUnitTestGenerationRunIdMapping(
65+
run_info = DbPrContextToUnitTestGenerationRunIdMapping(
6666
provider="github",
6767
owner=self.repo_client.repo.owner.login,
6868
repo=self.repo_client.repo.name,
@@ -71,6 +71,5 @@ def store_pr_context(self, new_pr: PullRequest, original_pr: PullRequest):
7171
run_id=self.unit_test_run_id,
7272
original_pr_url=self.pr.html_url,
7373
)
74-
session.add(pr_id_mapping)
74+
session.add(run_info)
7575
session.commit()
76-
print("SAVED TO DB")

0 commit comments

Comments
 (0)