|
| 1 | +import { TRPCError } from '@trpc/server'; |
1 | 2 | import { inArray } from 'drizzle-orm/expressions';
|
2 | 3 | import { z } from 'zod';
|
3 | 4 |
|
@@ -126,60 +127,75 @@ export const chunkRouter = router({
|
/**
 * Semantic search over file chunks for a chat message (RAG retrieval).
 *
 * Flow:
 *  1. Reuse the message's cached query embedding when one exists; otherwise
 *     embed the (length-capped) rewritten query and persist it alongside a
 *     new message-query record.
 *  2. Expand `knowledgeIds` into their member file ids and merge with any
 *     explicitly supplied `fileIds`.
 *  3. Run vector similarity search over chunks scoped to those files.
 *
 * Returns `{ chunks, queryId }` where `queryId` identifies the stored
 * message-query row (new or reused).
 *
 * @throws TRPCError INTERNAL_SERVER_ERROR wrapping any failure, with the
 *         provider `errorType` when available.
 */
semanticSearchForChat: chunkProcedure
  .input(SemanticSearchSchema)
  .mutation(async ({ ctx, input }) => {
    // Cap on query text sent to the embedding provider, to stay within the
    // model's context window. NOTE(review): 8000 chars is assumed safe for
    // the default embedding model — confirm against provider limits.
    const MAX_EMBEDDING_INPUT_LENGTH = 8000;

    try {
      const item = await ctx.messageModel.findMessageQueriesById(input.messageId);
      const { model, provider } =
        getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
      let embedding: number[];
      let ragQueryId: string;

      // If there is no message rag query or it has no embeddings, create one.
      if (!item || !item.embeddings) {
        // TODO: need to support customize
        const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);

        // Slice content to make sure it fits the context window limit.
        const query =
          input.rewriteQuery.length > MAX_EMBEDDING_INPUT_LENGTH
            ? input.rewriteQuery.slice(0, MAX_EMBEDDING_INPUT_LENGTH)
            : input.rewriteQuery;

        const embeddings = await agentRuntime.embeddings({
          dimensions: 1024,
          input: query,
          model,
        });

        embedding = embeddings![0];
        const embeddingsId = await ctx.embeddingModel.create({
          embeddings: embedding,
          model,
        });

        // Persist the full (untruncated) queries with the embedding so the
        // record reflects what the user actually asked.
        const result = await ctx.messageModel.createMessageQuery({
          embeddingsId,
          messageId: input.messageId,
          rewriteQuery: input.rewriteQuery,
          userQuery: input.userQuery,
        });

        ragQueryId = result.id;
      } else {
        embedding = item.embeddings;
        ragQueryId = item.id;
      }

      let finalFileIds = input.fileIds ?? [];

      // Knowledge bases are searched by expanding them to their member files.
      if (input.knowledgeIds && input.knowledgeIds.length > 0) {
        const knowledgeFiles = await serverDB.query.knowledgeBaseFiles.findMany({
          where: inArray(knowledgeBaseFiles.knowledgeBaseId, input.knowledgeIds),
        });

        finalFileIds = knowledgeFiles.map((f) => f.fileId).concat(finalFileIds);
      }

      const chunks = await ctx.chunkModel.semanticSearchForChat({
        embedding,
        fileIds: finalFileIds,
        query: input.rewriteQuery,
      });

      // TODO: need to rerank the chunks

      return { chunks, queryId: ragQueryId };
    } catch (e) {
      console.error(e);

      // JSON.stringify on an Error instance yields '{}' (Error has no
      // enumerable own properties), so prefer the provider's errorType,
      // then Error#message, before falling back to serialization.
      const errorType =
        typeof e === 'object' && e !== null && 'errorType' in e
          ? String((e as { errorType: unknown }).errorType)
          : undefined;

      throw new TRPCError({
        code: 'INTERNAL_SERVER_ERROR',
        message: errorType || (e instanceof Error ? e.message : JSON.stringify(e)),
      });
    }
  }),
|
185 | 201 | });
|
0 commit comments