Today we move on to implementing the backend API, the hardest part of this project. Time to take off~
*Droidstacean by Ivan Lozano, based on a design by Karen Rustad Tölva.
🏮 Today's complete code can be found in the Put it together section at the very bottom, or on GitHub.
Continuing in the spirit of modularization, we first create src/api.rs, where all of the server-side functions will live. Remember to add pub mod api; at the top of lib.rs so that it becomes part of the module tree.
The converse function
What we need to build in src/api.rs is the converse async function that we left inside the todo! macro two days ago in [Day 14] - 鋼鐵草泥馬 🦙 LLM chatbot 🤖 (5/9)|Signal & Action. Its function signature looks like this:
use crate::model::conversation::Conversation;
use leptos::*;
#[server(Converse, "/api")]
pub async fn converse(prompt: Conversation) -> Result<String, ServerFnError> {}
Its input is of type Conversation, and its output is a Result enum that yields a String on success and a ServerFnError on failure.
We annotate the function with the #[server(Converse, "/api")] macro, where the first argument is the name and the second is the URL prefix.
This macro lets us write only the API that runs on the server, and in return we automatically get a client-side function of the same name that sends an HTTP request to the corresponding backend API. In other words, we never have to build the HTTP request by hand or spell out the API URL on the client; it all happens automatically, which is one of the highlights of Leptos. All we have to do is declare, in the function signature, the type sent from the client to the backend API (Conversation) and the type returned to the client (String).
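To make that concrete, here is a minimal sketch (assuming Conversation derives Clone; the component name is illustrative, not from the project) of how the client side can call the generated stub through an Action, much like what we wired up on Day 14:
use crate::api::converse;
use crate::model::conversation::Conversation;
use leptos::*;
// Hypothetical component, for illustration only.
#[component]
fn ChatBox() -> impl IntoView {
    // On the client, `converse` is the auto-generated stub: dispatching the
    // action sends the HTTP request to the "/api" endpoint for us.
    let send = create_action(|conversation: &Conversation| {
        let conversation = conversation.clone();
        async move { converse(conversation).await }
    });
    // e.g. call send.dispatch(conversation) from a submit handler,
    // then read send.value() to get the Result<String, ServerFnError>.
    view! { <p>"..."</p> }
}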
🚨 Note that converse is a public API with no authentication or authorization whatsoever. If you ever take this to production, remember to protect these endpoints~
Back to the function itself: its job is to send the Conversation to the model for inference and return the result to the client. Running inference means using the llm crate from Rustformers, so these are the libraries we need to bring into the function's scope:
use actix_web::dev::ConnectionInfo;
use actix_web::web::Data;
use leptos_actix::extract;
use llm::models::Llama;
use llm::KnownModel;
At the same time we have to add the crate to Cargo.toml. Following the official instructions in Using llm in a Rust Project, we track the same version as the main branch to get the latest features, so append this at the end of the [dependencies] block:
llm = { git = "https://github.com/rustformers/llm.git", branch = "main", optional = true}
The optional flag here means the crate is only compiled when a feature pulls it in, which lets us enable it for ssr only, so also add it to the ssr list in the [features] block:
ssr = [
"dep:actix-files",
"dep:actix-web",
"dep:leptos_actix",
**"dep:llm",**
"leptos/ssr",
"leptos_meta/ssr",
"leptos_router/ssr",
]
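As a quick aside, this is also why the use statements for llm sit inside the function body rather than at the top of src/api.rs: the body of a #[server] function is only compiled for the server, and llm simply is not a dependency unless the ssr feature is on. A minimal sketch of the same idea for any other server-only helper (hypothetical function, not from the project):
// Hypothetical helper, server-side only: `llm` is not even a dependency
// unless the `ssr` feature (and thus dep:llm) is enabled, so the cfg gate is mandatory.
#[cfg(feature = "ssr")]
fn backend_model_type() -> &'static str {
    std::any::type_name::<llm::models::Llama>()
}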
In addition, to keep performance bearable while debugging, we exempt ggml-sys from the unoptimized debug settings by compiling it at opt-level 3 even in the dev profile, so add the following block below the [features] section:
[profile.dev.package.ggml-sys]
opt-level = 3
Since we chose Actix as the backend framework, we need to know that its handlers are built on the concept of extractors. An extractor "extracts" typed data from an HTTP request, giving us access to server-side data, and Leptos provides the extract helper so these extractors can be used directly inside server functions. This helper is the extract function from leptos_actix: it takes a handler as its argument, where the handler is an async function that receives the data extracted from the request as its parameters, can process that data further inside its async move body, and returns whatever value we want back into the server function. Here, that value is our model:
let model =
extract(|data: Data<Llama>, _connection: ConnectionInfo| async move { data.into_inner() })
.await
.unwrap();
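Just to show the pattern is general (a sketch, not code from the project): the same call can hand back any of the extracted values, for example the caller's address from ConnectionInfo, which the handler above receives but ignores.
// Illustrative only: reuse the same extractors as above, but return the peer address instead.
let peer = extract(|_data: Data<Llama>, connection: ConnectionInfo| async move {
    connection.peer_addr().map(str::to_string)
})
.await
.unwrap();
println!("converse request from {peer:?}");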
For the details, see the official guide Using Extractors.
To build a chatbot we also need some prompt engineering so that the LLM interacts with us the way we want. Here we borrow directly from the conv_one_shot setup in Taiwan-LLaMa/demo/conversation.py, whose format looks roughly like this:
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
### Assistant: Hello! How may I help you today?
### Human: What is the capital of Taiwan?
### Assistant: Taipei is the capital of Taiwan.
Different language models use different templates; the differences are usually the first line and the role names. A list of common ones can be found in Prompt Template for OpenSource LLMs.
This prompt is fed to the LLM as the start of the instruction, followed by the conversation history, and after the history it is important to append ### Assistant: at the very end so the model continues with its own answer.
Let's start by building the opening and the conversation history:
let system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
let user_name = "### Human";
let bot_name = "### Assistant";
let mut history = format!(
"{bot_name}:Hello! How may I help you today?\n\
{user_name}:What is the capital of Taiwan?\n\
{bot_name}:Taipei is the capital of Taiwan.\n"
);
for message in prompt.messages.into_iter() {
let msg = message.text;
let curr_line = if message.user {
format!("{user_name}:{msg}\n")
} else {
format!("{bot_name}:{msg}\n")
};
history.push_str(&curr_line);
}
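As a quick sanity check (a sketch only, not code from the project), the string that will actually be handed to the model a few lines further down is the system prompt, then this history, then one more ### Assistant: to invite the model to answer:
// Hypothetical peek at the assembled prompt, using the variables defined above.
let final_prompt = format!("{system_prompt}\n{history}\n{bot_name}:");
// The printout ends with the newest history line followed by "### Assistant:",
// which is exactly what cues the model to produce the next assistant reply.
println!("{final_prompt}");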
Next, following the official llm documentation and the official example vicuna-chat.rs, we can start running inference with the model; the explanations are included directly as comments:
let mut res = String::new();
let mut buf = String::new();
// use the model to generate text from a prompt
let mut session = model.start_session(Default::default());
session
.infer(
// model to use for text generation
model.as_ref(),
// randomness provider
&mut rand::thread_rng(),
// the prompt to use for text generation, as well as other
// inference parameters
&llm::InferenceRequest {
prompt: format!("{system_prompt}\n{history}\n{bot_name}:")
.as_str()
.into(),
parameters: &llm::InferenceParameters::default(),
play_back_previous_tokens: false,
maximum_token_count: None,
},
// llm::OutputRequest
&mut Default::default(),
// output callback
inference_callback(String::from(user_name), &mut buf, &mut res),
)
.unwrap_or_else(|e| panic!("{e}"));
By default, llm downloads the tokenizer automatically from the Hugging Face model hub.
Since the code above uses the rand crate, remember to add this to Cargo.toml:
rand = { version = "0.8.5", optional = true}
and remember to add "dep:rand" to the ssr feature list as well.
The inference_callback function
The main job of inference_callback is to tell session.infer() when to stop: we want the model to halt as soon as it outputs the human role name ### Human, otherwise it would keep talking to itself forever. Here we follow the conversation_inference_callback function in llm/crates/llm-base/src/inference_session.rs, again with the explanations included as comments:
cfg_if! {
if #[cfg(feature = "ssr")] {
use std::convert::Infallible;
fn inference_callback<'a>(
stop_sequence: String,
buf: &'a mut String,
out_str: &'a mut String,
) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + 'a {
move |resp| match resp {
llm::InferenceResponse::InferredToken(token) => {
// We've generated a token, so we need to check if it's contained in the stop sequence.
let mut stop_sequence_buf = buf.clone();
stop_sequence_buf.push_str(token.as_str());
if stop_sequence.as_str().eq(stop_sequence_buf.as_str()) {
// We've generated the stop sequence, so we're done.
// Note that this will contain the extra tokens that were generated after the stop sequence,
// which may affect generation. This is non-ideal, but it's the best we can do without
// modifying the model.
buf.clear();
return Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Halt);
} else if stop_sequence.as_str().starts_with(stop_sequence_buf.as_str()) {
// We've generated a prefix of the stop sequence, so we need to keep buffering.
buf.push_str(token.as_str());
return Ok(llm::InferenceFeedback::Continue);
}
// We've generated a token that isn't part of the stop sequence, so we can
// pass it to the callback.
if buf.is_empty() {
out_str.push_str(&token);
} else {
out_str.push_str(&stop_sequence_buf);
}
Ok(llm::InferenceFeedback::Continue)
}
llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
_ => Ok(llm::InferenceFeedback::Continue),
}
}
}
}
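To make the buffering logic above concrete, here is a tiny standalone sketch (simplified, with the llm types stripped out; not code from the project) of the three branches the callback takes for every generated token:
// Simplified classification of one token, mirroring the three branches above.
fn classify(stop_sequence: &str, buf: &str, token: &str) -> &'static str {
    let mut candidate = buf.to_string();
    candidate.push_str(token);
    if candidate == stop_sequence {
        "halt: the full stop sequence was generated"
    } else if stop_sequence.starts_with(&candidate) {
        "buffer: the token may be the start of the stop sequence"
    } else {
        "emit: flush the buffer and the token to the output string"
    }
}

fn main() {
    // With stop_sequence = "### Human":
    assert_eq!(classify("### Human", "", "Tai"), "emit: flush the buffer and the token to the output string");
    assert_eq!(classify("### Human", "", "##"), "buffer: the token may be the start of the stop sequence");
    assert_eq!(classify("### Human", "### Hum", "an"), "halt: the full stop sequence was generated");
}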
The cfg_if macro wrapping inference_callback makes sure the function is only compiled on the server side (ssr), and with that the API is complete. This is also the first time lifetime annotations (<'a>) appear in this series; we will come back to them later!
At this point the whole src/api.rs looks like this:
use crate::model::conversation::Conversation;
use cfg_if::cfg_if;
use leptos::*;
#[server(Converse, "/api")]
pub async fn converse(prompt: Conversation) -> Result<String, ServerFnError> {
use actix_web::dev::ConnectionInfo;
use actix_web::web::Data;
use leptos_actix::extract;
use llm::models::Llama;
use llm::KnownModel;
let model =
extract(|data: Data<Llama>, _connection: ConnectionInfo| async move { data.into_inner() })
.await
.unwrap();
let system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
let user_name = "### Human";
let bot_name = "### Assistant";
let mut history = format!(
"{bot_name}:Hello! How may I help you today?\n\
{user_name}:What is the capital of Taiwan?\n\
{bot_name}:Taipei is the capital of Taiwan.\n"
);
for message in prompt.messages.into_iter() {
let msg = message.text;
let curr_line = if message.user {
format!("{user_name}:{msg}\n")
} else {
format!("{bot_name}:{msg}\n")
};
history.push_str(&curr_line);
}
let mut res = String::new();
let mut buf = String::new();
// use the model to generate text from a prompt
let mut session = model.start_session(Default::default());
session
.infer(
// model to use for text generation
model.as_ref(),
// randomness provider
&mut rand::thread_rng(),
// the prompt to use for text generation, as well as other
// inference parameters
&llm::InferenceRequest {
prompt: format!("{system_prompt}\n{history}\n{bot_name}:")
.as_str()
.into(),
parameters: &llm::InferenceParameters::default(),
play_back_previous_tokens: false,
maximum_token_count: None,
},
// llm::OutputRequest
&mut Default::default(),
// output callback
inference_callback(String::from(user_name), &mut buf, &mut res),
)
.unwrap_or_else(|e| panic!("{e}"));
Ok(res)
}
cfg_if! {
if #[cfg(feature = "ssr")] {
use std::convert::Infallible;
fn inference_callback<'a>(
stop_sequence: String,
buf: &'a mut String,
out_str: &'a mut String,
) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + 'a {
move |resp| match resp {
llm::InferenceResponse::InferredToken(token) => {
// We've generated a token, so we need to check if it's contained in the stop sequence.
let mut stop_sequence_buf = buf.clone();
stop_sequence_buf.push_str(token.as_str());
if stop_sequence.as_str().eq(stop_sequence_buf.as_str()) {
// We've generated the stop sequence, so we're done.
// Note that this will contain the extra tokens that were generated after the stop sequence,
// which may affect generation. This is non-ideal, but it's the best we can do without
// modifying the model.
buf.clear();
return Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Halt);
} else if stop_sequence.as_str().starts_with(stop_sequence_buf.as_str()) {
// We've generated a prefix of the stop sequence, so we need to keep buffering.
buf.push_str(token.as_str());
return Ok(llm::InferenceFeedback::Continue);
}
// We've generated a token that isn't part of the stop sequence, so we can
// pass it to the callback.
if buf.is_empty() {
out_str.push_str(&token);
} else {
out_str.push_str(&stop_sequence_buf);
}
Ok(llm::InferenceFeedback::Continue)
}
llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
_ => Ok(llm::InferenceFeedback::Continue),
}
}
}
}
And Cargo.toml, from [dependencies] onward, looks like this:
[dependencies]
actix-files = { version = "0.6", optional = true }
actix-web = { version = "4", optional = true, features = ["macros"] }
console_error_panic_hook = "0.1"
cfg-if = "1"
http = { version = "0.2", optional = true }
leptos = { version = "0.5.0", features = ["nightly"] }
leptos_meta = { version = "0.5.0", features = ["nightly"] }
leptos_actix = { version = "0.5.0", optional = true }
leptos_router = { version = "0.5.0", features = ["nightly"] }
wasm-bindgen = "=0.2.87"
serde = { version = "1.0.188", features = ["derive"] }
llm = { git = "https://github.com/rustformers/llm.git", branch = "main", optional = true}
rand = { version = "0.8.5", optional = true}
[features]
csr = ["leptos/csr", "leptos_meta/csr", "leptos_router/csr"]
hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"]
ssr = [
"dep:actix-files",
"dep:actix-web",
"dep:leptos_actix",
"dep:llm",
"dep:rand",
"leptos/ssr",
"leptos_meta/ssr",
"leptos_router/ssr",
]
[profile.dev.package.ggml-sys]
opt-level = 3
Note that in this converse function we assume the model has already been loaded, but it actually has not yet; that is tomorrow's job. See you tomorrow!!!