Skip to content

Commit cce291c

Browse files
More tests
1 parent 7dcb140 commit cce291c

File tree

2 files changed

+191
-1
lines changed

2 files changed

+191
-1
lines changed

src/seer/app.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def codecov_request_endpoint(
319319
elif data.request_type == "unit-tests":
320320
return codegen_unit_tests_endpoint(data.data)
321321
elif data.request_type == "retry-unit-tests":
322-
return codegen_retry_unittest(data)
322+
return codegen_retry_unittest(data.data)
323323

324324
raise ValueError(f"Unsupported request_type: {data.request_type}")
325325

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
from seer.automation.agent.models import Message
4+
from seer.automation.codegen.codegen_context import CodegenContext
5+
from seer.automation.models import RepoDefinition
6+
7+
8+
class DummyRequest:
9+
def __init__(self, repo):
10+
self.repo = repo
11+
12+
13+
class DummyContinuation:
14+
def __init__(self, run_id, request, file_changes=None, signals=None):
15+
self.run_id = run_id
16+
self.request = request
17+
self.file_changes = file_changes or []
18+
self.signals = signals or []
19+
20+
21+
class DummyCodegenContinuationState:
22+
def __init__(self, run_id, request, file_changes=None, signals=None):
23+
self._state = DummyContinuation(run_id, request, file_changes, signals)
24+
25+
def get(self):
26+
return self._state
27+
28+
def update(self):
29+
class DummyContextManager:
30+
def __init__(self, state):
31+
self.state = state
32+
33+
def __enter__(self):
34+
return self.state
35+
36+
def __exit__(self, exc_type, exc_val, exc_tb):
37+
pass
38+
39+
return DummyContextManager(self._state)
40+
41+
42+
class DummyFileChange:
43+
def __init__(self, path, new_content):
44+
self.path = path
45+
self.new_content = new_content
46+
47+
def apply(self, original_content):
48+
return self.new_content
49+
50+
51+
class TestCodegenContext(unittest.TestCase):
52+
def setUp(self):
53+
self.repo = RepoDefinition(
54+
provider="github", owner="test_owner", name="test_repo", external_id="dummy_id"
55+
)
56+
self.request = DummyRequest(self.repo)
57+
self.state = DummyCodegenContinuationState(run_id=1, request=self.request)
58+
self.codegen_context = CodegenContext(self.state)
59+
60+
def test_run_id(self):
61+
self.assertEqual(self.codegen_context.run_id, 1)
62+
63+
def test_signals_getter_setter(self):
64+
self.codegen_context.signals = ["signal1", "signal2"]
65+
self.assertEqual(self.codegen_context.signals, ["signal1", "signal2"])
66+
67+
def test_get_file_contents_no_local_changes(self):
68+
dummy_client = MagicMock()
69+
dummy_client.get_file_content.return_value = ("original content", None)
70+
with patch.object(self.codegen_context, "get_repo_client", return_value=dummy_client):
71+
content = self.codegen_context.get_file_contents(
72+
"dummy_path", ignore_local_changes=True
73+
)
74+
self.assertEqual(content, "original content")
75+
76+
def test_get_file_contents_with_local_changes(self):
77+
file_change = DummyFileChange("dummy_path", "changed content")
78+
self.state.get().file_changes = [file_change]
79+
dummy_client = MagicMock()
80+
dummy_client.get_file_content.return_value = ("original content", None)
81+
with patch.object(self.codegen_context, "get_repo_client", return_value=dummy_client):
82+
content = self.codegen_context.get_file_contents("dummy_path")
83+
self.assertEqual(content, "changed content")
84+
85+
@patch("seer.automation.codegen.codegen_context.Session")
86+
def test_store_and_get_memory(self, mock_session):
87+
fake_db = {}
88+
89+
class FakeSession:
90+
def __enter__(self):
91+
return self
92+
93+
def __exit__(self, exc_type, exc_val, exc_tb):
94+
pass
95+
96+
def query(self, model):
97+
class FakeQuery:
98+
def __init__(self, db):
99+
self.db = db
100+
101+
def where(self, condition):
102+
return self
103+
104+
def one_or_none(self):
105+
return self.db.get("memory")
106+
107+
return FakeQuery(fake_db)
108+
109+
def merge(self, obj):
110+
fake_db["memory"] = obj
111+
112+
def commit(self):
113+
pass
114+
115+
mock_session.return_value = FakeSession()
116+
key = "test_key"
117+
memory = [Message(role="user", content="Test message")]
118+
self.codegen_context.store_memory(key, memory)
119+
result = self.codegen_context.get_memory(key, past_run_id=1)
120+
self.assertEqual(len(result), 1)
121+
self.assertEqual(result[0].role, "user")
122+
self.assertEqual(result[0].content, "Test message")
123+
124+
@patch("seer.automation.codegen.codegen_context.Session")
125+
def test_update_stored_memory(self, mock_session):
126+
fake_db = {}
127+
128+
class FakeSession:
129+
def __init__(self):
130+
self.db = fake_db
131+
132+
def __enter__(self):
133+
return self
134+
135+
def __exit__(self, exc_type, exc_val, exc_tb):
136+
pass
137+
138+
def query(self, model):
139+
class FakeQuery:
140+
def __init__(self, db):
141+
self.db = db
142+
143+
def where(self, condition):
144+
return self
145+
146+
def one_or_none(self):
147+
return self.db.get("memory")
148+
149+
return FakeQuery(self.db)
150+
151+
def merge(self, obj):
152+
self.db["memory"] = obj
153+
154+
def commit(self):
155+
pass
156+
157+
mock_session.return_value = FakeSession()
158+
key = "update_key"
159+
initial_memory = [Message(role="user", content="Old message")]
160+
self.codegen_context.store_memory(key, initial_memory)
161+
new_memory = [Message(role="user", content="New message")]
162+
self.codegen_context.update_stored_memory(key, new_memory, original_run_id=1)
163+
result = self.codegen_context.get_memory(key, past_run_id=1)
164+
self.assertEqual(len(result), 1)
165+
self.assertEqual(result[0].content, "New message")
166+
167+
@patch("seer.automation.codegen.codegen_context.Session")
168+
def test_get_previous_run_context(self, mock_session):
169+
fake_context = MagicMock()
170+
171+
class FakeSession:
172+
def __enter__(self):
173+
return self
174+
175+
def __exit__(self, exc_type, exc_val, exc_tb):
176+
pass
177+
178+
def query(self, model):
179+
class FakeQuery:
180+
def where(self, *args, **kwargs):
181+
return self
182+
183+
def one_or_none(self):
184+
return fake_context
185+
186+
return FakeQuery()
187+
188+
mock_session.return_value = FakeSession()
189+
result = self.codegen_context.get_previous_run_context("test_owner", "test_repo", 123)
190+
self.assertEqual(result, fake_context)

0 commit comments

Comments
 (0)