Add the following dependencies to Cargo.toml (use the current stable versions):
utoipa = { version = "5", features = ["macros"] }
utoipa-swagger-ui = { version = "9", features = ["axum"] }
Add ToSchema to the models and DTOs
In models.rs, add utoipa's derive:
use utoipa::ToSchema;
#[derive(Serialize, FromRow, Clone, ToSchema)]
pub struct User {
pub id: i64,
pub username: String,
pub email: String,
pub password_hash: String,
pub created_at: OffsetDateTime,
pub updated_at: OffsetDateTime,
}
#[derive(Serialize, Deserialize, Clone, ToSchema)]
pub struct UserResponse {
pub id: i64,
pub username: String,
pub email: String,
pub created_at: OffsetDateTime,
pub updated_at: OffsetDateTime,
}
#[derive(Deserialize, Validate, ToSchema)]
pub struct CreateUser {
#[validate(length(min = 3, max = 50))]
pub username: String,
#[validate(email)]
pub email: String,
#[validate(length(min = 8))]
pub password: String,
}
#[derive(Deserialize, Validate, ToSchema)]
pub struct UpdateUser {
#[validate(length(min = 3, max = 50))]
pub username: Option<String>,
#[validate(email)]
pub email: Option<String>,
#[validate(length(min = 8))]
pub password: Option<String>,
}
#[derive(Serialize, Deserialize, Clone, Validate, ToSchema)]
pub struct LoginRequest {
#[validate(length(min = 1))]
pub username_or_email: String,
#[validate(length(min = 8))]
pub password: String,
}
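Note:
utoipa can only describe OffsetDateTime when its time feature is enabled; without it, the ToSchema derive above does not compile. Either enable that feature or document the time fields as strings with the schema attribute (apply it to User as well, since User also derives ToSchema). For example:
#[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
pub created_at: OffsetDateTime,
#[schema(value_type = String, example = "2024-01-02T12:00:00Z")]
pub updated_at: OffsetDateTime,
value_type only changes the documentation, not the JSON. With just the serde feature of the time crate, OffsetDateTime is serialized as an array of numbers rather than an RFC 3339 string; to make the actual responses match the example, the fields can also use #[serde(with = "time::serde::rfc3339")], which needs additional time features (see the time::serde documentation).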
utoipa describes each path with an attribute macro. The annotations for create_user and login look like this:
Annotation for create_user:
#[utoipa::path(
post,
path = "/users",
tag = "users",
request_body = CreateUser,
responses(
(status = 201, description = "User created", body = UserResponse),
(status = 400, description = "Invalid payload")
)
)]
pub async fn create_user( /* ... */ ) -> Result<impl IntoResponse, AppError> {
// function body unchanged
}
Annotation for login:
#[utoipa::path(
post,
path = "/users/login",
tag = "auth",
request_body = LoginRequest,
responses(
(status = 200, description = "Login success", content_type = "application/json"),
(status = 401, description = "Invalid credentials")
)
)]
pub async fn login( /* ... */ ) -> Result<impl IntoResponse, AppError> {
// function body unchanged
}
The other handlers follow the same pattern.
Create the OpenAPI root object (derive OpenApi)
Add the following to src/api_doc.rs:
use utoipa::OpenApi;
use crate::models::{UserResponse, CreateUser, UpdateUser, LoginRequest};
#[derive(OpenApi)]
#[openapi(
paths(
crate::handlers::create_user,
crate::handlers::login,
crate::handlers::get_user,
crate::handlers::list_users,
crate::handlers::update_user,
crate::handlers::delete_user,
crate::handlers::myid,
),
components(
schemas(UserResponse, CreateUser, UpdateUser, LoginRequest)
),
tags(
(name = "users", description = "User management endpoints"),
(name = "auth", description = "Authentication endpoints")
)
)]
pub struct ApiDoc;
Declaring security (Bearer JWT)
A security scheme such as a Bearer token can be added through the OpenApi derive or when building the OpenAPI object by hand. Here it is declared in the same api_doc.rs through a modifier:
#[derive(OpenApi)]
#[openapi(
paths(...),
components(schemas(...)),
modifiers(&SecurityAddon),
tags(
(name = "users", description = "User management"),
(name = "auth", description = "Authentication"),
)
)]
pub struct ApiDoc;
pub struct SecurityAddon;
impl utoipa::openapi::Modify for SecurityAddon {
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
use utoipa::openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme};
// HTTP Bearer scheme; bearerFormat "JWT" is only a hint shown in the UI
let bearer = SecurityScheme::Http(
HttpBuilder::new()
.scheme(HttpAuthScheme::Bearer)
.bearer_format("JWT")
.build(),
);
openapi
.components
.get_or_insert_with(Default::default)
.add_security_scheme("bearer_auth", bearer);
}
}
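Any path annotated with security(("bearer_auth" = [])) then shows a lock icon in Swagger UI, and the Authorize button lets you enter a token that is sent as an Authorization: Bearer header.
Serving the OpenAPI JSON and Swagger UI from main.rs
Add the following to main.rs to serve the JSON at /api-doc/openapi.json and the interactive UI at /swagger (provided by utoipa-swagger-ui):
Add to the imports in main.rs:
use utoipa::OpenApi; // brings ApiDoc::openapi() into scope
use utoipa_swagger_ui::SwaggerUi;
use crate::api_doc::ApiDoc;
When building the Router (where app is created), add the Swagger routes: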
let app = Router::new()
// existing routes
.route("/users", post(handlers::create_user).get(handlers::list_users))
// ...
// Swagger UI routes
.merge(
SwaggerUi::new("/swagger") // endpoint serving the UI
.url("/api-doc/openapi.json", ApiDoc::openapi()) // path and content of openapi.json
)
.layer(cors)
.layer(trace)
.layer(Extension(pool))
.layer(Extension(redis_conn));
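utoipa-swagger-ui bundles the Swagger UI static assets, so nothing else has to be served.
Start the project and open localhost:3000/swagger in a browser to browse the API documentation.
The complete project source follows: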
Cargo.toml
[package]
name = "sqlx_connect_demo"
version = "0.1.0"
edition = "2024"
[dependencies]
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "macros", "migrate", "time"] }
dotenvy = "0.15"
axum = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
anyhow = "1.0"
time = { version = "0.3", features = ["serde"] }
redis = { version = "0.32", features = ["aio", "tokio-comp"] }
argon2 = "0.5" # or the latest stable version
password-hash = "0.5" # used with argon2 to decode/verify hashes
tower-http = { version = "0.6", features = ["cors", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
http = "1.3"
validator = "0.20"
validator_derive = "0.20"
jsonwebtoken = "9.3"
chrono = { version = "0.4", features = ["serde"] }
futures = "0.3"
utoipa = { version = "5", features = ["macros"] }
utoipa-swagger-ui = { version = "9", features = ["axum"] }
[dev-dependencies]
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
uuid = { version = "1", features = ["v4"] }
main.rs
mod api_doc;
mod app;
mod auth;
mod cache;
mod extractors;
mod handlers;
mod models;
mod password;
use crate::api_doc::ApiDoc;
use axum::{
Router,
extract::Extension,
routing::{get, post},
};
use dotenvy::dotenv;
use http::{HeaderValue, Method};
use redis::Client as RedisClient;
use redis::aio::MultiplexedConnection;
use sqlx::migrate::MigrateDatabase;
use sqlx::postgres::PgPoolOptions;
use std::env;
use std::time::Duration;
use tokio::time;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
use tracing::Level;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{EnvFilter, fmt};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
#[tokio::main]
async fn main() {
dotenv().ok();
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(fmt::layer().with_target(false)) // or .json() for JSON output (needs tracing-subscriber's "json" feature)
.init();
let database_url = match env::var("DATABASE_URL") {
Ok(v) => v,
Err(_) => {
eprintln!("error: DATABASE_URL is not set");
std::process::exit(1);
}
};
// Pool connection future; it is lazy and only runs when awaited below under a 5-second timeout
let connect_future = PgPoolOptions::new()
.max_connections(5)
.connect(&database_url);
match sqlx::postgres::Postgres::database_exists(&database_url).await {
Ok(exists) => {
if !exists {
println!("database does not exist, trying to create it...");
if let Err(e) = sqlx::postgres::Postgres::create_database(&database_url).await {
eprintln!("failed to create database: {}", e);
}
}
}
Err(e) => eprintln!("failed to check whether the database exists: {}", e),
}
let pool = match time::timeout(Duration::from_secs(5), connect_future).await {
Ok(Ok(p)) => {
println!("PgPool created");
p
}
Ok(Err(e)) => {
eprintln!("failed to create PgPool: {}", e);
std::process::exit(1);
}
Err(_) => {
eprintln!("timed out creating PgPool");
std::process::exit(1);
}
};
if let Err(e) = sqlx::migrate!("./migrations").run(&pool).await {
eprintln!("migrations failed: {}", e);
std::process::exit(1);
}
println!("migrations completed");
let redis_url = env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/".to_string());
let redis_client = RedisClient::open(redis_url.as_str()).expect("invalid redis url");
let redis_conn: MultiplexedConnection = redis_client
.get_multiplexed_tokio_connection()
.await
.expect("redis connect fail");
let localhost = HeaderValue::from_static("http://localhost:5173");
let cors = CorsLayer::new()
.allow_origin(AllowOrigin::exact(localhost))
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
.allow_headers(Any)
.max_age(std::time::Duration::from_secs(600));
let trace = TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::new().include_headers(false)) // do not record headers, to avoid logging sensitive data
.on_response(DefaultOnResponse::new().level(Level::INFO));
let app = Router::new()
.route(
"/users",
post(handlers::create_user).get(handlers::list_users),
)
.route(
"/users/{id}",
get(handlers::get_user)
.put(handlers::update_user)
.delete(handlers::delete_user),
)
.route("/users/login", post(handlers::login))
.route("/myid", get(handlers::myid))
// Swagger UI routes
.merge(
SwaggerUi::new("/swagger") // endpoint serving the UI
.url("/api-doc/openapi.json", ApiDoc::openapi()), // path and content of openapi.json
)
.layer(cors)
.layer(trace)
.layer(Extension(pool))
.layer(Extension(redis_conn));
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
axum::serve(listener, app).await.unwrap();
}
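main.rs wraps PgPoolOptions::connect in tokio::time::timeout so that a database that never answers cannot hang startup. To see the give-up-after-a-deadline idea in isolation, here is a std-only sketch of the same pattern applied to blocking work on a helper thread; with_timeout is a hypothetical helper and is not part of the project. One difference matters: tokio::time::timeout drops the future it wraps once the deadline passes, whereas the thread below keeps running and its late result is simply discarded.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

/// Hypothetical helper: run `work` on a separate thread and wait at most
/// `limit` for its result. On timeout the thread is NOT cancelled; it keeps
/// running and whatever it eventually produces is thrown away.
fn with_timeout<T, F>(limit: Duration, work: F) -> Result<T, &'static str>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        // After a timeout the receiver is gone, so ignore the send error.
        let _ = tx.send(work());
    });
    match rx.recv_timeout(limit) {
        Ok(value) => Ok(value),
        Err(mpsc::RecvTimeoutError::Timeout) => Err("timed out"),
        // The sender was dropped without sending, e.g. because `work` panicked.
        Err(mpsc::RecvTimeoutError::Disconnected) => Err("worker exited without a result"),
    }
}

fn main() {
    // Fast work finishes well inside the limit.
    println!("{:?}", with_timeout(Duration::from_secs(5), || "connected"));
    // Slow work exceeds the limit, like the "timed out creating PgPool" branch.
    let slow = with_timeout(Duration::from_millis(50), || {
        thread::sleep(Duration::from_millis(500));
        "connected"
    });
    println!("{:?}", slow);
}
```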
models.rs
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use time::OffsetDateTime;
use utoipa::ToSchema;
use validator_derive::Validate;
#[derive(Serialize, FromRow, Clone, ToSchema)]
pub struct User {
pub id: i64,
pub username: String,
pub email: String,
pub password_hash: String,
#[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
pub created_at: OffsetDateTime,
#[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
pub updated_at: OffsetDateTime,
}
#[derive(Serialize, Deserialize, Clone, ToSchema)]
pub struct UserResponse {
pub id: i64,
pub username: String,
pub email: String,
#[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
pub created_at: OffsetDateTime,
#[schema(value_type = String, example = "2024-01-01T12:00:00Z")]
pub updated_at: OffsetDateTime,
}
impl From<User> for UserResponse {
fn from(u: User) -> Self {
UserResponse {
id: u.id,
username: u.username,
email: u.email,
created_at: u.created_at,
updated_at: u.updated_at,
}
}
}
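// Used for creation (request body)
#[derive(Deserialize, Validate, ToSchema)]
pub struct CreateUser {
// between 3 and 50 characters
#[validate(length(min = 3, max = 50))]
pub username: String,
// must be a valid email address
#[validate(email)]
pub email: String,
// at least 8 characters
#[validate(length(min = 8))]
pub password: String,
}
// Used for updates. Every field is optional, so this is a partial update; PATCH is the conventional verb for that, while PUT conventionally replaces the whole resource
// UpdateUser: validation rules on Option fields apply only when the value is Some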
#[derive(Deserialize, Validate, ToSchema)]
pub struct UpdateUser {
#[validate(length(min = 3, max = 50))]
pub username: Option<String>,
#[validate(email)]
pub email: Option<String>,
#[validate(length(min = 8))]
pub password: Option<String>,
}
// Used for login
#[derive(Serialize, Deserialize, Clone, Validate, ToSchema)]
pub struct LoginRequest {
#[validate(length(min = 1))]
pub username_or_email: String,
#[validate(length(min = 8))]
pub password: String,
}
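UpdateUser makes every field optional, and update_user fills each missing field from the stored row with unwrap_or. The std-only sketch below isolates that merge rule, assuming simplified stand-in structs (StoredUser and UpdatePayload are hypothetical names, and the password field is left out): a field absent from the JSON body deserializes to None and keeps its old value.

```rust
/// Hypothetical stand-in for the stored row (only the fields needed here).
#[derive(Debug, Clone, PartialEq)]
struct StoredUser {
    username: String,
    email: String,
}

/// Hypothetical stand-in for the update body: None means "not sent".
#[derive(Debug, Default)]
struct UpdatePayload {
    username: Option<String>,
    email: Option<String>,
}

/// Partial update: a Some field replaces the stored value, a None field keeps it
/// (the same unwrap_or rule update_user uses).
fn merge(existing: StoredUser, payload: UpdatePayload) -> StoredUser {
    StoredUser {
        username: payload.username.unwrap_or(existing.username),
        email: payload.email.unwrap_or(existing.email),
    }
}

fn main() {
    let existing = StoredUser {
        username: "user1".to_string(),
        email: "user1@a.com".to_string(),
    };
    // Body {"username": "new_user1"}: email is absent, so it is None.
    let payload = UpdatePayload {
        username: Some("new_user1".to_string()),
        ..Default::default()
    };
    println!("{:?}", merge(existing, payload));
}
```

Because None means "keep", this scheme cannot express "clear this field". That is fine here since every column is required, but a nullable column would need a different encoding.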
handlers.rs
use crate::auth::sign_access_token;
use crate::cache::{get_user_cache, invalidate_user_cache, set_user_cache};
use crate::extractors::AuthenticatedUser;
use crate::models::LoginRequest;
use crate::models::UserResponse;
use crate::models::{CreateUser, UpdateUser, User};
use crate::password::{hash_password, verify_password};
use axum::{
extract::{Extension, Json, Path, Query},
http::StatusCode,
response::IntoResponse,
};
use redis::aio::MultiplexedConnection;
use serde::Deserialize;
use serde_json::json;
use sqlx::PgPool;
use std::env;
use validator::{Validate, ValidationErrors};
type AppError = (StatusCode, String);
// Helper: map any error that implements Display to HTTP 500
fn internal_err(e: impl std::fmt::Display) -> AppError {
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", e))
}
// POST /users -> INSERT
#[utoipa::path(
post,
path = "/users",
tag = "users",
request_body = CreateUser,
responses(
(status = 201, description = "User created", body = UserResponse),
(status = 400, description = "Invalid payload")
)
)]
pub async fn create_user(
Extension(pool): Extension<PgPool>,
Extension(redis): Extension<MultiplexedConnection>,
Json(payload): Json<CreateUser>,
) -> Result<impl IntoResponse, AppError> {
// DTO validation (synchronous)
if let Err(e) = payload.validate() {
let body = validation_errors_to_json(&e);
return Err((StatusCode::BAD_REQUEST, body.to_string()));
}
// Stricter rules (e.g. password complexity) could be added on top of the DTO validation above
// Hash the password (argon2 in this demo)
let password_hash = hash_password(payload.password.clone())
.await
.map_err(|e| internal_err(e))?;
let rec = sqlx::query_as::<_, User>(
r#"
INSERT INTO users (username, email, password_hash, created_at, updated_at)
VALUES ($1, $2, $3, now(), now())
RETURNING id, username, email, password_hash, created_at, updated_at
"#,
)
.bind(&payload.username)
.bind(&payload.email)
.bind(&password_hash)
.fetch_one(&pool)
.await
.map_err(|e| internal_err(e))?;
// Convert to the outward-facing type (drops password_hash)
let resp: UserResponse = rec.into();
// Write the cache in a background task so the response is not delayed
// Note: the task gets its own clone of the redis connection
let mut redis_for_set = redis.clone();
let user_resp = resp.clone();
tokio::spawn(async move {
let _ = set_user_cache(&mut redis_for_set, user_resp.id, &user_resp).await;
});
// Return 201 Created with the new resource
Ok((StatusCode::CREATED, Json(resp)))
}
// GET /users/{id} -> SELECT single
#[utoipa::path(
get,
path = "/users/{id}",
tag = "users",
params(
("id" = i64, Path, description = "User id")
),
responses(
(status = 200, description = "Get user", body = UserResponse),
(status = 404, description = "Not found")
)
)]
pub async fn get_user(
Extension(pool): Extension<PgPool>,
Extension(mut redis): Extension<MultiplexedConnection>, // extensions are looked up by type, so their order in the signature does not matter
Path(id): Path<i64>,
) -> Result<impl IntoResponse, AppError> {
// 1) Try Redis first
if let Ok(Some(user_res)) = get_user_cache(&mut redis, id).await {
// Cache hit: return it directly (200)
return Ok((StatusCode::OK, Json(user_res)));
}
// 2) Cache miss: read from the DB
let user = sqlx::query_as::<_, User>(
r#"
SELECT id, username, email, password_hash, created_at, updated_at
FROM users
WHERE id = $1
"#,
)
.bind(id)
.fetch_optional(&pool)
.await
.map_err(|e| internal_err(e))?;
match user {
Some(u) => {
let resp: UserResponse = u.into();
// 3) Write the result to Redis (write errors are ignored so the request is not held up)
let mut redis_for_set = redis.clone();
let resp_clone = resp.clone();
// Set the cache without blocking, in a background task
tokio::spawn(async move {
let _ = set_user_cache(&mut redis_for_set, id, &resp_clone).await;
});
Ok((StatusCode::OK, Json(resp)))
}
None => {
// Negative caching (optional): briefly cache "not found" so repeated lookups of missing ids do not all hit the DB
// e.g. set the key with a short TTL (30s) that marks "not found"
Err((StatusCode::NOT_FOUND, format!("user {} not found", id)))
}
}
}
// GET /users -> SELECT list (simple limit/offset)
#[derive(Deserialize, utoipa::IntoParams)]
#[into_params(parameter_in = Query)]
pub struct ListParams {
/// Maximum number of rows (default 50)
pub limit: Option<u32>,
/// Number of rows to skip (default 0)
pub offset: Option<u32>,
}
#[utoipa::path(
get,
path = "/users",
tag = "users",
params(ListParams),
responses(
(status = 200, description = "List users", body = [UserResponse]),
)
)]
pub async fn list_users(
Extension(pool): Extension<PgPool>,
Query(params): Query<ListParams>,
) -> Result<impl IntoResponse, AppError> {
let limit = params.limit.unwrap_or(50) as i64;
let offset = params.offset.unwrap_or(0) as i64;
let users = sqlx::query_as::<_, User>(
r#"
SELECT id, username, email, password_hash, created_at, updated_at
FROM users
ORDER BY id
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.fetch_all(&pool)
.await
.map_err(|e| internal_err(e))?;
// Map to UserResponse so password hashes never leave the server
let users: Vec<UserResponse> = users.into_iter().map(UserResponse::from).collect();
Ok((StatusCode::OK, Json(users)))
}
// PUT /users/{id} -> UPDATE
#[utoipa::path(
put,
path = "/users/{id}",
tag = "users",
params(
("id" = i64, Path, description = "User id")
),
responses(
(status = 200, description = "Update user", body = UserResponse),
(status = 400, description = "Invalid payload"),
(status = 401, description = "Missing or invalid token"),
(status = 403, description = "Not owner"),
(status = 404, description = "Not found")
),
security(
("bearer_auth" = []) // this route requires a Bearer token
)
)]
pub async fn update_user(
Extension(pool): Extension<PgPool>,
Extension(redis): Extension<MultiplexedConnection>,
AuthenticatedUser(claims): AuthenticatedUser,
Path(id): Path<i64>,
Json(payload): Json<UpdateUser>,
) -> Result<impl IntoResponse, AppError> {
// Authorization: only the owner may update
let caller_id: i64 = claims
.sub
.parse()
.map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token sub".to_string()))?;
if caller_id != id {
return Err((StatusCode::FORBIDDEN, "forbidden".to_string()));
}
// DTO validation: UpdateUser derives Validate, so check it before touching the DB
if let Err(e) = payload.validate() {
let body = validation_errors_to_json(&e);
return Err((StatusCode::BAD_REQUEST, body.to_string()));
}
// Simple approach: load the current row, then overwrite only the fields that were sent
let existing = sqlx::query_as::<_, User>(
"SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE id = $1",
)
.bind(id)
.fetch_optional(&pool)
.await
.map_err(|e| internal_err(e))?;
let existing = match existing {
Some(e) => e,
None => return Err((StatusCode::NOT_FOUND, format!("user {} not found", id))),
};
let new_username = payload.username.unwrap_or(existing.username);
let new_email = payload.email.unwrap_or(existing.email);
let new_password_hash = match payload.password {
Some(p) => hash_password(p).await.map_err(|e| internal_err(e))?, // hash the new password
None => existing.password_hash,
};
let updated = sqlx::query_as::<_, User>(
r#"
UPDATE users
SET username = $1, email = $2, password_hash = $3, updated_at = now()
WHERE id = $4
RETURNING id, username, email, password_hash, created_at, updated_at
"#,
)
.bind(&new_username)
.bind(&new_email)
.bind(&new_password_hash)
.bind(id)
.fetch_one(&pool)
.await
.map_err(|e| internal_err(e))?;
let resp: UserResponse = updated.into();
// Refresh the cache in the background; the task only uses clones, so resp stays available for the response
let mut redis_for_set = redis.clone();
let resp_clone = resp.clone();
tokio::spawn(async move {
let _ = set_user_cache(&mut redis_for_set, resp_clone.id, &resp_clone).await;
});
Ok((StatusCode::OK, Json(resp)))
}
// DELETE /users/{id} -> DELETE
#[utoipa::path(
delete,
path = "/users/{id}",
tag = "users",
params(
("id" = i64, Path, description = "User id")
),
responses(
(status = 204, description = "User deleted"),
(status = 401, description = "Missing or invalid token"),
(status = 403, description = "Not owner"),
(status = 404, description = "Not found")
),
security(
("bearer_auth" = []) // this route requires a Bearer token
)
)]
pub async fn delete_user(
Extension(pool): Extension<PgPool>,
Extension(mut redis): Extension<MultiplexedConnection>,
AuthenticatedUser(claims): AuthenticatedUser,
Path(id): Path<i64>,
) -> Result<impl IntoResponse, AppError> {
let caller_id: i64 = claims
.sub
.parse()
.map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token sub".to_string()))?;
if caller_id != id {
return Err((StatusCode::FORBIDDEN, "forbidden".to_string()));
}
// Repeating the owner condition in the WHERE clause keeps the check and the delete in one
// statement (no TOCTOU gap). sqlx::query (instead of the query! macro) needs no
// DATABASE_URL at compile time, matching the other handlers.
let res = sqlx::query("DELETE FROM users WHERE id = $1 AND id = $2")
.bind(id)
.bind(caller_id)
.execute(&pool)
.await
.map_err(|e| internal_err(e))?;
if res.rows_affected() == 0 {
return Err((StatusCode::NOT_FOUND, format!("user {} not found", id)));
}
// Invalidate the cache before responding, so a GET right after the DELETE
// is not served the stale cached user (cache errors are ignored)
let _ = invalidate_user_cache(&mut redis, id).await;
// Return 204 No Content
Ok(StatusCode::NO_CONTENT)
}
// POST /users/login -> Login
// LoginRequest { username_or_email, password }
#[utoipa::path(
post,
path = "/users/login",
tag = "auth",
request_body = LoginRequest,
responses(
(status = 200, description = "Login success", content_type = "application/json"),
(status = 400, description = "Invalid payload"),
(status = 401, description = "Invalid credentials")
)
)]
pub async fn login(
Extension(pool): Extension<PgPool>,
Json(payload): Json<LoginRequest>,
) -> Result<impl IntoResponse, AppError> {
// Validate the format first
if let Err(e) = payload.validate() {
let body = validation_errors_to_json(&e);
return Err((StatusCode::BAD_REQUEST, body.to_string()));
}
// 1. Look up the user row by username or email
let user = sqlx::query_as::<_, User>(
"SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE username = $1 OR email = $1"
)
.bind(&payload.username_or_email)
.fetch_optional(&pool)
.await
.map_err(|e| internal_err(e))?;
let user = match user {
Some(u) => u,
None => return Err((StatusCode::UNAUTHORIZED, "invalid credentials".to_string())),
};
// 2. Verify the password
let ok = verify_password(payload.password.clone(), user.password_hash.clone())
.await
.map_err(|e| internal_err(e))?;
if !ok {
return Err((StatusCode::UNAUTHORIZED, "invalid credentials".to_string()));
}
// 3. Issue a JWT access token
// The fallback secret is for local development only; always set JWT_SECRET in production
let jwt_secret = env::var("JWT_SECRET").unwrap_or_else(|_| "set_jwt_secret".to_string());
// Access token valid for 15 minutes (could be made configurable via env)
let access_token =
sign_access_token(&jwt_secret, user.id, &user.username, 15).map_err(|e| internal_err(e))?;
Ok((
StatusCode::OK,
Json(json!({"access_token": access_token, "token_type": "bearer", "expires_in": 15*60})),
))
}
// Convert ValidationErrors into { "errors": { "field": ["msg1", "msg2"], ... } }
pub fn validation_errors_to_json(errs: &ValidationErrors) -> serde_json::Value {
use serde_json::Value;
let mut map = serde_json::Map::new();
for (field, errors) in errs.field_errors().iter() {
let messages: Vec<String> = errors
.iter()
.map(|fe| {
// Prefer the message; fall back to the error code
if let Some(msg) = &fe.message {
msg.clone().to_string()
} else {
fe.code.to_string().into()
}
})
.collect();
map.insert(
field.to_string(),
Value::Array(messages.into_iter().map(Value::String).collect()),
);
}
json!({ "errors": Value::Object(map) })
}
#[utoipa::path(
get,
path = "/myid",
tag = "auth",
responses(
(status = 200, description = "Id and username of the current user, read from the token", content_type = "application/json"),
(status = 401, description = "Missing or invalid token")
),
security(
("bearer_auth" = [])
)
)]
pub async fn myid(
AuthenticatedUser(claims): AuthenticatedUser,
) -> Result<impl IntoResponse, AppError> {
// claims.sub is the user id (as a string)
Ok((
StatusCode::OK,
Json(json!({"id": claims.sub, "username": claims.username})),
))
}
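update_user and delete_user share the same ownership rule: the sub claim must parse as an i64 (otherwise 401) and must equal the {id} in the path (otherwise 403). Pulled out into a plain function, which is a hypothetical refactor rather than what handlers.rs currently does, the rule can be unit-tested without a database or a token. authorize_owner is an illustrative name, and the u16 error stands in for axum's StatusCode.

```rust
/// Hypothetical helper: returns the caller id on success, or the HTTP status
/// code to reject with (401 when `sub` is not a valid i64, 403 when it names
/// a different user).
fn authorize_owner(sub: &str, path_id: i64) -> Result<i64, u16> {
    let caller_id = sub.parse::<i64>().map_err(|_| 401u16)?;
    if caller_id != path_id {
        return Err(403);
    }
    Ok(caller_id)
}

fn main() {
    for (sub, id) in [("42", 42), ("42", 7), ("not-a-number", 42)] {
        println!("sub={sub:?} path id={id} -> {:?}", authorize_owner(sub, id));
    }
}
```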
api_doc.rs
use crate::models::{CreateUser, LoginRequest, UpdateUser, UserResponse};
use utoipa::OpenApi;
#[derive(OpenApi)]
#[openapi(
paths(
crate::handlers::create_user,
crate::handlers::login,
crate::handlers::get_user,
crate::handlers::list_users,
crate::handlers::update_user,
crate::handlers::delete_user,
crate::handlers::myid,
),
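components(
schemas(UserResponse, CreateUser, UpdateUser, LoginRequest)
),
modifiers(&SecurityAddon),
tags(
(name = "users", description = "User management endpoints"),
(name = "auth", description = "Authentication endpoints")
)
)]
pub struct ApiDoc;
// Registers the "bearer_auth" scheme referenced by security(("bearer_auth" = [])) in handlers.rs
pub struct SecurityAddon;
impl utoipa::openapi::Modify for SecurityAddon {
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
use utoipa::openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme};
let bearer = SecurityScheme::Http(
HttpBuilder::new()
.scheme(HttpAuthScheme::Bearer)
.bearer_format("JWT")
.build(),
);
openapi
.components
.get_or_insert_with(Default::default)
.add_security_scheme("bearer_auth", bearer);
}
}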
app.rs
use axum::{
Extension, Router,
routing::{get, post},
};
use redis::aio::MultiplexedConnection;
use sqlx::PgPool;
pub fn create_app(pool: PgPool, redis: MultiplexedConnection) -> Router {
Router::new()
.route(
"/users",
post(crate::handlers::create_user).get(crate::handlers::list_users),
)
.route(
"/users/{id}",
get(crate::handlers::get_user)
.put(crate::handlers::update_user)
.delete(crate::handlers::delete_user),
)
.route("/users/login", post(crate::handlers::login))
.layer(Extension(pool))
.layer(Extension(redis))
}
auth.rs
use chrono::{Duration, Utc};
use jsonwebtoken::{
DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
errors::Result as JwtResult,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: String, // subject, e.g. the user id
pub username: String, // username, handy for display in the frontend
pub exp: i64, // expiry time as a UNIX timestamp (seconds)
}
/// Create an access token
pub fn sign_access_token(
secret: &str,
user_id: i64,
username: &str,
minutes: i64,
) -> anyhow::Result<String> {
let exp = Utc::now()
.checked_add_signed(Duration::minutes(minutes))
.ok_or_else(|| anyhow::anyhow!("invalid exp"))?
.timestamp();
let claims = Claims {
sub: user_id.to_string(),
username: username.to_string(),
exp,
};
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_bytes()),
)
.map_err(|e| anyhow::anyhow!("token sign error: {}", e))?;
Ok(token)
}
/// Validate a token and return its Claims
pub fn validate_token(secret: &str, token: &str) -> JwtResult<TokenData<Claims>> {
decode::<Claims>(
token,
&DecodingKey::from_secret(secret.as_bytes()),
&Validation::default(),
)
}
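sign_access_token stores exp as a Unix timestamp in seconds, and login reports expires_in as 15*60. The sketch below reproduces that arithmetic with std::time only (exp_from is a hypothetical name; the project itself uses chrono): exp is the issue time plus minutes*60 seconds, so the two values always agree.

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// Hypothetical helper: the JWT exp claim (seconds since the Unix epoch)
/// for a token issued at `now` that lives for `minutes`.
/// Returns None on overflow or if `now` is before 1970.
fn exp_from(now: SystemTime, minutes: u64) -> Option<i64> {
    let lifetime = Duration::from_secs(minutes.checked_mul(60)?);
    let expires_at = now.checked_add(lifetime)?;
    let secs = expires_at.duration_since(UNIX_EPOCH).ok()?.as_secs();
    i64::try_from(secs).ok()
}

fn main() {
    let minutes = 15;
    let exp = exp_from(SystemTime::now(), minutes).expect("system clock is after 1970");
    // expires_in in the login response is the token lifetime in seconds
    println!("exp = {exp}, expires_in = {}", minutes * 60);
}
```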
cache.rs
use crate::models::UserResponse;
use redis::AsyncCommands;
use redis::aio::MultiplexedConnection;
use serde_json;
const USER_CACHE_PREFIX: &str = "user:"; // key = user:{id}
const USER_CACHE_TTL_SECS: usize = 60 * 5; // 5-minute TTL, adjustable
pub async fn get_user_cache(
redis: &mut MultiplexedConnection,
id: i64,
) -> redis::RedisResult<Option<UserResponse>> {
let key = format!("{}{}", USER_CACHE_PREFIX, id);
let v: Option<String> = redis.get(&key).await?; // get returns Option<String>
match v {
Some(s) => {
let user: UserResponse = serde_json::from_str(&s).map_err(|_e| {
// map serde error into redis::RedisError
redis::RedisError::from((redis::ErrorKind::TypeError, "serde json parse error"))
})?;
Ok(Some(user))
}
None => Ok(None),
}
}
pub async fn set_user_cache(
redis: &mut MultiplexedConnection,
id: i64,
user: &UserResponse,
) -> redis::RedisResult<()> {
let key = format!("{}{}", USER_CACHE_PREFIX, id);
let s = serde_json::to_string(user).map_err(|_| {
redis::RedisError::from((redis::ErrorKind::TypeError, "serde json serialize error"))
})?;
// SETEX: set with TTL
let _: () = redis
.set_ex(key, s, USER_CACHE_TTL_SECS.try_into().unwrap())
.await?;
Ok(())
}
pub async fn invalidate_user_cache(
redis: &mut MultiplexedConnection,
id: i64,
) -> redis::RedisResult<()> {
let key = format!("{}{}", USER_CACHE_PREFIX, id);
let _: () = redis.del(key).await?;
Ok(())
}
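get_user in handlers.rs follows the cache-aside pattern: read user:{id} from Redis, fall back to PostgreSQL on a miss, then write the row back with a TTL; its comments also suggest optionally caching "not found" for about 30 seconds. The sketch below models that flow with an in-memory HashMap instead of Redis so it runs with std only. TtlCache, Cached, MISS_TTL, and the load closure are hypothetical stand-ins, and the negative-cache branch is the optional idea from the comments, not something the project implements.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical cache slot: a found user (stored as JSON text, like cache.rs)
/// or a "not found" marker used for negative caching.
#[derive(Clone, Debug, PartialEq)]
enum Cached {
    Found(String),
    NotFound,
}

/// In-memory stand-in for Redis SETEX/GET; each entry carries its expiry instant.
struct TtlCache {
    entries: HashMap<String, (Cached, Instant)>,
}

impl TtlCache {
    fn new() -> Self {
        TtlCache { entries: HashMap::new() }
    }

    /// Like GET: returns the value only while it has not expired.
    fn get(&mut self, key: &str, now: Instant) -> Option<Cached> {
        let expired = match self.entries.get(key) {
            Some((value, expires_at)) if now < *expires_at => return Some(value.clone()),
            Some(_) => true,
            None => false,
        };
        if expired {
            // Evict lazily on read.
            self.entries.remove(key);
        }
        None
    }

    /// Like SETEX: store a value with a time-to-live.
    fn set_ex(&mut self, key: &str, value: Cached, ttl: Duration, now: Instant) {
        self.entries.insert(key.to_string(), (value, now + ttl));
    }
}

const USER_TTL: Duration = Duration::from_secs(300); // same 5 minutes as USER_CACHE_TTL_SECS
const MISS_TTL: Duration = Duration::from_secs(30); // short TTL for "not found"

/// Cache-aside read: try the cache, fall back to `load` (the database),
/// then cache either the row or a NotFound marker.
fn get_user(
    cache: &mut TtlCache,
    id: i64,
    now: Instant,
    load: impl FnOnce(i64) -> Option<String>,
) -> Option<String> {
    let key = format!("user:{id}");
    match cache.get(&key, now) {
        Some(Cached::Found(json)) => return Some(json),
        Some(Cached::NotFound) => return None,
        None => {}
    }
    match load(id) {
        Some(json) => {
            cache.set_ex(&key, Cached::Found(json.clone()), USER_TTL, now);
            Some(json)
        }
        None => {
            cache.set_ex(&key, Cached::NotFound, MISS_TTL, now);
            None
        }
    }
}

fn main() {
    let mut cache = TtlCache::new();
    let t0 = Instant::now();
    let db = |id: i64| if id == 1 { Some(r#"{"id":1}"#.to_string()) } else { None };
    println!("{:?}", get_user(&mut cache, 1, t0, db)); // miss: loaded from the "db"
    println!("{:?}", get_user(&mut cache, 1, t0, db)); // hit
    println!("{:?}", get_user(&mut cache, 2, t0, db)); // miss: caches NotFound
}
```

Passing now explicitly keeps the expiry logic deterministic in tests; a real cache would read the clock itself, and Redis expires keys on its own rather than on read.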
extractors.rs
use crate::auth::{Claims, validate_token};
use axum::extract::FromRequestParts;
use axum::http::StatusCode;
use axum::http::request::Parts;
use std::env;
pub struct AuthenticatedUser(pub Claims);
// axum 0.8 allows a plain async fn in FromRequestParts impls, so no BoxFuture is needed
impl<S> FromRequestParts<S> for AuthenticatedUser
where
S: Send + Sync,
{
type Rejection = (StatusCode, String);
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Read the Authorization header and strip the "Bearer " prefix
let token = parts
.headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| (StatusCode::UNAUTHORIZED, "missing authorization".to_string()))?
.strip_prefix("Bearer ")
.ok_or_else(|| (StatusCode::UNAUTHORIZED, "invalid authorization scheme".to_string()))?;
// Same secret (and development fallback) as the login handler
let jwt_secret = env::var("JWT_SECRET").unwrap_or_else(|_| "set_jwt_secret".to_string());
match validate_token(&jwt_secret, token) {
Ok(data) => Ok(AuthenticatedUser(data.claims)),
Err(_) => Err((StatusCode::UNAUTHORIZED, "invalid token".to_string())),
}
}
}
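AuthenticatedUser only accepts a header that starts with the exact text "Bearer ". HTTP defines the authentication scheme name as case-insensitive (RFC 7235), so a client sending "bearer ..." is technically valid but would be rejected. The std-only sketch below isolates the extractor's rule as bearer_token and adds a hypothetical lenient variant for comparison; neither function exists in the project.

```rust
/// Same rule as the AuthenticatedUser extractor: a missing header and any
/// value not starting with the exact prefix "Bearer " are rejected.
fn bearer_token(header: Option<&str>) -> Result<&str, &'static str> {
    header
        .ok_or("missing authorization")?
        .strip_prefix("Bearer ")
        .ok_or("invalid authorization scheme")
}

/// Hypothetical lenient variant: the scheme name is matched case-insensitively
/// and an empty token is rejected.
fn bearer_token_lenient(header: &str) -> Option<&str> {
    let (scheme, token) = header.split_once(' ')?;
    let token = token.trim();
    if scheme.eq_ignore_ascii_case("bearer") && !token.is_empty() {
        Some(token)
    } else {
        None
    }
}

fn main() {
    for h in [Some("Bearer abc.def.ghi"), Some("bearer abc"), Some("Basic dXNlcjpwYXNz"), None] {
        println!("{:?} -> strict {:?} / lenient {:?}", h, bearer_token(h), h.and_then(bearer_token_lenient));
    }
}
```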
lib.rs
pub mod app;
pub mod auth;
pub mod cache;
pub mod extractors;
pub mod handlers;
pub mod models;
pub mod password;
password.rs
use anyhow::Result;
use argon2::{
Argon2, Params, PasswordHasher, PasswordVerifier,
password_hash::{PasswordHash, SaltString, rand_core::OsRng},
};
use tokio::task;
/// Argon2 parameters (tune per environment)
fn default_argon2_params() -> Params {
// memory cost is in KiB here (65536 KiB = 64 MiB)
// time_cost = number of iterations
// lanes = degree of parallelism
Params::new(65536, 3, 1, None).expect("invalid argon2 params")
}
/// Async wrapper: hash a plaintext password into an encoded PHC string (CPU-heavy, blocking work)
pub async fn hash_password(password: String) -> Result<String> {
// spawn_blocking keeps the hashing off the async worker threads
task::spawn_blocking(move || {
// Use Argon2id
let argon2 = Argon2::new(
argon2::Algorithm::Argon2id,
argon2::Version::V0x13,
default_argon2_params(),
);
// Generate a random salt
let salt = SaltString::generate(&mut OsRng);
// Hash
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| anyhow::anyhow!("argon2 hash error: {}", e))?
.to_string();
Ok(password_hash)
})
.await?
}
/// Async wrapper: check whether a plaintext password matches a stored hash
pub async fn verify_password(password: String, password_hash: String) -> Result<bool> {
task::spawn_blocking(move || {
let parsed_hash = PasswordHash::new(&password_hash)
.map_err(|e| anyhow::anyhow!("invalid password hash format: {}", e))?;
// The algorithm and cost parameters are read from the stored hash, so the default instance is fine here
let argon2 = Argon2::default();
match argon2.verify_password(password.as_bytes(), &parsed_hash) {
Ok(_) => Ok(true),
Err(argon2::password_hash::Error::Password) => Ok(false),
Err(e) => Err(anyhow::anyhow!("argon2 verify error: {}", e)),
}
})
.await?
}
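default_argon2_params asks for m_cost = 65536 KiB, and Argon2 allocates that much for every hash and every verify (verify_password reads the same parameters back from the stored hash). Both run on tokio's blocking pool, which by default allows up to 512 threads, so worst-case memory grows linearly with concurrent logins. A back-of-the-envelope calculation, using a hypothetical helper peak_memory_mib:

```rust
/// m_cost used by default_argon2_params(), in KiB.
const M_COST_KIB: u64 = 65_536;

/// Peak Argon2 memory, in MiB, when `concurrent` hashes or verifies run at once;
/// each one allocates its own m_cost KiB block.
fn peak_memory_mib(m_cost_kib: u64, concurrent: u64) -> u64 {
    m_cost_kib * concurrent / 1024
}

fn main() {
    // one request, a burst of 8 logins, and tokio's default cap of 512 blocking threads
    for n in [1, 8, 512] {
        println!("{n:>3} concurrent -> {} MiB", peak_memory_mib(M_COST_KIB, n));
    }
}
```

If that ceiling is uncomfortable, one option is to gate the spawn_blocking calls with a tokio::sync::Semaphore so that only a fixed number of hashes run at once.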
tests/integration_tests.rs
use redis::Client as RedisClient;
use reqwest::StatusCode;
use serde_json::Value;
use sqlx::migrate::MigrateDatabase;
use sqlx::PgPool;
use uuid::Uuid;
async fn spawn_app() -> (String, PgPool, redis::aio::MultiplexedConnection) {
// Create a throwaway test database (TEST_DATABASE_URL_BASE keeps tests away from the real database)
let base_url = std::env::var("TEST_DATABASE_URL_BASE").expect("TEST_DATABASE_URL_BASE");
let db_name = format!("test_db_{}", Uuid::new_v4().to_string().replace("-", ""));
let db_url = format!("{}/{}", base_url, db_name);
if !sqlx::postgres::Postgres::database_exists(&db_url)
.await
.unwrap_or(false)
{
sqlx::Postgres::create_database(&db_url).await.unwrap();
}
let pool = PgPool::connect(&db_url).await.unwrap();
// Run migrations
sqlx::migrate!("./migrations").run(&pool).await.unwrap();
// Redis: TEST_REDIS_URL keeps tests apart from the development instance (defaults to DB 1)
// Note: tests run in parallel and share this Redis DB while cache keys are user:{id}; each fresh
// test database numbers its users independently, so cached entries can collide across tests.
// Running with --test-threads=1 (or giving each test its own Redis DB) avoids that.
let redis_url =
std::env::var("TEST_REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/1".to_string());
let redis_client = RedisClient::open(redis_url.as_str()).unwrap();
let redis_conn = redis_client
.get_multiplexed_tokio_connection()
.await
.unwrap();
// Bind to port 0 so the OS picks a free port and the real server's port stays untouched
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let app = sqlx_connect_demo::app::create_app(pool.clone(), redis_conn.clone());
// Run the server in the background for the rest of the test
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
(format!("http://{}", addr), pool, redis_conn)
}
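#[tokio::test]
async fn integration_create_get_update_delete_user_flow() {
let (base_url, _pool, _redis) = spawn_app().await;
let client = reqwest::Client::new();
// POST /users (create)
let res = client
.post(format!("{}/users", base_url))
.json(&serde_json::json!({
"username": "user1",
"email": "user1@a.com",
"password": "password"
}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::CREATED);
let body: Value = res.json().await.unwrap();
let id = body["id"].as_i64().unwrap();
// GET /users/{id} -> 200
let res = client
.get(format!("{}/users/{}", base_url, id))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
// POST /users/login -> 200; keep the access token for the protected routes
let res = client
.post(format!("{}/users/login", base_url))
.json(&serde_json::json!({
"username_or_email": "user1",
"password": "password"
}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.unwrap();
let token = body["access_token"].as_str().unwrap().to_string();
// PUT /users/{id} without a token -> 401
let res = client
.put(format!("{}/users/{}", base_url, id))
.json(&serde_json::json!({"username": "new_user1"}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
// PUT /users/{id} with the owner's token -> 200
let res = client
.put(format!("{}/users/{}", base_url, id))
.bearer_auth(&token)
.json(&serde_json::json!({"username": "new_user1"}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
// POST /users/login with the new username -> 200
let res = client
.post(format!("{}/users/login", base_url))
.json(&serde_json::json!({
"username_or_email": "new_user1",
"password": "password"
}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
// DELETE /users/{id} with the owner's token -> 204
let res = client
.delete(format!("{}/users/{}", base_url, id))
.bearer_auth(&token)
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NO_CONTENT);
// GET the deleted user -> 404
let res = client
.get(format!("{}/users/{}", base_url, id))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}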
#[tokio::test]
async fn integration_login_failures() {
let (base_url, _pool, _redis) = spawn_app().await;
let client = reqwest::Client::new();
// log in as a user that does not exist -> 401
let res = client
.post(format!("{}/users/login", base_url))
.json(&serde_json::json!({
"username_or_email": "not_exists",
"password": "password"
}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
// create the user -> 201
let res = client
.post(format!("{}/users", base_url))
.json(&serde_json::json!({
"username": "user1",
"email": "user1@a.com",
"password": "truepass"
}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::CREATED);
// log in with the wrong password -> 401
let res = client
.post(format!("{}/users/login", base_url))
.json(&serde_json::json!({
"username_or_email": "user1",
"password": "falsepass"
}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}