1
1
import json
2
+ import tempfile
3
+ from pathlib import Path
2
4
from unittest .mock import Mock , patch
3
5
4
6
import pandas as pd
7
+ import pytest
5
8
from nemoguardrails .colang import parse_colang_file
6
9
7
10
from giskard .llm .client .base import ChatMessage
8
11
from giskard .scanner .issues import Issue , Robustness
9
12
from giskard .scanner .report import ScanReport
10
13
11
14
15
+ def _generate_rails (report : ScanReport , filename = None , colang_version = "1.0" ):
16
+ if filename :
17
+ with tempfile .TemporaryDirectory () as tmpdir :
18
+ dest = Path (tmpdir ).joinpath ("rails.co" )
19
+ report .generate_rails (filename = dest , colang_version = colang_version )
20
+ assert dest .exists ()
21
+ assert dest .is_file ()
22
+ rails = dest .read_text (encoding = "utf-8" )
23
+ else :
24
+ rails = report .generate_rails (colang_version = colang_version )
25
+ return rails
26
+
27
+
28
+ @pytest .mark .parametrize ("filename" , [(None ), ("rails.co" )])
12
29
@patch ("giskard.integrations.nemoguardrails.get_default_client" )
13
- def test_generate_colang_v1_rails_from_scan (get_default_client_mock ):
30
+ def test_generate_colang_v1_rails_from_scan (get_default_client_mock , filename ):
14
31
report = make_test_report ()
15
32
16
33
llm_client = get_default_client_mock ()
17
34
llm_client .complete .side_effect = make_llm_answers ()
18
35
19
- rails = report . generate_rails ( )
36
+ rails = _generate_rails ( report , filename = filename , colang_version = "1.0" )
20
37
21
38
# Check that the file is correctly formatted
22
39
parsed = parse_colang_file ("rails.co" , rails , version = "1.0" )
@@ -27,14 +44,15 @@ def test_generate_colang_v1_rails_from_scan(get_default_client_mock):
27
44
assert parsed ["flows" ][1 ]["id" ] == "ask help on illegal activities"
28
45
29
46
47
+ @pytest .mark .parametrize ("filename" , [(None ), ("rails.co" )])
30
48
@patch ("giskard.integrations.nemoguardrails.get_default_client" )
31
- def test_generate_colang_v2_rails_from_scan (get_default_client_mock ):
49
+ def test_generate_colang_v2_rails_from_scan (get_default_client_mock , filename ):
32
50
report = make_test_report ()
33
51
34
52
llm_client = get_default_client_mock ()
35
53
llm_client .complete .side_effect = make_llm_answers ()
36
54
37
- rails = report . generate_rails ( colang_version = "2.x" )
55
+ rails = _generate_rails ( report , filename = filename , colang_version = "2.x" )
38
56
39
57
# Check that the file is correctly formatted
40
58
parsed = parse_colang_file ("rails.co" , rails , version = "2.x" )
0 commit comments