iT邦幫忙

2023 iThome 鐵人賽

DAY 16
0
AI & Data

Rust 加 MLOps,你說有沒有搞頭?系列 第 16

[Day 16] - 鋼鐵草泥馬 🦙 LLM chatbot 🤖 (7/10)|後端 LLM API

  • 分享至 

  • xImage
  •  

今日份 Ferris

今天要接著要實作後端 API 了,是這個專案中最難的部分,要起飛啦~
https://ithelp.ithome.com.tw/upload/images/20231002/20141304nkE8K0Q6Rh.jpg
*Droidstacean by Ivan Lozano, based on a design by Karen Rustad Tölva.

模組化

🏮 今天完整的程式碼可以拉到最底下 Put it together 區塊或是在 GitHub 找到。

我們繼續延續模組化的精神,先建立 src/api.rs,所有的 server-side function 都會放在這裡,記得在 lib.rs 的開頭加上 pub mod api; 以將其加入模組樹中。

converse 函式

src/api.rs 中要建立的是前天 [Day 14] - 鋼鐵草泥馬 🦙 LLM chatbot 🤖 (5/9)|Signal & Action 放在 todo! 巨集中的 converse 異步函式 (async),其函式簽名如下:

use crate::model::conversation::Conversation;
use leptos::*;

#[server(Converse "/api")]
pub async fn converse(prompt: Conversation) -> Result<String, ServerFnError> {}

其輸入資料的型別為 Conversation,而輸出是一個 Result 列舉,在成功時回傳 String 型別,失敗時回傳 ServerFnError

而我們使用 #[server(Converse "/api")] 巨集來標注這個函式,其中前者是名稱,後者是位址。

這個巨集讓我們只需要建立運行於伺服器端的 API,就能自動得到另一個可用於客戶端的同名函式,以將 HTTP request 發送至相對應的後端 API。

換句話說,我們並不需要在客戶端手動建立 HTTP request 或指定 API 的 URL 等細節,一切都會自動發生,這也是使用 Leptos 的一大特點。

我們要做的只有在函式簽名中給定要從客戶端傳到後端 API 的型別 (Conversation),以及要傳回客戶端的型別 (String) 即可。

🚨 注意,converse 函式是公開的 API,並沒有設定任何的認證與授權,如果真的要產品化,記得保護這些 API~

回到這個函式,它的任務就是要將 Conversation 送給模型進行推論,然後把推論結果回傳到客戶端。

而要進行推論就得使用到 Rustformers 的 llm crate 啦,以下是需要在函式 Scope 中引入的函式庫:

    use actix_web::dev::ConnectionInfo;
    use actix_web::web::Data;
    use leptos_actix::extract;

    use llm::models::Llama;
    use llm::KnownModel;

與此同時也得把 crate 加到 Cargo.toml 中,這裡參考官方在 Using llm in a Rust Project 中的說明,使用與 main 分支相同的版本以獲取最新的功能,所以在 [dependencies] 區塊最後面加上:

llm = { git = "https://github.com/rustformers/llm.git", branch = "main", optional = true}

這裡的 optional 的意思是我們可以只將這個 crate 加到 ssr 中,所以在 [features] 區塊 ssr 的 list 中也加上:

ssr = [
  "dep:actix-files",
  "dep:actix-web",
  "dep:leptos_actix",
  **"dep:llm",**
  "leptos/ssr",
  "leptos_meta/ssr",
  "leptos_router/ssr",
]

另外,為了提升除錯上的表現,可以把 ggml-sys 在除錯模式中排除,所以把以下區塊加到 [features] 區塊下面:

[profile.dev.package.ggml-sys]
opt-level = 3

提取器

由於我們選用 Actix 作為後端框架,所以必須了解其處理程序都建立在名為提取器 (extractor) 的概念之上。

提取器能從 HTTP 請求中 “提取” 型別資料,讓我們得以存取伺服器端的特定資料,而 Leptos 提供了 extract 輔助函式,讓我們可以直接在伺服器端函式中使用這些提取器。

這個輔助函式就是 leptos_actix 中的 extract 函式,它接受處理程序函式作為其參數,其中處理程序是一個異步函式,接收從請求中提取的參數並返回某個值。

而處理程序函式將提取到的資料作為其參數,並且可以在異步移動塊 (async move) 的主體內對其進行進一步的異步處理,並回傳任何想要回傳到伺服器端函式中的值,而這裡就是我們的模型:

let model =
        extract(|data: Data<Llama>, _connection: ConnectionInfo| async move { data.into_inner() })
            .await
            .unwrap();

詳細可以參考官方教學 Using Extractors

Prompt engineering

而要建立聊天機器人,我們需要做適當的 prompt engineering,以使 LLM 能以我們想要的方式進行互動,這裡我們直接參考 Taiwan-LLaMa/demo/conversation.pyconv_one_shot 的部分,得知其格式類似於:

