Refactor
index.ts (168 lines changed)
@@ -1,157 +1,21 @@
-import path from "node:path";
-import { fileURLToPath } from "node:url";
+import type { Note } from "misskey-js/entities.js";
+import { complete, getModel, type Message } from "./lib/llm";
+import {
+  expandReplyTree,
+  getNotes,
+  me,
+  misskey,
+  noteToMessage,
+} from "./lib/misskey";
+import { Stream } from "misskey-js";
+import { sleep } from "./lib/util";
 import { parseArgs } from "node:util";
-import { api } from "misskey-js";
-import { Stream } from "misskey-js";
-import type { Note } from "misskey-js/entities.js";
-import {
-  type ChatHistoryItem,
-  LlamaChatSession,
-  createModelDownloader,
-  getLlama,
-  resolveChatWrapper,
-} from "node-llama-cpp";
-
-const __dirname = path.dirname(fileURLToPath(import.meta.url));
-
-// #region llm
-const model = await (async () => {
-  const downloader = await createModelDownloader({
-    modelUri: `hf:${Bun.env["MODEL"] ?? "mradermacher/gemma-2-baku-2b-it-GGUF:IQ4_XS"}`,
-    dirPath: path.join(__dirname, "models"),
-  });
-  const modelPath = await downloader.download();
-  const llama = await getLlama({
-    maxThreads: 2,
-  });
-  return await llama.loadModel({ modelPath });
-})();
-
-type Message = {
-  type: "system" | "model" | "user";
-  text: string;
-};
-
-async function complete(messages: Message[]) {
-  if (messages.length < 1) throw new Error("messages are empty");
-  const init = messages.slice(0, -1);
-  const last = messages.at(-1) as Message;
-  const context = await model.createContext();
-  const session = new LlamaChatSession({
-    contextSequence: context.getSequence(),
-    chatWrapper: resolveChatWrapper(model),
-  });
-  session.setChatHistory(
-    init.map((m): ChatHistoryItem => {
-      switch (m.type) {
-        case "system":
-          return {
-            type: "system",
-            text: m.text,
-          };
-        case "model":
-          return {
-            type: "model",
-            response: [m.text],
-          };
-        case "user":
-          return {
-            type: "user",
-            text: m.text,
-          };
-      }
-    }),
-  );
-
-  const res = await session.prompt(last.text, {
-    temperature: 1.0,
-    repeatPenalty: {
-      frequencyPenalty: 1,
-    },
-    onResponseChunk(chunk) {
-      process.stdout.write(chunk.text);
-    },
-  });
-  console.log();
-  session.dispose();
-  await context.dispose();
-  return res;
-}
-// #endregion
-
-// #region util
-
-/** pick up to random N elements from array.
- * just shuffle it if N is unspecified or greater than the length of the array.
- * the original array remains unmodified. */
-function sample<T>(arr: T[], n: number = arr.length): T[] {
-  if (n > arr.length) return sample(arr, arr.length);
-  const copy = [...arr];
-  for (let i = 0; i < n; i++) {
-    const j = i + Math.floor(Math.random() * (copy.length - i));
-    [copy[i], copy[j]] = [copy[j] as T, copy[i] as T];
-  }
-  return copy.slice(0, n);
-}
-
-/** sleep for N milliseconds */
-const sleep = (msec: number) =>
-  new Promise((resolve) => setTimeout(resolve, msec));
-// #endregion
-
-// #region misskey
-const misskey = new api.APIClient({
-  origin: Bun.env["MISSKEY_ORIGIN"] || "https://misskey.cannorin.net",
-  credential: Bun.env["MISSKEY_CREDENTIAL"],
-});
-
-const me = await misskey.request("i", {});
-
-/** check if a note is suitable as an input */
-const isSuitableAsInput = (n: Note) =>
-  !n.user.isBot &&
-  !n.replyId &&
-  (!n.mentions || n.mentions.length === 0) &&
-  n.text?.length &&
-  n.text.length > 0;
-
-/** randomly sample some notes from the timeline */
-async function getNotes() {
-  // randomly sample N local notes
-  const localNotes = (count: number) =>
-    misskey
-      .request("notes/local-timeline", { limit: 100 })
-      .then((xs) => xs.filter(isSuitableAsInput))
-      .then((xs) => sample(xs, count));
-
-  // randomly sample N global notes
-  const globalNotes = (count: number) =>
-    misskey
-      .request("notes/global-timeline", { limit: 100 })
-      .then((xs) => xs.filter(isSuitableAsInput))
-      .then((xs) => sample(xs, count));
-
-  const notes = await Promise.all([localNotes(5), globalNotes(10)]);
-  return sample(notes.flat());
-}
-
-/** fetch the whole reply tree */
-async function expandReplyTree(note: Note, acc: Note[] = [], cutoff = 5) {
-  if (!note.reply || cutoff < 1) return [...acc, note];
-  const reply = await misskey.request("notes/show", { noteId: note.reply.id });
-  return await expandReplyTree(reply, [...acc, note], cutoff - 1);
-}
-
-/** convert a note to a chat message */
-const noteToMessage = (note: Note): Message => ({
-  type: note.userId === me.id ? ("model" as const) : ("user" as const),
-  text: note.text?.replaceAll(`@${me.username}`, "") || "",
-});
-// #endregion
-
-// #region job
+
+const modelName =
+  Bun.env["MODEL"] ?? "mradermacher/gemma-2-baku-2b-it-GGUF:IQ4_XS";
+console.log(`* loading model '${modelName}'`);
+const model = await getModel(modelName);
+
 type Job =
   // read posts and post a note
   | { type: "post" }
@@ -219,7 +83,7 @@ async function generate(job: Job) {
   const messages = await preparePrompt(job);

   // request chat completion
-  const response = await complete(messages);
+  const response = await complete(model, messages);

   // concatenate the partial responses
   const text = response
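The only behavioural change in index.ts is that complete() now receives the loaded model explicitly instead of closing over a module-level one. A minimal sketch of the new call pattern, based on the updated index.ts above (the example messages are invented purely for illustration):

import { complete, getModel, type Message } from "./lib/llm";

// load the GGUF model once at startup, with the same default as index.ts
const model = await getModel(
  Bun.env["MODEL"] ?? "mradermacher/gemma-2-baku-2b-it-GGUF:IQ4_XS",
);

// complete() now takes the loaded model as its first argument
const messages: Message[] = [
  { type: "system", text: "hypothetical system prompt" },
  { type: "user", text: "hypothetical user note text" },
];
const reply = await complete(model, messages);
console.log(reply);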
lib/llm.ts (new file, 76 lines)
@@ -0,0 +1,76 @@
+import path from "node:path";
+import { fileURLToPath } from "node:url";
+
+import {
+  type ChatHistoryItem,
+  LlamaChatSession,
+  type LlamaModel,
+  createModelDownloader,
+  getLlama,
+  resolveChatWrapper,
+} from "node-llama-cpp";
+
+const __dirname = path.dirname(fileURLToPath(import.meta.url));
+
+export async function getModel(model: string) {
+  const downloader = await createModelDownloader({
+    modelUri: `hf:${model}`,
+    dirPath: path.join(__dirname, "..", "models"),
+  });
+  const modelPath = await downloader.download();
+  const llama = await getLlama({
+    maxThreads: 2,
+  });
+  return await llama.loadModel({ modelPath });
+}
+
+export type Message = {
+  type: "system" | "model" | "user";
+  text: string;
+};
+
+export async function complete(model: LlamaModel, messages: Message[]) {
+  if (messages.length < 1) throw new Error("messages are empty");
+  const init = messages.slice(0, -1);
+  const last = messages.at(-1) as Message;
+  const context = await model.createContext();
+  const session = new LlamaChatSession({
+    contextSequence: context.getSequence(),
+    chatWrapper: resolveChatWrapper(model),
+  });
+  session.setChatHistory(
+    init.map((m): ChatHistoryItem => {
+      switch (m.type) {
+        case "system":
+          return {
+            type: "system",
+            text: m.text,
+          };
+        case "model":
+          return {
+            type: "model",
+            response: [m.text],
+          };
+        case "user":
+          return {
+            type: "user",
+            text: m.text,
+          };
+      }
+    }),
+  );
+
+  const res = await session.prompt(last.text, {
+    temperature: 1.0,
+    repeatPenalty: {
+      frequencyPenalty: 1,
+    },
+    onResponseChunk(chunk) {
+      process.stderr.write(chunk.text);
+    },
+    maxTokens: 200,
+  });
+  session.dispose();
+  await context.dispose();
+  return res;
+}
lib/misskey.ts (new file, 56 lines)
@@ -0,0 +1,56 @@
+import { api } from "misskey-js";
+import type { Note } from "misskey-js/entities.js";
+import { sample } from "./util";
+import type { Message } from "./llm";
+
+export const misskey = new api.APIClient({
+  origin: Bun.env["MISSKEY_ORIGIN"] || "https://misskey.cannorin.net",
+  credential: Bun.env["MISSKEY_CREDENTIAL"],
+});
+
+export const me = await misskey.request("i", {});
+
+/** check if a note is suitable as an input */
+export const isSuitableAsInput = (n: Note) =>
+  !n.user.isBot &&
+  !n.replyId &&
+  (!n.mentions || n.mentions.length === 0) &&
+  n.text?.length &&
+  n.text.length > 0;
+
+/** randomly sample some notes from the timeline */
+export async function getNotes() {
+  // randomly sample N local notes
+  const localNotes = (count: number) =>
+    misskey
+      .request("notes/local-timeline", { limit: 100 })
+      .then((xs) => xs.filter(isSuitableAsInput))
+      .then((xs) => sample(xs, count));
+
+  // randomly sample N global notes
+  const globalNotes = (count: number) =>
+    misskey
+      .request("notes/global-timeline", { limit: 100 })
+      .then((xs) => xs.filter(isSuitableAsInput))
+      .then((xs) => sample(xs, count));
+
+  const notes = await Promise.all([localNotes(5), globalNotes(10)]);
+  return sample(notes.flat());
+}
+
+/** fetch the whole reply tree */
+export async function expandReplyTree(
+  note: Note,
+  acc: Note[] = [],
+  cutoff = 5,
+) {
+  if (!note.reply || cutoff < 1) return [...acc, note];
+  const reply = await misskey.request("notes/show", { noteId: note.reply.id });
+  return await expandReplyTree(reply, [...acc, note], cutoff - 1);
+}
+
+/** convert a note to a chat message */
+export const noteToMessage = (note: Note): Message => ({
+  type: note.userId === me.id ? ("model" as const) : ("user" as const),
+  text: note.text?.replaceAll(`@${me.username}`, "") || "",
+});
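A minimal sketch of how these exports compose. Note that expandReplyTree returns the chain newest-first (the given note, then its ancestors), so the reversal to chronological order shown here is an assumption; the actual preparePrompt in index.ts is not part of this diff:

import { expandReplyTree, getNotes, noteToMessage } from "./lib/misskey";
import type { Message } from "./lib/llm";

// pick a random suitable note from the local/global timelines
const [note] = await getNotes();
if (note) {
  // walk up the reply chain (the note itself plus up to 5 ancestors)
  const thread = await expandReplyTree(note);
  // assumed: reverse to oldest-first and map to chat messages;
  // notes written by the bot account become "model" turns, others "user" turns
  const history: Message[] = thread.reverse().map(noteToMessage);
  console.log(history);
}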
lib/util.ts (new file, 16 lines)
@@ -0,0 +1,16 @@
+/** pick up to random N elements from array.
+ * just shuffle it if N is unspecified or greater than the length of the array.
+ * the original array remains unmodified. */
+export function sample<T>(arr: T[], n: number = arr.length): T[] {
+  if (n > arr.length) return sample(arr, arr.length);
+  const copy = [...arr];
+  for (let i = 0; i < n; i++) {
+    const j = i + Math.floor(Math.random() * (copy.length - i));
+    [copy[i], copy[j]] = [copy[j] as T, copy[i] as T];
+  }
+  return copy.slice(0, n);
+}
+
+/** sleep for N milliseconds */
+export const sleep = (msec: number) =>
+  new Promise((resolve) => setTimeout(resolve, msec));
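A small usage sketch for the two helpers (the values are illustrative; how index.ts actually uses sleep is not shown in this diff):

import { sample, sleep } from "./lib/util";

// pick 3 random elements without modifying the source array
const picks = sample([1, 2, 3, 4, 5], 3);

// sample() with no count just shuffles a copy of the whole array
const shuffled = sample(["a", "b", "c"]);

// wait one minute, as a polling loop might do between jobs
await sleep(60_000);
console.log(picks, shuffled);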