// Helpers for downloading a GGUF model from Hugging Face and running chat
// completions with node-llama-cpp.
import path from "node:path";
import { fileURLToPath } from "node:url";

import {
  type ChatHistoryItem,
  type LLamaChatPromptOptions,
  LlamaChatSession,
  type LlamaModel,
  createModelDownloader,
  getLlama,
  resolveChatWrapper,
} from "node-llama-cpp";

// ESM modules have no __dirname; derive it from import.meta.url.
const __dirname = path.dirname(fileURLToPath(import.meta.url));

/**
 * Downloads the given model from Hugging Face (if not already cached) into
 * `../models` and loads it. `model` is a Hugging Face model URI without the
 * `hf:` prefix, e.g. `owner/repo/model-file.gguf`.
 */
export async function getModel(model: string) {
  const downloader = await createModelDownloader({
    modelUri: `hf:${model}`,
    dirPath: path.join(__dirname, "..", "models"),
  });
  const modelPath = await downloader.download();
  const llama = await getLlama({
    maxThreads: 2,
  });
  return await llama.loadModel({ modelPath });
}

// A simplified chat message, mapped onto node-llama-cpp's ChatHistoryItem
// in complete() below.
export type Message = {
  type: "system" | "model" | "user";
  text: string;
};

/**
 * Runs a chat completion: all but the last message seed the chat history,
 * and the last message is sent as the prompt. Response tokens are streamed
 * to stderr as they are generated.
 */
export async function complete(
  model: LlamaModel,
  messages: Message[],
  options: LLamaChatPromptOptions = {},
) {
  if (messages.length < 1) throw new Error("messages are empty");
  const init = messages.slice(0, -1);
  const last = messages.at(-1) as Message;
  const context = await model.createContext();
  const session = new LlamaChatSession({
    contextSequence: context.getSequence(),
    chatWrapper: resolveChatWrapper(model),
  });
  try {
    // Map our simplified messages onto node-llama-cpp's history items;
    // a model turn carries its text as an array of response segments.
    session.setChatHistory(
      init.map((m): ChatHistoryItem => {
        switch (m.type) {
          case "system":
            return {
              type: "system",
              text: m.text,
            };
          case "model":
            return {
              type: "model",
              response: [m.text],
            };
          case "user":
            return {
              type: "user",
              text: m.text,
            };
        }
      }),
    );

    return await session.prompt(last.text, {
      trimWhitespaceSuffix: true,
      onResponseChunk(chunk) {
        process.stderr.write(chunk.text);
      },
      ...options,
    });
  } finally {
    // Dispose even if prompting throws, so the context is not leaked.
    session.dispose();
    await context.dispose();
  }
}
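
// Example usage (a minimal sketch; the Hugging Face URI below is a
// placeholder — substitute any GGUF chat model that node-llama-cpp supports):
//
//   const model = await getModel("<owner>/<repo>/<model-file>.gguf");
//   const reply = await complete(model, [
//     { type: "system", text: "You are a concise assistant." },
//     { type: "user", text: "What does llama.cpp do?" },
//   ]);
//   console.log(reply);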