Skip to content

Commit bc5c66f

Browse files
authored
fix context pills (#1100)
1 parent d785fce commit bc5c66f

22 files changed

+278
-157
lines changed

common/presets.ts

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export const presetValidator = {
102102
epsilonCutoff: 'number?',
103103
etaCutoff: 'number?',
104104
mirostatToggle: 'boolean?',
105+
presetMode: ['simple', 'advanced', null],
105106
} as const
106107

107108
const disabledValues: { [key in keyof GenMap]?: AppSchema.GenSettings[key] } = {

common/prompt.ts

+11-2
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ export async function assemblePrompt(
283283
const template = getTemplate(opts)
284284

285285
const history = { lines, order: 'asc' } as const
286-
let { parsed, inserts, length, sections } = await injectPlaceholders(template, {
286+
let { parsed, inserts, length, sections, linesAddedCount } = await injectPlaceholders(template, {
287287
opts,
288288
parts,
289289
history,
@@ -293,7 +293,16 @@ export async function assemblePrompt(
293293
jsonValues: opts.jsonValues,
294294
})
295295

296-
return { lines: history.lines, prompt: parsed, inserts, parts, post, length, sections }
296+
return {
297+
lines: history.lines,
298+
prompt: parsed,
299+
inserts,
300+
parts,
301+
post,
302+
length,
303+
sections,
304+
linesAddedCount,
305+
}
297306
}
298307

299308
export function getTemplate(opts: Pick<GenerateRequestV2, 'settings' | 'chat'>) {

common/template-parser.ts

+11-6
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,13 @@ export async function parseTemplate(
221221
// }
222222
// }
223223

224+
/**
225+
* Some placeholders require re-parsing as they also contain placeholders
226+
*/
227+
opts.isFinal = true
228+
const result = render(output, opts).replace(/\r\n/g, '\n').replace(/\n\n+/g, '\n\n').trim()
229+
opts.isFinal = false
230+
224231
/** Replace iterators */
225232
let history: string[] = []
226233
if (opts.limit && opts.limit.output) {
@@ -236,7 +243,7 @@ export async function parseTemplate(
236243
})
237244
unusedTokens = filled.unusedTokens
238245
const trimmed = filled.adding.slice().reverse()
239-
output = output.replace(id, trimmed.join('\n'))
246+
output = output.replace(new RegExp(id, 'gi'), trimmed.join('\n'))
240247
linesAddedCount += filled.linesAddedCount
241248
history = trimmed
242249
}
@@ -256,10 +263,6 @@ export async function parseTemplate(
256263
}
257264
}
258265

259-
opts.isFinal = true
260-
const result = render(output, opts).replace(/\r\n/g, '\n').replace(/\n\n+/g, '\n\n').trim()
261-
opts.isFinal = false
262-
263266
sections.sections.history = history
264267

265268
// console.log(
@@ -273,8 +276,10 @@ export async function parseTemplate(
273276
// sections.sections.post.join('')
274277
// )
275278

279+
output = output.replace(/\r\n/g, '\n').replace(/\n\n+/g, '\n\n').trim()
280+
276281
return {
277-
parsed: result,
282+
parsed: output,
278283
inserts: opts.inserts ?? new Map(),
279284
length: await opts.limit?.encoder?.(result),
280285
linesAddedCount,

common/types/presets.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ export interface UserGenPreset extends GenSettings {
7676
export interface GenSettings {
7777
name: string
7878
description?: string
79-
presetMode?: 'simple' | 'advanced'
79+
presetMode?: 'simple' | 'advanced' | undefined
8080

8181
service?: AIAdapter
8282

common/valid/validate.ts

+7-7
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,21 @@ export function isValidPartial<T extends Validator>(
6161
}
6262

6363
export function validateBody<T extends Validator>(
64-
type: T,
64+
guard: T,
6565
compare: any,
6666
opts: { partial?: boolean; prefix?: string; notThrow?: boolean } = {}
67-
): { errors: string[]; actual: UnwrapBody<T> } {
67+
): { errors: string[]; actual: UnwrapBody<T>; original: UnwrapBody<T> } {
6868
const prefix = opts.prefix ? `${opts.prefix}.` : ''
6969
const errors: string[] = []
7070
const actual: any = {}
7171

72-
if (!compare && '?' in type && (type as any)['?'] === '?') {
73-
return { errors, actual }
72+
if (!compare && '?' in guard && (guard as any)['?'] === '?') {
73+
return { errors, actual, original: compare }
7474
}
7575

76-
start: for (const key in type) {
76+
start: for (const key in guard) {
7777
const prop = `${prefix}${key}`
78-
const bodyType = type[key]
78+
const bodyType = guard[key]
7979
let value
8080
try {
8181
value = compare?.[key]
@@ -246,5 +246,5 @@ export function validateBody<T extends Validator>(
246246
throw new Error(`Object does not match type: ${errors.join(', ')}`)
247247
}
248248

249-
return { errors, actual }
249+
return { errors, actual, original: compare }
250250
}

srv/adapter/gemini.ts

+22-9
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import needle from 'needle'
22
import { decryptText } from '../db/util'
3-
import { getEncoderByName } from '../tokenize'
3+
import { getEncoder, getEncoderByName } from '../tokenize'
44
import { toChatCompletionPayload } from './chat-completion'
55
import { getStoppingStrings } from './prompt'
6-
import { ModelAdapter } from './type'
6+
import { AdapterProps, ModelAdapter } from './type'
77
import { AppLog } from '../middleware'
88
import { sanitise, sanitiseAndTrim, trimResponseV2 } from '/common/requests/util'
99
import { requestStream } from './stream'
10+
import { injectPlaceholders } from '/common/prompt'
1011

1112
const BASE_URL = `https://generativelanguage.googleapis.com/v1beta/models/`
1213

@@ -34,12 +35,13 @@ export const handleGemini: ModelAdapter = async function* (opts) {
3435
},
3536
}
3637

37-
const systems: string[] = []
38+
const fallback = await fallbackSystemMessage(opts)
39+
const systems: string[] = [opts.parts.systemPrompt || fallback.parsed]
3840
const contents: any[] = []
3941

4042
for (const msg of messages) {
4143
if (msg.role === 'system') {
42-
systems.push(msg.content)
44+
contents.push({ role: 'user', parts: [{ text: msg.content }] })
4345
continue
4446
}
4547

@@ -51,11 +53,7 @@ export const handleGemini: ModelAdapter = async function* (opts) {
5153
if (systems.length) {
5254
if (!SYSTEM_INCAPABLE[opts.gen.googleModel]) {
5355
payload.system_instruction = {
54-
parts: [
55-
{
56-
text: systems.join('\n'),
57-
},
58-
],
56+
parts: [{ text: systems.join('\n') }],
5957
}
6058
} else {
6159
contents.unshift({ role: 'user', parts: [{ text: systems.join('\n') }] })
@@ -224,3 +222,18 @@ const safetySettings = [
224222
threshold: 'BLOCK_NONE',
225223
},
226224
]
225+
226+
function fallbackSystemMessage(opts: AdapterProps) {
227+
const message = injectPlaceholders(
228+
`Write "{{char}}'s" next reply in a fictional roleplay chat between "{{user}}" and "{{char}}"`,
229+
{
230+
characters: opts.characters,
231+
encoder: getEncoder('main').count,
232+
jsonValues: {},
233+
parts: opts.parts,
234+
opts,
235+
}
236+
)
237+
238+
return message
239+
}

srv/adapter/generate.ts

+8
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export type InferenceRequest = {
6464
guest?: string
6565
user: AppSchema.User
6666
settings?: Partial<AppSchema.UserGenPreset>
67+
maxKnownLines?: number
6768

6869
guidance?: boolean
6970
placeholders?: any
@@ -394,6 +395,13 @@ export async function createChatStream(
394395
*/
395396

396397
const prompt = await assemblePrompt(opts, opts.parts, opts.lines, encoder)
398+
if (prompt.linesAddedCount === 0 && opts.linesCount) {
399+
throw new StatusError(
400+
`Could not fit any messages in prompt. Check your character definition, context size, and template`,
401+
400
402+
)
403+
}
404+
397405
const messages = await toChatMessages(opts, prompt, encoder)
398406

399407
const size = encoder(

srv/adapter/payloads.ts

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
102102
skip_special_tokens: gen.skipSpecialTokens ?? true,
103103
stopping_strings: getStoppingStrings(opts, stops),
104104
dynamic_temperature: gen.dynatemp_range ? true : false,
105+
smoothing_curve: gen.smoothingCurve,
105106
smoothing_factor: gen.smoothingFactor,
106107
token_healing: gen.tokenHealing,
107108
temp_last: gen.tempLast,

srv/adapter/type.ts

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ export type GenerateRequestV2 = {
6161

6262
parts: PromptParts
6363
lines: string[]
64+
linesCount?: number
6465
text?: string
6566
settings?: Partial<AppSchema.GenSettings>
6667
replacing?: AppSchema.ChatMessage

srv/api/chat/message.ts

+29-15
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ const genValidator = {
6666
userEmbeds: 'any?',
6767
},
6868
lines: ['string'],
69+
linesCount: 'number?',
6970
text: 'string?',
7071
settings: 'any?',
7172
lastMessage: 'string?',
@@ -247,15 +248,6 @@ export const generateMessageV2 = handle(async (req, res) => {
247248
})
248249
}
249250

250-
res.json({
251-
requestId,
252-
success: true,
253-
generating: true,
254-
message: 'Generating message',
255-
messageId,
256-
created: userMsg,
257-
})
258-
259251
const entities = await getResponseEntities(chat, body.sender.userId, body.settings)
260252
const schema = entities.gen.jsonSource === 'character' ? replyAs.json : entities.gen.json
261253
const hydrator = entities.gen.jsonEnabled && schema ? jsonHydrator(schema) : undefined
@@ -271,9 +263,10 @@ export const generateMessageV2 = handle(async (req, res) => {
271263
let probs: any
272264

273265
if (body.response === undefined) {
274-
const { stream, ...metadata } = await createChatStream(
266+
const chatStream = await createChatStream(
275267
{
276268
...body,
269+
linesCount: body.linesCount,
277270
chat,
278271
replyAs,
279272
impersonate,
@@ -282,7 +275,22 @@ export const generateMessageV2 = handle(async (req, res) => {
282275
chatSchema: schema,
283276
},
284277
log
285-
)
278+
).catch((err) => ({ err }))
279+
280+
if ('err' in chatStream) {
281+
throw chatStream.err
282+
}
283+
284+
res.json({
285+
requestId,
286+
success: true,
287+
generating: true,
288+
message: 'Generating message',
289+
messageId,
290+
created: userMsg,
291+
})
292+
293+
const { stream, ...metadata } = chatStream
286294

287295
adapter = metadata.adapter
288296

@@ -584,8 +592,6 @@ async function handleGuestGenerate(body: GenRequest, req: AppRequest, res: Respo
584592
return { success: true }
585593
}
586594

587-
res.json({ success: true, generating: true, message: 'Generating message', requestId })
588-
589595
const schema = body.settings.jsonSource === 'character' ? body.char.json : body.settings.json
590596
const hydrator = body.settings.jsonEnabled && schema ? jsonHydrator(schema) : undefined
591597
let generated = body.response || ''
@@ -597,12 +603,20 @@ async function handleGuestGenerate(body: GenRequest, req: AppRequest, res: Respo
597603
let jsonPartial: any
598604

599605
if (body.response === undefined) {
600-
const { stream, ...entities } = await createChatStream(
601-
{ ...body, chat, replyAs, requestId, chatSchema: schema },
606+
const chatStream = await createChatStream(
607+
{ ...body, chat, replyAs, requestId, chatSchema: schema, linesCount: body.linesCount },
602608
log,
603609
guest
604610
)
605611

612+
if ('err' in chatStream) {
613+
throw chatStream.err
614+
}
615+
616+
const { stream, ...entities } = chatStream
617+
618+
res.json({ success: true, generating: true, message: 'Generating message', requestId })
619+
606620
log.setBindings({ adapter })
607621

608622
adapter = entities.adapter

srv/api/user/presets.ts

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { assertValid } from '/common/valid'
22
import { defaultPresets, presetValidator } from '../../../common/presets'
33
import { store } from '../../db'
4-
import { StatusError, errors, handle } from '../wrap'
4+
import { StatusError, handle } from '../wrap'
55
import { AIAdapter } from '../../../common/adapters'
66
import { AppSchema } from '/common/types'
77
import { toSamplerOrder } from '/common/sampler-order'
@@ -43,10 +43,7 @@ export const createUserPreset = handle(async ({ userId, body, authed }) => {
4343
}
4444

4545
if (body.chatId) {
46-
const res = await store.chats.getChat(body.chatId)
47-
if (res?.chat.userId !== userId) {
48-
throw errors.Forbidden
49-
}
46+
delete body.chatId
5047
}
5148

5249
const samplers = toSamplerOrder(body.service, body.order, body.disabledSamplers)

tests/__snapshots__/prompt.spec.js.snap

-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ Object {
220220
},
221221
"template": Object {
222222
"inserts": Map {},
223-
"length": 38,
224223
"linesAddedCount": 2,
225224
"parsed": "GASLIGHT TEMPLATE
226225
ChatOwner

tests/prompt.spec.ts

+1
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ This is how {{char}} should talk: {{example_dialogue}}`,
171171
},
172172
})
173173
delete (actual.template as any).sections
174+
delete actual.template.length
174175
expect(actual).to.matchSnapshot()
175176
})
176177

0 commit comments

Comments
 (0)