A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
### Assistant: Hello! How may I help you today?
### Human: What is the capital of Taiwan?
### Assistant: Taipei is the capital of Taiwan.

不同的語言模型有不同的模板,通常差別在於第一行與角色的名稱,常見的清單可以參考 Prompt Template for OpenSource LLMs

這個 prompt engineering 就會作為指令的開頭餵給 LLM,而在其之後就會接著歷史對話,而在歷史對話之後,很重要的是要再接上 ### Assistant: 作為結尾,讓模型可以接著回答下去。

現在就先建立開頭與歷史對話的部分:

		let system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
    let user_name = "### Human";
    let bot_name = "### Assistant";
    let mut history = format!(
        "{bot_name}:Hello! How may I help you today?\n\
        {user_name}:What is the capital of Taiwan?\n\
        {bot_name}:Taipei is the capital of Taiwan.\n"
    );

    for message in prompt.messages.into_iter() {
        let msg = message.text;
        let curr_line = if message.user {
            format!("{user_name}:{msg}\n")
        } else {
            format!("{bot_name}:{msg}\n")
        };

        history.push_str(&curr_line);
    }

模型推論

接著參考 llm 官方文件官方範例 vicuna-chat.rs 我們可以開始使用模型來推論了,其中說明直接註解在上面:

	  let mut res = String::new();
    let mut buf = String::new();

    // use the model to generate text from a prompt
    let mut session = model.start_session(Default::default());

    session
        .infer(
            // model to use for text generation
            model.as_ref(),
            // randomness provider
            &mut rand::thread_rng(),
            // the prompt to use for text generation, as well as other
            // inference parameters
            &llm::InferenceRequest {
                prompt: format!("{system_prompt}\n{history}\n{bot_name}:}")
                    .as_str()
                    .into(),
                parameters: &llm::InferenceParameters::default(),
                play_back_previous_tokens: false,
                maximum_token_count: None,
            },
            // llm::OutputRequest
            &mut Default::default(),
            // output callback
            inference_callback(String::from(user_name), &mut buf, &mut res),
        )
        .unwrap_or_else(|e| panic!("{e}"));

預設上,llm 會自動從 Hugging Face 的 model hub 下載 tokenizer。

由於其中用到了 rand crate,所以要記得要在 Cargo.toml 加上:
rand = { version = "0.8.5", optional = true}
以及 ssr 也要記得加。

inference_callback 函式

其中 inference_callback 的主要功能為告訴 session.infer() 函式何時該停止,而我們想要模型在輸出人類的角色名稱 ### Human 時停止,否則模型就會開始無止境地自言自語了。

這裡我們參考的是 llm/crates/llm-base/src/inference_session.rsconversation_inference_callback 函式的寫法,同樣說明也註解在其中:

cfg_if! {
    if #[cfg(feature = "ssr")] {
    use std::convert::Infallible;

        fn inference_callback<'a>(
            stop_sequence: String,
            buf: &'a mut String,
            out_str: &'a mut String,
        ) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + 'a {
            move |resp| match resp {
                llm::InferenceResponse::InferredToken(token) => {
                    // We've generated a token, so we need to check if it's contained in the stop sequence.
                    let mut stop_sequence_buf = buf.clone();
                    stop_sequence_buf.push_str(token.as_str());

                    if stop_sequence.as_str().eq(stop_sequence_buf.as_str()) {
                        // We've generated the stop sequence, so we're done.
                        // Note that this will contain the extra tokens that were generated after the stop sequence,
                        // which may affect generation. This is non-ideal, but it's the best we can do without
                        // modifying the model.
                        buf.clear();
                        return Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Halt);
                    } else if stop_sequence.as_str().starts_with(stop_sequence_buf.as_str()) {
                        // We've generated a prefix of the stop sequence, so we need to keep buffering.
                        buf.push_str(token.as_str());
                        return Ok(llm::InferenceFeedback::Continue);
                    }

                    // We've generated a token that isn't part of the stop sequence, so we can
                    // pass it to the callback.
                    if buf.is_empty() {
                        out_str.push_str(&token);
                    } else {
                        out_str.push_str(&stop_sequence_buf);
                    }

                    Ok(llm::InferenceFeedback::Continue)
                }
                llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
                _ => Ok(llm::InferenceFeedback::Continue),
            }
        }
    }
}

其中 cfg_if 巨集的用意在於讓這個函式只在伺服器端 (ssr) 被編譯,如此一來就完成 API 的建立囉。
另外這裡也是 Lifetimes 註解 <'a> 第一次在這次系列文中出現,這部分之後會再說明!

Put it together

此時整個 src/api.rs 會像這樣:

use crate::model::conversation::Conversation;

use cfg_if::cfg_if;
use leptos::*;

