iT邦幫忙

2025 iThome Ironman Contest

DAY 26
Rust

Rust Backend Beginner Series, Part 26

Day 26: Adding Swagger to Axum


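Goals

  • Use utoipa to generate an OpenAPI (v3) document automatically from type definitions and handler annotations.
  • Expose the OpenAPI JSON and serve an interactive Swagger UI with utoipa-swagger-ui (or swagger-ui-dist), so frontend developers, testers, and third parties can read and try out the API.
  • Describe authorization (e.g. Bearer JWT) and request/response schemas in the document.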

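Why add Swagger / OpenAPI

  1. Documentation: OpenAPI is an API specification that both machines and people can read, so team members and third parties know exactly what each endpoint takes and returns: request/response bodies, parameters, status codes, and examples.
  2. Automation: utoipa generates the document directly from Rust types and annotations, which cuts the cost and the error rate of maintaining documentation by hand.
  3. Testing and debugging: Swagger UI is interactive, so you can send API requests from the browser (including headers and bodies), which is convenient for QA and frontend testing.
  4. Toolchain integration: an OpenAPI document can drive generated client SDKs, mock servers, contract tests, and more.
  5. Collaboration: the document is the contract, and frontend and backend can develop against the same specification.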

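Required crates (Cargo.toml)

Add the following dependencies to Cargo.toml (check the official releases for the current stable versions):

utoipa = { version = "5", features = ["macros"] }
utoipa-swagger-ui = { version = "9", features = ["axum"] }

The axum feature of utoipa-swagger-ui is what allows SwaggerUi to be merged into an Axum Router.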

Overall flow

  1. Derive utoipa's ToSchema (#[derive(ToSchema)]) on the request/response types (CreateUser, UserResponse, LoginRequest).
  2. Annotate each handler with #[utoipa::path(...)] to describe the endpoint's method, path, request body, responses, and security.
  3. Create a root OpenAPI struct with #[derive(OpenApi)] that lists those paths and schemas.
  4. In main.rs, add two routes: one that returns the OpenAPI JSON (/api-doc/openapi.json) and one that serves the Swagger UI through utoipa-swagger-ui (/swagger).
  5. Declare the security scheme in the OpenAPI document, a Bearer token (JWT), so that Swagger UI lets you supply an Authorization header.

Examples

  1. Add ToSchema to the models and DTOs

    In models.rs, add utoipa's derive:

    use serde::{Deserialize, Serialize};
    use sqlx::FromRow;
    use time::OffsetDateTime;
    use utoipa::ToSchema;
    use validator_derive::Validate;
    
    // The OffsetDateTime fields may need #[schema(value_type = String)]; see the notes below.
    #[derive(Serialize, FromRow, Clone, ToSchema)]
    pub struct User {
        pub id: i64,
        pub username: String,
        pub email: String,
        pub password_hash: String,
        pub created_at: OffsetDateTime,
        pub updated_at: OffsetDateTime,
    }
    
    #[derive(Serialize, Deserialize, Clone, ToSchema)]
    pub struct UserResponse {
        pub id: i64,
        pub username: String,
        pub email: String,
        pub created_at: OffsetDateTime,
        pub updated_at: OffsetDateTime,
    }
    
    #[derive(Deserialize, Validate, ToSchema)]
    pub struct CreateUser {
        #[validate(length(min = 3, max = 50))]
        pub username: String,
    
        #[validate(email)]
        pub email: String,
    
        #[validate(length(min = 8))]
        pub password: String,
    }
    
    #[derive(Deserialize, Validate, ToSchema)]
    pub struct UpdateUser {
        #[validate(length(min = 3, max = 50))]
        pub username: Option<String>,
    
        #[validate(email)]
        pub email: Option<String>,
    
        #[validate(length(min = 8))]
        pub password: Option<String>,
    }
    
    #[derive(Serialize, Deserialize, Clone, Validate, ToSchema)]
    pub struct LoginRequest {
        #[validate(length(min = 1))]
        pub username_or_email: String,
    
        #[validate(length(min = 8))]
        pub password: String,
    }
    

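Notes:

  • Derive ToSchema on every struct that appears in a request or response body. utoipa understands serde attributes and common types, but some types need help: time::OffsetDateTime, for example, can be described with #[schema(value_type = String)] or wrapped in a type of your own. (utoipa also offers a `time` cargo feature that adds schema support for the time crate; check the feature list of your utoipa version.) Keep in mind that value_type only changes the documentation, not the JSON the API actually produces: with only the time crate's `serde` feature, OffsetDateTime is not necessarily serialized as an RFC 3339 string, so consider #[serde(with = "time::serde::rfc3339")] (check which time features it requires for your version) so the output matches the documented format.

If OffsetDateTime causes problems in the schema, add attributes to the time fields of UserResponse (or change them to String) so the document renders correctly. For example: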

    #[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
    pub created_at: OffsetDateTime,
    #[schema(value_type = String, example = "2024-01-02T12:00:00Z")]
    pub updated_at: OffsetDateTime,
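The example values in the schema attributes above are RFC 3339 timestamps in UTC. As a dependency-free sketch of what such a string encodes (not how the time crate implements it), the hypothetical helper below formats Unix seconds in that form, using the well-known days-to-civil-date conversion for the proleptic Gregorian calendar. In the real project the time crate does this formatting for you.

```rust
/// Hypothetical helper: converts days since 1970-01-01 into a (year, month, day)
/// civil date in the proleptic Gregorian calendar (Howard Hinnant's algorithm).
fn civil_from_days(days: i64) -> (i64, u32, u32) {
    let z = days + 719_468; // shift the epoch to 0000-03-01
    let era = (if z >= 0 { z } else { z - 146_096 }) / 146_097; // 400-year eras
    let doe = z - era * 146_097; // day of era, in [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of year, March-based
    let mp = (5 * doy + 2) / 153; // month index with March = 0
    let day = (doy - (153 * mp + 2) / 5 + 1) as u32;
    let month = (if mp < 10 { mp + 3 } else { mp - 9 }) as u32;
    // January and February belong to the next civil year in the March-based count
    let year = yoe + era * 400 + if month <= 2 { 1 } else { 0 };
    (year, month, day)
}

/// Formats Unix seconds as an RFC 3339 UTC timestamp such as "2024-01-01T12:00:00Z".
pub fn format_rfc3339_utc(unix_secs: i64) -> String {
    let days = unix_secs.div_euclid(86_400);
    let secs_of_day = unix_secs.rem_euclid(86_400);
    let (year, month, day) = civil_from_days(days);
    let (h, m, s) = (secs_of_day / 3_600, secs_of_day % 3_600 / 60, secs_of_day % 60);
    format!("{year:04}-{month:02}-{day:02}T{h:02}:{m:02}:{s:02}Z")
}
```

For example, format_rfc3339_utc(1_704_110_400) yields the "2024-01-01T12:00:00Z" used as the schema example.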
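  2. Add path annotations to the handlers

utoipa describes each path with an attribute macro. We will annotate the create_user, login, and get_user routes (get_user appears in the complete handlers.rs at the end):

Route annotation for create_user: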

#[utoipa::path(
    post,
    path = "/users",
    tag = "users",
    request_body = CreateUser,
    responses(
        (status = 201, description = "User created", body = UserResponse),
        (status = 400, description = "Invalid payload")
    )
)]
pub async fn create_user( /* ... */ ) -> Result<impl IntoResponse, AppError> {
    // function body unchanged
}

Route annotation for login:

#[utoipa::path(
    post,
    path = "/users/login",
    tag = "auth",
    request_body = LoginRequest,
    responses(
        (status = 200, description = "Login success", content_type = "application/json"),
        (status = 400, description = "Invalid payload"),
        (status = 401, description = "Invalid credentials")
    )
)]
pub async fn login( /* ... */ ) -> Result<impl IntoResponse, AppError> {
    // function body unchanged
}

The remaining handlers follow the same pattern.

Creating the root OpenAPI object (derive OpenApi)

Add the following to src/api_doc.rs:

use utoipa::OpenApi;
use crate::models::{UserResponse, CreateUser, UpdateUser, LoginRequest};

#[derive(OpenApi)]
#[openapi(
    paths(
        crate::handlers::create_user,
        crate::handlers::login,
        crate::handlers::get_user,
        crate::handlers::list_users,
        crate::handlers::update_user,
        crate::handlers::delete_user,
        crate::handlers::myid,
    ),
    components(
        schemas(UserResponse, CreateUser, UpdateUser, LoginRequest)
    ),
    tags(
        (name = "users", description = "User management endpoints"),
        (name = "auth", description = "Authentication endpoints")
    )
)]
pub struct ApiDoc;

Notes:

  • paths lists the handler functions to include in the document (each must be annotated with #[utoipa::path(...)]).
  • components(schemas(...)) registers the model schemas that the paths refer to.
  • tags group the endpoints into sections in Swagger UI.

Declaring security (Bearer JWT)

A security scheme such as a Bearer token can be added when deriving (or manually building) the OpenAPI document. In the same api_doc.rs, declare it through a modifier:

#[derive(OpenApi)]
#[openapi(
    paths(...),
    components(schemas(...)),
    modifiers(&SecurityAddon),
    tags(
        (name = "users", description = "User management"),
        (name = "auth", description = "Authentication"),
    )
)]
pub struct ApiDoc;

pub struct SecurityAddon;

impl utoipa::Modify for SecurityAddon {
    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
        use utoipa::openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme};

        // HTTP bearer authentication; bearer_format is only a hint displayed by the UI
        let bearer = SecurityScheme::Http(
            HttpBuilder::new()
                .scheme(HttpAuthScheme::Bearer)
                .bearer_format("JWT")
                .build(),
        );

        openapi
            .components
            .get_or_insert_with(Default::default)
            .security_schemes
            .insert("bearer_auth".to_string(), bearer);
    }
}

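With this in place, every path annotated with security(("bearer_auth" = [])) shows a lock icon in Swagger UI, and the Authorize button lets you enter a token that Swagger UI then sends in the Authorization header.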
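For an HTTP bearer scheme, the token you paste into Authorize arrives at the server as `Authorization: Bearer <token>`. The server side of that contract is the header parsing in the project's AuthenticatedUser extractor (extractors.rs, listed at the end). The std-only sketch below isolates that step in a hypothetical `bearer_token` helper so it can be unit-tested without Axum; the error strings mirror the extractor's messages.

```rust
/// Hypothetical helper mirroring extractors.rs: extracts the token from the
/// value of an `Authorization` header of the form "Bearer <token>".
pub fn bearer_token(header: Option<&str>) -> Result<&str, &'static str> {
    // No header at all -> the real extractor answers 401 "missing authorization"
    let value = header.ok_or("missing authorization")?;
    // Any other scheme (e.g. "Basic ...") -> 401 "invalid authorization scheme"
    value
        .strip_prefix("Bearer ")
        .ok_or("invalid authorization scheme")
}
```

Like the original extractor, this compares the scheme name case-sensitively. HTTP treats authentication scheme names as case-insensitive, so a more lenient comparison may be worth considering if other clients send `bearer`.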

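Serving the OpenAPI JSON and Swagger UI from main.rs

Add the following to main.rs to serve the JSON at /api-doc/openapi.json and the interactive UI at /swagger (via utoipa-swagger-ui).

Add these imports to main.rs:

use utoipa::OpenApi; // trait that provides ApiDoc::openapi()
use utoipa_swagger_ui::SwaggerUi;
use crate::api_doc::ApiDoc;

Then add the Swagger routes where the Router (app) is built: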

let app = Router::new()
    // existing routes
    .route("/users", post(handlers::create_user).get(handlers::list_users))
    // ...
    // Swagger UI routes
    .merge(
        SwaggerUi::new("/swagger") // path where the UI is served
            .url("/api-doc/openapi.json", ApiDoc::openapi()) // path and content of openapi.json
    )
    .layer(cors)
    .layer(trace)
    .layer(Extension(pool))
    .layer(Extension(redis_conn));

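utoipa-swagger-ui embeds the Swagger UI static assets, so nothing else has to be served. (By default the assets are downloaded when the crate is built, which matters for offline builds.)

Start the project and open http://localhost:3000/swagger in a browser to see the rendered API documentation.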

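Notes and details

  • Type support: some types (such as time::OffsetDateTime) may need #[schema(value_type = String)] in the schema.
  • Path annotations and the derive: every handler you want documented needs its own #[utoipa::path(...)] annotation and must also be listed in paths(...) of the derive(OpenApi) struct. Start with a few important routes, confirm the document is generated correctly, then add the rest.
  • Security: in production, expose Swagger UI only on an internal network or in another controlled environment, so the API spec and the test interface are not available to unauthorized users.
  • Testing: in CI or local tests, make a check of openapi.json part of the test suite, to ensure changes do not break the existing API contract.
  • Documentation tip: give models and handler annotations good example values (the example attribute) so Swagger UI displays them and the API is easier to understand.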
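One way to act on the testing advice above is a drift check: compare the operations the router actually serves with the ones the OpenAPI document lists. The sketch below is deliberately dependency-free; the list of documented operations is assumed to come from elsewhere (for example, from walking the paths of `ApiDoc::openapi()` in a test), and `undocumented` is a hypothetical helper name.

```rust
use std::collections::BTreeSet;

/// Hypothetical drift check: returns the routed operations ("METHOD /path")
/// that are missing from the documented set, in routing order.
pub fn undocumented<'a>(routed: &[&'a str], documented: &[&'a str]) -> Vec<&'a str> {
    // A set makes each membership test cheap, whatever the number of routes
    let documented: BTreeSet<&str> = documented.iter().copied().collect();
    routed
        .iter()
        .copied()
        .filter(|op| !documented.contains(op))
        .collect()
}
```

For instance, a handler routed as GET /myid but annotated with post and path = "/users/myid" would be reported, because the two strings never match.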

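Summary of benefits

  1. Lower documentation cost: the document is generated from code annotations, avoiding drift between documentation and implementation.
  2. Better collaboration: frontend, backend, and QA align requirements and tests on the same OpenAPI specification.
  3. Faster development: Swagger UI offers live, interactive testing, so you can try the API without writing a Postman collection.
  4. Ecosystem tooling: OpenAPI can drive generated client SDKs, mock servers, API test scripts, and more.
  5. Specification as contract: a precise API specification helps with versioning, backward-compatibility checks, and automated testing.

The complete project code follows: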

Cargo.toml

[package]
name = "sqlx_connect_demo"
version = "0.1.0"
edition = "2024"

[dependencies]
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "macros", "migrate", "time"] }
dotenvy = "0.15"
axum = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
anyhow = "1.0"
time = { version = "0.3", features = ["serde"] }
redis = { version = "0.32", features = ["aio", "tokio-comp"] }
argon2 = "0.5"          # or the latest stable version
password-hash = "0.5"   # used by argon2 for decode/verify
tower-http = { version = "0.6", features = ["cors", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
http = "1.3"
validator = "0.20"
validator_derive = "0.20"
jsonwebtoken = "9.3"
chrono = { version = "0.4", features = ["serde"] }
futures = "0.3"
utoipa = { version = "5", features = ["macros"] }
utoipa-swagger-ui = { version = "9", features = ["axum"] }

[dev-dependencies]
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
uuid = { version = "1", features = ["v4"] }

main.rs

mod api_doc;
mod app;
mod auth;
mod cache;
mod extractors;
mod handlers;
mod models;
mod password;

use crate::api_doc::ApiDoc;
use axum::{
    Router,
    extract::Extension,
    routing::{get, post},
};
use dotenvy::dotenv;
use http::{HeaderValue, Method};
use redis::Client as RedisClient;
use redis::aio::MultiplexedConnection;
use sqlx::migrate::MigrateDatabase;
use sqlx::postgres::PgPoolOptions;
use std::env;
use std::time::Duration;
use tokio::time;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
use tracing::Level;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{EnvFilter, fmt};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

#[tokio::main]
async fn main() {
    dotenv().ok();

    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt::layer().with_target(false)) // or .json() for JSON output (needs the tracing-subscriber "json" feature)
        .init();

    let database_url = match env::var("DATABASE_URL") {
        Ok(v) => v,
        Err(_) => {
            eprintln!("Error: DATABASE_URL not found");
            std::process::exit(1);
        }
    };

    // Build the connection future now; it is awaited below with a 5-second timeout
    let connect_future = PgPoolOptions::new()
        .max_connections(5)
        .connect(&database_url);

    match sqlx::postgres::Postgres::database_exists(&database_url).await {
        Ok(exists) => {
            if !exists {
                println!("Database does not exist, trying to create it...");
                if let Err(e) = sqlx::postgres::Postgres::create_database(&database_url).await {
                    eprintln!("Failed to create database: {}", e);
                }
            }
        }
        Err(e) => eprintln!("Failed to check whether the database exists: {}", e),
    }

    let pool = match time::timeout(Duration::from_secs(5), connect_future).await {
        Ok(Ok(p)) => {
            println!("PgPool created");
            p
        }
        Ok(Err(e)) => {
            eprintln!("Failed to create PgPool: {}", e);
            std::process::exit(1);
        }
        Err(_) => {
            eprintln!("Timed out while creating PgPool");
            std::process::exit(1);
        }
    };

    if let Err(e) = sqlx::migrate!("./migrations").run(&pool).await {
        eprintln!("Migrations failed: {}", e);
        std::process::exit(1);
    }
    println!("Migrations completed");

    let redis_url = env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/".to_string());
    let redis_client = RedisClient::open(redis_url.as_str()).expect("invalid redis url");
    let redis_conn: MultiplexedConnection = redis_client
        .get_multiplexed_tokio_connection()
        .await
        .expect("redis connect fail");

    let localhost = HeaderValue::from_static("http://localhost:5173");

    let cors = CorsLayer::new()
        .allow_origin(AllowOrigin::exact(localhost))
        .allow_methods([
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::OPTIONS,
        ])
        .allow_headers(Any)
        .max_age(std::time::Duration::from_secs(600));

    let trace = TraceLayer::new_for_http()
        .make_span_with(DefaultMakeSpan::new().include_headers(false)) // don't record headers automatically, to avoid logging sensitive data
        .on_response(DefaultOnResponse::new().level(Level::INFO));

    let app = Router::new()
        .route(
            "/users",
            post(handlers::create_user).get(handlers::list_users),
        )
        .route(
            "/users/{id}",
            get(handlers::get_user)
                .put(handlers::update_user)
                .delete(handlers::delete_user),
        )
        .route("/users/login", post(handlers::login))
        .route("/myid", get(handlers::myid))
        // Swagger UI routes
        .merge(
            SwaggerUi::new("/swagger") // path where the UI is served
                .url("/api-doc/openapi.json", ApiDoc::openapi()), // path and content of openapi.json
        )
        .layer(cors)
        .layer(trace)
        .layer(Extension(pool))
        .layer(Extension(redis_conn));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    axum::serve(listener, app).await.unwrap();
}

models.rs

use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use time::OffsetDateTime;
use utoipa::ToSchema;
use validator_derive::Validate;

#[derive(Serialize, FromRow, Clone, ToSchema)]
pub struct User {
    pub id: i64,
    pub username: String,
    pub email: String,
    pub password_hash: String,
    #[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
    pub created_at: OffsetDateTime,
    #[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
    pub updated_at: OffsetDateTime,
}

#[derive(Serialize, Deserialize, Clone, ToSchema)]
pub struct UserResponse {
    pub id: i64,
    pub username: String,
    pub email: String,
    #[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
    pub created_at: OffsetDateTime,
    #[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
    pub updated_at: OffsetDateTime,
}

impl From<User> for UserResponse {
    fn from(u: User) -> Self {
        UserResponse {
            id: u.id,
            username: u.username,
            email: u.email,
            created_at: u.created_at,
            updated_at: u.updated_at,
        }
    }
}

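// Request body for creating a user
#[derive(Deserialize, Validate, ToSchema)]
pub struct CreateUser {
    // 3 to 50 characters
    #[validate(length(min = 3, max = 50))]
    pub username: String,

    // must be a valid email address
    #[validate(email)]
    pub email: String,

    // password must be at least 8 characters
    #[validate(length(min = 8))]
    pub password: String,
}

// Request body for updating a user. All fields are optional, so this is a partial update
// (served here via PUT; PATCH is the more conventional method for partial updates).
// UpdateUser: the validation rules apply to an Option field only when it is Some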
#[derive(Deserialize, Validate, ToSchema)]
pub struct UpdateUser {
    #[validate(length(min = 3, max = 50))]
    pub username: Option<String>,

    #[validate(email)]
    pub email: Option<String>,

    #[validate(length(min = 8))]
    pub password: Option<String>,
}

// Request body for logging in
#[derive(Serialize, Deserialize, Clone, Validate, ToSchema)]
pub struct LoginRequest {
    #[validate(length(min = 1))]
    pub username_or_email: String,

    #[validate(length(min = 8))]
    pub password: String,
}

handlers.rs

use crate::auth::sign_access_token;
use crate::cache::{get_user_cache, invalidate_user_cache, set_user_cache};
use crate::extractors::AuthenticatedUser;
use crate::models::LoginRequest;
use crate::models::UserResponse;
use crate::models::{CreateUser, UpdateUser, User};
use crate::password::{hash_password, verify_password};
use axum::{
    extract::{Extension, Json, Path, Query},
    http::StatusCode,
    response::IntoResponse,
};
use redis::aio::MultiplexedConnection;
use serde::Deserialize;
use serde_json::json;
use sqlx::PgPool;
use std::env;
use validator::{Validate, ValidationErrors};

type AppError = (StatusCode, String);

// Helper: map anyhow::Error -> HTTP 500
fn internal_err(e: impl std::fmt::Display) -> AppError {
    (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", e))
}

// POST /users -> INSERT
#[utoipa::path(
    post,
    path = "/users",
    tag = "users",
    request_body = CreateUser,
    responses(
        (status = 201, description = "User created", body = UserResponse),
        (status = 400, description = "Invalid payload")
    )
)]
pub async fn create_user(
    Extension(pool): Extension<PgPool>,
    Extension(mut redis): Extension<MultiplexedConnection>,
    Json(payload): Json<CreateUser>,
) -> Result<impl IntoResponse, AppError> {
    // DTO validation (synchronous)
    if let Err(e) = payload.validate() {
        let body = validation_errors_to_json(&e);
        return Err((StatusCode::BAD_REQUEST, body.to_string()));
    }

    // Hash the password with argon2 (the payload was already validated above)
    let password_hash = hash_password(payload.password.clone())
        .await
        .map_err(|e| internal_err(e))?;

    let rec = sqlx::query_as::<_, User>(
        r#"
        INSERT INTO users (username, email, password_hash, created_at, updated_at)
        VALUES ($1, $2, $3, now(), now())
        RETURNING id, username, email, password_hash, created_at, updated_at
        "#,
    )
    .bind(&payload.username)
    .bind(&payload.email)
    .bind(&password_hash)
    .fetch_one(&pool)
    .await
    .map_err(|e| internal_err(e))?;

    // Convert to the public response type (without password_hash)
    let resp: UserResponse = rec.into();

    // Write to the cache in the background (does not block the response)
    // Note: the background task gets its own clone of the redis connection
    let mut redis_for_set = redis.clone();
    let user_resp = resp.clone();
    tokio::spawn(async move {
        let _ = set_user_cache(&mut redis_for_set, user_resp.id, &user_resp).await;
    });

    // Return 201 Created with the newly created resource
    Ok((StatusCode::CREATED, Json(resp)))
}

// GET /users/{id} -> SELECT single
#[utoipa::path(
    get,
    path = "/users/{id}",
    tag = "users",
    params(
        ("id" = i64, Path, description = "User id")
    ),
    responses(
        (status = 200, description = "Get user", body = UserResponse),
        (status = 404, description = "Not found")
    )
)]
pub async fn get_user(
    Extension(pool): Extension<PgPool>,
    Extension(mut redis): Extension<MultiplexedConnection>, // Extension extractors are matched by type, so their order does not matter
    Path(id): Path<i64>,
) -> Result<impl IntoResponse, AppError> {
    // 1) Try Redis first
    if let Ok(Some(user_res)) = get_user_cache(&mut redis, id).await {
        // Cache hit: return it directly (200)
        return Ok((StatusCode::OK, Json(user_res)));
    }

    // 2) Cache miss -> read from the DB
    let user = sqlx::query_as::<_, User>(
        r#"
        SELECT id, username, email, password_hash, created_at, updated_at
        FROM users
        WHERE id = $1
        "#,
    )
    .bind(id)
    .fetch_optional(&pool)
    .await
    .map_err(|e| internal_err(e))?;

    match user {
        Some(u) => {
            let resp: UserResponse = u.into();
            // 3) Write the result to Redis (write errors are ignored so the request is not blocked)
            let mut redis_for_set = redis.clone();
            let resp_clone = resp.clone();
            // Set the cache without blocking: spawn a background task
            tokio::spawn(async move {
                let _ = set_user_cache(&mut redis_for_set, id, &resp_clone).await;
            });
            Ok((StatusCode::OK, Json(resp)))
        }
        None => {
            // Negative caching (optional): briefly cache "not found" results so missing ids don't hammer the DB,
            // e.g. set a key with a short TTL (30s) that means "not found"
            Err((StatusCode::NOT_FOUND, format!("user {} not found", id)))
        }
    }
}

// GET /users -> SELECT list (simple limit/offset)
#[derive(Deserialize, utoipa::IntoParams)]
#[into_params(parameter_in = Query)]
pub struct ListParams {
    /// Page size (default 50)
    pub limit: Option<u32>,
    /// Number of rows to skip (default 0)
    pub offset: Option<u32>,
}

#[utoipa::path(
    get,
    path = "/users",
    tag = "users",
    params(ListParams),
    responses(
        (status = 200, description = "List users", body = [UserResponse]),
    )
)]
pub async fn list_users(
    Extension(pool): Extension<PgPool>,
    Query(params): Query<ListParams>,
) -> Result<impl IntoResponse, AppError> {
    let limit = params.limit.unwrap_or(50) as i64;
    let offset = params.offset.unwrap_or(0) as i64;

    let users = sqlx::query_as::<_, User>(
        r#"
        SELECT id, username, email, password_hash, created_at, updated_at
        FROM users
        ORDER BY id
        LIMIT $1 OFFSET $2
        "#,
    )
    .bind(limit)
    .bind(offset)
    .fetch_all(&pool)
    .await
    .map_err(|e| internal_err(e))?;

    // Map to UserResponse so password_hash is never sent to the client
    let resp: Vec<UserResponse> = users.into_iter().map(UserResponse::from).collect();
    Ok((StatusCode::OK, Json(resp)))
}

// PUT /users/{id} -> UPDATE
#[utoipa::path(
    put,
    path = "/users/{id}",
    tag = "users",
    params(
        ("id" = i64, Path, description = "User id")
    ),
    responses(
        (status = 200, description = "Update user", body = UserResponse),
		(status = 403, description = "Not owner"),
        (status = 404, description = "Not found")
    ),
    security(
        ("bearer_auth" = []) // 如果此路由需要授權,可標註 security
    )
)]
pub async fn update_user(
    Extension(pool): Extension<PgPool>,
    Extension(mut redis): Extension<MultiplexedConnection>,
    AuthenticatedUser(claims): AuthenticatedUser,
    Path(id): Path<i64>,
    Json(payload): Json<UpdateUser>,
) -> Result<impl IntoResponse, AppError> {
    // 授權檢查:只有 owner 可以更新
    let caller_id: i64 = claims
        .sub
        .parse()
        .map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token sub".to_string()))?;

    if caller_id != id {
        return Err((StatusCode::FORBIDDEN, "forbidden".to_string()));
    }

    // 簡單示範:先取得現有資料,再更新指定欄位
    let existing = sqlx::query_as::<_, User>(
        "SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE id = $1",
    )
    .bind(id)
    .fetch_optional(&pool)
    .await
    .map_err(|e| internal_err(e))?;

    let existing = match existing {
        Some(e) => e,
        None => return Err((StatusCode::NOT_FOUND, format!("user {} not found", id))),
    };

    let new_username = payload.username.unwrap_or(existing.username);
    let new_email = payload.email.unwrap_or(existing.email);
    let new_password_hash = match payload.password {
        Some(p) => hash_password(p).await.map_err(|e| internal_err(e))?, // 實務上 hash 密碼
        None => existing.password_hash,
    };

    let updated = sqlx::query_as::<_, User>(
        r#"
        UPDATE users
        SET username = $1, email = $2, password_hash = $3, updated_at = now()
        WHERE id = $4
        RETURNING id, username, email, password_hash, created_at, updated_at
        "#,
    )
    .bind(&new_username)
    .bind(&new_email)
    .bind(&new_password_hash)
    .bind(id)
    .fetch_one(&pool)
    .await
    .map_err(|e| internal_err(e))?;

    let resp: UserResponse = updated.into();
    let mut redis_for_set = redis.clone();
    let resp_clone = resp.clone();
    tokio::spawn(async move {
        let _ = set_user_cache(&mut redis_for_set, resp.id, &resp_clone).await;
    });
    Ok((StatusCode::OK, Json(resp)))

    //Ok((StatusCode::OK, Json(updated)))
}

// DELETE /users/{id} -> DELETE
#[utoipa::path(
    delete,
    path = "/users/{id}",
    tag = "users",
    params(
        ("id" = i64, Path, description = "User id")
    ),
    responses(
        (status = 204, description = "Delete user"),
        (status = 401, description = "Missing or invalid token"),
        (status = 403, description = "Not owner"),
        (status = 404, description = "Not found")
    ),
    security(
        ("bearer_auth" = []) // this route requires authorization
    )
)]
pub async fn delete_user(
    Extension(pool): Extension<PgPool>,
    Extension(mut redis): Extension<MultiplexedConnection>,
    AuthenticatedUser(claims): AuthenticatedUser,
    Path(id): Path<i64>,
) -> Result<impl IntoResponse, AppError> {
    let caller_id: i64 = claims
        .sub
        .parse()
        .map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token sub".to_string()))?;

    if caller_id != id {
        return Err((StatusCode::FORBIDDEN, "forbidden".to_string()));
    }

    // To avoid a TOCTOU gap, the owner condition is also part of the SQL WHERE clause
    let res = sqlx::query!(
        r#"
		DELETE FROM users
		WHERE id = $1 AND id = $2
		"#,
        id,
        caller_id
    )
    .execute(&pool)
    .await
    .map_err(|e| internal_err(e))?;

    if res.rows_affected() == 0 {
        return Err((StatusCode::NOT_FOUND, format!("user {} not found", id)));
    }

    let mut redis_for_del = redis.clone();
    tokio::spawn(async move {
        let _ = invalidate_user_cache(&mut redis_for_del, id).await;
    });

    // Return 204 No Content
    Ok(StatusCode::NO_CONTENT)
}

// POST /users/login -> Login
// LoginRequest { username_or_email, password }
#[utoipa::path(
    post,
    path = "/users/login",
    tag = "auth",
    request_body = LoginRequest,
    responses(
        (status = 200, description = "Login success", content_type = "application/json"),
        (status = 400, description = "Invalid payload"),
        (status = 401, description = "Invalid credentials")
    )
)]
pub async fn login(
    Extension(pool): Extension<PgPool>,
    Json(payload): Json<LoginRequest>,
) -> Result<impl IntoResponse, AppError> {
    // Validate the format first
    if let Err(e) = payload.validate() {
        let body = validation_errors_to_json(&e);
        return Err((StatusCode::BAD_REQUEST, body.to_string()));
    }

    // 1. Look up the user row by username or email
    let user = sqlx::query_as::<_, User>(
        "SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE username = $1 OR email = $1"
    )
    .bind(&payload.username_or_email)
    .fetch_optional(&pool)
    .await
    .map_err(|e| internal_err(e))?;

    let user = match user {
        Some(u) => u,
        None => return Err((StatusCode::UNAUTHORIZED, "invalid credentials".to_string())),
    };

    // 2. Verify the password
    let ok = verify_password(payload.password.clone(), user.password_hash.clone())
        .await
        .map_err(|e| internal_err(e))?;

    if !ok {
        return Err((StatusCode::UNAUTHORIZED, "invalid credentials".to_string()));
    }

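    // 3. Issue a JWT (access token)
    // The fallback secret is for local development only; always set JWT_SECRET in production
    let jwt_secret = env::var("JWT_SECRET").unwrap_or_else(|_| "set_jwt_secret".to_string());
    // Access token valid for 15 minutes (could be made configurable via env)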
    let access_token =
        sign_access_token(&jwt_secret, user.id, &user.username, 15).map_err(|e| internal_err(e))?;

    Ok((
        StatusCode::OK,
        Json(json!({"access_token": access_token, "token_type": "bearer", "expires_in": 15*60})),
    ))
}

// Convert ValidationErrors into { "errors": { "field": ["msg1", "msg2"], ... } }
pub fn validation_errors_to_json(errs: &ValidationErrors) -> serde_json::Value {
    use serde_json::Value;
    let mut map = serde_json::Map::new();

    for (field, errors) in errs.field_errors().iter() {
        let messages: Vec<String> = errors
            .iter()
            .map(|fe| {
                // Prefer the message; fall back to the error code
                match &fe.message {
                    Some(msg) => msg.to_string(),
                    None => fe.code.to_string(),
                }
            })
            .collect();
        map.insert(
            field.to_string(),
            Value::Array(messages.into_iter().map(Value::String).collect()),
        );
    }

    json!({ "errors": Value::Object(map) })
}

#[utoipa::path(
    get,
    path = "/myid",
    tag = "auth",
    responses(
        (status = 200, description = "Display User Info", content_type = "application/json"),
        (status = 401, description = "Missing or invalid token")
    ),
    security(
        ("bearer_auth" = [])
    )
)]
pub async fn myid(
    AuthenticatedUser(claims): AuthenticatedUser,
) -> Result<impl IntoResponse, AppError> {
    // claims.sub is the user id (as a string)
    Ok((
        StatusCode::OK,
        Json(json!({"id": claims.sub, "username": claims.username})),
    ))
}
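list_users accepts whatever limit a client sends, including whatever a tester types into the generated Swagger UI form. A common hardening step is to clamp it. The sketch below is an assumption, not part of the original code: the `page_bounds` helper and the maximum of 100 are illustrative, while the default of 50 and the default offset of 0 match list_users.

```rust
/// Assumed limits for illustration: default page size 50 (as in list_users), cap 100.
const DEFAULT_LIMIT: u32 = 50;
const MAX_LIMIT: u32 = 100;

/// Hypothetical helper: turns the optional query parameters into the
/// (LIMIT, OFFSET) values bound to the SQL query, clamping the page size.
pub fn page_bounds(limit: Option<u32>, offset: Option<u32>) -> (i64, i64) {
    // Clamp to [1, MAX_LIMIT] so a single request cannot ask for unbounded rows
    let limit = limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT);
    let offset = offset.unwrap_or(0);
    (i64::from(limit), i64::from(offset))
}
```

In list_users this would replace the two unwrap_or lines with `let (limit, offset) = page_bounds(params.limit, params.offset);`, and the cap is worth stating in the parameter descriptions so it shows up in the generated documentation.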

api_doc.rs

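use crate::models::{CreateUser, LoginRequest, UpdateUser, UserResponse};
use utoipa::OpenApi;
use utoipa::openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme};

#[derive(OpenApi)]
#[openapi(
    paths(
        crate::handlers::create_user,
        crate::handlers::login,
        crate::handlers::get_user,
        crate::handlers::list_users,
        crate::handlers::update_user,
        crate::handlers::delete_user,
        crate::handlers::myid,
    ),
    components(
        schemas(UserResponse, CreateUser, UpdateUser, LoginRequest)
    ),
    // Registers the "bearer_auth" scheme referenced by the handlers' security(...)
    modifiers(&SecurityAddon),
    tags(
        (name = "users", description = "User management endpoints"),
        (name = "auth", description = "Authentication endpoints")
    )
)]
pub struct ApiDoc;

pub struct SecurityAddon;

impl utoipa::Modify for SecurityAddon {
    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
        // HTTP bearer authentication; bearer_format is only a hint displayed by the UI
        let bearer = SecurityScheme::Http(
            HttpBuilder::new()
                .scheme(HttpAuthScheme::Bearer)
                .bearer_format("JWT")
                .build(),
        );
        openapi
            .components
            .get_or_insert_with(Default::default)
            .security_schemes
            .insert("bearer_auth".to_string(), bearer);
    }
}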

app.rs

use axum::{
    Extension, Router,
    routing::{get, post},
};
use redis::aio::MultiplexedConnection;
use sqlx::PgPool;

pub fn create_app(pool: PgPool, redis: MultiplexedConnection) -> Router {
    Router::new()
        .route(
            "/users",
            post(crate::handlers::create_user).get(crate::handlers::list_users),
        )
        .route(
            "/users/{id}",
            get(crate::handlers::get_user)
                .put(crate::handlers::update_user)
                .delete(crate::handlers::delete_user),
        )
        .route("/users/login", post(crate::handlers::login))
        .layer(Extension(pool))
        .layer(Extension(redis))
}

auth.rs

use chrono::{Duration, Utc};
use jsonwebtoken::{
    DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
    errors::Result as JwtResult,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    pub sub: String,      // subject, e.g. the user id
    pub username: String, // username, convenient for display in the frontend
    pub exp: i64,         // expiry time as a UNIX timestamp (seconds)
}

/// Creates an access token
pub fn sign_access_token(
    secret: &str,
    user_id: i64,
    username: &str,
    minutes: i64,
) -> anyhow::Result<String> {
    let exp = Utc::now()
        .checked_add_signed(Duration::minutes(minutes))
        .ok_or_else(|| anyhow::anyhow!("invalid exp"))?
        .timestamp();

    let claims = Claims {
        sub: user_id.to_string(),
        username: username.to_string(),
        exp,
    };

    let token = encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(secret.as_bytes()),
    )
    .map_err(|e| anyhow::anyhow!("token sign error: {}", e))?;
    Ok(token)
}

/// Validates a token and returns its Claims (Validation::default() also checks exp)
pub fn validate_token(secret: &str, token: &str) -> JwtResult<TokenData<Claims>> {
    decode::<Claims>(
        token,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::default(),
    )
}
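sign_access_token sets exp to the current time plus the given number of minutes, in Unix seconds, and login reports the same lifetime as expires_in = 15 * 60 = 900 seconds. The std-only sketch below spells out that arithmetic with overflow checks, plus a simplified version of the expiry test a validator performs. The helper names and the leeway parameter are illustrative assumptions; jsonwebtoken's Validation has its own configurable leeway, so check its documentation for the default.

```rust
use std::time::{SystemTime, UNIX_EPOCH};

/// Current time as Unix seconds (0 if the system clock is before 1970).
pub fn now_unix() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs() as i64)
        .unwrap_or(0)
}

/// exp = now + minutes * 60, or None on overflow
/// (the role checked_add_signed plays in sign_access_token).
pub fn exp_after(now: i64, minutes: i64) -> Option<i64> {
    minutes.checked_mul(60)?.checked_add(now)
}

/// Simplified expiry rule: a token counts as expired once exp < now - leeway.
pub fn is_expired(exp: i64, now: i64, leeway_secs: i64) -> bool {
    exp < now - leeway_secs
}
```

Usage: `exp_after(now_unix(), 15)` gives the value that ends up in the exp claim of a 15-minute access token.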

cache.rs

use crate::models::UserResponse;
use redis::AsyncCommands;
use redis::aio::MultiplexedConnection;
use serde_json;

const USER_CACHE_PREFIX: &str = "user:"; // key = user:{id}
const USER_CACHE_TTL_SECS: usize = 60 * 5; // 5-minute TTL, adjustable

pub async fn get_user_cache(
    redis: &mut MultiplexedConnection,
    id: i64,
) -> redis::RedisResult<Option<UserResponse>> {
    let key = format!("{}{}", USER_CACHE_PREFIX, id);
    let v: Option<String> = redis.get(&key).await?; // get returns Option<String>
    match v {
        Some(s) => {
            let user: UserResponse = serde_json::from_str(&s).map_err(|_e| {
                // map serde error into redis::RedisError
                redis::RedisError::from((redis::ErrorKind::TypeError, "serde json parse error"))
            })?;
            Ok(Some(user))
        }
        None => Ok(None),
    }
}

pub async fn set_user_cache(
    redis: &mut MultiplexedConnection,
    id: i64,
    user: &UserResponse,
) -> redis::RedisResult<()> {
    let key = format!("{}{}", USER_CACHE_PREFIX, id);
    let s = serde_json::to_string(user).map_err(|_| {
        redis::RedisError::from((redis::ErrorKind::TypeError, "serde json serialize error"))
    })?;
    // SETEX: set with TTL
    let _: () = redis
        .set_ex(key, s, USER_CACHE_TTL_SECS.try_into().unwrap())
        .await?;
    Ok(())
}

pub async fn invalidate_user_cache(
    redis: &mut MultiplexedConnection,
    id: i64,
) -> redis::RedisResult<()> {
    let key = format!("{}{}", USER_CACHE_PREFIX, id);
    let _: () = redis.del(key).await?;
    Ok(())
}
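The comment in get_user mentions negative caching: briefly remembering that an id does not exist, so repeated lookups for it stop reaching PostgreSQL. The sketch below shows the idea with an in-memory map instead of Redis; `TtlCache` is a hypothetical type for illustration only. With Redis you would store a sentinel value under the same user:{id} key with a short TTL, which get_user_cache would then need to recognize.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical in-memory stand-in for the Redis user cache, for illustration only.
/// A stored `Some(json)` is a cached user; a stored `None` is a cached "not found".
pub struct TtlCache {
    entries: HashMap<i64, (Option<String>, Instant)>,
}

impl TtlCache {
    pub fn new() -> Self {
        Self { entries: HashMap::new() }
    }

    /// Stores a value (or a negative entry) that expires after `ttl`.
    pub fn put(&mut self, id: i64, value: Option<String>, ttl: Duration) {
        self.entries.insert(id, (value, Instant::now() + ttl));
    }

    /// `None`: cache miss, query the DB. `Some(None)`: known to be missing, answer 404 directly.
    pub fn get(&mut self, id: i64) -> Option<Option<String>> {
        let (value, expires_at) = self.entries.get(&id)?.clone();
        if Instant::now() < expires_at {
            Some(value)
        } else {
            // Expired entries are dropped and treated as a miss
            self.entries.remove(&id);
            None
        }
    }
}
```

Negative entries should use a short TTL (such as the 30 seconds suggested in the comment) so that a user created shortly afterwards does not stay invisible for long.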

extractors.rs

use crate::auth::{Claims, validate_token};
use axum::extract::FromRequestParts;
use axum::http::StatusCode;
use axum::http::request::Parts;
use futures::future::{BoxFuture, FutureExt};
use std::env;

pub struct AuthenticatedUser(pub Claims);

impl<S> FromRequestParts<S> for AuthenticatedUser
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> BoxFuture<'static, Result<Self, Self::Rejection>> {
        // Synchronously pull the Authorization header out of parts.headers as a String (or a rejection)
        let token_opt: Result<String, (StatusCode, String)> = parts
            .headers
            .get(axum::http::header::AUTHORIZATION)
            .and_then(|v| v.to_str().ok())
            .ok_or_else(|| {
                (
                    StatusCode::UNAUTHORIZED,
                    "missing authorization".to_string(),
                )
            })
            .and_then(|auth| {
                auth.strip_prefix("Bearer ")
                    .map(|s| s.to_string())
                    .ok_or_else(|| {
                        (
                            StatusCode::UNAUTHORIZED,
                            "invalid authorization scheme".to_string(),
                        )
                    })
            });

        // Read the secret from the environment (also synchronous)
        let jwt_secret = env::var("JWT_SECRET").unwrap_or_else(|_| "set_jwt_secret".to_string());

        // Move the synchronously computed results into the async block
        async move {
            let token = match token_opt {
                Ok(t) => t,
                Err(e) => return Err(e),
            };

            match validate_token(&jwt_secret, &token) {
                Ok(data) => Ok(AuthenticatedUser(data.claims)),
                Err(_) => Err((StatusCode::UNAUTHORIZED, "invalid token".to_string())),
            }
        }
        .boxed()
    }
}
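The extractor parses the header synchronously and only hands the token to `validate_token` inside the async block. Pulling that parsing into a plain function, as sketched below, makes it unit-testable without building a request. `parse_bearer` is a hypothetical helper. It deliberately differs from the original in one respect: RFC 7235 treats the auth scheme as case-insensitive, so the helper accepts `bearer` as well as `Bearer`. The original `strip_prefix("Bearer ")` accepts only `Bearer`.

```rust
/// Hypothetical helper mirroring the synchronous part of AuthenticatedUser.
/// Returns the token, or a (status code, message) pair like the extractor's Rejection.
pub fn parse_bearer(header: Option<&str>) -> Result<&str, (u16, &'static str)> {
    // No Authorization header at all -> 401
    let value = header.ok_or((401, "missing authorization"))?;
    // Split "<scheme> <credentials>" at the first space
    let (scheme, token) = value
        .split_once(' ')
        .ok_or((401, "invalid authorization scheme"))?;
    // The auth scheme is case-insensitive per RFC 7235
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return Err((401, "invalid authorization scheme"));
    }
    let token = token.trim();
    if token.is_empty() {
        return Err((401, "missing token"));
    }
    Ok(token)
}
```

In the extractor, the `token_opt` chain could then call this helper on `parts.headers.get(AUTHORIZATION).and_then(|v| v.to_str().ok())` and map the numeric status to `StatusCode`.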

lib.rs

pub mod app;
pub mod auth;
pub mod cache;
pub mod extractors;
pub mod handlers;
pub mod models;
pub mod password;

password.rs

use anyhow::Result;
use argon2::{
    Argon2, Params, PasswordHasher, PasswordVerifier,
    password_hash::{PasswordHash, SaltString, rand_core::OsRng},
};
use tokio::task;

/// Argon2 parameters (tune per environment)
fn default_argon2_params() -> Params {
    // m_cost (memory size) is in KiB, e.g. 65536 KiB = 64 MiB
    // t_cost = number of iterations
    // p_cost = degree of parallelism (lanes)
    Params::new(65536, 3, 1, None).expect("invalid argon2 params")
}

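/// Async wrapper: hashes a plaintext password into a PHC-encoded string (CPU-heavy, blocking work)
pub async fn hash_password(password: String) -> Result<String> {
    // Run on Tokio's blocking thread pool so the async worker threads are not stalled
    task::spawn_blocking(move || {
        // Use Argon2id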
        let argon2 = Argon2::new(
            argon2::Algorithm::Argon2id,
            argon2::Version::V0x13,
            default_argon2_params(),
        );
        // Generate a random salt
        let salt = SaltString::generate(&mut OsRng);
        // hash
        let password_hash = argon2
            .hash_password(password.as_bytes(), &salt)
            .map_err(|e| anyhow::anyhow!("argon2 hash error: {}", e))?
            .to_string();
        Ok(password_hash)
    })
    .await?
}

/// Async wrapper: verifies whether a plaintext password matches the stored hash
pub async fn verify_password(password: String, password_hash: String) -> Result<bool> {
    task::spawn_blocking(move || {
        let parsed_hash = PasswordHash::new(&password_hash)
            .map_err(|e| anyhow::anyhow!("invalid password hash format: {}", e))?;
        // The PHC string records the algorithm, version and params used for verification
        let argon2 = Argon2::default();
        match argon2.verify_password(password.as_bytes(), &parsed_hash) {
            Ok(_) => Ok(true),
            Err(argon2::password_hash::Error::Password) => Ok(false),
            Err(e) => Err(anyhow::anyhow!("argon2 verify error: {}", e)),
        }
    })
    .await?
}
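`Params::new(65536, 3, 1, None)` allocates 64 MiB per hash. Every hash runs through `spawn_blocking`, so the number of hashes in flight is bounded only by Tokio's blocking pool, which defaults to 512 threads. The rough arithmetic below uses hypothetical helper names to show how quickly that adds up under a burst of logins.

```rust
/// Tokio's documented default for `max_blocking_threads`.
pub const TOKIO_DEFAULT_MAX_BLOCKING_THREADS: u32 = 512;

/// Memory used by one Argon2 hash, in MiB (m_cost is given in KiB).
pub fn per_hash_mib(m_cost_kib: u32) -> u32 {
    m_cost_kib / 1024
}

/// Worst-case memory in MiB if `concurrent` hashes run at the same time.
pub fn peak_hash_memory_mib(m_cost_kib: u32, concurrent: u32) -> u64 {
    // Widen to u64 first so large products cannot overflow u32.
    u64::from(m_cost_kib) * u64::from(concurrent) / 1024
}
```

With the parameters above, 8 simultaneous logins need about 512 MiB. A full blocking pool would need 32768 MiB (32 GiB). If that worst case is too high, one option is to cap concurrent hashing. You could wrap `hash_password`/`verify_password` in a `tokio::sync::Semaphore`, or set a lower `max_blocking_threads` on the runtime builder. Lowering `m_cost` instead trades memory for weaker brute-force resistance.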

tests/integration_tests.rs

use redis::Client as RedisClient;
use reqwest::StatusCode;
use serde_json::Value;
use sqlx::migrate::MigrateDatabase;
use sqlx::PgPool;
use tokio::net::TcpListener;
use uuid::Uuid;

async fn spawn_app() -> (String, PgPool, redis::aio::MultiplexedConnection) {
    // Create a uniquely named test database (TEST_DATABASE_URL_BASE keeps tests away from the real database)
    let base_url = std::env::var("TEST_DATABASE_URL_BASE").expect("TEST_DATABASE_URL_BASE");
    let db_name = format!("test_db_{}", Uuid::new_v4().to_string().replace("-", ""));
    let db_url = format!("{}/{}", base_url, db_name);

    if !sqlx::postgres::Postgres::database_exists(&db_url)
        .await
        .unwrap_or(false)
    {
        sqlx::Postgres::create_database(&db_url).await.unwrap();
    }

    let pool = PgPool::connect(&db_url).await.unwrap();
    // 執行 migrations
    sqlx::migrate!("./migrations").run(&pool).await.unwrap();

    // Redis: use TEST_REDIS_URL (defaults to DB 1) to stay isolated from the development data
    let redis_url =
        std::env::var("TEST_REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/1".to_string());
    let redis_client = RedisClient::open(redis_url.as_str()).unwrap();
    let redis_conn = redis_client
        .get_multiplexed_tokio_connection()
        .await
        .unwrap();

    // Bind to port 0 so the OS picks a free port and never collides with the real app's port
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let app = sqlx_connect_demo::app::create_app(pool.clone(), redis_conn.clone());

    // Run the server in the background; the task is dropped when the test's runtime shuts down
    tokio::spawn(async move {
        axum::serve(listener, app).await.unwrap();
    });

    (format!("http://{}", addr), pool, redis_conn)
}

#[tokio::test]
async fn integration_create_get_update_delete_user_flow() {
    let (base_url, _pool, _redis) = spawn_app().await;

    let client = reqwest::Client::new();

    // POST /users (create)
    let res = client
        .post(format!("{}/users", base_url))
        .json(&serde_json::json!({
            "username": "user1",
            "email": "user1@a.com",
            "password": "password"
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::CREATED);
    let body: Value = res.json().await.unwrap();
    let id = body["id"].as_i64().unwrap();

    // GET /users/{id} -> success
    let res = client
        .get(format!("{}/users/{}", base_url, id))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);

    // PUT /users/{id} -> success
    let res = client
        .put(format!("{}/users/{}", base_url, id))
        .json(&serde_json::json!({"username": "new_user1"}))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);

    // POST /users/login -> success
    let res = client
        .post(format!("{}/users/login", base_url))
        .json(&serde_json::json!({
            "username_or_email": "new_user1",
            "password": "password"
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);

    // DELETE /users/{id} -> success
    let res = client
        .delete(format!("{}/users/{}", base_url, id))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::NO_CONTENT);

    // GET the deleted user -> 404
    let res = client
        .get(format!("{}/users/{}", base_url, id))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::NOT_FOUND);
}

#[tokio::test]
async fn integration_login_failures() {
    let (base_url, _pool, _redis) = spawn_app().await;
    let client = reqwest::Client::new();

    // Log in as a nonexistent user -> 401
    let res = client
        .post(format!("{}/users/login", base_url))
        .json(&serde_json::json!({
            "username_or_email": "not_exists",
            "password": "password"
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

    // Create a user -> success
    let res = client
        .post(format!("{}/users", base_url))
        .json(&serde_json::json!({
            "username": "user1",
            "email": "user1@a.com",
            "password": "truepass"
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::CREATED);

    // Log in with the wrong password -> 401
    let res = client
        .post(format!("{}/users/login", base_url))
        .json(&serde_json::json!({
            "username_or_email": "user1",
            "password": "falsepass"
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
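`spawn_app` isolates every test in two ways. It creates a freshly named database, and it binds the server to port 0 so the OS assigns a free port. The std-only sketch below shows both ideas. `unique_test_db_name` is a hypothetical replacement for the `Uuid`-based name. It combines the process id, a timestamp and an atomic counter, and checks PostgreSQL's 63-byte identifier limit.

```rust
use std::net::{SocketAddr, TcpListener};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

/// PostgreSQL truncates identifiers longer than 63 bytes (NAMEDATALEN - 1).
const PG_MAX_IDENT_LEN: usize = 63;

/// Per-process counter so two calls in the same nanosecond still differ.
static COUNTER: AtomicU64 = AtomicU64::new(0);

/// Hypothetical std-only alternative to the Uuid-based name in spawn_app:
/// only lowercase letters, digits and underscores, so no quoting is needed.
pub fn unique_test_db_name() -> String {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is before 1970")
        .as_nanos();
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let name = format!("test_db_{}_{}_{}", std::process::id(), nanos, n);
    assert!(name.len() <= PG_MAX_IDENT_LEN, "db name too long: {name}");
    name
}

/// Bind to port 0 so the OS picks a free port, as spawn_app does with
/// tokio::net::TcpListener; returns the listener and the address it got.
pub fn bind_ephemeral() -> std::io::Result<(TcpListener, SocketAddr)> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let addr = listener.local_addr()?;
    Ok((listener, addr))
}
```

Neither version removes the test databases afterwards. One possible cleanup is to close the pool and then call `sqlx::Postgres::drop_database(&db_url)`, which comes from the same `MigrateDatabase` trait.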
