Skip to content

Commit ee45b6e

Browse files
authored
ArliAI support (#1095)
* log image gen failures * client-side horde image gen * add arliai support
1 parent 8d0d801 commit ee45b6e

26 files changed

+443
-30
lines changed

common/adapters.ts

+2
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ export const THIRDPARTY_HANDLERS: { [svc in ThirdPartyFormat]: AIAdapter } = {
9999
vllm: 'kobold',
100100
featherless: 'kobold',
101101
gemini: 'kobold',
102+
arli: 'kobold',
102103
}
103104

104105
export const BASIC_PROMPT_ONLY: { [svc in ThirdPartyFormat]?: boolean } = {
@@ -122,6 +123,7 @@ export const THIRDPARTY_FORMATS = [
122123
'ollama',
123124
'vllm',
124125
'featherless',
126+
'arli',
125127
'gemini',
126128
] as const
127129

common/horde-gen.ts

+26-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import crypto from 'crypto'
12
import { AppSchema } from './types/schema'
23
import { defaultPresets } from './default-preset'
34
import { SD_SAMPLER } from './image'
45
import { toArray } from './util'
56
import type { AppLog } from '../srv/middleware'
7+
import { v4 } from 'uuid'
8+
9+
export const HORDE_SEED = v4()
10+
11+
const ALGO = 'aes256'
612

713
const HORDE_GUEST_KEY = '0000000000'
814
const baseUrl = 'https://aihorde.net/api/v2'
@@ -107,7 +113,8 @@ export async function generateImage(
107113
height: base?.height ?? 1024,
108114
width: base?.width ?? 1024,
109115
cfg_scale: base?.cfg ?? 9,
110-
seed: Math.trunc(Math.random() * 1_000_000_000).toString(),
116+
clip_skip: base?.clipSkip,
117+
denoising_strength: 1,
111118
karras: false,
112119
n: 1,
113120
post_processing: [],
@@ -125,10 +132,17 @@ export async function generateImage(
125132
log?.debug({ ...payload, prompt: null }, 'Horde payload')
126133
log?.debug(`Prompt:\n${payload.prompt}`)
127134

135+
let key = user.hordeKey
136+
if (!key) {
137+
key = HORDE_GUEST_KEY
138+
} else {
139+
key = decryptText(user.hordeKey)
140+
}
141+
128142
const image = await generate({
129143
type: 'image',
130144
payload,
131-
key: user.hordeKey || HORDE_GUEST_KEY,
145+
key,
132146
onTick,
133147
})
134148

@@ -280,6 +294,16 @@ function wait(secs: number) {
280294
return new Promise((resolve) => setTimeout(resolve, secs * 1000))
281295
}
282296

297+
/**
 * Attempts to decrypt a hex-encoded string with the module's `ALGO` cipher,
 * keyed by `HORDE_SEED`.
 *
 * @param text Hex-encoded ciphertext.
 * @returns The UTF-8 plaintext, or `text` unchanged if any step of the decryption throws.
 */
function decryptText(text: string) {
  let plain: string

  try {
    const decipher = crypto.createDecipher(ALGO, HORDE_SEED)
    const head = decipher.update(text, 'hex', 'utf8')
    const tail = decipher.final('utf8')
    plain = head + tail
  } catch (ex) {
    // Best effort: hand the input back untouched when it cannot be decrypted.
    return text
  }

  return plain
}
306+
283307
export type FindUserResponse = {
284308
kudos_details: {
285309
accumulated: number

common/presets.ts

+6
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ export const presetValidator = {
5757
oaiModel: 'string',
5858
openRouterModel: 'any?',
5959
featherlessModel: 'string?',
60+
arliModel: 'string?',
6061
googleModel: 'string?',
6162

6263
mirostatTau: 'number?',
@@ -73,6 +74,11 @@ export const presetValidator = {
7374
thirdPartyUrlNoSuffix: 'boolean?',
7475
thirdPartyModel: 'string?',
7576

77+
dryAllowedLength: 'number?',
78+
dryBase: 'number?',
79+
drySequenceBreakers: ['string?'],
80+
dryMultiplier: 'number?',
81+
7682
novelModel: 'string?',
7783
novelModelOverride: 'string?',
7884

common/types/presets.ts

+1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ export interface GenSettings {
164164
openRouterModel?: OpenRouterModel
165165
googleModel?: string
166166
featherlessModel?: string
167+
arliModel?: string
167168

168169
thirdPartyUrl?: string
169170
thirdPartyFormat?: ThirdPartyFormat

common/types/schema.ts

+3
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ export namespace AppSchema {
130130
featherlessApiKey?: string
131131
featherlessApiKeySet?: boolean
132132

133+
arliApiKey?: string
134+
arliApiKeySet?: boolean
135+
133136
defaultAdapter: AIAdapter
134137
defaultPresets?: { [key in AIAdapter]?: string }
135138
defaultPreset?: string

srv/adapter/agnaistic.ts

+1
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ export function getHandlers(settings: Partial<AppSchema.GenSettings>) {
427427
return handlers.openai
428428

429429
case 'featherless':
430+
case 'arli':
430431
return handlers.kobold
431432

432433
case 'gemini':

srv/adapter/arli.ts

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import { logger } from '../middleware'
2+
3+
/**
 * Shape that `getModelList` casts each entry of the ArliAI `/model/all`
 * response to.
 */
type V1Model = {
  /** Used as both the `id` and the `name` of the cached model. */
  name: string
  /** Used as the model class key when caching. */
  modelSize: string
  /** Numeric string; converted with unary plus when cached. */
  contextSize: string
  /** Truthy is cached as 'active', anything else as 'not_deployed'. */
  status: boolean
  promptFormat: string
  modelLink: string
}
11+
12+
/** A single entry of the model list exposed by `getArliModels`. */
export type ArliModel = {
  id: string
  name: string
  /** Key into the class cache returned alongside the models. */
  model_class: string

  status: 'active' | 'not_deployed' | 'pending_deploy'
  health?: 'OFFLINE' | 'UNHEALTHY' | 'HEALTHY'

  /** Numeric form of the API's `contextSize`. */
  ctx: number
  /** `getModelList` currently fills this with a constant 500. */
  res: number
}
22+
23+
/** Most recent model list; reassigned by `getModelList` after a successful fetch. */
let modelCache: ArliModel[] = []

/** Per model class limits, keyed by class name; entries are added by `getModelList`. */
let classCache: Record<string, { ctx: number; res: number }> = {}

/**
 * Returns the currently cached ArliAI models and model classes.
 *
 * The cache objects themselves are returned, not copies.
 */
export function getArliModels() {
  const snapshot = { models: modelCache, classes: classCache }
  return snapshot
}
29+
30+
/**
 * Fetches the ArliAI model catalogue and refreshes the module-level caches.
 *
 * On success `modelCache` is replaced with the freshly mapped list and
 * `classCache` gains an entry for every model size not seen before (existing
 * class entries are left untouched).
 *
 * This function never throws. A non-2xx response or a payload that is not an
 * array is logged as a warning, and any exception is logged as an error; in all
 * of those cases the previously cached models are kept.
 *
 * @returns The raw API models keyed by name, `{}` when the payload was not a
 * list, or `undefined` when the request failed.
 */
async function getModelList() {
  try {
    const response = await fetch('https://api.arliai.com/model/all', {
      headers: {
        accept: '*/*',
      },
      method: 'GET',
    })

    if (!response.ok) {
      // An error body is not guaranteed to be JSON; a parse failure must not hide the status.
      const body = await response.json().catch(() => undefined)
      logger.warn({ body, status: response.status }, `ArliAI model list failed`)
      return
    }

    const payload: unknown = await response.json()
    if (!Array.isArray(payload)) {
      // Keep the last known-good list instead of replacing it with nothing.
      logger.warn({ body: payload }, `ArliAI model list was not an array`)
      return {}
    }

    const list = payload as V1Model[]
    const next: ArliModel[] = []
    const map: { [key: string]: V1Model } = {}

    for (const model of list) {
      const ctx = +model.contextSize

      // Only the first model seen for a class defines that class' limits.
      if (!classCache[model.modelSize]) {
        classCache[model.modelSize] = { ctx, res: 500 }
      }

      next.push({
        id: model.name,
        model_class: model.modelSize,
        name: model.name,
        res: 500,
        status: model.status ? 'active' : 'not_deployed',
        ctx,
      })

      map[model.name] = model
    }

    modelCache = next

    return map
  } catch (ex) {
    logger.error({ err: ex }, `ArliAI model list failed`)
  }
}
81+
82+
// Fill the caches once at module load, then refresh them every 120 seconds.
const ARLI_REFRESH_INTERVAL_MS = 2 * 60 * 1000

void getModelList()

setInterval(() => {
  void getModelList()
}, ARLI_REFRESH_INTERVAL_MS)

srv/adapter/kobold.ts

+27-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ export const handleThirdParty: ModelAdapter = async function* (opts) {
4747
opts.gen.thirdPartyFormat === 'llamacpp' ||
4848
opts.gen.thirdPartyFormat === 'exllamav2' ||
4949
opts.gen.thirdPartyFormat === 'koboldcpp' ||
50-
opts.gen.thirdPartyFormat === 'featherless'
50+
opts.gen.thirdPartyFormat === 'featherless' ||
51+
opts.gen.thirdPartyFormat === 'arli'
5152
? getThirdPartyPayload(opts)
5253
: { ...base, ...mappedSettings, prompt }
5354

@@ -186,6 +187,13 @@ async function dispatch(opts: AdapterProps, body: any) {
186187
: fullCompletion(url, body, headers, opts.gen.thirdPartyFormat, opts.log)
187188
}
188189

190+
case 'arli': {
191+
const url = 'https://api.arliai.com/v1/completions'
192+
return opts.gen.streamResponse
193+
? streamCompletion(url, body, headers, opts.gen.thirdPartyFormat, opts.log)
194+
: fullCompletion(url, body, headers, opts.gen.thirdPartyFormat, opts.log)
195+
}
196+
189197
default:
190198
const isStreamSupported = await checkStreamSupported(`${baseURL}/api/extra/version`)
191199
return opts.gen.streamResponse && isStreamSupported
@@ -253,6 +261,24 @@ async function getHeaders(opts: AdapterProps) {
253261
break
254262
}
255263

264+
case 'arli': {
265+
if (!opts.gen.arliModel) {
266+
throw new Error(`ArliAI model not set. Check your preset`)
267+
}
268+
269+
const key = opts.gen.thirdPartyKey || opts.user.arliApiKey
270+
if (!key) {
271+
throw new Error(`ArliAI API key not set. Check your Settings->AI->Third-party settings`)
272+
}
273+
274+
const apiKey = key ? (opts.guest ? key : decryptText(key)) : ''
275+
if (apiKey) {
276+
headers['Authorization'] = `Bearer ${apiKey}`
277+
}
278+
headers['Content-Type'] = 'application/json'
279+
break
280+
}
281+
256282
case 'mistral': {
257283
const key = opts.user.mistralKey
258284
if (!key)

srv/adapter/payloads.ts

+63
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,55 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
216216
return payload
217217
}
218218

219+
if (format === 'arli') {
220+
const body: any = {
221+
model: gen.arliModel,
222+
prompt,
223+
stop: getStoppingStrings(opts, stops),
224+
presence_penalty: gen.presencePenalty,
225+
frequency_penalty: gen.frequencyPenalty,
226+
length_penalty: gen.repetitionPenalty,
227+
tfs: gen.tailFreeSampling,
228+
temperature: gen.temp,
229+
top_p: gen.topP,
230+
top_k: gen.topK,
231+
min_p: gen.minP,
232+
typical_p: gen.typicalP,
233+
ignore_eos: false,
234+
max_tokens: gen.maxTokens,
235+
smoothing_factor: gen.smoothingFactor,
236+
smoothing_curve: gen.smoothingCurve,
237+
238+
stream: gen.streamResponse,
239+
}
240+
241+
if (gen.dryMultiplier) {
242+
body.dry_multiplier = gen.dryMultiplier
243+
body.dry_base = gen.dryBase
244+
body.dry_allowed_length = gen.dryAllowedLength
245+
body.dry_range = gen.dryRange
246+
body.dry_sequence_breakers = sequenceBreakers
247+
}
248+
249+
if (gen.dynatemp_range) {
250+
body.dynamic_temperature = true
251+
body.dynatemp_min = (gen.temp ?? 1) - (gen.dynatemp_range ?? 0)
252+
body.dynatemp_max = (gen.temp ?? 1) + (gen.dynatemp_range ?? 0)
253+
body.dynatemp_exponent = gen.dynatemp_exponent
254+
}
255+
256+
if (gen.xtcThreshold) {
257+
body.xtc_threshold = gen.xtcThreshold
258+
body.xtc_probability = gen.xtcProbability
259+
}
260+
261+
if (body.top_k <= 0) {
262+
body.top_k = -1
263+
}
264+
265+
return body
266+
}
267+
219268
if (format === 'ollama') {
220269
const payload: any = {
221270
prompt,
@@ -426,12 +475,26 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
426475
epsilon_cutoff: gen.epsilonCutoff,
427476
}
428477

478+
if (gen.dryMultiplier) {
479+
body.dry_multiplier = gen.dryMultiplier
480+
body.dry_base = gen.dryBase
481+
body.dry_allowed_length = gen.dryAllowedLength
482+
body.dry_range = gen.dryRange
483+
body.dry_sequence_breakers = sequenceBreakers
484+
}
485+
429486
if (gen.dynatemp_range) {
487+
body.dynamic_temperature = true
430488
body.dynatemp_min = (gen.temp ?? 1) - (gen.dynatemp_range ?? 0)
431489
body.dynatemp_max = (gen.temp ?? 1) + (gen.dynatemp_range ?? 0)
432490
body.dynatemp_exponent = gen.dynatemp_exponent
433491
}
434492

493+
if (gen.xtcThreshold) {
494+
body.xtc_threshold = gen.xtcThreshold
495+
body.xtc_probability = gen.xtcProbability
496+
}
497+
435498
return body
436499
}
437500

srv/api/settings.ts

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import { getOpenRouterModels } from '../adapter/openrouter'
1313
import { updateRegisteredSubs } from '../adapter/agnaistic'
1414
import { getFeatherModels } from '../adapter/featherless'
1515
import { filterImageModels } from '/common/image-util'
16+
import { getArliModels } from '../adapter/arli'
1617

1718
const router = Router()
1819

@@ -35,6 +36,10 @@ router.get('/featherless', (_, res) => {
3536
const { models, classes } = getFeatherModels()
3637
res.json({ models, classes })
3738
})
39+
// Responds with the cached ArliAI models and model classes as JSON.
router.get('/arli', (_, res) => {
  const arli = getArliModels()
  res.json({ models: arli.models, classes: arli.classes })
})
3843

3944
export default router
4045

0 commit comments

Comments
 (0)