#[server(Converse "/api")]
pub async fn converse(prompt: Conversation) -> Result<String, ServerFnError> {
    use actix_web::dev::ConnectionInfo;
    use actix_web::web::Data;
    use leptos_actix::extract;

    use llm::models::Llama;
    use llm::KnownModel;

    let model =
        extract(|data: Data<Llama>, _connection: ConnectionInfo| async move { data.into_inner() })
            .await
            .unwrap();

    let system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
    let user_name = "### Human";
    let bot_name = "### Assistant";
    let mut history = format!(
        "{bot_name}:Hello! How may I help you today?\n\
        {user_name}:What is the capital of Taiwan?\n\
        {bot_name}:Taipei is the capital of Taiwan.\n"
    );

    for message in prompt.messages.into_iter() {
        let msg = message.text;
        let curr_line = if message.user {
            format!("{user_name}:{msg}\n")
        } else {
            format!("{bot_name}:{msg}\n")
        };

        history.push_str(&curr_line);
    }

    let mut res = String::new();
    let mut buf = String::new();

    // use the model to generate text from a prompt
    let mut session = model.start_session(Default::default());

    session
        .infer(
            // model to use for text generation
            model.as_ref(),
            // randomness provider
            &mut rand::thread_rng(),
            // the prompt to use for text generation, as well as other
            // inference parameters
            &llm::InferenceRequest {
                prompt: format!("{system_prompt}\n{history}\n{bot_name}:")
                    .as_str()
                    .into(),
                parameters: &llm::InferenceParameters::default(),
                play_back_previous_tokens: false,
                maximum_token_count: None,
            },
            // llm::OutputRequest
            &mut Default::default(),
            // output callback
            inference_callback(String::from(user_name), &mut buf, &mut res),
        )
        .unwrap_or_else(|e| panic!("{e}"));

    Ok(res)
}

cfg_if! {
    if #[cfg(feature = "ssr")] {
    use std::convert::Infallible;

        fn inference_callback<'a>(
            stop_sequence: String,
            buf: &'a mut String,
            out_str: &'a mut String,
        ) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + 'a {
            move |resp| match resp {
                llm::InferenceResponse::InferredToken(token) => {
                    // We've generated a token, so we need to check if it's contained in the stop sequence.
                    let mut stop_sequence_buf = buf.clone();
                    stop_sequence_buf.push_str(token.as_str());

                    if stop_sequence.as_str().eq(stop_sequence_buf.as_str()) {
                        // We've generated the stop sequence, so we're done.
                        // Note that this will contain the extra tokens that were generated after the stop sequence,
                        // which may affect generation. This is non-ideal, but it's the best we can do without
                        // modifying the model.
                        buf.clear();
                        return Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Halt);
                    } else if stop_sequence.as_str().starts_with(stop_sequence_buf.as_str()) {
                        // We've generated a prefix of the stop sequence, so we need to keep buffering.
                        buf.push_str(token.as_str());
                        return Ok(llm::InferenceFeedback::Continue);
                    }

                    // We've generated a token that isn't part of the stop sequence, so we can
                    // pass it to the callback.
                    if buf.is_empty() {
                        out_str.push_str(&token);
                    } else {
                        out_str.push_str(&stop_sequence_buf);
                    }

                    Ok(llm::InferenceFeedback::Continue)
                }
                llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
                _ => Ok(llm::InferenceFeedback::Continue),
            }
        }
    }
}

而 Cargo.toml 在 [dependencies] 以下的區塊如下:

[dependencies]
actix-files = { version = "0.6", optional = true }
actix-web = { version = "4", optional = true, features = ["macros"] }
console_error_panic_hook = "0.1"
cfg-if = "1"
http = { version = "0.2", optional = true }
leptos = { version = "0.5.0", features = ["nightly"] }
leptos_meta = { version = "0.5.0", features = ["nightly"] }
leptos_actix = { version = "0.5.0", optional = true }
leptos_router = { version = "0.5.0", features = ["nightly"] }
wasm-bindgen = "=0.2.87"
serde = { version = "1.0.188", features = ["derive"] }
llm = { git = "https://github.com/rustformers/llm.git", branch = "main", optional = true}
rand = { version = "0.8.5", optional = true}

[features]
csr = ["leptos/csr", "leptos_meta/csr", "leptos_router/csr"]
hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"]
ssr = [
  "dep:actix-files",
  "dep:actix-web",
  "dep:leptos_actix",
  "dep:llm",
  "dep:rand",
  "leptos/ssr",
  "leptos_meta/ssr",
  "leptos_router/ssr",
]

[profile.dev.package.ggml-sys]
opt-level = 3

注意在這個 converse 函式中,我們假定模型已經載入進來了,但實際上還沒,這就是我們明天的工作囉,明天見!!!
/images/emoticon/emoticon78.gif


上一篇
[Day 15] - 鋼鐵草泥馬 🦙 LLM chatbot 🤖 (6/10)|GGML 量化 LLaMa
下一篇
[Day 17] - 鋼鐵草泥馬 🦙 LLM chatbot 🤖 (8/10)|Rust 中載入 GGML 模型
系列文
Rust 加 MLOps,你說有沒有搞頭?30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言