Skip to content

Commit c9fb2de

Browse files
authored
🐛 fix: fix the missing user id in chat compeletition and fix remove unstarred topic not working (#2677)
* 🐛 fix: fix remove all topic not working * 🐛 fix: fix remove all topic not working * 🐛 fix: fix user id missing in chat competition
1 parent 3fc4265 commit c9fb2de

File tree

12 files changed

+76
-51
lines changed

12 files changed

+76
-51
lines changed

src/app/api/chat/[provider]/route.test.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ describe('POST handler', () => {
159159
accessCode: 'test-access-code',
160160
apiKey: 'test-api-key',
161161
azureApiVersion: 'v1',
162+
userId: 'abc',
162163
});
163164

164165
const mockParams = { provider: 'test-provider' };
@@ -176,7 +177,7 @@ describe('POST handler', () => {
176177
const response = await POST(request as unknown as Request, { params: mockParams });
177178

178179
expect(response).toEqual(mockChatResponse);
179-
expect(AgentRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload);
180+
expect(AgentRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload, { user: 'abc' });
180181
});
181182

182183
it('should return an error response when chat completion fails', async () => {

src/app/api/chat/[provider]/route.ts

+7-8
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,16 @@ export const POST = checkAuth(async (req: Request, { params, jwtPayload }) => {
2525

2626
const tracePayload = getTracePayload(req);
2727

28+
let traceOptions = {};
2829
// If user enable trace
2930
if (tracePayload?.enabled) {
30-
return await agentRuntime.chat(
31-
data,
32-
createTraceOptions(data, {
33-
provider,
34-
trace: tracePayload,
35-
}),
36-
);
31+
traceOptions = createTraceOptions(data, {
32+
provider,
33+
trace: tracePayload,
34+
});
3735
}
38-
return await agentRuntime.chat(data);
36+
37+
return await agentRuntime.chat(data, { user: jwtPayload.userId, ...traceOptions });
3938
} catch (e) {
4039
const {
4140
errorType = ChatErrorType.InternalServerError,

src/const/auth.ts

+6
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,11 @@ export interface JWTPayload {
3535
awsAccessKeyId?: string;
3636
awsRegion?: string;
3737
awsSecretAccessKey?: string;
38+
/**
39+
* user id
40+
* in client db mode it's a uuid
41+
* in server db mode it's a user id
42+
*/
43+
userId?: string;
3844
}
3945
/* eslint-enable */

src/libs/agent-runtime/types/chat.ts

+4
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ export interface ChatCompetitionOptions {
9595
callback?: ChatStreamCallbacks;
9696
headers?: Record<string, any>;
9797
signal?: AbortSignal;
98+
/**
99+
* userId for the chat completion
100+
*/
101+
user?: string;
98102
}
99103

100104
export interface ChatCompletionFunctions {

src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts

+10-7
Original file line numberDiff line numberDiff line change
@@ -195,18 +195,21 @@ describe('LobeOpenAICompatibleFactory', () => {
195195
});
196196

197197
describe('handlePayload option', () => {
198-
it('should modify request payload correctly', async () => {
198+
it('should add user in payload correctly', async () => {
199199
const mockCreateMethod = vi.spyOn(instance['client'].chat.completions, 'create');
200200

201-
await instance.chat({
202-
messages: [{ content: 'Hello', role: 'user' }],
203-
model: 'mistralai/mistral-7b-instruct:free',
204-
temperature: 0,
205-
});
201+
await instance.chat(
202+
{
203+
messages: [{ content: 'Hello', role: 'user' }],
204+
model: 'mistralai/mistral-7b-instruct:free',
205+
temperature: 0,
206+
},
207+
{ user: 'abc' },
208+
);
206209

207210
expect(mockCreateMethod).toHaveBeenCalledWith(
208211
expect.objectContaining({
209-
// 根据实际的 handlePayload 函数,添加断言
212+
user: 'abc',
210213
}),
211214
expect.anything(),
212215
);

src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts

+8-5
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,14 @@ export const LobeOpenAICompatibleFactory = ({
7979
stream: payload.stream ?? true,
8080
} as OpenAI.ChatCompletionCreateParamsStreaming);
8181

82-
const response = await this.client.chat.completions.create(postPayload, {
83-
// https://github.com/lobehub/lobe-chat/pull/318
84-
headers: { Accept: '*/*' },
85-
signal: options?.signal,
86-
});
82+
const response = await this.client.chat.completions.create(
83+
{ ...postPayload, user: options?.user },
84+
{
85+
// https://github.com/lobehub/lobe-chat/pull/318
86+
headers: { Accept: '*/*' },
87+
signal: options?.signal,
88+
},
89+
);
8790

8891
if (postPayload.stream) {
8992
const [prod, useForDebug] = response.tee();

src/services/_auth.ts

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import { JWTPayload, LOBE_CHAT_AUTH_HEADER } from '@/const/auth';
22
import { ModelProvider } from '@/libs/agent-runtime';
33
import { useUserStore } from '@/store/user';
4-
import { keyVaultsConfigSelectors, settingsSelectors } from '@/store/user/selectors';
4+
import {
5+
keyVaultsConfigSelectors,
6+
settingsSelectors,
7+
userProfileSelectors,
8+
} from '@/store/user/selectors';
59
import { GlobalLLMProviderKey } from '@/types/user/settings';
610
import { createJWT } from '@/utils/jwt';
711

@@ -48,8 +52,9 @@ export const getProviderAuthPayload = (provider: string) => {
4852

4953
const createAuthTokenWithPayload = async (payload = {}) => {
5054
const accessCode = settingsSelectors.password(useUserStore.getState());
55+
const userId = userProfileSelectors.userId(useUserStore.getState());
5156

52-
return await createJWT<JWTPayload>({ accessCode, ...payload });
57+
return await createJWT<JWTPayload>({ accessCode, userId, ...payload });
5358
};
5459

5560
interface AuthParams {

src/store/chat/slices/topic/action.test.ts

+12-6
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,14 @@ describe('topic action', () => {
384384
// Set up mock state with unstarred topics
385385
await act(async () => {
386386
useChatStore.setState({
387-
topics: [
388-
{ id: 'topic-1', favorite: false },
389-
{ id: 'topic-2', favorite: true },
390-
{ id: 'topic-3', favorite: false },
391-
] as ChatTopic[],
387+
activeId: 'abc',
388+
topicMaps: {
389+
abc: [
390+
{ id: 'topic-1', favorite: false },
391+
{ id: 'topic-2', favorite: true },
392+
{ id: 'topic-3', favorite: false },
393+
] as ChatTopic[],
394+
},
392395
});
393396
});
394397
const refreshTopicSpy = vi.spyOn(result.current, 'refreshTopic');
@@ -431,7 +434,10 @@ describe('topic action', () => {
431434
});
432435

433436
// Mock the `updateTopicTitleInSummary` and `refreshTopic` for spying
434-
const updateTopicTitleInSummarySpy = vi.spyOn(result.current, 'updateTopicTitleInSummary');
437+
const updateTopicTitleInSummarySpy = vi.spyOn(
438+
result.current,
439+
'internal_updateTopicTitleInSummary',
440+
);
435441
const refreshTopicSpy = vi.spyOn(result.current, 'refreshTopic');
436442

437443
// Mock the `chatService.fetchPresetTaskResult` to simulate the AI response

src/store/chat/slices/topic/action.ts

+17-18
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// DON'T REMOVE THE FIRST LINE
44
import isEqual from 'fast-deep-equal';
55
import { t } from 'i18next';
6-
import { produce } from 'immer';
76
import useSWR, { SWRResponse, mutate } from 'swr';
87
import { StateCreator } from 'zustand/vanilla';
98

@@ -37,19 +36,19 @@ export interface ChatTopicAction {
3736
removeAllTopics: () => Promise<void>;
3837
removeSessionTopics: () => Promise<void>;
3938
removeTopic: (id: string) => Promise<void>;
40-
removeUnstarredTopic: () => void;
39+
removeUnstarredTopic: () => Promise<void>;
4140
saveToTopic: () => Promise<string | undefined>;
4241
createTopic: () => Promise<string | undefined>;
4342

4443
autoRenameTopicTitle: (id: string) => Promise<void>;
4544
duplicateTopic: (id: string) => Promise<void>;
4645
summaryTopicTitle: (topicId: string, messages: ChatMessage[]) => Promise<void>;
4746
switchTopic: (id?: string, skipRefreshMessage?: boolean) => Promise<void>;
48-
updateTopicTitleInSummary: (id: string, title: string) => void;
4947
updateTopicTitle: (id: string, title: string) => Promise<void>;
5048
useFetchTopics: (sessionId: string) => SWRResponse<ChatTopic[]>;
5149
useSearchTopics: (keywords?: string, sessionId?: string) => SWRResponse<ChatTopic[]>;
5250

51+
internal_updateTopicTitleInSummary: (id: string, title: string) => void;
5352
internal_updateTopicLoading: (id: string, loading: boolean) => void;
5453
internal_createTopic: (params: CreateTopicParams) => Promise<string>;
5554
internal_updateTopic: (id: string, data: Partial<ChatTopic>) => Promise<void>;
@@ -133,18 +132,18 @@ export const chatTopic: StateCreator<
133132
},
134133
// update
135134
summaryTopicTitle: async (topicId, messages) => {
136-
const { updateTopicTitleInSummary, internal_updateTopicLoading } = get();
135+
const { internal_updateTopicTitleInSummary, internal_updateTopicLoading } = get();
137136
const topic = topicSelectors.getTopicById(topicId)(get());
138137
if (!topic) return;
139138

140-
updateTopicTitleInSummary(topicId, LOADING_FLAT);
139+
internal_updateTopicTitleInSummary(topicId, LOADING_FLAT);
141140

142141
let output = '';
143142

144143
// 自动总结话题标题
145144
await chatService.fetchPresetTaskResult({
146145
onError: () => {
147-
updateTopicTitleInSummary(topicId, topic.title);
146+
internal_updateTopicTitleInSummary(topicId, topic.title);
148147
},
149148
onFinish: async (text) => {
150149
await get().internal_updateTopic(topicId, { title: text });
@@ -159,7 +158,7 @@ export const chatTopic: StateCreator<
159158
}
160159
}
161160

162-
updateTopicTitleInSummary(topicId, output);
161+
internal_updateTopicTitleInSummary(topicId, output);
163162
},
164163
params: await chainSummaryTitle(messages),
165164
trace: get().getCurrentTracePayload({ traceName: TraceNameMap.SummaryTopicTitle, topicId }),
@@ -264,15 +263,11 @@ export const chatTopic: StateCreator<
264263
},
265264

266265
// Internal process method of the topics
267-
updateTopicTitleInSummary: (id, title) => {
268-
const topics = produce(get().topics, (draftState) => {
269-
const topic = draftState.find((i) => i.id === id);
270-
271-
if (!topic) return;
272-
topic.title = title;
273-
});
274-
275-
set({ topics }, false, n(`updateTopicTitleInSummary`, { id, title }));
266+
internal_updateTopicTitleInSummary: (id, title) => {
267+
get().internal_dispatchTopic(
268+
{ type: 'updateTopic', id, value: { title } },
269+
'updateTopicTitleInSummary',
270+
);
276271
},
277272
refreshTopic: async () => {
278273
return mutate([SWR_USE_FETCH_TOPIC, get().activeId]);
@@ -317,8 +312,12 @@ export const chatTopic: StateCreator<
317312
},
318313

319314
internal_dispatchTopic: (payload, action) => {
320-
const nextTopics = topicReducer(get().topics, payload);
315+
const nextTopics = topicReducer(topicSelectors.currentTopics(get()), payload);
316+
const nextMap = { ...get().topicMaps, [get().activeId]: nextTopics };
317+
318+
// no need to update map if is the same
319+
if (isEqual(nextMap, get().topicMaps)) return;
321320

322-
set({ topics: nextTopics }, false, action ?? n(`dispatchTopic/${payload.type}`));
321+
set({ topicMaps: nextMap }, false, action ?? n(`dispatchTopic/${payload.type}`));
323322
},
324323
});

src/store/chat/slices/topic/initialState.ts

-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ export interface ChatTopicState {
99
topicMaps: Record<string, ChatTopic[]>;
1010
topicRenamingId?: string;
1111
topicSearchKeywords: string;
12-
topics: ChatTopic[];
1312
/**
1413
* whether topics have fetched
1514
*/
@@ -23,6 +22,5 @@ export const initialTopicState: ChatTopicState = {
2322
topicLoadingIds: [],
2423
topicMaps: {},
2524
topicSearchKeywords: '',
26-
topics: [],
2725
topicsInit: false,
2826
};

src/store/chat/slices/topic/selectors.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ describe('topicSelectors', () => {
5858

5959
describe('currentUnFavTopics', () => {
6060
it('should return all unfavorited topics', () => {
61-
const state = merge(initialStore, { topics: topicMaps.test });
61+
const state = merge(initialStore, { topicMaps, activeId: 'test' });
6262
const topics = topicSelectors.currentUnFavTopics(state);
6363
expect(topics).toEqual([topicMaps.test[1]]);
6464
});

src/store/chat/slices/topic/selectors.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ const searchTopics = (s: ChatStore): ChatTopic[] => s.searchTopics;
1212
const displayTopics = (s: ChatStore): ChatTopic[] | undefined =>
1313
s.isSearchingTopic ? searchTopics(s) : currentTopics(s);
1414

15-
const currentUnFavTopics = (s: ChatStore): ChatTopic[] => s.topics.filter((s) => !s.favorite);
15+
const currentUnFavTopics = (s: ChatStore): ChatTopic[] =>
16+
currentTopics(s)?.filter((s) => !s.favorite) || [];
1617

1718
const currentTopicLength = (s: ChatStore): number => currentTopics(s)?.length || 0;
1819

0 commit comments

Comments
 (0)