Refactor
Commit: lib/llm.ts (127 lines changed)
@@ -3,6 +3,7 @@ import { fileURLToPath } from "node:url";
|
||||
|
||||
import {
|
||||
type ChatHistoryItem,
|
||||
type ChatSessionModelFunctions,
|
||||
type LLamaChatPromptOptions,
|
||||
LlamaChatSession,
|
||||
type LlamaModel,
|
||||
@@ -13,66 +14,88 @@ import {
|
||||
|
||||
const __dirname = path.dirname(fileURLToPath(import.meta.url));
|
||||
|
||||
const llama = await getLlama({
|
||||
maxThreads: 2,
|
||||
});
|
||||
|
||||
export async function getModel(model: string) {
|
||||
const downloader = await createModelDownloader({
|
||||
modelUri: `hf:${model}`,
|
||||
dirPath: path.join(__dirname, "..", "models"),
|
||||
});
|
||||
const modelPath = await downloader.download();
|
||||
const llama = await getLlama({
|
||||
maxThreads: 2,
|
||||
});
|
||||
return await llama.loadModel({ modelPath });
|
||||
}
|
||||
|
||||
export type Message = {
|
||||
type: "system" | "model" | "user";
|
||||
text: string;
|
||||
};
|
||||
export const grammar = await llama.createGrammarForJsonSchema({
|
||||
type: "object",
|
||||
properties: {
|
||||
text: { type: "string" },
|
||||
},
|
||||
required: ["text"],
|
||||
additionalProperties: false,
|
||||
});
|
||||
|
||||
export async function complete(
|
||||
model: LlamaModel,
|
||||
messages: Message[],
|
||||
options: LLamaChatPromptOptions = {},
|
||||
) {
|
||||
if (messages.length < 1) throw new Error("messages are empty");
|
||||
const init = messages.slice(0, -1);
|
||||
const last = messages.at(-1) as Message;
|
||||
const context = await model.createContext();
|
||||
const session = new LlamaChatSession({
|
||||
contextSequence: context.getSequence(),
|
||||
chatWrapper: resolveChatWrapper(model),
|
||||
});
|
||||
session.setChatHistory(
|
||||
init.map((m): ChatHistoryItem => {
|
||||
switch (m.type) {
|
||||
case "system":
|
||||
return {
|
||||
type: "system",
|
||||
text: m.text,
|
||||
};
|
||||
case "model":
|
||||
return {
|
||||
type: "model",
|
||||
response: [m.text],
|
||||
};
|
||||
case "user":
|
||||
return {
|
||||
type: "user",
|
||||
text: m.text,
|
||||
};
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
const res = await session.prompt(last.text, {
|
||||
trimWhitespaceSuffix: true,
|
||||
onResponseChunk(chunk) {
|
||||
process.stderr.write(chunk.text);
|
||||
},
|
||||
...options,
|
||||
});
|
||||
session.dispose();
|
||||
await context.dispose();
|
||||
return res;
|
||||
export function parseResponse(text: string) {
|
||||
try {
|
||||
const res = grammar.parse(text.trim());
|
||||
return res.text;
|
||||
} catch (e) {
|
||||
console.error("Failed to parse response:", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export class LlmSession {
|
||||
model: LlamaModel;
|
||||
systemPrompt: string;
|
||||
additionalChatHistory: ChatHistoryItem[] = [];
|
||||
private context: Awaited<ReturnType<LlamaModel["createContext"]>> | null =
|
||||
null;
|
||||
private session: LlamaChatSession | null = null;
|
||||
|
||||
constructor(
|
||||
model: LlamaModel,
|
||||
systemPrompt: string,
|
||||
additionalChatHistory: ChatHistoryItem[] = [],
|
||||
) {
|
||||
this.model = model;
|
||||
this.systemPrompt = systemPrompt;
|
||||
this.additionalChatHistory = additionalChatHistory;
|
||||
}
|
||||
|
||||
async init() {
|
||||
this.context = await this.model.createContext();
|
||||
this.session = new LlamaChatSession({
|
||||
contextSequence: this.context.getSequence(),
|
||||
chatWrapper: resolveChatWrapper(this.model),
|
||||
});
|
||||
this.session.setChatHistory([
|
||||
{
|
||||
type: "system",
|
||||
text: this.systemPrompt,
|
||||
},
|
||||
...this.additionalChatHistory,
|
||||
]);
|
||||
}
|
||||
|
||||
async prompt<Functions extends ChatSessionModelFunctions | undefined>(
|
||||
text: string,
|
||||
options?: LLamaChatPromptOptions<Functions>,
|
||||
) {
|
||||
if (!this.session) await this.init();
|
||||
if (!this.session) throw new Error("session is not initialized");
|
||||
return await this.session.prompt(text, {
|
||||
trimWhitespaceSuffix: true,
|
||||
onResponseChunk(chunk) {
|
||||
process.stderr.write(chunk.text);
|
||||
},
|
||||
...options,
|
||||
});
|
||||
}
|
||||
|
||||
async [Symbol.asyncDispose]() {
|
||||
await this.session?.dispose();
|
||||
await this.context?.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,10 +15,23 @@ export const isSuitableAsInput = (n: Note) =>
|
||||
!n.replyId &&
|
||||
(!n.mentions || n.mentions.length === 0) &&
|
||||
n.text?.length &&
|
||||
["public", "home"].includes(n.visibility) &&
|
||||
!n.cw &&
|
||||
n.text.length > 0;
|
||||
|
||||
/** randomly sample some notes from the timeline */
|
||||
export async function getNotes(localNotesCount = 5, globalNotesCount = 10) {
|
||||
export async function getNotes(
|
||||
followNotesCount: number,
|
||||
localNotesCount: number,
|
||||
globalNotesCount: number,
|
||||
) {
|
||||
// randomly sample N following notes
|
||||
const followNotes = (count: number) =>
|
||||
misskey
|
||||
.request("notes/timeline", { limit: 100 })
|
||||
.then((xs) => xs.filter(isSuitableAsInput))
|
||||
.then((xs) => sample(xs, count));
|
||||
|
||||
// randomly sample N local notes
|
||||
const localNotes = (count: number) =>
|
||||
misskey
|
||||
@@ -34,6 +47,7 @@ export async function getNotes(localNotesCount = 5, globalNotesCount = 10) {
|
||||
.then((xs) => sample(xs, count));
|
||||
|
||||
const notes = await Promise.all([
|
||||
followNotes(followNotesCount),
|
||||
localNotes(localNotesCount),
|
||||
globalNotes(globalNotesCount),
|
||||
]);
|
||||
@@ -43,10 +57,18 @@ export async function getNotes(localNotesCount = 5, globalNotesCount = 10) {
|
||||
/** fetch the whole reply tree */
|
||||
export async function expandReplyTree(
|
||||
note: Note,
|
||||
acc: Note[] = [],
|
||||
cutoff = 5,
|
||||
) {
|
||||
if (!note.reply || cutoff < 1) return [...acc, note];
|
||||
const reply = await misskey.request("notes/show", { noteId: note.reply.id });
|
||||
return await expandReplyTree(reply, [...acc, note], cutoff - 1);
|
||||
): Promise<{ last: Note; history: Note[] }> {
|
||||
let current = note;
|
||||
let count = 0;
|
||||
const history: Note[] = [];
|
||||
while (current.replyId && count < cutoff) {
|
||||
const parent = await misskey.request("notes/show", {
|
||||
noteId: current.replyId,
|
||||
});
|
||||
history.push(parent);
|
||||
current = parent;
|
||||
count++;
|
||||
}
|
||||
return { last: current, history: history.reverse() };
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user