Today we carry on with implementing the backend API, the hardest part of this project. Time to take off~
*Droidstacean by Ivan Lozano, based on a design by Karen Rustad Tölva.
🏮 Today's complete code can be found in the Put it together section at the very bottom, or on GitHub.
Continuing in the spirit of modularity, we first create src/api.rs, where all of the server-side functions will live. Remember to add pub mod api; at the top of lib.rs to bring it into the module tree.
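For reference, the top of lib.rs would then look roughly like this (the other module names are just the ones declared on previous days and may differ in your project):
// lib.rs (sketch)
pub mod api;   // today's server-side functions
pub mod app;   // assumed: the root component module from the Leptos template
pub mod model; // holds model::conversation::Conversation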
The converse function we need to build in src/api.rs is the async one we left inside a todo! macro the day before yesterday in [Day 14] - 鋼鐵草泥馬 🦙 LLM chatbot 🤖 (5/9)|Signal & Action. Its signature looks like this:
use crate::model::conversation::Conversation;
use leptos::*;
#[server(Converse, "/api")]
pub async fn converse(prompt: Conversation) -> Result<String, ServerFnError> { todo!() }
Its input has type Conversation, and its output is a Result enum that returns a String on success and a ServerFnError on failure.
We annotate the function with the #[server(Converse, "/api")] macro, where the former argument is the server function's name and the latter is the URL path it is served under.
This macro means we only have to write the API that runs on the server; we automatically get a same-named function usable on the client, which sends the HTTP request to the corresponding backend endpoint.
In other words, we never have to hand-craft HTTP requests on the client or spell out details like the API's URL; it all happens automatically, which is one of the big selling points of Leptos.
All we have to do is specify, in the function signature, the type sent from the client to the backend API (Conversation) and the type returned to the client (String).
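As a rough illustration of what that buys us (the component and action below are hypothetical, not part of the project code), calling the generated client-side converse from a component could look something like this:
use crate::api::converse;
use crate::model::conversation::Conversation;
use leptos::*;

// Hypothetical component just to show the call site; the real UI was built on Day 14.
// Assumes Conversation derives Clone.
#[component]
fn ChatExample() -> impl IntoView {
    // `converse` here is the client-side stub generated by #[server]:
    // calling it issues the HTTP request to the "/api" endpoint for us.
    let send = create_action(|conversation: &Conversation| {
        let conversation = conversation.clone();
        async move { converse(conversation).await }
    });
    // e.g. send.dispatch(some_conversation) from a submit handler
    view! { <p>"..."</p> }
}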
🚨 Note: the converse function is a public API with no authentication or authorization at all. If you ever take this to production, remember to protect these endpoints~
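If you do want a quick safeguard, a minimal sketch (not part of the original code) is to require a shared secret as an extra argument and check it on the server. The token parameter and API_TOKEN variable below are hypothetical, and a real deployment would want proper sessions or auth middleware instead:
use crate::model::conversation::Conversation;
use leptos::*;

#[server(ConverseGuarded, "/api")]
pub async fn converse_guarded(token: String, prompt: Conversation) -> Result<String, ServerFnError> {
    // Reject the call unless the client supplied the expected secret.
    if Some(token) != std::env::var("API_TOKEN").ok() {
        return Err(ServerFnError::ServerError("unauthorized".to_string()));
    }
    // ... then run the same inference logic as converse() below ...
    todo!()
}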
Back to the function: its job is to send the Conversation to the model for inference and return the result to the client.
To run inference we need Rustformers' llm crate. These are the libraries to bring into the function's scope:
    use actix_web::dev::ConnectionInfo;
    use actix_web::web::Data;
    use leptos_actix::extract;
    use llm::models::Llama;
    use llm::KnownModel;
At the same time we need to add the crate to Cargo.toml. Following the official notes in Using llm in a Rust Project, we track the main branch to get the latest features, so append this to the end of the [dependencies] section:
llm = { git = "https://github.com/rustformers/llm.git", branch = "main", optional = true}
The optional here means we can pull this crate in only for the ssr feature, so also add it to the ssr list in the [features] section:
ssr = [
  "dep:actix-files",
  "dep:actix-web",
  "dep:leptos_actix",
  **"dep:llm",**
  "leptos/ssr",
  "leptos_meta/ssr",
  "leptos_router/ssr",
]
Also, to improve performance while debugging, we can keep ggml-sys out of unoptimized debug compilation by adding the following block after the [features] section:
[profile.dev.package.ggml-sys]
opt-level = 3
Since we picked Actix as the backend framework, we need to know that its handlers are built on the concept of extractors.
An extractor "extracts" typed data from an HTTP request, giving us access to server-side-specific data, and Leptos provides an extract helper so we can use these extractors directly inside server functions.
That helper is the extract function from leptos_actix. It takes a handler as its argument: an async function that receives the extracted data as parameters, can do further async work on it inside the async move body, and returns whatever value we want handed back to the server function. Here, that value is our model:
let model =
        extract(|data: Data<Llama>, _connection: ConnectionInfo| async move { data.into_inner() })
            .await
            .unwrap();
For details, see the official guide Using Extractors.
To build a chatbot we also need some prompt engineering so the LLM interacts with us the way we want. Here we borrow directly from the conv_one_shot part of Taiwan-LLaMa/demo/conversation.py, whose format looks something like this:
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
### Assistant: Hello! How may I help you today?
### Human: What is the capital of Taiwan?
### Assistant: Taipei is the capital of Taiwan.
Different language models use different templates; the differences are usually in the first line and the role names. For a list of common ones, see Prompt Template for OpenSource LLMs.
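For instance, one well-known alternative is the Alpaca-style template, shown here only for contrast (we will stick with the Taiwan-LLaMa format above; the question is just a placeholder):
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the capital of Taiwan?

### Response: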
This prompt engineering is fed to the LLM as the start of the instruction, followed by the conversation history, and after the history it is important to append ### Assistant: at the end so the model continues with its answer.
Let's first build the opening and the conversation history:
    let system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
    let user_name = "### Human";
    let bot_name = "### Assistant";
    let mut history = format!(
        "{bot_name}:Hello! How may I help you today?\n\
        {user_name}:What is the capital of Taiwan?\n\
        {bot_name}:Taipei is the capital of Taiwan.\n"
    );
    for message in prompt.messages.into_iter() {
        let msg = message.text;
        let curr_line = if message.user {
            format!("{user_name}:{msg}\n")
        } else {
            format!("{bot_name}:{msg}\n")
        };
        history.push_str(&curr_line);
    }
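To make this concrete: after the loop, history combined with the system prompt and the trailing {bot_name}: (appended below when we build the inference request) ends up looking roughly like this for a hypothetical single user message:
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
### Assistant:Hello! How may I help you today?
### Human:What is the capital of Taiwan?
### Assistant:Taipei is the capital of Taiwan.
### Human:Tell me about the tallest mountain in Taiwan
### Assistant: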
Next, following the official llm documentation and the official example vicuna-chat.rs, we can start running inference with the model. The explanations are written directly as comments:
    let mut res = String::new();
    let mut buf = String::new();
    // use the model to generate text from a prompt
    let mut session = model.start_session(Default::default());
    session
        .infer(
            // model to use for text generation
            model.as_ref(),
            // randomness provider
            &mut rand::thread_rng(),
            // the prompt to use for text generation, as well as other
            // inference parameters
            &llm::InferenceRequest {
                prompt: format!("{system_prompt}\n{history}\n{bot_name}:")
                    .as_str()
                    .into(),
                parameters: &llm::InferenceParameters::default(),
                play_back_previous_tokens: false,
                maximum_token_count: None,
            },
            // llm::OutputRequest
            &mut Default::default(),
            // output callback
            inference_callback(String::from(user_name), &mut buf, &mut res),
        )
        .unwrap_or_else(|e| panic!("{e}"));
By default, llm automatically downloads the tokenizer from the Hugging Face model hub.
Since this uses the rand crate, remember to add it to Cargo.toml:
rand = { version = "0.8.5", optional = true}
and remember to add "dep:rand" to the ssr feature list as well.
The main job of inference_callback is to tell session.infer() when to stop. We want the model to halt as soon as it outputs the human role name ### Human; otherwise it would just keep talking to itself endlessly.
Here we follow the conversation_inference_callback function in llm/crates/llm-base/src/inference_session.rs, again with the explanations as comments:
cfg_if! {
    if #[cfg(feature = "ssr")] {
        use std::convert::Infallible;
        fn inference_callback<'a>(
            stop_sequence: String,
            buf: &'a mut String,
            out_str: &'a mut String,
        ) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + 'a {
            move |resp| match resp {
                llm::InferenceResponse::InferredToken(token) => {
                    // We've generated a token, so we need to check if it's contained in the stop sequence.
                    let mut stop_sequence_buf = buf.clone();
                    stop_sequence_buf.push_str(token.as_str());
                    if stop_sequence.as_str().eq(stop_sequence_buf.as_str()) {
                        // We've generated the stop sequence, so we're done.
                        // Note that this will contain the extra tokens that were generated after the stop sequence,
                        // which may affect generation. This is non-ideal, but it's the best we can do without
                        // modifying the model.
                        buf.clear();
                        return Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Halt);
                    } else if stop_sequence.as_str().starts_with(stop_sequence_buf.as_str()) {
                        // We've generated a prefix of the stop sequence, so we need to keep buffering.
                        buf.push_str(token.as_str());
                        return Ok(llm::InferenceFeedback::Continue);
                    }
                    // We've generated a token that isn't part of the stop sequence, so we can
                    // pass it to the callback.
                    if buf.is_empty() {
                        out_str.push_str(&token);
                    } else {
                        out_str.push_str(&stop_sequence_buf);
                    }
                    Ok(llm::InferenceFeedback::Continue)
                }
                llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
                _ => Ok(llm::InferenceFeedback::Continue),
            }
        }
    }
}
The cfg_if macro ensures this function is only compiled on the server side (ssr), and with that our API is complete.
This is also the first appearance of a lifetime annotation <'a> in this series (in short, it ties the returned closure to the two &mut String buffers it borrows); we'll come back to lifetimes later!
At this point the whole src/api.rs looks like this:
use crate::model::conversation::Conversation;
use cfg_if::cfg_if;
use leptos::*;
#[server(Converse, "/api")]
pub async fn converse(prompt: Conversation) -> Result<String, ServerFnError> {
    use actix_web::dev::ConnectionInfo;
    use actix_web::web::Data;
    use leptos_actix::extract;
    use llm::models::Llama;
    use llm::KnownModel;
    let model =
        extract(|data: Data<Llama>, _connection: ConnectionInfo| async move { data.into_inner() })
            .await
            .unwrap();
    let system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
    let user_name = "### Human";
    let bot_name = "### Assistant";
    let mut history = format!(
        "{bot_name}:Hello! How may I help you today?\n\
        {user_name}:What is the capital of Taiwan?\n\
        {bot_name}:Taipei is the capital of Taiwan.\n"
    );
    for message in prompt.messages.into_iter() {
        let msg = message.text;
        let curr_line = if message.user {
            format!("{user_name}:{msg}\n")
        } else {
            format!("{bot_name}:{msg}\n")
        };
        history.push_str(&curr_line);
    }
    let mut res = String::new();
    let mut buf = String::new();
    // use the model to generate text from a prompt
    let mut session = model.start_session(Default::default());
    session
        .infer(
            // model to use for text generation
            model.as_ref(),
            // randomness provider
            &mut rand::thread_rng(),
            // the prompt to use for text generation, as well as other
            // inference parameters
            &llm::InferenceRequest {
                prompt: format!("{system_prompt}\n{history}\n{bot_name}:")
                    .as_str()
                    .into(),
                parameters: &llm::InferenceParameters::default(),
                play_back_previous_tokens: false,
                maximum_token_count: None,
            },
            // llm::OutputRequest
            &mut Default::default(),
            // output callback
            inference_callback(String::from(user_name), &mut buf, &mut res),
        )
        .unwrap_or_else(|e| panic!("{e}"));
    Ok(res)
}
cfg_if! {
    if #[cfg(feature = "ssr")] {
        use std::convert::Infallible;
        fn inference_callback<'a>(
            stop_sequence: String,
            buf: &'a mut String,
            out_str: &'a mut String,
        ) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + 'a {
            move |resp| match resp {
                llm::InferenceResponse::InferredToken(token) => {
                    // We've generated a token, so we need to check if it's contained in the stop sequence.
                    let mut stop_sequence_buf = buf.clone();
                    stop_sequence_buf.push_str(token.as_str());
                    if stop_sequence.as_str().eq(stop_sequence_buf.as_str()) {
                        // We've generated the stop sequence, so we're done.
                        // Note that this will contain the extra tokens that were generated after the stop sequence,
                        // which may affect generation. This is non-ideal, but it's the best we can do without
                        // modifying the model.
                        buf.clear();
                        return Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Halt);
                    } else if stop_sequence.as_str().starts_with(stop_sequence_buf.as_str()) {
                        // We've generated a prefix of the stop sequence, so we need to keep buffering.
                        buf.push_str(token.as_str());
                        return Ok(llm::InferenceFeedback::Continue);
                    }
                    // We've generated a token that isn't part of the stop sequence, so we can
                    // pass it to the callback.
                    if buf.is_empty() {
                        out_str.push_str(&token);
                    } else {
                        out_str.push_str(&stop_sequence_buf);
                    }
                    Ok(llm::InferenceFeedback::Continue)
                }
                llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
                _ => Ok(llm::InferenceFeedback::Continue),
            }
        }
    }
}
And Cargo.toml, from the [dependencies] section down, looks like this:
[dependencies]
actix-files = { version = "0.6", optional = true }
actix-web = { version = "4", optional = true, features = ["macros"] }
console_error_panic_hook = "0.1"
cfg-if = "1"
http = { version = "0.2", optional = true }
leptos = { version = "0.5.0", features = ["nightly"] }
leptos_meta = { version = "0.5.0", features = ["nightly"] }
leptos_actix = { version = "0.5.0", optional = true }
leptos_router = { version = "0.5.0", features = ["nightly"] }
wasm-bindgen = "=0.2.87"
serde = { version = "1.0.188", features = ["derive"] }
llm = { git = "https://github.com/rustformers/llm.git", branch = "main", optional = true}
rand = { version = "0.8.5", optional = true}
[features]
csr = ["leptos/csr", "leptos_meta/csr", "leptos_router/csr"]
hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"]
ssr = [
  "dep:actix-files",
  "dep:actix-web",
  "dep:leptos_actix",
  "dep:llm",
  "dep:rand",
  "leptos/ssr",
  "leptos_meta/ssr",
  "leptos_router/ssr",
]
[profile.dev.package.ggml-sys]
opt-level = 3
Note that this converse function assumes the model has already been loaded, which it actually hasn't been yet. That's our job for tomorrow. See you then!!!