🐛 fix: refactor knowledge base issue #6973

Merged · 3 commits · Mar 15, 2025
7 changes: 3 additions & 4 deletions src/const/settings/agent.ts
@@ -1,6 +1,5 @@
 import { DEFAULT_AGENT_META } from '@/const/meta';
-import { DEFAULT_MODEL } from '@/const/settings/llm';
-import { ModelProvider } from '@/libs/agent-runtime';
+import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '@/const/settings/llm';
 import { LobeAgentChatConfig, LobeAgentConfig, LobeAgentTTSConfig } from '@/types/agent';
 import { UserDefaultAgent } from '@/types/user/settings';
@@ -15,7 +14,7 @@ export const DEFAUTT_AGENT_TTS_CONFIG: LobeAgentTTSConfig = {

 export const DEFAULT_AGENT_SEARCH_FC_MODEL = {
   model: DEFAULT_MODEL,
-  provider: ModelProvider.OpenAI,
+  provider: DEFAULT_PROVIDER,
 };

 export const DEFAULT_AGENT_CHAT_CONFIG: LobeAgentChatConfig = {
@@ -41,7 +40,7 @@ export const DEFAULT_AGENT_CONFIG: LobeAgentConfig = {
     top_p: 1,
   },
   plugins: [],
-  provider: ModelProvider.OpenAI,
+  provider: DEFAULT_PROVIDER,
   systemRole: '',
   tts: DEFAUTT_AGENT_TTS_CONFIG,
 };
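Note: both defaults now read from a single `DEFAULT_PROVIDER` constant instead of hard-coding `ModelProvider.OpenAI`, so changing the app-wide default provider takes effect in every place that consumes these configs. A minimal sketch of what `src/const/settings/llm.ts` is assumed to export (the names come from the diff; the literal values here are assumptions):

```ts
// src/const/settings/llm.ts (sketch, not the actual file)
export const DEFAULT_MODEL = 'gpt-4o-mini'; // assumed value
export const DEFAULT_PROVIDER = 'openai'; // assumed value; the diff only shows the name
```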
8 changes: 4 additions & 4 deletions src/database/server/models/__tests__/chunk.test.ts
@@ -495,13 +495,13 @@ content in Table html is below:
     });

     // limit on the number of results
-    it('should limit results to 5 items', async () => {
+    it('should limit results to 15 items', async () => {
       const fileId = '1';
-      // Create 6 chunks
+      // Create 24 chunks
       const chunkResult = await serverDB
         .insert(chunks)
         .values(
-          Array(6)
+          Array(24)
             .fill(0)
             .map((_, i) => ({ text: `Test Chunk ${i}`, userId })),
         )
@@ -528,7 +528,7 @@ content in Table html is below:
         query: 'test',
       });

-      expect(result).toHaveLength(5);
+      expect(result).toHaveLength(15);
     });
   });
 });
3 changes: 2 additions & 1 deletion src/database/server/models/chunk.ts
@@ -207,7 +207,8 @@ export class ChunkModel {
       .leftJoin(files, eq(files.id, fileChunks.fileId))
       .where(inArray(fileChunks.fileId, fileIds))
       .orderBy((t) => desc(t.similarity))
-      .limit(5);
+      // relaxed to 15 for now
+      .limit(15);

     return result.map((item) => {
       return {
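Note: raising the retrieval cap from 5 to 15 trades a little precision for recall, which matches the `// TODO: need to rerank the chunks` left in the router below: fetch more candidates now, rerank later. A hypothetical follow-up (not part of this PR) would make the cap configurable rather than hard-coded:

```ts
// hypothetical follow-up, not in this PR: read the cap from an env var
// (RAG_CHUNK_TOP_K is an invented name)
const TOP_K = Number(process.env.RAG_CHUNK_TOP_K ?? 15);

// ...
//   .orderBy((t) => desc(t.similarity))
//   .limit(TOP_K);
```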
11 changes: 10 additions & 1 deletion src/libs/agent-runtime/anthropic/index.ts
@@ -38,6 +38,10 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
   apiKey?: string;
   private id: string;

+  private isDebug() {
+    return process.env.DEBUG_ANTHROPIC_CHAT_COMPLETION === '1';
+  }
+
   constructor({ apiKey, baseURL = DEFAULT_BASE_URL, id, ...res }: AnthropicAIParams = {}) {
     if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

@@ -51,6 +55,11 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
     try {
       const anthropicPayload = await this.buildAnthropicPayload(payload);

+      if (this.isDebug()) {
+        console.log('[requestPayload]');
+        console.log(JSON.stringify(anthropicPayload), '\n');
+      }
+
       const response = await this.client.messages.create(
         { ...anthropicPayload, stream: true },
         {
@@ -60,7 +69,7 @@ export class LobeAnthropicAI implements LobeRuntimeAI {

       const [prod, debug] = response.tee();

-      if (process.env.DEBUG_ANTHROPIC_CHAT_COMPLETION === '1') {
+      if (this.isDebug()) {
         debugStream(debug.toReadableStream()).catch(console.error);
       }
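Note: the env check already gated the response-stream debugger; extracting it into `isDebug()` lets the same `DEBUG_ANTHROPIC_CHAT_COMPLETION=1` flag also dump the outgoing request payload. A sketch of the same pattern as a shared helper, in case other runtimes want it (the helper name is invented here):

```ts
// sketch: one helper instead of a private method per runtime class
const isDebugEnabled = (envKey: string): boolean => process.env[envKey] === '1';

// usage inside a runtime class would then read:
// if (isDebugEnabled('DEBUG_ANTHROPIC_CHAT_COMPLETION')) { ... }
```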

9 changes: 6 additions & 3 deletions src/libs/agent-runtime/google/index.ts
@@ -23,7 +23,6 @@ import {
   OpenAIChatMessage,
   UserMessageContentPart,
 } from '../types';
-import { ModelProvider } from '../types/type';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
 import { StreamingResponse } from '../utils/response';
@@ -77,6 +76,7 @@ interface LobeGoogleAIParams {
   apiKey?: string;
   baseURL?: string;
   client?: GoogleGenerativeAI | VertexAI;
+  id?: string;
   isVertexAi?: boolean;
 }

@@ -85,15 +85,18 @@ export class LobeGoogleAI implements LobeRuntimeAI {
   private isVertexAi: boolean;
   baseURL?: string;
   apiKey?: string;
+  provider: string;

-  constructor({ apiKey, baseURL, client, isVertexAi }: LobeGoogleAIParams = {}) {
+  constructor({ apiKey, baseURL, client, isVertexAi, id }: LobeGoogleAIParams = {}) {
     if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

-    this.client = new GoogleGenerativeAI(apiKey);
     this.apiKey = apiKey;
+    this.client = client ? (client as GoogleGenerativeAI) : new GoogleGenerativeAI(apiKey);
     this.baseURL = client ? undefined : baseURL || DEFAULT_BASE_URL;
     this.isVertexAi = isVertexAi || false;
+
+    this.provider = id || (isVertexAi ? 'vertexai' : 'google');
   }

   async chat(rawPayload: ChatStreamPayload, options?: ChatCompetitionOptions) {
@@ -168,7 +171,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
       console.log(err);
       const { errorType, error } = this.parseErrorMessage(err.message);

-      throw AgentRuntimeError.chat({ error, errorType, provider: ModelProvider.Google });
+      throw AgentRuntimeError.chat({ error, errorType, provider: this.provider });
     }
   }
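Note: threading an optional `id` through the constructor means error payloads no longer hard-code `ModelProvider.Google`; a Vertex AI instance now reports under its own id. A sketch under those assumptions (constructor options are from the diff; the import path is assumed):

```ts
// sketch: how the provider id flows into error reporting
import { LobeGoogleAI } from '@/libs/agent-runtime/google'; // assumed path

const runtime = new LobeGoogleAI({ apiKey: process.env.GOOGLE_API_KEY, isVertexAi: true });
// runtime.provider === 'vertexai'
// on a chat failure: AgentRuntimeError.chat({ ..., provider: 'vertexai' })
```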

95 changes: 47 additions & 48 deletions src/libs/agent-runtime/runtimeMap.ts
@@ -37,7 +37,6 @@ import { LobeStepfunAI } from './stepfun';
 import { LobeTaichuAI } from './taichu';
 import { LobeTencentCloudAI } from './tencentcloud';
 import { LobeTogetherAI } from './togetherai';
-import { ModelProvider } from './types';
 import { LobeUpstageAI } from './upstage';
 import { LobeVLLMAI } from './vllm';
 import { LobeVolcengineAI } from './volcengine';
@@ -47,51 +46,51 @@ import { LobeZeroOneAI } from './zeroone';
 import { LobeZhipuAI } from './zhipu';

 export const providerRuntimeMap = {
-  [ModelProvider.OpenAI]: LobeOpenAI,
-  [ModelProvider.Azure]: LobeAzureOpenAI,
-  [ModelProvider.AzureAI]: LobeAzureAI,
-  [ModelProvider.ZhiPu]: LobeZhipuAI,
-  [ModelProvider.Google]: LobeGoogleAI,
-  [ModelProvider.Moonshot]: LobeMoonshotAI,
-  [ModelProvider.Bedrock]: LobeBedrockAI,
-  [ModelProvider.LMStudio]: LobeLMStudioAI,
-  [ModelProvider.Ollama]: LobeOllamaAI,
-  [ModelProvider.VLLM]: LobeVLLMAI,
-  [ModelProvider.Perplexity]: LobePerplexityAI,
-  [ModelProvider.Anthropic]: LobeAnthropicAI,
-  [ModelProvider.DeepSeek]: LobeDeepSeekAI,
-  [ModelProvider.HuggingFace]: LobeHuggingFaceAI,
-  [ModelProvider.Minimax]: LobeMinimaxAI,
-  [ModelProvider.Mistral]: LobeMistralAI,
-  [ModelProvider.Groq]: LobeGroq,
-  [ModelProvider.Github]: LobeGithubAI,
-  [ModelProvider.OpenRouter]: LobeOpenRouterAI,
-  [ModelProvider.TogetherAI]: LobeTogetherAI,
-  [ModelProvider.FireworksAI]: LobeFireworksAI,
-  [ModelProvider.ZeroOne]: LobeZeroOneAI,
-  [ModelProvider.Stepfun]: LobeStepfunAI,
-  [ModelProvider.Qwen]: LobeQwenAI,
-  [ModelProvider.Novita]: LobeNovitaAI,
-  [ModelProvider.Nvidia]: LobeNvidiaAI,
-  [ModelProvider.Taichu]: LobeTaichuAI,
-  [ModelProvider.Baichuan]: LobeBaichuanAI,
-  [ModelProvider.Ai360]: LobeAi360AI,
-  [ModelProvider.SiliconCloud]: LobeSiliconCloudAI,
-  [ModelProvider.GiteeAI]: LobeGiteeAI,
-  [ModelProvider.Upstage]: LobeUpstageAI,
-  [ModelProvider.Spark]: LobeSparkAI,
-  [ModelProvider.Ai21]: LobeAi21AI,
-  [ModelProvider.Hunyuan]: LobeHunyuanAI,
-  [ModelProvider.SenseNova]: LobeSenseNovaAI,
-  [ModelProvider.XAI]: LobeXAI,
-  [ModelProvider.Jina]: LobeJinaAI,
-  [ModelProvider.SambaNova]: LobeSambaNovaAI,
-  [ModelProvider.Cloudflare]: LobeCloudflareAI,
-  [ModelProvider.InternLM]: LobeInternLMAI,
-  [ModelProvider.Higress]: LobeHigressAI,
-  [ModelProvider.TencentCloud]: LobeTencentCloudAI,
-  [ModelProvider.Volcengine]: LobeVolcengineAI,
-  [ModelProvider.PPIO]: LobePPIOAI,
-  [ModelProvider.Doubao]: LobeVolcengineAI,
-  [ModelProvider.Wenxin]: LobeWenxinAI,
+  ai21: LobeAi21AI,
+  ai360: LobeAi360AI,
+  anthropic: LobeAnthropicAI,
+  azure: LobeAzureOpenAI,
+  azureai: LobeAzureAI,
+  baichuan: LobeBaichuanAI,
+  bedrock: LobeBedrockAI,
+  cloudflare: LobeCloudflareAI,
+  deepseek: LobeDeepSeekAI,
+  doubao: LobeVolcengineAI,
+  fireworksai: LobeFireworksAI,
+  giteeai: LobeGiteeAI,
+  github: LobeGithubAI,
+  google: LobeGoogleAI,
+  groq: LobeGroq,
+  higress: LobeHigressAI,
+  huggingface: LobeHuggingFaceAI,
+  hunyuan: LobeHunyuanAI,
+  internlm: LobeInternLMAI,
+  jina: LobeJinaAI,
+  lmstudio: LobeLMStudioAI,
+  minimax: LobeMinimaxAI,
+  mistral: LobeMistralAI,
+  moonshot: LobeMoonshotAI,
+  novita: LobeNovitaAI,
+  nvidia: LobeNvidiaAI,
+  ollama: LobeOllamaAI,
+  openai: LobeOpenAI,
+  openrouter: LobeOpenRouterAI,
+  perplexity: LobePerplexityAI,
+  ppio: LobePPIOAI,
+  qwen: LobeQwenAI,
+  sambanova: LobeSambaNovaAI,
+  sensenova: LobeSenseNovaAI,
+  siliconcloud: LobeSiliconCloudAI,
+  spark: LobeSparkAI,
+  stepfun: LobeStepfunAI,
+  taichu: LobeTaichuAI,
+  tencentcloud: LobeTencentCloudAI,
+  togetherai: LobeTogetherAI,
+  upstage: LobeUpstageAI,
+  vllm: LobeVLLMAI,
+  volcengine: LobeVolcengineAI,
+  wenxin: LobeWenxinAI,
+  xai: LobeXAI,
+  zeroone: LobeZeroOneAI,
+  zhipu: LobeZhipuAI,
 };
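Note: switching the keys from computed `ModelProvider` enum members to plain string literals (sorted alphabetically) drops the runtime import of the enum while keeping the same keys, since the enum values are exactly these strings. It also lets the provider-id union be derived from the map itself; a sketch of that derivation (the type names are invented here):

```ts
// sketch: derive the provider id union from the map instead of the enum
type ProviderId = keyof typeof providerRuntimeMap; // 'ai21' | 'ai360' | 'anthropic' | ...

// lookups stay type-safe; a typo in the key is now a compile error:
const AnthropicRuntime = providerRuntimeMap['anthropic' satisfies ProviderId];
```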
17 changes: 10 additions & 7 deletions src/server/routers/lambda/agent.ts
@@ -122,13 +122,16 @@ export const agentRouter = router({
       const knowledge = await ctx.agentModel.getAgentAssignedKnowledge(input.agentId);

       return [
-        ...files.map((file) => ({
-          enabled: knowledge.files.some((item) => item.id === file.id),
-          fileType: file.fileType,
-          id: file.id,
-          name: file.name,
-          type: KnowledgeType.File,
-        })),
+        ...files
+          // filter out all images
+          .filter((file) => !file.fileType.startsWith('image'))
+          .map((file) => ({
+            enabled: knowledge.files.some((item) => item.id === file.id),
+            fileType: file.fileType,
+            id: file.id,
+            name: file.name,
+            type: KnowledgeType.File,
+          })),
        ...knowledgeBases.map((knowledgeBase) => ({
          avatar: knowledgeBase.avatar,
          description: knowledgeBase.description,
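Note: the assignable-knowledge list previously surfaced every uploaded file, including images the RAG pipeline cannot chunk. The filter keys off the MIME-type prefix, so all `image/*` files disappear from the list:

```ts
// how the prefix check behaves on typical fileType values
'image/png'.startsWith('image'); // true  -> excluded
'image/jpeg'.startsWith('image'); // true  -> excluded
'application/pdf'.startsWith('image'); // false -> kept
```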
114 changes: 65 additions & 49 deletions src/server/routers/lambda/chunk.ts
@@ -1,3 +1,4 @@
+import { TRPCError } from '@trpc/server';
 import { inArray } from 'drizzle-orm/expressions';
 import { z } from 'zod';

@@ -126,60 +127,75 @@ export const chunkRouter = router({
   semanticSearchForChat: chunkProcedure
     .input(SemanticSearchSchema)
     .mutation(async ({ ctx, input }) => {
-      const item = await ctx.messageModel.findMessageQueriesById(input.messageId);
-      const { model, provider } =
-        getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
-      let embedding: number[];
-      let ragQueryId: string;
-      // if there is no message rag or it's embeddings, then we need to create one
-      if (!item || !item.embeddings) {
-        // TODO: need to support customize
-        const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);
-
-        const embeddings = await agentRuntime.embeddings({
-          dimensions: 1024,
-          input: input.rewriteQuery,
-          model,
-        });
-
-        embedding = embeddings![0];
-        const embeddingsId = await ctx.embeddingModel.create({
-          embeddings: embedding,
-          model,
-        });
-
-        const result = await ctx.messageModel.createMessageQuery({
-          embeddingsId,
-          messageId: input.messageId,
-          rewriteQuery: input.rewriteQuery,
-          userQuery: input.userQuery,
-        });
-
-        ragQueryId = result.id;
-      } else {
-        embedding = item.embeddings;
-        ragQueryId = item.id;
-      }
-
-      console.time('semanticSearch');
-      let finalFileIds = input.fileIds ?? [];
-
-      if (input.knowledgeIds && input.knowledgeIds.length > 0) {
-        const knowledgeFiles = await serverDB.query.knowledgeBaseFiles.findMany({
-          where: inArray(knowledgeBaseFiles.knowledgeBaseId, input.knowledgeIds),
-        });
-
-        finalFileIds = knowledgeFiles.map((f) => f.fileId).concat(finalFileIds);
-      }
-
-      const chunks = await ctx.chunkModel.semanticSearchForChat({
-        embedding,
-        fileIds: finalFileIds,
-        query: input.rewriteQuery,
-      });
-      // TODO: need to rerank the chunks
-      console.timeEnd('semanticSearch');
-
-      return { chunks, queryId: ragQueryId };
+      try {
+        const item = await ctx.messageModel.findMessageQueriesById(input.messageId);
+        const { model, provider } =
+          getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
+        let embedding: number[];
+        let ragQueryId: string;
+
+        // if there is no message rag or it's embeddings, then we need to create one
+        if (!item || !item.embeddings) {
+          // TODO: need to support customize
+          const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);
+
+          // slice content to make sure in the context window limit
+          const query =
+            input.rewriteQuery.length > 8000
+              ? input.rewriteQuery.slice(0, 8000)
+              : input.rewriteQuery;
+
+          const embeddings = await agentRuntime.embeddings({
+            dimensions: 1024,
+            input: query,
+            model,
+          });
+
+          embedding = embeddings![0];
+          const embeddingsId = await ctx.embeddingModel.create({
+            embeddings: embedding,
+            model,
+          });
+
+          const result = await ctx.messageModel.createMessageQuery({
+            embeddingsId,
+            messageId: input.messageId,
+            rewriteQuery: input.rewriteQuery,
+            userQuery: input.userQuery,
+          });
+
+          ragQueryId = result.id;
+        } else {
+          embedding = item.embeddings;
+          ragQueryId = item.id;
+        }
+
+        let finalFileIds = input.fileIds ?? [];
+
+        if (input.knowledgeIds && input.knowledgeIds.length > 0) {
+          const knowledgeFiles = await serverDB.query.knowledgeBaseFiles.findMany({
+            where: inArray(knowledgeBaseFiles.knowledgeBaseId, input.knowledgeIds),
+          });
+
+          finalFileIds = knowledgeFiles.map((f) => f.fileId).concat(finalFileIds);
+        }
+
+        const chunks = await ctx.chunkModel.semanticSearchForChat({
+          embedding,
+          fileIds: finalFileIds,
+          query: input.rewriteQuery,
+        });
+
+        // TODO: need to rerank the chunks
+
+        return { chunks, queryId: ragQueryId };
+      } catch (e) {
+        console.error(e);
+
+        throw new TRPCError({
+          code: 'INTERNAL_SERVER_ERROR',
+          message: (e as any).errorType || JSON.stringify(e),
+        });
+      }
     }),
 });
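Note: three changes land in this procedure. The rewritten query is clamped to 8,000 characters before embedding so it stays within the embedding model's context window, the `console.time('semanticSearch')` instrumentation is dropped, and the whole body is wrapped in try/catch so failures surface as a structured `TRPCError` instead of an unhandled rejection. The clamp in isolation (the helper name is invented; 8000 is the cap chosen in this PR):

```ts
// standalone version of the guard added above
const clampQuery = (q: string, max = 8000): string => (q.length > max ? q.slice(0, max) : q);

clampQuery('short query'); // returned unchanged
clampQuery('x'.repeat(10_000)).length; // 8000
```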
30 changes: 18 additions & 12 deletions src/store/chat/slices/aiChat/actions/rag.ts
@@ -21,7 +21,7 @@ export interface ChatRAGAction {
     id: string,
     userQuery: string,
     messages: string[],
-  ) => Promise<{ chunks: ChatSemanticSearchChunk[]; queryId: string; rewriteQuery?: string }>;
+  ) => Promise<{ chunks: ChatSemanticSearchChunk[]; queryId?: string; rewriteQuery?: string }>;
   /**
    * Rewrite user content to better RAG query
    */
@@ -74,17 +74,23 @@ export const chatRag: StateCreator<ChatStore, [['zustand/devtools', never]], [],

     // 2. retrieve chunks from semantic search
     const files = chatSelectors.currentUserFiles(get()).map((f) => f.id);
-    const { chunks, queryId } = await ragService.semanticSearchForChat({
-      fileIds: knowledgeIds().fileIds.concat(files),
-      knowledgeIds: knowledgeIds().knowledgeBaseIds,
-      messageId: id,
-      rewriteQuery: rewriteQuery || userQuery,
-      userQuery,
-    });
-
-    get().internal_toggleMessageRAGLoading(false, id);
-
-    return { chunks, queryId, rewriteQuery };
+    try {
+      const { chunks, queryId } = await ragService.semanticSearchForChat({
+        fileIds: knowledgeIds().fileIds.concat(files),
+        knowledgeIds: knowledgeIds().knowledgeBaseIds,
+        messageId: id,
+        rewriteQuery: rewriteQuery || userQuery,
+        userQuery,
+      });
+
+      get().internal_toggleMessageRAGLoading(false, id);
+
+      return { chunks, queryId, rewriteQuery };
+    } catch {
+      get().internal_toggleMessageRAGLoading(false, id);
+
+      return { chunks: [] };
+    }
   },
   internal_rewriteQuery: async (id, content, messages) => {
     let rewriteQuery = content;
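Note: the signature change in the first hunk (`queryId` becomes optional) follows from the new catch path, which returns `{ chunks: [] }` with no query id; the loading flag is cleared on both paths so a failed search no longer leaves the message stuck in a RAG-loading state. Callers must now guard the id; a sketch of a tolerant consumer (the retrieval action's name is not visible in this hunk, so `internal_retrieveChunks` is an assumption):

```ts
// sketch of a caller handling the now-optional queryId
const { chunks, queryId, rewriteQuery } = await get().internal_retrieveChunks(
  id,
  userQuery,
  messages,
);

if (queryId) {
  // only associate the rag query with the message when retrieval succeeded
}
```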