This commit is contained in:
2026-02-24 12:27:53 +00:00
parent c276b8e319
commit 0f9eb68262
5 changed files with 234 additions and 160 deletions

View File

@@ -3,6 +3,7 @@ import { fileURLToPath } from "node:url";
import {
type ChatHistoryItem,
type ChatSessionModelFunctions,
type LLamaChatPromptOptions,
LlamaChatSession,
type LlamaModel,
@@ -13,66 +14,88 @@ import {
const __dirname = path.dirname(fileURLToPath(import.meta.url));
const llama = await getLlama({
maxThreads: 2,
});
/**
 * Resolves a model by Hugging Face URI, downloading it into `../models`
 * when not already cached, and loads it into a llama runtime.
 */
export async function getModel(model: string) {
  // The runtime handle is independent of the download; acquire it first.
  const llama = await getLlama({ maxThreads: 2 });
  const downloader = await createModelDownloader({
    modelUri: `hf:${model}`,
    dirPath: path.join(__dirname, "..", "models"),
  });
  const modelPath = await downloader.download();
  return llama.loadModel({ modelPath });
}
/**
 * A single chat turn. `type` mirrors the role tags used by
 * node-llama-cpp's ChatHistoryItem ("system" | "model" | "user").
 */
export type Message = {
type: "system" | "model" | "user";
text: string;
};
// JSON-schema grammar constraining model output to exactly
// `{ "text": string }` (no extra properties). Used by parseResponse below.
// NOTE(review): references the module-level `llama` instance — in this diff
// view that instance appears to have been moved into getModel(); confirm a
// top-level `llama` still exists after the change.
export const grammar = await llama.createGrammarForJsonSchema({
type: "object",
properties: {
text: { type: "string" },
},
required: ["text"],
additionalProperties: false,
});
/**
 * One-shot completion: seeds a fresh session with all but the last message
 * as chat history, then prompts with the last message's text.
 *
 * Streams response chunks to stderr as they arrive; disposes the session
 * and context before returning the full response string.
 *
 * @throws Error when `messages` is empty.
 */
export async function complete(
model: LlamaModel,
messages: Message[],
options: LLamaChatPromptOptions = {},
) {
// Need at least one message to prompt with.
if (messages.length < 1) throw new Error("messages are empty");
const init = messages.slice(0, -1);
// Safe cast: the length guard above ensures at(-1) is defined.
const last = messages.at(-1) as Message;
// Fresh context per call; released via dispose() below.
const context = await model.createContext();
const session = new LlamaChatSession({
contextSequence: context.getSequence(),
chatWrapper: resolveChatWrapper(model),
});
// Map our Message shape onto node-llama-cpp's ChatHistoryItem variants;
// "model" turns into a response array, the others keep plain text.
session.setChatHistory(
init.map((m): ChatHistoryItem => {
switch (m.type) {
case "system":
return {
type: "system",
text: m.text,
};
case "model":
return {
type: "model",
response: [m.text],
};
case "user":
return {
type: "user",
text: m.text,
};
}
}),
);
const res = await session.prompt(last.text, {
trimWhitespaceSuffix: true,
// Stream partial output to stderr for progress visibility.
onResponseChunk(chunk) {
process.stderr.write(chunk.text);
},
...options,
});
session.dispose();
await context.dispose();
return res;
// NOTE(review): this function's closing brace is not visible here — the
// span appears truncated by the following diff hunk; this is the removed
// (pre-commit) implementation superseded by LlmSession.
/**
 * Parses a raw model response against the JSON grammar.
 *
 * @returns the extracted `text` field, or `null` when the response does
 *   not conform to the grammar.
 */
export function parseResponse(text: string) {
  const trimmed = text.trim();
  try {
    const parsed = grammar.parse(trimmed);
    return parsed.text;
  } catch (e) {
    console.error("Failed to parse response:", e);
    return null;
  }
}
/**
 * A lazily-initialized chat session bound to a loaded model.
 *
 * The llama context and session are created on the first `prompt()` call
 * (or an explicit `init()`), seeded with a system prompt plus optional
 * extra history. Resources are released via `Symbol.asyncDispose`, so the
 * class works with `await using`.
 */
export class LlmSession {
  private context: Awaited<ReturnType<LlamaModel["createContext"]>> | null =
    null;
  private session: LlamaChatSession | null = null;

  // Parameter properties replace the field-declaration + manual-assignment
  // boilerplate of the previous version; the public field names and the
  // constructor signature are unchanged.
  constructor(
    public model: LlamaModel,
    public systemPrompt: string,
    public additionalChatHistory: ChatHistoryItem[] = [],
  ) {}

  /**
   * Creates the context/session and seeds the chat history.
   * Normally invoked lazily by `prompt()`; calling it again replaces the
   * current session.
   */
  async init() {
    this.context = await this.model.createContext();
    this.session = new LlamaChatSession({
      contextSequence: this.context.getSequence(),
      chatWrapper: resolveChatWrapper(this.model),
    });
    this.session.setChatHistory([
      {
        type: "system",
        text: this.systemPrompt,
      },
      ...this.additionalChatHistory,
    ]);
  }

  /**
   * Prompts the session with `text`, streaming chunks to stderr.
   * Initializes the session on first use.
   */
  async prompt<Functions extends ChatSessionModelFunctions | undefined>(
    text: string,
    options?: LLamaChatPromptOptions<Functions>,
  ) {
    if (!this.session) await this.init();
    if (!this.session) throw new Error("session is not initialized");
    return await this.session.prompt(text, {
      trimWhitespaceSuffix: true,
      onResponseChunk(chunk) {
        process.stderr.write(chunk.text);
      },
      ...options,
    });
  }

  /** Releases the session and context. */
  async [Symbol.asyncDispose]() {
    await this.session?.dispose();
    await this.context?.dispose();
    // Clear the handles so a later prompt() re-initializes instead of
    // reusing disposed resources.
    this.session = null;
    this.context = null;
  }
}

View File

@@ -15,10 +15,23 @@ export const isSuitableAsInput = (n: Note) =>
!n.replyId &&
(!n.mentions || n.mentions.length === 0) &&
n.text?.length &&
["public", "home"].includes(n.visibility) &&
!n.cw &&
n.text.length > 0;
/** randomly sample some notes from the timeline */
export async function getNotes(localNotesCount = 5, globalNotesCount = 10) {
export async function getNotes(
followNotesCount: number,
localNotesCount: number,
globalNotesCount: number,
) {
// randomly sample N following notes
const followNotes = (count: number) =>
misskey
.request("notes/timeline", { limit: 100 })
.then((xs) => xs.filter(isSuitableAsInput))
.then((xs) => sample(xs, count));
// randomly sample N local notes
const localNotes = (count: number) =>
misskey
@@ -34,6 +47,7 @@ export async function getNotes(localNotesCount = 5, globalNotesCount = 10) {
.then((xs) => sample(xs, count));
const notes = await Promise.all([
followNotes(followNotesCount),
localNotes(localNotesCount),
globalNotes(globalNotesCount),
]);
@@ -43,10 +57,18 @@ export async function getNotes(localNotesCount = 5, globalNotesCount = 10) {
/** fetch the whole reply tree */
export async function expandReplyTree(
note: Note,
acc: Note[] = [],
cutoff = 5,
) {
if (!note.reply || cutoff < 1) return [...acc, note];
const reply = await misskey.request("notes/show", { noteId: note.reply.id });
return await expandReplyTree(reply, [...acc, note], cutoff - 1);
): Promise<{ last: Note; history: Note[] }> {
let current = note;
let count = 0;
const history: Note[] = [];
while (current.replyId && count < cutoff) {
const parent = await misskey.request("notes/show", {
noteId: current.replyId,
});
history.push(parent);
current = parent;
count++;
}
return { last: current, history: history.reverse() };
}