
2025 iThome 鐵人賽

DAY 30

Rust Projects in Practice: 30 Progressive Projects from Tools to Services, Part 30 of the series

Microservice Gateway: Implementing an API Gateway with Load Balancing (plus a closing retrospective on finishing the challenge)

Preface

This is the final day's topic: implementing the core functionality of an API Gateway. Since the goal here is learning, the implementation won't cover nearly as many concerns as the load balancers and API gateways you'd find in production. It's the last day, so let's finish strong!

Today's Learning Goals

  • Implement a reverse proxy with request routing
  • Support multiple load balancing strategies (round robin, least connections, weighted round robin)
  • Add a health check mechanism
  • Implement rate limiting and the circuit breaker pattern
  • Provide service discovery and dynamic configuration
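
Putting these pieces together, a request passes through the gateway roughly like this (a conceptual sketch of the components listed above; the status codes match what the proxy handler below returns):

client request
      |
      v
 rate limiter --(over limit)--> 429 Too Many Requests
      |
      v
 circuit breaker --(open)--> 503 Service Unavailable
      |
      v
 load balancer (pick a healthy backend)
      |
      v
 reverse proxy --> backend service (502 Bad Gateway if the backend call fails)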

Project Structure

api-gateway/
├── Cargo.toml
├── config.yaml
└── src/
    ├── main.rs
    ├── config.rs
    ├── proxy.rs
    ├── load_balancer.rs
    ├── health_check.rs
    ├── rate_limiter.rs
    └── circuit_breaker.rs

Dependencies

Cargo.toml

[package]
name = "api-gateway"
version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "1.35", features = ["full"] }
axum = "0.7"
hyper = { version = "1.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
tower = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9"
tracing = "0.1"
tracing-subscriber = "0.3"
tokio-util = "0.7"
http-body-util = "0.1"
bytes = "1.5"
governor = "0.6"
parking_lot = "0.12"
reqwest = "0.11"  # used by the health checker

Implementation

As usual, let's scaffold the project files first.

src/config.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    pub server: ServerConfig,
    pub services: HashMap<String, ServiceConfig>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServiceConfig {
    pub path_prefix: String,
    pub backends: Vec<Backend>,
    pub load_balancer: LoadBalancerType,
    pub health_check: HealthCheckConfig,
    pub rate_limit: Option<RateLimitConfig>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Backend {
    pub url: String,
    pub weight: Option<u32>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum LoadBalancerType {
    RoundRobin,
    LeastConnections,
    WeightedRoundRobin,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HealthCheckConfig {
    pub interval_secs: u64,
    pub timeout_secs: u64,
    pub path: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RateLimitConfig {
    pub requests_per_second: u32,
}

impl Config {
    pub fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let content = std::fs::read_to_string(path)?;
        let config: Config = serde_yaml::from_str(&content)?;
        Ok(config)
    }
}

Load Balancer Implementation

src/load_balancer.rs

use crate::config::{Backend, LoadBalancerType};
use parking_lot::RwLock;
use std::sync::Arc;

#[derive(Clone)]
pub struct LoadBalancer {
    backends: Arc<RwLock<Vec<BackendState>>>,
    strategy: LoadBalancerType,
    current_index: Arc<RwLock<usize>>,
}

#[derive(Clone)]
struct BackendState {
    backend: Backend,
    healthy: bool,
    active_connections: usize,
}

impl LoadBalancer {
    pub fn new(backends: Vec<Backend>, strategy: LoadBalancerType) -> Self {
        let backend_states = backends
            .into_iter()
            .map(|b| BackendState {
                backend: b,
                healthy: true,
                active_connections: 0,
            })
            .collect();

        Self {
            backends: Arc::new(RwLock::new(backend_states)),
            strategy,
            current_index: Arc::new(RwLock::new(0)),
        }
    }

    pub fn next_backend(&self) -> Option<String> {
        let backends = self.backends.read();
        let healthy_backends: Vec<_> = backends
            .iter()
            .filter(|b| b.healthy)
            .collect();

        if healthy_backends.is_empty() {
            return None;
        }

        match self.strategy {
            LoadBalancerType::RoundRobin => self.round_robin(&healthy_backends),
            LoadBalancerType::LeastConnections => self.least_connections(&healthy_backends),
            LoadBalancerType::WeightedRoundRobin => self.weighted_round_robin(&healthy_backends),
        }
    }

    fn round_robin(&self, backends: &[&BackendState]) -> Option<String> {
        let mut index = self.current_index.write();
        let backend = backends.get(*index % backends.len())?;
        *index += 1;
        Some(backend.backend.url.clone())
    }

    fn least_connections(&self, backends: &[&BackendState]) -> Option<String> {
        backends
            .iter()
            .min_by_key(|b| b.active_connections)
            .map(|b| b.backend.url.clone())
    }

    fn weighted_round_robin(&self, backends: &[&BackendState]) -> Option<String> {
        let total_weight: u32 = backends
            .iter()
            .map(|b| b.backend.weight.unwrap_or(1))
            .sum();

        if total_weight == 0 {
            return self.round_robin(backends);
        }

        let mut index = self.current_index.write();
        let position = (*index % total_weight as usize) as u32;
        *index += 1;

        let mut cumulative = 0u32;
        for backend in backends {
            cumulative += backend.backend.weight.unwrap_or(1);
            if position < cumulative {
                return Some(backend.backend.url.clone());
            }
        }

        backends.first().map(|b| b.backend.url.clone())
    }

    pub fn increment_connections(&self, url: &str) {
        let mut backends = self.backends.write();
        if let Some(backend) = backends.iter_mut().find(|b| b.backend.url == url) {
            backend.active_connections += 1;
        }
    }

    pub fn decrement_connections(&self, url: &str) {
        let mut backends = self.backends.write();
        if let Some(backend) = backends.iter_mut().find(|b| b.backend.url == url) {
            backend.active_connections = backend.active_connections.saturating_sub(1);
        }
    }

    pub fn mark_unhealthy(&self, url: &str) {
        let mut backends = self.backends.write();
        if let Some(backend) = backends.iter_mut().find(|b| b.backend.url == url) {
            backend.healthy = false;
        }
    }

    pub fn mark_healthy(&self, url: &str) {
        let mut backends = self.backends.write();
        if let Some(backend) = backends.iter_mut().find(|b| b.backend.url == url) {
            backend.healthy = true;
        }
    }
}
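
Before wiring this into the gateway, here's a rough sketch of how the balancer behaves in isolation (for example in a small test module inside the crate; the backend URLs are purely illustrative):

use crate::config::{Backend, LoadBalancerType};
use crate::load_balancer::LoadBalancer;

fn round_robin_demo() {
    let backends = vec![
        Backend { url: "http://localhost:3001".into(), weight: None },
        Backend { url: "http://localhost:3002".into(), weight: None },
    ];
    let lb = LoadBalancer::new(backends, LoadBalancerType::RoundRobin);

    // Requests alternate between the two backends: 3001, 3002, 3001, ...
    assert_eq!(lb.next_backend().as_deref(), Some("http://localhost:3001"));
    assert_eq!(lb.next_backend().as_deref(), Some("http://localhost:3002"));

    // Once a backend is marked unhealthy it is skipped entirely.
    lb.mark_unhealthy("http://localhost:3002");
    assert_eq!(lb.next_backend().as_deref(), Some("http://localhost:3001"));
}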

Health Check

src/health_check.rs

use crate::config::HealthCheckConfig;
use crate::load_balancer::LoadBalancer;
use std::time::Duration;
use tokio::time;
use tracing::{error, info};

pub struct HealthChecker {
    load_balancer: LoadBalancer,
    config: HealthCheckConfig,
    backends: Vec<String>,
}

impl HealthChecker {
    pub fn new(
        load_balancer: LoadBalancer,
        config: HealthCheckConfig,
        backends: Vec<String>,
    ) -> Self {
        Self {
            load_balancer,
            config,
            backends,
        }
    }

    pub async fn start(self) {
        let mut interval = time::interval(Duration::from_secs(self.config.interval_secs));

        loop {
            interval.tick().await;
            self.check_all_backends().await;
        }
    }

    async fn check_all_backends(&self) {
        for backend_url in &self.backends {
            let health_url = format!("{}{}", backend_url, self.config.path);
            
            match self.check_backend(&health_url).await {
                Ok(true) => {
                    info!("Backend {} is healthy", backend_url);
                    self.load_balancer.mark_healthy(backend_url);
                }
                Ok(false) | Err(_) => {
                    error!("Backend {} is unhealthy", backend_url);
                    self.load_balancer.mark_unhealthy(backend_url);
                }
            }
        }
    }

    async fn check_backend(&self, url: &str) -> Result<bool, Box<dyn std::error::Error>> {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(self.config.timeout_secs))
            .build()?;

        let response = client.get(url).send().await?;
        Ok(response.status().is_success())
    }
}

Rate Limiter

src/rate_limiter.rs

use governor::{Quota, RateLimiter as GovernorRateLimiter};
use std::num::NonZeroU32;
use std::sync::Arc;

#[derive(Clone)]
pub struct RateLimiter {
    limiter: Arc<GovernorRateLimiter<governor::state::direct::NotKeyed, governor::state::InMemoryState, governor::clock::DefaultClock>>,
}

impl RateLimiter {
    pub fn new(requests_per_second: u32) -> Self {
        let quota = Quota::per_second(NonZeroU32::new(requests_per_second).unwrap());
        let limiter = Arc::new(GovernorRateLimiter::direct(quota));
        
        Self { limiter }
    }

    pub fn check(&self) -> bool {
        self.limiter.check().is_ok()
    }
}
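
A quick behaviour sketch (the numbers are only for illustration): governor gives a direct limiter a burst capacity equal to the per-second quota, so immediate extra requests are rejected, and in proxy_handler below a rejection turns into a 429 response.

use crate::rate_limiter::RateLimiter;

// Rough sketch, e.g. as a unit test: allow 2 requests per second.
fn rate_limit_demo() {
    let limiter = RateLimiter::new(2);
    assert!(limiter.check());   // first request: allowed
    assert!(limiter.check());   // second request (still within the burst): allowed
    assert!(!limiter.check());  // third immediate request: rejected
}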

Circuit Breaker

src/circuit_breaker.rs

use parking_lot::RwLock;
use std::sync::Arc;
use std::time::{Duration, Instant};

#[derive(Clone)]
pub struct CircuitBreaker {
    state: Arc<RwLock<CircuitState>>,
    failure_threshold: usize,
    timeout: Duration,
}

struct CircuitState {
    failures: usize,
    last_failure: Option<Instant>,
    state: State,
}

#[derive(PartialEq, Clone, Copy)]
enum State {
    Closed,
    Open,
    HalfOpen,
}

impl CircuitBreaker {
    pub fn new(failure_threshold: usize, timeout_secs: u64) -> Self {
        Self {
            state: Arc::new(RwLock::new(CircuitState {
                failures: 0,
                last_failure: None,
                state: State::Closed,
            })),
            failure_threshold,
            timeout: Duration::from_secs(timeout_secs),
        }
    }

    pub fn can_request(&self) -> bool {
        let mut state = self.state.write();

        match state.state {
            State::Closed => true,
            State::Open => {
                if let Some(last_failure) = state.last_failure {
                    if last_failure.elapsed() > self.timeout {
                        state.state = State::HalfOpen;
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            }
            State::HalfOpen => true,
        }
    }

    pub fn record_success(&self) {
        let mut state = self.state.write();
        state.failures = 0;
        state.state = State::Closed;
    }

    pub fn record_failure(&self) {
        let mut state = self.state.write();
        state.failures += 1;
        state.last_failure = Some(Instant::now());

        if state.failures >= self.failure_threshold {
            state.state = State::Open;
        }
    }
}
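
A sketch of the state machine in action (the threshold and timeout here are illustrative; proxy.rs below hard-codes 5 failures and a 30-second timeout):

use crate::circuit_breaker::CircuitBreaker;

// Rough sketch: trip after 2 failures, allow a retry after 1 second.
fn circuit_breaker_demo() {
    let breaker = CircuitBreaker::new(2, 1);

    assert!(breaker.can_request());   // Closed: requests pass through
    breaker.record_failure();
    breaker.record_failure();         // threshold reached -> Open
    assert!(!breaker.can_request());  // Open: requests are rejected immediately

    std::thread::sleep(std::time::Duration::from_secs(2));
    assert!(breaker.can_request());   // timeout elapsed -> HalfOpen: one trial request allowed
    breaker.record_success();         // trial succeeded -> back to Closed
}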

Proxy Handler

src/proxy.rs

use crate::circuit_breaker::CircuitBreaker;
use crate::load_balancer::LoadBalancer;
use crate::rate_limiter::RateLimiter;
use axum::{
    body::Body,
    extract::State,
    http::{Request, Response, StatusCode},
    response::IntoResponse,
};
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use std::sync::Arc;
use tracing::{error, info};

#[derive(Clone)]
pub struct ProxyState {
    pub load_balancer: LoadBalancer,
    pub rate_limiter: Option<RateLimiter>,
    pub circuit_breaker: CircuitBreaker,
    pub client: Client<hyper_util::client::legacy::connect::HttpConnector, Body>,
}

impl ProxyState {
    pub fn new(
        load_balancer: LoadBalancer,
        rate_limiter: Option<RateLimiter>,
    ) -> Self {
        let client = Client::builder(TokioExecutor::new()).build_http();
        let circuit_breaker = CircuitBreaker::new(5, 30);

        Self {
            load_balancer,
            rate_limiter,
            circuit_breaker,
            client,
        }
    }
}

pub async fn proxy_handler(
    State(state): State<Arc<ProxyState>>,
    mut req: Request<Body>,
) -> impl IntoResponse {
    // Rate limit check
    if let Some(ref limiter) = state.rate_limiter {
        if !limiter.check() {
            return Response::builder()
                .status(StatusCode::TOO_MANY_REQUESTS)
                .body(Body::from("Rate limit exceeded"))
                .unwrap();
        }
    }

    // Circuit breaker check
    if !state.circuit_breaker.can_request() {
        return Response::builder()
            .status(StatusCode::SERVICE_UNAVAILABLE)
            .body(Body::from("Service temporarily unavailable"))
            .unwrap();
    }

    // Pick a backend service
    let backend_url = match state.load_balancer.next_backend() {
        Some(url) => url,
        None => {
            return Response::builder()
                .status(StatusCode::SERVICE_UNAVAILABLE)
                .body(Body::from("No healthy backends available"))
                .unwrap();
        }
    };

    // Rewrite the request URI to point at the chosen backend
    let path = req.uri().path();
    let path_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or(path);

    let target_url = format!("{}{}", backend_url, path_query);
    
    match target_url.parse::<hyper::Uri>() {
        Ok(uri) => {
            *req.uri_mut() = uri;
        }
        Err(e) => {
            error!("Failed to parse URI: {}", e);
            return Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(Body::from("Invalid backend URL"))
                .unwrap();
        }
    }

    // Increment the active connection count
    state.load_balancer.increment_connections(&backend_url);

    // Forward the request
    let response = match state.client.request(req).await {
        Ok(resp) => {
            info!("Request forwarded to {}", backend_url);
            state.circuit_breaker.record_success();
            resp
        }
        Err(e) => {
            error!("Proxy error: {}", e);
            state.circuit_breaker.record_failure();
            state.load_balancer.decrement_connections(&backend_url);
            return Response::builder()
                .status(StatusCode::BAD_GATEWAY)
                .body(Body::from("Backend service error"))
                .unwrap();
        }
    };

    // Decrement the active connection count
    state.load_balancer.decrement_connections(&backend_url);

    // Convert the hyper response body into an axum Body so every return path has the same type
    response.map(Body::new)
}

Main Program

src/main.rs

mod circuit_breaker;
mod config;
mod health_check;
mod load_balancer;
mod proxy;
mod rate_limiter;

use axum::{routing::any, Router};
use config::Config;
use health_check::HealthChecker;
use load_balancer::LoadBalancer;
use proxy::{proxy_handler, ProxyState};
use rate_limiter::RateLimiter;
use std::sync::Arc;
use tracing::info;
use tracing_subscriber;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Initialize logging
    tracing_subscriber::fmt::init();

    // Load configuration
    let config = Config::from_file("config.yaml")?;
    let addr = format!("{}:{}", config.server.host, config.server.port);

    info!("Starting API Gateway on {}", addr);

    // Create a route for each configured service
    let mut app = Router::new();

    for (service_name, service_config) in config.services {
        info!("Configuring service: {}", service_name);

        // Create the load balancer
        let load_balancer = LoadBalancer::new(
            service_config.backends.clone(),
            service_config.load_balancer.clone(),
        );

        // Create the rate limiter (if configured)
        let rate_limiter = service_config
            .rate_limit
            .as_ref()
            .map(|rl| RateLimiter::new(rl.requests_per_second));

        // Create the proxy state
        let proxy_state = Arc::new(ProxyState::new(load_balancer.clone(), rate_limiter));

        // Register the route
        let path = format!("{}/*path", service_config.path_prefix);
        app = app.route(&path, any(proxy_handler).with_state(proxy_state));

        // Start the health checker
        let backend_urls: Vec<String> = service_config
            .backends
            .iter()
            .map(|b| b.url.clone())
            .collect();

        let health_checker = HealthChecker::new(
            load_balancer,
            service_config.health_check,
            backend_urls,
        );

        tokio::spawn(async move {
            health_checker.start().await;
        });
    }

    // Start the server
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    info!("API Gateway listening on {}", addr);

    axum::serve(listener, app).await?;

    Ok(())
}

Configuration File

config.yaml

server:
  host: "0.0.0.0"
  port: 8080

services:
  user_service:
    path_prefix: "/api/users"
    load_balancer: weighted_round_robin
    backends:
      - url: "http://localhost:3001"
        weight: 3
      - url: "http://localhost:3002"
        weight: 2
      - url: "http://localhost:3003"
        weight: 1
    health_check:
      interval_secs: 10
      timeout_secs: 3
      path: "/health"
    rate_limit:
      requests_per_second: 100

  order_service:
    path_prefix: "/api/orders"
    load_balancer: least_connections
    backends:
      - url: "http://localhost:4001"
      - url: "http://localhost:4002"
    health_check:
      interval_secs: 15
      timeout_secs: 5
      path: "/health"
    rate_limit:
      requests_per_second: 50

  product_service:
    path_prefix: "/api/products"
    load_balancer: round_robin
    backends:
      - url: "http://localhost:5001"
      - url: "http://localhost:5002"
      - url: "http://localhost:5003"
    health_check:
      interval_secs: 10
      timeout_secs: 3
      path: "/health"
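
With the weights above (3, 2 and 1 for the user_service backends), the weighted round robin in load_balancer.rs cycles through a total weight of 6: positions 0 to 2 go to :3001, positions 3 and 4 to :3002, and position 5 to :3003, so roughly half of the traffic lands on the first backend.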

Testing

A simple test backend

use axum::{routing::get, Router};

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/health", get(|| async { "OK" }))
        .route("/api/users/*path", get(|| async { "User Service Response" }));

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3001")
        .await
        .unwrap();
    
    println!("Backend service running on :3001");
    axum::serve(listener, app).await.unwrap();
}
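
With the gateway on port 8080 (per config.yaml) and a backend like the one above on port 3001, one quick way to exercise the proxy path is a small client like the sketch below (the port numbers and the reqwest dependency are assumptions of this sketch):

// Sends one request through the gateway and prints the proxied backend response.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let resp = reqwest::get("http://localhost:8080/api/users/123").await?;
    println!("status: {}", resp.status());
    println!("body:   {}", resp.text().await?);
    Ok(())
}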

Adding Middleware

use axum::{
    body::Body,
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};

pub async fn auth_middleware(
    req: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    // Check the Authorization header
    let auth_header = req
        .headers()
        .get("Authorization")
        .and_then(|h| h.to_str().ok());

    match auth_header {
        Some(token) if token.starts_with("Bearer ") => {
            // Validate the token here
            Ok(next.run(req).await)
        }
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}
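
To actually apply the middleware it has to be layered onto the router in main.rs; a minimal sketch, assuming auth_middleware is in scope:

use axum::middleware;

// Require a Bearer token on every route the gateway serves
app = app.layer(middleware::from_fn(auth_middleware));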

Adding Logging

use tower_http::trace::TraceLayer;

// Add to the Router (requires the tower-http crate with the "trace" feature in Cargo.toml)
app = app.layer(TraceLayer::new_for_http());

WebSocket Support

use axum::extract::ws::{WebSocket, WebSocketUpgrade};

async fn ws_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<ProxyState>>,
) -> impl IntoResponse {
    ws.on_upgrade(|socket| handle_socket(socket, state))
}

async fn handle_socket(socket: WebSocket, state: Arc<ProxyState>) {
    // WebSocket proxy logic goes here
}
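
handle_socket above is only a stub; real proxying would open a second WebSocket connection to the chosen backend and pump frames in both directions. As a minimal placeholder that keeps the connection alive (an echo loop, not actual proxying), something like this can replace the stub:

use axum::extract::ws::{Message, WebSocket};

// Placeholder only: echo text frames back instead of forwarding them to a backend.
async fn handle_socket(mut socket: WebSocket, _state: Arc<ProxyState>) {
    while let Some(Ok(msg)) = socket.recv().await {
        if let Message::Text(text) = msg {
            if socket.send(Message::Text(text)).await.is_err() {
                break; // client disconnected
            }
        }
    }
}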

Epilogue

Looking back on these 30 days, I feel I learned a great deal. Work was busy, and writing on top of it was honestly painful, especially days 15 to 25, which were the hardest stretch. But it slowly became a habit; my body clock would ping me when it was time to write. The 30-day journey was rewarding, and somewhere along the way I got quite a lot done.

This year I didn't add extra diagrams or long explanations. I found that if I write something too difficult nobody reads it, while if I write something too simple I get bored, and then I want to pad it with extra material anyway, so in the end I just wrote whatever I wanted. The goal was to take part, not to win, and I don't expect any prize this year XD (if I do win, I'll delete this paragraph XD). My style is as plain as it gets, with no prepared images or promotional copy, written entirely on my own terms, so I figure the odds of winning are low. I doubt many readers will be drawn in, and I hear the contest audience skews heavily toward beginners, so I knew I wouldn't win and simply chose to write the "advanced technical" kind of article instead.

