|
1 |
| -import { Column, count, sql } from 'drizzle-orm'; |
2 |
| -import { and, desc, eq, exists, gt, inArray, isNull, like, or } from 'drizzle-orm/expressions'; |
| 1 | +import { count, sql } from 'drizzle-orm'; |
| 2 | +import { and, desc, eq, gt, inArray, isNull, like } from 'drizzle-orm/expressions'; |
3 | 3 |
|
4 | 4 | import { LobeChatDatabase } from '@/database/type';
|
5 | 5 | import {
|
@@ -79,27 +79,57 @@ export class TopicModel {
|
79 | 79 |
|
80 | 80 | const keywordLowerCase = keyword.toLowerCase();
|
81 | 81 |
|
82 |
| - const matchKeyword = (field: any) => |
83 |
| - like(sql`lower(${field})` as unknown as Column, `%${keywordLowerCase}%`); |
84 |
| - |
85 |
| - return this.db.query.topics.findMany({ |
| 82 | + // 查询标题匹配的主题 |
| 83 | + const topicsByTitle = await this.db.query.topics.findMany({ |
86 | 84 | orderBy: [desc(topics.updatedAt)],
|
87 | 85 | where: and(
|
88 | 86 | eq(topics.userId, this.userId),
|
89 | 87 | this.matchSession(sessionId),
|
90 |
| - or( |
91 |
| - matchKeyword(topics.title), |
92 |
| - exists( |
93 |
| - this.db |
94 |
| - .select() |
95 |
| - .from(messages) |
96 |
| - .where(and(eq(messages.topicId, topics.id), matchKeyword(messages.content))), |
97 |
| - ), |
98 |
| - ), |
| 88 | + like(topics.title, `%${keywordLowerCase}%`), |
99 | 89 | ),
|
100 | 90 | });
|
101 |
| - }; |
102 | 91 |
|
| 92 | + // 查询消息内容匹配的主题ID |
| 93 | + const topicIdsByMessages = await this.db |
| 94 | + .select({ topicId: messages.topicId }) |
| 95 | + .from(messages) |
| 96 | + .innerJoin(topics, eq(messages.topicId, topics.id)) |
| 97 | + .where( |
| 98 | + and( |
| 99 | + eq(messages.userId, this.userId), |
| 100 | + like(messages.content, `%${keywordLowerCase}%`), |
| 101 | + eq(topics.userId, this.userId), |
| 102 | + this.matchSession(sessionId), |
| 103 | + ), |
| 104 | + ) |
| 105 | + .groupBy(messages.topicId); |
| 106 | + // 如果没有通过消息内容找到主题,直接返回标题匹配的主题 |
| 107 | + if (topicIdsByMessages.length === 0) { |
| 108 | + return topicsByTitle; |
| 109 | + } |
| 110 | + |
| 111 | + // 查询通过消息内容找到的主题 |
| 112 | + const topicIds = topicIdsByMessages.map((t) => t.topicId); |
| 113 | + const topicsByMessages = await this.db.query.topics.findMany({ |
| 114 | + orderBy: [desc(topics.updatedAt)], |
| 115 | + where: and(eq(topics.userId, this.userId), inArray(topics.id, topicIds)), |
| 116 | + }); |
| 117 | + |
| 118 | + // 合并结果并去重 |
| 119 | + const allTopics = [...topicsByTitle]; |
| 120 | + const existingIds = new Set(topicsByTitle.map((t) => t.id)); |
| 121 | + |
| 122 | + for (const topic of topicsByMessages) { |
| 123 | + if (!existingIds.has(topic.id)) { |
| 124 | + allTopics.push(topic); |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + // 按更新时间排序 |
| 129 | + return allTopics.sort( |
| 130 | + (a, b) => new Date(b.updatedAt).getTime() - new Date(a.updatedAt).getTime(), |
| 131 | + ); |
| 132 | + }; |
103 | 133 | count = async (params?: {
|
104 | 134 | endDate?: string;
|
105 | 135 | range?: [string, string];
|
|
0 commit comments