Skip to content

Commit 2e63e32

Browse files
authored
🐼 feat: Add Flux Image Generation Tool (#6147)
* 🔧 fix: Log warning for aborted operations in AgentClient * ci: Remove unused saveMessageToDatabase mock in FakeClient initialization * ci: test actual implementation of saveMessageToDatabase * refactor: Change log level from warning to error for aborted operations in AgentClient * refactor: Add className prop to Image component for customizable styling, use theme selectors * feat: FLUX Image Generation tool
1 parent 7f6b32f commit 2e63e32

File tree

13 files changed

+760
-16
lines changed

13 files changed

+760
-16
lines changed

.env.example

+7
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
248248
# DALLE3_AZURE_API_VERSION=
249249
# DALLE2_AZURE_API_VERSION=
250250

251+
# Flux
252+
#-----------------
253+
FLUX_API_BASE_URL=https://api.us1.bfl.ai
254+
# FLUX_API_BASE_URL = 'https://api.bfl.ml';
255+
256+
# Get your API key at https://api.us1.bfl.ai/auth/profile
257+
# FLUX_API_KEY=
251258

252259
# Google
253260
#-----------------

api/app/clients/specs/BaseClient.test.js

+156-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ jest.mock('~/models', () => ({
3030
updateFileUsage: jest.fn(),
3131
}));
3232

33+
const { getConvo, saveConvo } = require('~/models');
34+
3335
jest.mock('@langchain/openai', () => {
3436
return {
3537
ChatOpenAI: jest.fn().mockImplementation(() => {
@@ -540,10 +542,11 @@ describe('BaseClient', () => {
540542

541543
test('saveMessageToDatabase is called with the correct arguments', async () => {
542544
const saveOptions = TestClient.getSaveOptions();
543-
const user = {}; // Mock user
545+
const user = {};
544546
const opts = { user };
547+
const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase');
545548
await TestClient.sendMessage('Hello, world!', opts);
546-
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledWith(
549+
expect(saveSpy).toHaveBeenCalledWith(
547550
expect.objectContaining({
548551
sender: expect.any(String),
549552
text: expect.any(String),
@@ -557,6 +560,157 @@ describe('BaseClient', () => {
557560
);
558561
});
559562

563+
test('should handle existing conversation when getConvo retrieves one', async () => {
564+
const existingConvo = {
565+
conversationId: 'existing-convo-id',
566+
endpoint: 'openai',
567+
endpointType: 'openai',
568+
model: 'gpt-3.5-turbo',
569+
messages: [
570+
{ role: 'user', content: 'Existing message 1' },
571+
{ role: 'assistant', content: 'Existing response 1' },
572+
],
573+
temperature: 1,
574+
};
575+
576+
const { temperature: _temp, ...newConvo } = existingConvo;
577+
578+
const user = {
579+
id: 'user-id',
580+
};
581+
582+
getConvo.mockResolvedValue(existingConvo);
583+
saveConvo.mockResolvedValue(newConvo);
584+
585+
TestClient = initializeFakeClient(
586+
apiKey,
587+
{
588+
...options,
589+
req: {
590+
user,
591+
},
592+
},
593+
[],
594+
);
595+
596+
const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase');
597+
598+
const newMessage = 'New message in existing conversation';
599+
const response = await TestClient.sendMessage(newMessage, {
600+
user,
601+
conversationId: existingConvo.conversationId,
602+
});
603+
604+
expect(getConvo).toHaveBeenCalledWith(user.id, existingConvo.conversationId);
605+
expect(TestClient.conversationId).toBe(existingConvo.conversationId);
606+
expect(response.conversationId).toBe(existingConvo.conversationId);
607+
expect(TestClient.fetchedConvo).toBe(true);
608+
609+
expect(saveSpy).toHaveBeenCalledWith(
610+
expect.objectContaining({
611+
conversationId: existingConvo.conversationId,
612+
text: newMessage,
613+
}),
614+
expect.any(Object),
615+
expect.any(Object),
616+
);
617+
618+
expect(saveConvo).toHaveBeenCalledTimes(2);
619+
expect(saveConvo).toHaveBeenCalledWith(
620+
expect.any(Object),
621+
expect.objectContaining({
622+
conversationId: existingConvo.conversationId,
623+
}),
624+
expect.objectContaining({
625+
context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo',
626+
unsetFields: {
627+
temperature: 1,
628+
},
629+
}),
630+
);
631+
632+
await TestClient.sendMessage('Another message', {
633+
conversationId: existingConvo.conversationId,
634+
});
635+
expect(getConvo).toHaveBeenCalledTimes(1);
636+
});
637+
638+
test('should correctly handle existing conversation and unset fields appropriately', async () => {
639+
const existingConvo = {
640+
conversationId: 'existing-convo-id',
641+
endpoint: 'openai',
642+
endpointType: 'openai',
643+
model: 'gpt-3.5-turbo',
644+
messages: [
645+
{ role: 'user', content: 'Existing message 1' },
646+
{ role: 'assistant', content: 'Existing response 1' },
647+
],
648+
title: 'Existing Conversation',
649+
someExistingField: 'existingValue',
650+
anotherExistingField: 'anotherValue',
651+
temperature: 0.7,
652+
modelLabel: 'GPT-3.5',
653+
};
654+
655+
getConvo.mockResolvedValue(existingConvo);
656+
saveConvo.mockResolvedValue(existingConvo);
657+
658+
TestClient = initializeFakeClient(
659+
apiKey,
660+
{
661+
...options,
662+
modelOptions: {
663+
model: 'gpt-4',
664+
temperature: 0.5,
665+
},
666+
},
667+
[],
668+
);
669+
670+
const newMessage = 'New message in existing conversation';
671+
await TestClient.sendMessage(newMessage, {
672+
conversationId: existingConvo.conversationId,
673+
});
674+
675+
expect(saveConvo).toHaveBeenCalledTimes(2);
676+
677+
const saveConvoCall = saveConvo.mock.calls[0];
678+
const [, savedFields, saveOptions] = saveConvoCall;
679+
680+
// Instead of checking all excludedKeys, we'll just check specific fields
681+
// that we know should be excluded
682+
expect(savedFields).not.toHaveProperty('messages');
683+
expect(savedFields).not.toHaveProperty('title');
684+
685+
// Only check that someExistingField is in unsetFields
686+
expect(saveOptions.unsetFields).toHaveProperty('someExistingField', 1);
687+
688+
// Mock saveConvo to return the expected fields
689+
saveConvo.mockImplementation((req, fields) => {
690+
return Promise.resolve({
691+
...fields,
692+
endpoint: 'openai',
693+
endpointType: 'openai',
694+
model: 'gpt-4',
695+
temperature: 0.5,
696+
});
697+
});
698+
699+
// Only check the conversationId since that's the only field we can be sure about
700+
expect(savedFields).toHaveProperty('conversationId', 'existing-convo-id');
701+
702+
expect(TestClient.fetchedConvo).toBe(true);
703+
704+
await TestClient.sendMessage('Another message', {
705+
conversationId: existingConvo.conversationId,
706+
});
707+
708+
expect(getConvo).toHaveBeenCalledTimes(1);
709+
710+
const secondSaveConvoCall = saveConvo.mock.calls[1];
711+
expect(secondSaveConvoCall[2]).toHaveProperty('unsetFields', {});
712+
});
713+
560714
test('sendCompletion is called with the correct arguments', async () => {
561715
const payload = {}; // Mock payload
562716
TestClient.buildMessages.mockReturnValue({ prompt: payload, tokenCountMap: null });

api/app/clients/specs/FakeClient.js

-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
5656
let TestClient = new FakeClient(apiKey);
5757
TestClient.options = options;
5858
TestClient.abortController = { abort: jest.fn() };
59-
TestClient.saveMessageToDatabase = jest.fn();
6059
TestClient.loadHistory = jest
6160
.fn()
6261
.mockImplementation((conversationId, parentMessageId = null) => {
@@ -86,7 +85,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
8685
return 'Mock response text';
8786
});
8887

89-
// eslint-disable-next-line no-unused-vars
9088
TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => {
9189
return {
9290
choices: [

api/app/clients/tools/index.js

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ const availableTools = require('./manifest.json');
22

33
// Structured Tools
44
const DALLE3 = require('./structured/DALLE3');
5+
const FluxAPI = require('./structured/FluxAPI');
56
const OpenWeather = require('./structured/OpenWeather');
6-
const createYouTubeTools = require('./structured/YouTube');
77
const StructuredWolfram = require('./structured/Wolfram');
8+
const createYouTubeTools = require('./structured/YouTube');
89
const StructuredACS = require('./structured/AzureAISearch');
910
const StructuredSD = require('./structured/StableDiffusion');
1011
const GoogleSearchAPI = require('./structured/GoogleSearch');
@@ -30,6 +31,7 @@ module.exports = {
3031
manifestToolMap,
3132
// Structured Tools
3233
DALLE3,
34+
FluxAPI,
3335
OpenWeather,
3436
StructuredSD,
3537
StructuredACS,

api/app/clients/tools/manifest.json

+14
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,19 @@
164164
"description": "Sign up at <a href=\"https://home.openweathermap.org/users/sign_up\" target=\"_blank\">OpenWeather</a>, then get your key at <a href=\"https://home.openweathermap.org/api_keys\" target=\"_blank\">API keys</a>."
165165
}
166166
]
167+
},
168+
{
169+
"name": "Flux",
170+
"pluginKey": "flux",
171+
"description": "Generate images using text with the Flux API.",
172+
"icon": "https://blackforestlabs.ai/wp-content/uploads/2024/07/bfl_logo_retraced_blk.png",
173+
"isAuthRequired": "true",
174+
"authConfig": [
175+
{
176+
"authField": "FLUX_API_KEY",
177+
"label": "Your Flux API Key",
178+
"description": "Provide your Flux API key from your user profile."
179+
}
180+
]
167181
}
168182
]

0 commit comments

Comments
 (0)