import path from "node:path";
import { fileURLToPath } from "node:url";
import {
  type ChatHistoryItem,
  type ChatSessionModelFunctions,
  type LLamaChatPromptOptions,
  LlamaChatSession,
  type LlamaModel,
  createModelDownloader,
  getLlama,
  resolveChatWrapper,
} from "node-llama-cpp";
const __dirname = path.dirname(fileURLToPath(import.meta.url));
|
|
|
|
const llama = await getLlama({
|
|
maxThreads: 2,
|
|
});
|
|
|
|
export async function getModel(model: string) {
|
|
const downloader = await createModelDownloader({
|
|
modelUri: `hf:${model}`,
|
|
dirPath: path.join(__dirname, "..", "models"),
|
|
});
|
|
const modelPath = await downloader.download();
|
|
return await llama.loadModel({ modelPath });
|
|
}
|
|
|
|
export const createGrammar = (assistantName: string) =>
|
|
llama.createGrammarForJsonSchema({
|
|
type: "object",
|
|
properties: {
|
|
name: { type: "string", enum: [assistantName] },
|
|
text: { type: "string" },
|
|
},
|
|
required: ["text"],
|
|
additionalProperties: false,
|
|
});
|
|
|
|
export function parseResponse(
|
|
grammar: Awaited<ReturnType<typeof createGrammar>>,
|
|
text: string,
|
|
) {
|
|
try {
|
|
const res = grammar.parse(text.trim());
|
|
return res.text;
|
|
} catch (e) {
|
|
console.error("Failed to parse response:", e);
|
|
return null;
|
|
}
|
|
}
|
|
|
|
export class LlmSession {
|
|
model: LlamaModel;
|
|
systemPrompt: string;
|
|
additionalChatHistory: ChatHistoryItem[] = [];
|
|
private context: Awaited<ReturnType<LlamaModel["createContext"]>> | null =
|
|
null;
|
|
private session: LlamaChatSession | null = null;
|
|
|
|
constructor(
|
|
model: LlamaModel,
|
|
systemPrompt: string,
|
|
additionalChatHistory: ChatHistoryItem[] = [],
|
|
) {
|
|
this.model = model;
|
|
this.systemPrompt = systemPrompt;
|
|
this.additionalChatHistory = additionalChatHistory;
|
|
}
|
|
|
|
async init() {
|
|
this.context = await this.model.createContext();
|
|
this.session = new LlamaChatSession({
|
|
contextSequence: this.context.getSequence(),
|
|
chatWrapper: resolveChatWrapper(this.model),
|
|
});
|
|
this.session.setChatHistory([
|
|
{
|
|
type: "system",
|
|
text: this.systemPrompt,
|
|
},
|
|
...this.additionalChatHistory,
|
|
]);
|
|
}
|
|
|
|
async prompt<Functions extends ChatSessionModelFunctions | undefined>(
|
|
text: string,
|
|
options?: LLamaChatPromptOptions<Functions>,
|
|
) {
|
|
if (!this.session) await this.init();
|
|
if (!this.session) throw new Error("session is not initialized");
|
|
return await this.session.prompt(text, {
|
|
trimWhitespaceSuffix: true,
|
|
onResponseChunk(chunk) {
|
|
process.stderr.write(chunk.text);
|
|
},
|
|
...options,
|
|
});
|
|
}
|
|
|
|
async [Symbol.asyncDispose]() {
|
|
await this.session?.dispose();
|
|
await this.context?.dispose();
|
|
}
|
|
}
|