From cc43431a90d12129749fb2ce88b37a3a576e07d4 Mon Sep 17 00:00:00 2001
From: cannorin
Date: Thu, 2 Oct 2025 09:50:48 +0000
Subject: [PATCH] Refactor

---
 index.ts       | 168 +++++--------------------------------------
 lib/llm.ts     |  76 ++++++++++++++++++++++
 lib/misskey.ts |  56 +++++++++++++++++
 lib/util.ts    |  16 +++++
 4 files changed, 164 insertions(+), 152 deletions(-)
 create mode 100644 lib/llm.ts
 create mode 100644 lib/misskey.ts
 create mode 100644 lib/util.ts

diff --git a/index.ts b/index.ts
index ff10ce6..8256d58 100644
--- a/index.ts
+++ b/index.ts
@@ -1,157 +1,21 @@
-import path from "node:path";
-import { fileURLToPath } from "node:url";
+import type { Note } from "misskey-js/entities.js";
+import { complete, getModel, type Message } from "./lib/llm";
+import {
+  expandReplyTree,
+  getNotes,
+  me,
+  misskey,
+  noteToMessage,
+} from "./lib/misskey";
+import { Stream } from "misskey-js";
+import { sleep } from "./lib/util";
 import { parseArgs } from "node:util";
-import { api } from "misskey-js";
-import { Stream } from "misskey-js";
-import type { Note } from "misskey-js/entities.js";
+const modelName =
+  Bun.env["MODEL"] ?? "mradermacher/gemma-2-baku-2b-it-GGUF:IQ4_XS";
+console.log(`* loading model '${modelName}'`);
+const model = await getModel(modelName);
-import {
-  type ChatHistoryItem,
-  LlamaChatSession,
-  createModelDownloader,
-  getLlama,
-  resolveChatWrapper,
-} from "node-llama-cpp";
-
-const __dirname = path.dirname(fileURLToPath(import.meta.url));
-
-// #region llm
-const model = await (async () => {
-  const downloader = await createModelDownloader({
-    modelUri: `hf:${Bun.env["MODEL"] ?? "mradermacher/gemma-2-baku-2b-it-GGUF:IQ4_XS"}`,
-    dirPath: path.join(__dirname, "models"),
-  });
-  const modelPath = await downloader.download();
-  const llama = await getLlama({
-    maxThreads: 2,
-  });
-  return await llama.loadModel({ modelPath });
-})();
-
-type Message = {
-  type: "system" | "model" | "user";
-  text: string;
-};
-
-async function complete(messages: Message[]) {
-  if (messages.length < 1) throw new Error("messages are empty");
-  const init = messages.slice(0, -1);
-  const last = messages.at(-1) as Message;
-  const context = await model.createContext();
-  const session = new LlamaChatSession({
-    contextSequence: context.getSequence(),
-    chatWrapper: resolveChatWrapper(model),
-  });
-  session.setChatHistory(
-    init.map((m): ChatHistoryItem => {
-      switch (m.type) {
-        case "system":
-          return {
-            type: "system",
-            text: m.text,
-          };
-        case "model":
-          return {
-            type: "model",
-            response: [m.text],
-          };
-        case "user":
-          return {
-            type: "user",
-            text: m.text,
-          };
-      }
-    }),
-  );
-
-  const res = await session.prompt(last.text, {
-    temperature: 1.0,
-    repeatPenalty: {
-      frequencyPenalty: 1,
-    },
-    onResponseChunk(chunk) {
-      process.stdout.write(chunk.text);
-    },
-  });
-  console.log();
-  session.dispose();
-  await context.dispose();
-  return res;
-}
-// #endregion
-
-// #region util
-
-/** pick up to N random elements from the array.
- * just shuffle it if N is unspecified or greater than the length of the array.
- * the original array remains unmodified. */
-function sample<T>(arr: T[], n: number = arr.length): T[] {
-  if (n > arr.length) return sample(arr, arr.length);
-  const copy = [...arr];
-  for (let i = 0; i < n; i++) {
-    const j = i + Math.floor(Math.random() * (copy.length - i));
-    [copy[i], copy[j]] = [copy[j] as T, copy[i] as T];
-  }
-  return copy.slice(0, n);
-}
-
-/** sleep for N milliseconds */
-const sleep = (msec: number) =>
-  new Promise((resolve) => setTimeout(resolve, msec));
-// #endregion
-
-// #region misskey
-const misskey = new api.APIClient({
-  origin: Bun.env["MISSKEY_ORIGIN"] || "https://misskey.cannorin.net",
-  credential: Bun.env["MISSKEY_CREDENTIAL"],
-});
-
-const me = await misskey.request("i", {});
-
-/** check if a note is suitable as an input */
-const isSuitableAsInput = (n: Note) =>
-  !n.user.isBot &&
-  !n.replyId &&
-  (!n.mentions || n.mentions.length === 0) &&
-  n.text?.length &&
-  n.text.length > 0;
-
-/** randomly sample some notes from the timeline */
-async function getNotes() {
-  // randomly sample N local notes
-  const localNotes = (count: number) =>
-    misskey
-      .request("notes/local-timeline", { limit: 100 })
-      .then((xs) => xs.filter(isSuitableAsInput))
-      .then((xs) => sample(xs, count));
-
-  // randomly sample N global notes
-  const globalNotes = (count: number) =>
-    misskey
-      .request("notes/global-timeline", { limit: 100 })
-      .then((xs) => xs.filter(isSuitableAsInput))
-      .then((xs) => sample(xs, count));
-
-  const notes = await Promise.all([localNotes(5), globalNotes(10)]);
-  return sample(notes.flat());
-}
-
-/** fetch the whole reply tree */
-async function expandReplyTree(note: Note, acc: Note[] = [], cutoff = 5) {
-  if (!note.reply || cutoff < 1) return [...acc, note];
-  const reply = await misskey.request("notes/show", { noteId: note.reply.id });
-  return await expandReplyTree(reply, [...acc, note], cutoff - 1);
-}
-
-/** convert a note to a chat message */
-const noteToMessage = (note: Note): Message => ({
-  type: note.userId === me.id ? ("model" as const) : ("user" as const),
-  text: note.text?.replaceAll(`@${me.username}`, "") || "",
-});
-// #endregion
-
-// #region job
 type Job = // read posts and post a note
   | { type: "post" }
@@ -219,7 +83,7 @@ async function generate(job: Job) {
   const messages = await preparePrompt(job);
 
   // request chat completion
-  const response = await complete(messages);
+  const response = await complete(model, messages);
 
   // concatenate the partial responses
   const text = response
("model" as const) : ("user" as const), - text: note.text?.replaceAll(`@${me.username}`, "") || "", -}); -// #endregion - -// #region job type Job = // read posts and post a note | { type: "post" } @@ -219,7 +83,7 @@ async function generate(job: Job) { const messages = await preparePrompt(job); // request chat completion - const response = await complete(messages); + const response = await complete(model, messages); // concatenate the partial responses const text = response diff --git a/lib/llm.ts b/lib/llm.ts new file mode 100644 index 0000000..7c422f1 --- /dev/null +++ b/lib/llm.ts @@ -0,0 +1,76 @@ +import path from "node:path"; +import { fileURLToPath } from "node:url"; + +import { + type ChatHistoryItem, + LlamaChatSession, + type LlamaModel, + createModelDownloader, + getLlama, + resolveChatWrapper, +} from "node-llama-cpp"; + +const __dirname = path.dirname(fileURLToPath(import.meta.url)); + +export async function getModel(model: string) { + const downloader = await createModelDownloader({ + modelUri: `hf:${model}`, + dirPath: path.join(__dirname, "..", "models"), + }); + const modelPath = await downloader.download(); + const llama = await getLlama({ + maxThreads: 2, + }); + return await llama.loadModel({ modelPath }); +} + +export type Message = { + type: "system" | "model" | "user"; + text: string; +}; + +export async function complete(model: LlamaModel, messages: Message[]) { + if (messages.length < 1) throw new Error("messages are empty"); + const init = messages.slice(0, -1); + const last = messages.at(-1) as Message; + const context = await model.createContext(); + const session = new LlamaChatSession({ + contextSequence: context.getSequence(), + chatWrapper: resolveChatWrapper(model), + }); + session.setChatHistory( + init.map((m): ChatHistoryItem => { + switch (m.type) { + case "system": + return { + type: "system", + text: m.text, + }; + case "model": + return { + type: "model", + response: [m.text], + }; + case "user": + return { + type: "user", + text: m.text, + }; + } + }), + ); + + const res = await session.prompt(last.text, { + temperature: 1.0, + repeatPenalty: { + frequencyPenalty: 1, + }, + onResponseChunk(chunk) { + process.stderr.write(chunk.text); + }, + maxTokens: 200, + }); + session.dispose(); + await context.dispose(); + return res; +} diff --git a/lib/misskey.ts b/lib/misskey.ts new file mode 100644 index 0000000..afbfd70 --- /dev/null +++ b/lib/misskey.ts @@ -0,0 +1,56 @@ +import { api } from "misskey-js"; +import type { Note } from "misskey-js/entities.js"; +import { sample } from "./util"; +import type { Message } from "./llm"; + +export const misskey = new api.APIClient({ + origin: Bun.env["MISSKEY_ORIGIN"] || "https://misskey.cannorin.net", + credential: Bun.env["MISSKEY_CREDENTIAL"], +}); + +export const me = await misskey.request("i", {}); + +/** check if a note is suitable as an input */ +export const isSuitableAsInput = (n: Note) => + !n.user.isBot && + !n.replyId && + (!n.mentions || n.mentions.length === 0) && + n.text?.length && + n.text.length > 0; + +/** randomly sample some notes from the timeline */ +export async function getNotes() { + // randomly sample N local notes + const localNotes = (count: number) => + misskey + .request("notes/local-timeline", { limit: 100 }) + .then((xs) => xs.filter(isSuitableAsInput)) + .then((xs) => sample(xs, count)); + + // randomly sample N global notes + const globalNotes = (count: number) => + misskey + .request("notes/global-timeline", { limit: 100 }) + .then((xs) => xs.filter(isSuitableAsInput)) + 
.then((xs) => sample(xs, count)); + + const notes = await Promise.all([localNotes(5), globalNotes(10)]); + return sample(notes.flat()); +} + +/** fetch the whole reply tree */ +export async function expandReplyTree( + note: Note, + acc: Note[] = [], + cutoff = 5, +) { + if (!note.reply || cutoff < 1) return [...acc, note]; + const reply = await misskey.request("notes/show", { noteId: note.reply.id }); + return await expandReplyTree(reply, [...acc, note], cutoff - 1); +} + +/** convert a note to a chat message */ +export const noteToMessage = (note: Note): Message => ({ + type: note.userId === me.id ? ("model" as const) : ("user" as const), + text: note.text?.replaceAll(`@${me.username}`, "") || "", +}); diff --git a/lib/util.ts b/lib/util.ts new file mode 100644 index 0000000..5fedb57 --- /dev/null +++ b/lib/util.ts @@ -0,0 +1,16 @@ +/** pick up to random N elements from array. + * just shuffle it if N is unspecified or greater than the length of the array. + * the original array remains unmodified. */ +export function sample(arr: T[], n: number = arr.length): T[] { + if (n > arr.length) return sample(arr, arr.length); + const copy = [...arr]; + for (let i = 0; i < n; i++) { + const j = i + Math.floor(Math.random() * (copy.length - i)); + [copy[i], copy[j]] = [copy[j] as T, copy[i] as T]; + } + return copy.slice(0, n); +} + +/** sleep for N milliseconds */ +export const sleep = (msec: number) => + new Promise((resolve) => setTimeout(resolve, msec));
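
For review context, a minimal sketch of how the extracted modules compose after this refactor. It only uses exports visible in the diff (getModel/complete/Message from lib/llm, getNotes/noteToMessage from lib/misskey); the standalone driver itself is illustrative and not part of the patch:

// sketch.ts -- illustrative driver, not part of the patch
import { complete, getModel, type Message } from "./lib/llm";
import { getNotes, noteToMessage } from "./lib/misskey";

// load the GGUF model once at startup, as the new index.ts does
const model = await getModel(
  Bun.env["MODEL"] ?? "mradermacher/gemma-2-baku-2b-it-GGUF:IQ4_XS",
);

// sample suitable timeline notes and convert them to chat messages;
// the bot's own notes become "model" turns, everything else "user" turns
const notes = await getNotes();
const messages: Message[] = notes.map(noteToMessage);

// complete() replays all but the last message as chat history and prompts
// with the last one, streaming tokens to stderr and capping at 200 tokens;
// it throws on an empty message list, hence the guard
if (messages.length > 0) {
  const response = await complete(model, messages);
  console.log(response);
}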