Skip to content

Commit 9997b1c

Browse files
authored
Enforce Gremlin protocol and serializer based on database type (#697)
* Set allowed and default Gremlin protocol and serializer dynamically * Add unit test suite * update changelog
1 parent 8fa1749 commit 9997b1c

File tree

7 files changed

+457
-81
lines changed

7 files changed

+457
-81
lines changed

ChangeLog.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd
55
## Upcoming
66

77
- Updated Gremlin config `message_serializer` to accept all TinkerPop serializers ([Link to PR](https://github.com/aws/graph-notebook/pull/685))
8+
- Implemented service-based dynamic allowlists and defaults for Gremlin serializer and protocol combinations ([Link to PR](https://github.com/aws/graph-notebook/pull/697))
89
- Added `%get_import_task` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/668))
910
- Added `--export-to` JSON file option to `%%graph_notebook_config` ([Link to PR](https://github.com/aws/graph-notebook/pull/684))
1011
- Deprecated Python 3.8 support ([Link to PR](https://github.com/aws/graph-notebook/pull/683))

src/graph_notebook/configuration/generate_config.py

+81-50
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
HTTP_PROTOCOL_FORMATS, WS_PROTOCOL_FORMATS,
1515
DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE,
1616
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants,
17-
GRAPHBINARYV1, GREMLIN_SERIALIZERS_HTTP,
17+
GRAPHBINARYV1, GREMLIN_SERIALIZERS_HTTP, GREMLIN_SERIALIZERS_WS,
18+
GREMLIN_SERIALIZERS_ALL, NEPTUNE_GREMLIN_SERIALIZERS_HTTP,
19+
DEFAULT_GREMLIN_WS_SERIALIZER, DEFAULT_GREMLIN_HTTP_SERIALIZER,
1820
NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME,
19-
normalize_service_name)
21+
normalize_service_name, normalize_protocol_name,
22+
normalize_serializer_class_name)
2023

2124
DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')
2225

@@ -57,7 +60,8 @@ class GremlinSection(object):
5760
"""
5861

5962
def __init__(self, traversal_source: str = '', username: str = '', password: str = '',
60-
message_serializer: str = '', connection_protocol: str = '', include_protocol: bool = False):
63+
message_serializer: str = '', connection_protocol: str = '',
64+
include_protocol: bool = False, neptune_service: str = ''):
6165
"""
6266
:param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are
6367
connected to an endpoint that can access multiple graphs.
@@ -71,57 +75,78 @@ def __init__(self, traversal_source: str = '', username: str = '', password: str
7175
if traversal_source == '':
7276
traversal_source = DEFAULT_GREMLIN_TRAVERSAL_SOURCE
7377

74-
serializer_lower = message_serializer.lower()
75-
# TODO: Update with untyped serializers once supported in GremlinPython
76-
# Accept TinkerPop serializer class name
77-
# https://github.com/apache/tinkerpop/blob/fd040c94a66516e473811fe29eaeaf4081cf104c/docs/src/reference/gremlin-applications.asciidoc#graphson
78-
# https://github.com/apache/tinkerpop/blob/fd040c94a66516e473811fe29eaeaf4081cf104c/docs/src/reference/gremlin-applications.asciidoc#graphbinary
79-
if serializer_lower == '':
80-
message_serializer = DEFAULT_GREMLIN_SERIALIZER
81-
elif 'graphson' in serializer_lower:
82-
message_serializer = 'GraphSON'
83-
if 'untyped' in serializer_lower:
84-
message_serializer += 'Untyped'
85-
if 'v1' in serializer_lower:
86-
if 'untyped' in serializer_lower:
87-
message_serializer += 'MessageSerializerV1'
88-
else:
89-
message_serializer += 'MessageSerializerGremlinV1'
90-
elif 'v2' in serializer_lower:
91-
message_serializer += 'MessageSerializerV2'
78+
invalid_serializer_input = False
79+
if message_serializer != '':
80+
message_serializer, invalid_serializer_input = normalize_serializer_class_name(message_serializer)
81+
82+
if include_protocol:
83+
# Neptune endpoint
84+
invalid_protocol_input = False
85+
if connection_protocol != '':
86+
connection_protocol, invalid_protocol_input = normalize_protocol_name(connection_protocol)
87+
88+
if neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME:
89+
if connection_protocol != DEFAULT_HTTP_PROTOCOL:
90+
if invalid_protocol_input:
91+
print(f"Invalid connection protocol specified, you must use {DEFAULT_HTTP_PROTOCOL}. ")
92+
elif connection_protocol == DEFAULT_WS_PROTOCOL:
93+
print(f"Enforcing HTTP protocol.")
94+
connection_protocol = DEFAULT_HTTP_PROTOCOL
95+
# temporary restriction until GraphSON-typed and GraphBinary results are supported
96+
if message_serializer not in NEPTUNE_GREMLIN_SERIALIZERS_HTTP:
97+
if message_serializer not in GREMLIN_SERIALIZERS_ALL:
98+
if invalid_serializer_input:
99+
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
100+
f"Valid serializers: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
101+
else:
102+
print(f"{message_serializer} is not currently supported for HTTP connections, "
103+
f"defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
104+
f"Please use one of: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
105+
message_serializer = DEFAULT_GREMLIN_HTTP_SERIALIZER
92106
else:
93-
message_serializer += 'MessageSerializerV3'
94-
elif 'graphbinary' in serializer_lower:
95-
message_serializer = GRAPHBINARYV1
107+
if connection_protocol not in [DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL]:
108+
if invalid_protocol_input:
109+
print(f"Invalid connection protocol specified, defaulting to {DEFAULT_WS_PROTOCOL}. "
110+
f"Valid protocols: [websockets, http].")
111+
connection_protocol = DEFAULT_WS_PROTOCOL
112+
113+
if connection_protocol == DEFAULT_HTTP_PROTOCOL:
114+
# temporary restriction until GraphSON-typed and GraphBinary results are supported
115+
if message_serializer not in NEPTUNE_GREMLIN_SERIALIZERS_HTTP:
116+
if message_serializer not in GREMLIN_SERIALIZERS_ALL:
117+
if invalid_serializer_input:
118+
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
119+
f"Valid serializers: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
120+
else:
121+
print(f"{message_serializer} is not currently supported for HTTP connections, "
122+
f"defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
123+
f"Please use one of: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
124+
message_serializer = DEFAULT_GREMLIN_HTTP_SERIALIZER
125+
else:
126+
if message_serializer not in GREMLIN_SERIALIZERS_WS:
127+
if invalid_serializer_input:
128+
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. "
129+
f"Valid serializers: {GREMLIN_SERIALIZERS_WS}")
130+
elif message_serializer != '':
131+
print(f"{message_serializer} is not currently supported by Gremlin Python driver, "
132+
f"defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. "
133+
f"Valid serializers: {GREMLIN_SERIALIZERS_WS}")
134+
message_serializer = DEFAULT_GREMLIN_WS_SERIALIZER
135+
136+
self.connection_protocol = connection_protocol
96137
else:
97-
print(f'Invalid Gremlin serializer specified, defaulting to graphsonv3. '
98-
f'Valid serializers: {GREMLIN_SERIALIZERS_HTTP}.')
99-
message_serializer = DEFAULT_GREMLIN_SERIALIZER
138+
# Non-Neptune database - check and set valid WebSockets serializer if invalid/empty
139+
if message_serializer not in GREMLIN_SERIALIZERS_WS:
140+
message_serializer = DEFAULT_GREMLIN_WS_SERIALIZER
141+
if invalid_serializer_input:
142+
print(f'Invalid Gremlin serializer specified, defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. '
143+
f'Valid serializers: {GREMLIN_SERIALIZERS_WS}.')
100144

101145
self.traversal_source = traversal_source
102146
self.username = username
103147
self.password = password
104148
self.message_serializer = message_serializer
105149

106-
if include_protocol:
107-
protocol_lower = connection_protocol.lower()
108-
if message_serializer in GREMLIN_SERIALIZERS_HTTP:
109-
connection_protocol = DEFAULT_HTTP_PROTOCOL
110-
if protocol_lower != '' and protocol_lower not in HTTP_PROTOCOL_FORMATS:
111-
print(f"Enforcing HTTP protocol usage for serializer: {message_serializer}.")
112-
else:
113-
if protocol_lower == '':
114-
connection_protocol = DEFAULT_GREMLIN_PROTOCOL
115-
elif protocol_lower in HTTP_PROTOCOL_FORMATS:
116-
connection_protocol = DEFAULT_HTTP_PROTOCOL
117-
elif protocol_lower in WS_PROTOCOL_FORMATS:
118-
connection_protocol = DEFAULT_WS_PROTOCOL
119-
else:
120-
print(f"Invalid connection protocol specified, defaulting to {DEFAULT_GREMLIN_PROTOCOL}. "
121-
f"Valid protocols: [websockets, http].")
122-
connection_protocol = DEFAULT_GREMLIN_PROTOCOL
123-
self.connection_protocol = connection_protocol
124-
125150
def to_dict(self):
126151
return self.__dict__
127152

@@ -178,8 +203,8 @@ def __init__(self, host: str, port: int,
178203
self.auth_mode = auth_mode
179204
self.load_from_s3_arn = load_from_s3_arn
180205
self.aws_region = aws_region
181-
default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else DEFAULT_GREMLIN_PROTOCOL
182206
if gremlin_section is not None:
207+
default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else ''
183208
if hasattr(gremlin_section, "connection_protocol"):
184209
if self._proxy_host != '' and gremlin_section.connection_protocol != DEFAULT_HTTP_PROTOCOL:
185210
print("Enforcing HTTP connection protocol for proxy connections.")
@@ -189,9 +214,12 @@ def __init__(self, host: str, port: int,
189214
else:
190215
final_protocol = default_protocol
191216
self.gremlin = GremlinSection(message_serializer=gremlin_section.message_serializer,
192-
connection_protocol=final_protocol, include_protocol=True)
217+
connection_protocol=final_protocol,
218+
include_protocol=True,
219+
neptune_service=self.neptune_service)
193220
else:
194-
self.gremlin = GremlinSection(connection_protocol=default_protocol, include_protocol=True)
221+
self.gremlin = GremlinSection(include_protocol=True,
222+
neptune_service=self.neptune_service)
195223
self.neo4j = Neo4JSection()
196224
else:
197225
self.is_neptune_config = False
@@ -331,11 +359,14 @@ def generate_default_config():
331359
auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
332360
protocol_arg = args.gremlin_connection_protocol
333361
include_protocol = False
362+
gremlin_service = ''
334363
if is_allowed_neptune_host(args.host, args.neptune_hosts):
335364
include_protocol = True
365+
gremlin_service = args.neptune_service
336366
if not protocol_arg:
337367
protocol_arg = DEFAULT_HTTP_PROTOCOL \
338368
if args.neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME else DEFAULT_WS_PROTOCOL
369+
339370
config = generate_config(args.host, int(args.port),
340371
AuthModeEnum(auth_mode_arg),
341372
args.ssl, args.ssl_verify,
@@ -344,7 +375,7 @@ def generate_default_config():
344375
SparqlSection(args.sparql_path, ''),
345376
GremlinSection(args.gremlin_traversal_source, args.gremlin_username,
346377
args.gremlin_password, args.gremlin_serializer,
347-
protocol_arg, include_protocol),
378+
protocol_arg, include_protocol, gremlin_service),
348379
Neo4JSection(args.neo4j_username, args.neo4j_password,
349380
args.neo4j_auth, args.neo4j_database),
350381
args.neptune_hosts)

src/graph_notebook/configuration/get_config.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
SparqlSection, GremlinSection, Neo4JSection
1010
from graph_notebook.neptune.client import NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, \
1111
DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, \
12-
NEPTUNE_DB_SERVICE_NAME, DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL
12+
NEPTUNE_DB_SERVICE_NAME, DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL, \
13+
DEFAULT_GREMLIN_HTTP_SERIALIZER, DEFAULT_GREMLIN_WS_SERIALIZER, \
14+
normalize_service_name
1315

1416
neptune_params = ['neptune_service', 'auth_mode', 'load_from_s3_arn', 'aws_region']
1517
neptune_gremlin_params = ['connection_protocol']
@@ -30,18 +32,24 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I
3032
is_neptune_host = is_allowed_neptune_host(hostname=data["host"], host_allowlist=neptune_hosts)
3133

3234
if is_neptune_host:
33-
neptune_service = data['neptune_service'] if 'neptune_service' in data else NEPTUNE_DB_SERVICE_NAME
35+
if 'neptune_service' in data:
36+
neptune_service = normalize_service_name(data['neptune_service'])
37+
else:
38+
neptune_service = NEPTUNE_DB_SERVICE_NAME
3439
if 'gremlin' in data:
35-
data['gremlin']['include_protocol'] = True
3640
if 'connection_protocol' not in data['gremlin']:
3741
data['gremlin']['connection_protocol'] = DEFAULT_WS_PROTOCOL \
3842
if neptune_service == NEPTUNE_DB_SERVICE_NAME else DEFAULT_HTTP_PROTOCOL
39-
gremlin_section = GremlinSection(**data['gremlin'])
43+
gremlin_section = GremlinSection(**data['gremlin'],
44+
include_protocol=True,
45+
neptune_service=neptune_service)
4046
if gremlin_section.to_dict()['traversal_source'] != 'g':
4147
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
4248
else:
4349
protocol = DEFAULT_WS_PROTOCOL if neptune_service == NEPTUNE_DB_SERVICE_NAME else DEFAULT_HTTP_PROTOCOL
44-
gremlin_section = GremlinSection(include_protocol=True, connection_protocol=protocol)
50+
gremlin_section = GremlinSection(include_protocol=True,
51+
connection_protocol=protocol,
52+
neptune_service=neptune_service)
4553
if neo4j_section.to_dict()['username'] != DEFAULT_NEO4J_USERNAME \
4654
or neo4j_section.to_dict()['password'] != DEFAULT_NEO4J_PASSWORD:
4755
print('Ignoring Neo4J custom authentication, Amazon Neptune does not support this functionality.\n')

src/graph_notebook/magics/graph_magic.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, GREMLIN_EXPLAIN_MODES, \
5555
OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT, OPENCYPHER_STATUS_STATE_MODES, \
5656
normalize_service_name, NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, GRAPH_PG_INFO_METRICS, \
57-
DEFAULT_GREMLIN_PROTOCOL, GREMLIN_PROTOCOL_FORMATS, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, \
57+
GREMLIN_PROTOCOL_FORMATS, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, \
5858
GREMLIN_SERIALIZERS_WS, GREMLIN_SERIALIZERS_CLASS_TO_MIME_MAP, normalize_protocol_name, generate_snapshot_name)
5959
from graph_notebook.network import SPARQLNetwork
6060
from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
@@ -1250,11 +1250,18 @@ def gremlin(self, line, cell, local_ns: dict = None):
12501250
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
12511251
if self.client.is_neptune_domain():
12521252
if args.connection_protocol != '':
1253-
connection_protocol = normalize_protocol_name(args.connection_protocol)
1253+
connection_protocol, bad_protocol_input = normalize_protocol_name(args.connection_protocol)
1254+
if bad_protocol_input:
1255+
if self.client.is_analytics_domain():
1256+
connection_protocol = DEFAULT_HTTP_PROTOCOL
1257+
else:
1258+
connection_protocol = DEFAULT_WS_PROTOCOL
1259+
print(f"Connection protocol input is invalid for Neptune, "
1260+
f"defaulting to {connection_protocol}.")
12541261
if connection_protocol == DEFAULT_WS_PROTOCOL and \
12551262
self.graph_notebook_config.gremlin.message_serializer not in GREMLIN_SERIALIZERS_WS:
1256-
print("Unsupported serializer for GremlinPython client, "
1257-
"compatible serializers are: {GREMLIN_SERIALIZERS_WS}")
1263+
print(f"Serializer is unsupported for GremlinPython client, "
1264+
f"compatible serializers are: {GREMLIN_SERIALIZERS_WS}")
12581265
print("Defaulting to HTTP protocol.")
12591266
connection_protocol = DEFAULT_HTTP_PROTOCOL
12601267
else:

0 commit comments

Comments
 (0)