forked from embeddings-benchmark/mteb
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsentence_transformer_wrapper.py
126 lines (109 loc) · 4.64 KB
/
sentence_transformer_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
import numpy as np
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer
from mteb.encoder_interface import PromptType
from mteb.models.wrapper import Wrapper
logger = logging.getLogger(__name__)
class SentenceTransformerWrapper(Wrapper):
    """mteb wrapper around sentence-transformers bi- and cross-encoders.

    Resolves per-task prompts (either passed explicitly or read off the model)
    and exposes the mteb ``encode`` interface. For ``CrossEncoder`` models a
    ``predict`` method is attached as well.
    """

    def __init__(
        self,
        model: str | SentenceTransformer | CrossEncoder,
        revision: str | None = None,
        model_prompts: dict[str, str] | None = None,
        **kwargs,
    ) -> None:
        """Build the wrapper.

        Args:
            model: Model name (loaded as a ``SentenceTransformer``), or an
                already-constructed ``SentenceTransformer`` / ``CrossEncoder``.
            revision: Model revision to load when ``model`` is a name.
            model_prompts: Mapping from task names to prompt names. Lookup
                priority: task name + prompt type (query or passage), then
                task name, then task type + prompt type, then task type,
                and finally the prompt type alone.
            **kwargs: Forwarded to the ``SentenceTransformer`` constructor.
        """
        self.model = (
            SentenceTransformer(model, revision=revision, **kwargs)
            if isinstance(model, str)
            else model
        )

        has_builtin_prompts = hasattr(self.model, "prompts")
        if (
            model_prompts is None
            and has_builtin_prompts
            and len(self.model.prompts) > 0
        ):
            # No explicit prompts given: try adopting the ones the model ships with.
            try:
                model_prompts = self.validate_task_to_prompt_name(self.model.prompts)
            except KeyError:
                model_prompts = None
                logger.warning(
                    "Model prompts are not in the expected format. Ignoring them."
                )
        elif model_prompts is not None and has_builtin_prompts:
            # Explicit prompts win over whatever the model carries.
            logger.info(f"Model prompts will be overwritten with {model_prompts}")
            self.model.prompts = model_prompts
        self.model_prompts = self.validate_task_to_prompt_name(model_prompts)

        if isinstance(self.model, CrossEncoder):
            self.predict = self._predict

        # Reuse the model's own similarity function when one is available.
        similarity_fn = getattr(self.model, "similarity", None)
        if callable(similarity_fn):
            self.similarity = similarity_fn

    def encode(
        self,
        sentences: Sequence[str],
        *,
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        """Encode ``sentences`` and return their embeddings.

        Prompt selection priority:
            1. task name + prompt type (query or passage)
            2. task name
            3. task type + prompt type
            4. task type
            5. prompt type alone

        Args:
            sentences: Texts to embed.
            task_name: Task name used to pick a prompt from ``model_prompts``.
            prompt_type: Whether the texts are queries or passages.
            **kwargs: Forwarded to ``SentenceTransformer.encode``.

        Returns:
            The encoded sentences as a numpy array.
        """
        selected_prompt = None
        if self.model_prompts is not None:
            selected_prompt = self.get_prompt_name(
                self.model_prompts, task_name, prompt_type
            )

        if selected_prompt:
            logger.info(
                f"Using prompt_name={selected_prompt} for task={task_name} prompt_type={prompt_type}"
            )
        else:
            logger.info(
                f"No model prompts found for task={task_name} prompt_type={prompt_type}"
            )
        logger.info(f"Encoding {len(sentences)} sentences.")

        outputs = self.model.encode(
            sentences,
            prompt_name=selected_prompt,
            **kwargs,
        )
        if isinstance(outputs, torch.Tensor):
            # e.g. when kwargs carries return_tensors=True the model hands back
            # a tensor; normalize to float32 numpy for a uniform return type.
            outputs = outputs.cpu().detach().float().numpy()
        return outputs

    def _predict(
        self,
        sentences: Sequence[str],
        **kwargs: Any,
    ) -> np.ndarray:
        # Cross-encoder scoring path; always materialize scores as numpy.
        return self.model.predict(
            sentences,
            convert_to_numpy=True,
            **kwargs,
        )