Refactor (9)

2026-03-05 22:59:05 +09:00
parent 0d895797f6
commit f61298f9e4
2 changed files with 20 additions and 2 deletions


@@ -2,7 +2,13 @@ import { parseArgs } from "node:util";
 import { Stream } from "misskey-js";
 import type { Note } from "misskey-js/entities.js";
 import type { ChatHistoryItem, LLamaChatPromptOptions } from "node-llama-cpp";
-import { LlmSession, createGrammar, getModel, parseResponse } from "./lib/llm";
+import {
+  LlmSession,
+  createBias,
+  createGrammar,
+  getModel,
+  parseResponse,
+} from "./lib/llm";
 import {
   expandReplyTree,
   getNotes,
@@ -30,9 +36,11 @@ const modelName =
 console.log(`* loading model '${modelName}'`);
 const model = await getModel(modelName);
 const grammar = await createGrammar("あるびのちゃん");
+const bias = createBias(model);
 const baseChatPromptOptions = {
   grammar,
   maxTokens: 256,
+  tokenBias: bias,
 } as const satisfies LLamaChatPromptOptions;
 
 const getSystemPrompt = (
@@ -58,7 +66,7 @@ async function rephrase(text: string) {
   await using rephraseSession = new LlmSession(
     model,
     getSystemPrompt(
-      "user が与えたテキストを『ですます調』(丁寧な文体)で言い換えたものを、そのまま出力してください。",
+      "ユーザが与えたテキストを「~です」「~ます」調(丁寧な文体)で言い換えたものを、そのまま出力してください。",
     ),
   );
   await rephraseSession.init();

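The last hunk only rewords the rephrase system prompt (both the old and new strings ask the model to restate the user's text in polite です/ます style). The more substantial change is that baseChatPromptOptions now combines three constraints: the JSON-schema grammar, a 256-token cap, and the new token bias. The LlmSession wrapper from ./lib/llm is not part of this diff, so the following is only a rough sketch of how such options reach a prompt call, using node-llama-cpp's built-in LlamaChatSession directly; the model path, schema, and prompts are placeholders.

// Rough sketch, not the repository's LlmSession wrapper: it only shows how
// grammar, maxTokens and tokenBias travel into a node-llama-cpp prompt call.
import { getLlama, LlamaChatSession, TokenBias } from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({ modelPath: "path/to/model.gguf" }); // placeholder path
const grammar = await llama.createGrammarForJsonSchema({
  type: "object",
  properties: { text: { type: "string" } }, // illustrative schema, not the real one
});
const bias = new TokenBias(model.tokenizer); // same constructor createBias uses

const context = await model.createContext();
const session = new LlamaChatSession({
  contextSequence: context.getSequence(),
  systemPrompt: "You are a helpful assistant.", // placeholder system prompt
});

const reply = await session.prompt("こんにちは", {
  grammar,
  maxTokens: 256,
  tokenBias: bias,
});
console.log(reply);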

@@ -7,6 +7,7 @@ import {
   type LLamaChatPromptOptions,
   LlamaChatSession,
   type LlamaModel,
+  TokenBias,
   createModelDownloader,
   getLlama,
   resolveChatWrapper,
@@ -27,6 +28,15 @@ export async function getModel(model: string) {
   return await llama.loadModel({ modelPath });
 }
 
+export function createBias(model: LlamaModel) {
+  const customBias = new TokenBias(model.tokenizer);
+  for (const token of model.iterateAllTokens()) {
+    const text = model.detokenize([token]);
+    if (text === "{") customBias.set(token, -0.9); // suppress JSON string
+  }
+  return customBias;
+}
+
 export const createGrammar = (assistantName: string) =>
   llama.createGrammarForJsonSchema({
     type: "object",
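The new createBias helper scans the model's entire vocabulary, detokenizes each token, and down-weights any token that renders as "{", presumably so the grammar-constrained reply is less likely to contain nested JSON-looking text inside its string fields. A small generalization of the same pattern is sketched below; the helper name, parameter shape, and default weight are illustrative and not part of this commit.

// Sketch only: generalizes the createBias pattern to an arbitrary set of strings.
import { TokenBias, type LlamaModel } from "node-llama-cpp";

export function createBiasAgainst(
  model: LlamaModel,
  blockedTexts: ReadonlySet<string>,
  weight = -0.9,
) {
  const bias = new TokenBias(model.tokenizer);
  for (const token of model.iterateAllTokens()) {
    // Down-weight (not forbid) every token whose text is in the blocked set.
    const text = model.detokenize([token]);
    if (blockedTexts.has(text)) bias.set(token, weight);
  }
  return bias;
}

// Example: discourage both "{" and "}" in free-text output.
// const bias = createBiasAgainst(model, new Set(["{", "}"]));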