iT邦幫忙

2025 iThome 鐵人賽

DAY 11
0
Rust

Rust 實戰專案集:30 個漸進式專案從工具到服務系列 第 11

URL 縮短服務 - 類似 bit.ly 的 URL 縮短器

  • 分享至 

  • xImage
  •  

前言

今天要實作一個 URL 縮短服務,類似 bit.ly、tinyurl 等服務。
這個專案將展示如何使用 Rust 建構一個完整的 Web 服務,包含 REST API、資料存儲、以及 URL redirect 功能。

今天目標

  • 將長 URL 轉換為短 URL
  • 透過短 URL 重定向到原始 URL
  • 記錄訪問統計
  • 提供 REST API 接口

實作概況

Web 框架: Axum (高性能、現代化的 Rust Web 框架)
資料庫: SQLite (輕量級、嵌入式資料庫)
資料庫存取: SQLx (異步 SQL 工具包,並非 ORM)
序列化: Serde (JSON 處理)
短碼生成: 隨機挑選 Base62 字元 (0-9、A-Z、a-z)

專案結構

url_shortener/
├── Cargo.toml
├── src/
│   ├── main.rs
│   ├── models.rs
│   ├── handlers.rs
│   ├── database.rs
│   └── utils.rs
└── migrations/
    └── 001_initial.sql

依賴

# Package metadata for the URL shortener binary crate.
[package]
name = "url_shortener"
version = "0.1.0"
edition = "2021"

[dependencies]
# HTTP routing, extractors and the server entry point.
axum = "0.7"
# Async runtime (main uses #[tokio::main], timers, spawned tasks and TcpListener).
tokio = { version = "1.0", features = ["full"] }
# Async SQLite access and embedded migrations; "chrono" enables DateTime columns.
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid"] }
# JSON (de)serialization of request and response bodies.
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
# Timestamps; the "serde" feature lets DateTime values appear in JSON responses.
chrono = { version = "0.4", features = ["serde"] }
# NOTE(review): no module shown in this article references uuid — confirm it is needed.
uuid = { version = "1.0", features = ["v4"] }
# Error type returned by the database layer and by main.
anyhow = "1.0"
# Middleware composition (ServiceBuilder) plus CORS and request tracing layers.
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "trace"] }
# Structured logging.
tracing = "0.1"
tracing-subscriber = "0.3"
# URL parsing used to validate submitted links.
url = "2.4"
# NOTE(review): no module shown in this article references base64 — confirm it is needed.
base64 = "0.21"
# Random short-code generation.
rand = "0.8"

開始實作

  • 資料結構

src/models.rs

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;

/// One row of the `urls` table: a short code and the original URL it points to.
#[derive(Debug, FromRow, Serialize)]
pub struct Url {
    /// Database row id.
    pub id: i64,
    /// The long URL that the short code redirects to.
    pub original_url: String,
    /// The code that appears in the short URL's path.
    pub short_code: String,
    /// Number of redirects recorded for this code.
    pub clicks: i64,
    /// When the row was created (UTC).
    pub created_at: DateTime<Utc>,
    /// Optional expiry instant; `None` when no expiry was requested.
    pub expires_at: Option<DateTime<Utc>>,
}

/// JSON body accepted when creating a short URL.
#[derive(Debug, Deserialize)]
pub struct CreateUrlRequest {
    /// The long URL to shorten.
    pub url: String,
    /// Caller-chosen short code; when absent a random code is generated.
    pub custom_code: Option<String>,
    /// Lifetime of the link in days; when absent the link has no expiry.
    pub expires_in_days: Option<i64>,
}

/// JSON body returned after a short URL has been created.
#[derive(Debug, Serialize)]
pub struct CreateUrlResponse {
    /// The full short URL, including scheme and host.
    pub short_url: String,
    /// The long URL that was submitted.
    pub original_url: String,
    /// Expiry instant of the link, if one was requested.
    pub expires_at: Option<DateTime<Utc>>,
}

/// JSON body returned by the statistics endpoint; the stored row without its `id`.
#[derive(Debug, Serialize)]
pub struct UrlStats {
    /// The long URL behind the short code.
    pub original_url: String,
    /// The short code the statistics belong to.
    pub short_code: String,
    /// Number of redirects recorded for this code.
    pub clicks: i64,
    /// When the short URL was created (UTC).
    pub created_at: DateTime<Utc>,
    /// Expiry instant of the link, if any.
    pub expires_at: Option<DateTime<Utc>>,
}

/// JSON body used for every error response of the API.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Human-readable description of what went wrong.
    pub error: String,
}

基本工具的部分

src/utils.rs

use rand::{thread_rng, Rng};

/// Alphabet that short codes are drawn from: ASCII digits, upper-case letters, then
/// lower-case letters (62 symbols in total).
const BASE62_CHARS: &[u8] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

/// Builds a random short code of `length` characters.
///
/// Every character is picked independently and uniformly from `BASE62_CHARS` using the
/// thread-local random number generator.
pub fn generate_short_code(length: usize) -> String {
    let mut rng = thread_rng();
    let mut code = String::with_capacity(length);
    for _ in 0..length {
        let pick = rng.gen_range(0..BASE62_CHARS.len());
        code.push(BASE62_CHARS[pick] as char);
    }
    code
}

/// Checks that `url` parses as a URL and uses the `http` or `https` scheme.
///
/// Returns `Err` with a user-facing message when the string cannot be parsed or uses any
/// other scheme.
pub fn validate_url(url: &str) -> Result<(), String> {
    let parsed_url = url::Url::parse(url).map_err(|_| "無效的 URL 格式".to_string())?;

    if matches!(parsed_url.scheme(), "http" | "https") {
        Ok(())
    } else {
        Err("URL 必須使用 http 或 https 協議".to_string())
    }
}

/// Checks a caller-supplied short code.
///
/// The code must be 3 to 20 bytes long and consist solely of ASCII letters and digits.
/// The length rule is checked first, so a code that breaks both rules reports the length
/// error.
pub fn validate_custom_code(code: &str) -> Result<(), String> {
    let length_ok = (3..=20).contains(&code.len());
    if !length_ok {
        return Err("短碼長度必須在 3-20 字符之間".to_string());
    }

    let has_illegal_char = code.chars().any(|c| !c.is_ascii_alphanumeric());
    if has_illegal_char {
        return Err("短碼只能包含字母和數字".to_string());
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A generated code has the requested length and only uses alphabet characters.
    #[test]
    fn test_generate_short_code() {
        let code = generate_short_code(6);
        assert_eq!(code.len(), 6);
        for c in code.chars() {
            assert!(BASE62_CHARS.contains(&(c as u8)));
        }
    }

    /// Only well-formed http/https URLs are accepted.
    #[test]
    fn test_validate_url() {
        for accepted in ["https://example.com", "http://example.com"] {
            assert!(validate_url(accepted).is_ok());
        }
        for rejected in ["ftp://example.com", "invalid-url"] {
            assert!(validate_url(rejected).is_err());
        }
    }

    /// Custom codes are rejected when too short, too long, or not alphanumeric.
    #[test]
    fn test_validate_custom_code() {
        assert!(validate_custom_code("abc123").is_ok());

        let too_long = "a".repeat(21);
        let rejected = [
            "ab",              // too short
            too_long.as_str(), // too long
            "abc-123",         // contains an illegal character
        ];
        for code in rejected {
            assert!(validate_custom_code(code).is_err());
        }
    }
}

  • 資料庫的部分

src/database.rs

use anyhow::Result;
use chrono::{DateTime, Utc};
use sqlx::{Row, SqlitePool};

use crate::models::Url;

/// Data-access layer for the `urls` table, backed by a SQLite connection pool.
pub struct Database {
    /// Pool that every query in this type runs against.
    pool: SqlitePool,
}

impl Database {
    /// Opens the SQLite database at `database_url`, creating the file when it does not
    /// exist yet, and applies the migrations embedded from `./migrations`.
    ///
    /// # Errors
    /// Fails when the URL cannot be parsed, the database cannot be opened, or a migration
    /// fails.
    pub async fn new(database_url: &str) -> Result<Self> {
        // A plain `SqlitePool::connect` does not create a missing database file, so the
        // very first start with a file-based URL would fail. Opt in to creation explicitly.
        let options = database_url
            .parse::<sqlx::sqlite::SqliteConnectOptions>()?
            .create_if_missing(true);
        let pool = SqlitePool::connect_with(options).await?;

        // Run the database migrations.
        sqlx::migrate!("./migrations").run(&pool).await?;

        Ok(Self { pool })
    }

    /// Inserts a new short URL and returns the stored record.
    ///
    /// The click counter starts at 0 and `created_at` is set to the current UTC time.
    ///
    /// # Errors
    /// Returns the underlying database error, for example when `short_code` is already
    /// taken and the insert violates the table's `UNIQUE` constraint.
    pub async fn create_url(
        &self,
        original_url: &str,
        short_code: &str,
        expires_at: Option<DateTime<Utc>>,
    ) -> Result<Url> {
        let now = Utc::now();

        // Queries in this type are checked at runtime, so the crate builds without a
        // database being available at compile time.
        let result = sqlx::query(
            "INSERT INTO urls (original_url, short_code, clicks, created_at, expires_at) \
             VALUES (?, ?, 0, ?, ?)",
        )
        .bind(original_url)
        .bind(short_code)
        .bind(now)
        .bind(expires_at)
        .execute(&self.pool)
        .await?;

        Ok(Url {
            id: result.last_insert_rowid(),
            original_url: original_url.to_string(),
            short_code: short_code.to_string(),
            clicks: 0,
            created_at: now,
            expires_at,
        })
    }

    /// Looks up the record for `short_code`; `Ok(None)` when no such code exists.
    ///
    /// Expired records are returned as well; callers decide how to treat them.
    pub async fn find_by_short_code(&self, short_code: &str) -> Result<Option<Url>> {
        // Rows are mapped through the `FromRow` implementation derived on `Url`.
        let url = sqlx::query_as::<_, Url>(
            "SELECT id, original_url, short_code, clicks, created_at, expires_at \
             FROM urls WHERE short_code = ?",
        )
        .bind(short_code)
        .fetch_optional(&self.pool)
        .await?;

        Ok(url)
    }

    /// Reports whether `short_code` is already stored.
    pub async fn short_code_exists(&self, short_code: &str) -> Result<bool> {
        let row = sqlx::query("SELECT COUNT(*) AS count FROM urls WHERE short_code = ?")
            .bind(short_code)
            .fetch_one(&self.pool)
            .await?;

        let count: i64 = row.try_get("count")?;
        Ok(count > 0)
    }

    /// Adds one to the click counter of `short_code`.
    ///
    /// Succeeds without changing anything when the code does not exist.
    pub async fn increment_clicks(&self, short_code: &str) -> Result<()> {
        sqlx::query("UPDATE urls SET clicks = clicks + 1 WHERE short_code = ?")
            .bind(short_code)
            .execute(&self.pool)
            .await?;

        Ok(())
    }

    /// Returns the record used for statistics; identical to [`Self::find_by_short_code`].
    pub async fn get_url_stats(&self, short_code: &str) -> Result<Option<Url>> {
        self.find_by_short_code(short_code).await
    }

    /// Deletes every record whose expiry lies before the current time and returns how
    /// many rows were removed. Records without an expiry are never deleted.
    pub async fn cleanup_expired_urls(&self) -> Result<u64> {
        let now = Utc::now();

        let result =
            sqlx::query("DELETE FROM urls WHERE expires_at IS NOT NULL AND expires_at < ?")
                .bind(now)
                .execute(&self.pool)
                .await?;

        Ok(result.rows_affected())
    }

    /// Returns one page of records, newest first.
    ///
    /// `limit` and `offset` are passed to SQL `LIMIT` / `OFFSET` unchanged.
    pub async fn list_urls(&self, limit: i64, offset: i64) -> Result<Vec<Url>> {
        let urls = sqlx::query_as::<_, Url>(
            "SELECT id, original_url, short_code, clicks, created_at, expires_at \
             FROM urls ORDER BY created_at DESC LIMIT ? OFFSET ?",
        )
        .bind(limit)
        .bind(offset)
        .fetch_all(&self.pool)
        .await?;

        Ok(urls)
    }
}

http handler 相關

src/handlers.rs

use axum::{
    extract::{Path, Query, State},
    http::{header, StatusCode},
    response::{IntoResponse, Redirect},
    Json,
};
use chrono::{Duration, Utc};
use serde::Deserialize;
use std::collections::HashMap;

use crate::{
    database::Database,
    models::{CreateUrlRequest, CreateUrlResponse, ErrorResponse, UrlStats},
    utils::{generate_short_code, validate_custom_code, validate_url},
};

/// Shared application state handed to every handler: a reference-counted `Database`.
pub type AppState = std::sync::Arc<Database>;

/// 創建短 URL
pub async fn create_short_url(
    State(db): State<AppState>,
    Json(payload): Json<CreateUrlRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    // 驗證 URL 格式
    if let Err(msg) = validate_url(&payload.url) {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse { error: msg }),
        ));
    }

    // 確定短碼
    let short_code = if let Some(custom_code) = &payload.custom_code {
        // 驗證自定義短碼
        if let Err(msg) = validate_custom_code(custom_code) {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(ErrorResponse { error: msg }),
            ));
        }

        // 檢查是否已存在
        match db.short_code_exists(custom_code).await {
            Ok(true) => {
                return Err((
                    StatusCode::CONFLICT,
                    Json(ErrorResponse {
                        error: "短碼已被使用".to_string(),
                    }),
                ));
            }
            Ok(false) => custom_code.clone(),
            Err(_) => {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(ErrorResponse {
                        error: "資料庫錯誤".to_string(),
                    }),
                ));
            }
        }
    } else {
        // 生成隨機短碼,確保不重複
        loop {
            let code = generate_short_code(6);
            match db.short_code_exists(&code).await {
                Ok(false) => break code,
                Ok(true) => continue,
                Err(_) => {
                    return Err((
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(ErrorResponse {
                            error: "資料庫錯誤".to_string(),
                        }),
                    ));
                }
            }
        }
    };

    // 計算過期時間
    let expires_at = payload
        .expires_in_days
        .map(|days| Utc::now() + Duration::days(days));

    // 創建 URL 記錄
    match db.create_url(&payload.url, &short_code, expires_at).await {
        Ok(_) => {
            let response = CreateUrlResponse {
                short_url: format!("http://localhost:3000/{}", short_code),
                original_url: payload.url,
                expires_at,
            };
            Ok((StatusCode::CREATED, Json(response)))
        }
        Err(_) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "創建短 URL 失敗".to_string(),
            }),
        )),
    }
}

/// Redirects a short code to its original URL and counts the visit.
///
/// A *temporary* redirect is used on purpose: a permanent redirect may be cached by
/// browsers, after which repeat visits would skip this handler entirely, so clicks would
/// no longer be counted and expiry would no longer be enforced.
///
/// # Errors
/// * `404 Not Found` – the code does not exist or has expired.
/// * `500 Internal Server Error` – the lookup failed.
pub async fn redirect_url(
    State(db): State<AppState>,
    Path(short_code): Path<String>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let url_record = match db.find_by_short_code(&short_code).await {
        Ok(Some(record)) => record,
        Ok(None) => {
            return Err((
                StatusCode::NOT_FOUND,
                Json(ErrorResponse {
                    error: "短 URL 不存在".to_string(),
                }),
            ));
        }
        Err(_) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: "資料庫錯誤".to_string(),
                }),
            ));
        }
    };

    // Refuse links whose expiry has passed.
    if matches!(url_record.expires_at, Some(expires_at) if Utc::now() > expires_at) {
        return Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse {
                error: "短 URL 已過期".to_string(),
            }),
        ));
    }

    // Counting is best effort: a failed update is logged but does not block the redirect.
    if db.increment_clicks(&short_code).await.is_err() {
        tracing::warn!("Failed to increment clicks for {}", short_code);
    }

    Ok(Redirect::temporary(&url_record.original_url))
}

/// Returns the statistics of a short code as JSON.
///
/// Answers `404 Not Found` when the code is unknown and `500 Internal Server Error` when
/// the lookup fails.
pub async fn get_url_stats(
    State(db): State<AppState>,
    Path(short_code): Path<String>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let lookup = db.get_url_stats(&short_code).await;

    let record = match lookup {
        Err(_) => {
            let body = ErrorResponse {
                error: "資料庫錯誤".to_string(),
            };
            return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(body)));
        }
        Ok(None) => {
            let body = ErrorResponse {
                error: "短 URL 不存在".to_string(),
            };
            return Err((StatusCode::NOT_FOUND, Json(body)));
        }
        Ok(Some(found)) => found,
    };

    // Copy everything except the row id into the public statistics shape.
    Ok(Json(UrlStats {
        original_url: record.original_url,
        short_code: record.short_code,
        clicks: record.clicks,
        created_at: record.created_at,
        expires_at: record.expires_at,
    }))
}

/// Query-string parameters of the URL listing endpoint.
#[derive(Deserialize)]
pub struct ListQuery {
    /// Maximum number of records to return; the handler supplies a default when absent.
    pub limit: Option<i64>,
    /// Number of records to skip; the handler supplies a default when absent.
    pub offset: Option<i64>,
}

/// Returns one page of stored URLs as JSON, newest first.
///
/// `limit` defaults to 10 and is clamped to `0..=100`; `offset` defaults to 0 and is never
/// negative. Answers `500 Internal Server Error` when the query fails.
pub async fn list_urls(
    State(db): State<AppState>,
    Query(params): Query<ListQuery>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    // The lower bound matters: SQLite treats a negative LIMIT as "no limit", so an
    // unclamped `?limit=-1` would return the whole table despite the 100-row cap.
    let limit = params.limit.unwrap_or(10).clamp(0, 100);
    let offset = params.offset.unwrap_or(0).max(0);

    match db.list_urls(limit, offset).await {
        Ok(urls) => Ok(Json(urls)),
        Err(_) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "獲取 URL 列表失敗".to_string(),
            }),
        )),
    }
}

/// 健康檢查
pub async fn health_check() -> impl IntoResponse {
    Json(serde_json::json!({
        "status": "ok",
        "timestamp": Utc::now()
    }))
}

網頁應用少不了 migration 的部分

migrations/001_initial.sql

-- Create the urls table: one row per short link.
CREATE TABLE IF NOT EXISTS urls (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    original_url TEXT NOT NULL,
    short_code TEXT NOT NULL UNIQUE,
    clicks INTEGER NOT NULL DEFAULT 0,
    created_at DATETIME NOT NULL,
    expires_at DATETIME
);

-- Create indexes to speed up lookups by code, listing by creation time and expiry cleanup.
-- NOTE(review): the UNIQUE constraint on short_code presumably already provides an index,
-- which would make idx_short_code redundant — confirm before dropping it.
CREATE INDEX IF NOT EXISTS idx_short_code ON urls(short_code);
CREATE INDEX IF NOT EXISTS idx_created_at ON urls(created_at);
CREATE INDEX IF NOT EXISTS idx_expires_at ON urls(expires_at);

最後 main.rs 組合起來

use axum::{
    extract::MatchedPath,
    http::Request,
    middleware::{self, Next},
    response::Response,
    routing::{get, post},
    Router,
};
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing_subscriber;

mod database;
mod handlers;
mod models;
mod utils;

use database::Database;
use handlers::{
    create_short_url, get_url_stats, health_check, list_urls, redirect_url, AppState,
};

async fn track_metrics<B>(req: Request<B>, next: Next<B>) -> Response {
    let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
        matched_path.as_str().to_owned()
    } else {
        req.uri().path().to_owned()
    };

    let start = std::time::Instant::now();
    let response = next.run(req).await;
    let latency = start.elapsed();

    tracing::info!(
        "path = {}, status = {}, latency = {:?}",
        path,
        response.status(),
        latency
    );

    response
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // 初始化日誌
    tracing_subscriber::init();

    // 初始化資料庫
    let database_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "sqlite:./urls.db".to_string());
    
    let db = Database::new(&database_url).await?;
    let app_state = std::sync::Arc::new(db);

    // 定期清理過期 URL 的背景任務
    let cleanup_db = app_state.clone();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(3600)); // 每小時清理一次
        loop {
            interval.tick().await;
            match cleanup_db.cleanup_expired_urls().await {
                Ok(deleted) => {
                    if deleted > 0 {
                        tracing::info!("Cleaned up {} expired URLs", deleted);
                    }
                }
                Err(e) => tracing::error!("Failed to cleanup expired URLs: {}", e),
            }
        }
    });

    // 構建路由
    let app = Router::new()
        .route("/", get(|| async { "URL Shortener Service" }))
        .route("/health", get(health_check))
        .route("/api/shorten", post(create_short_url))
        .route("/api/urls", get(list_urls))
        .route("/api/stats/:short_code", get(get_url_stats))
        .route("/:short_code", get(redirect_url))
        .layer(
            ServiceBuilder::new()
                .layer(TraceLayer::new_for_http())
                .layer(CorsLayer::permissive())
                .layer(middleware::from_fn(track_metrics)),
        )
        .with_state(app_state);

    // 啟動服務器
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
        .await
        .unwrap();
    
    tracing::info!("Server running on http://0.0.0.0:3000");
    
    axum::serve(listener, app).await.unwrap();

    Ok(())
}

原神啟動!

cargo run

用 curl try try

curl -X POST http://localhost:3000/api/shorten \
  -H "Content-Type: application/json" \
  -d '{
    "url": "https://www.example.com/very/long/url/that/needs/shortening",
    "expires_in_days": 30
  }'

自定義

curl -X POST http://localhost:3000/api/shorten \
  -H "Content-Type: application/json" \
  -d '{
    "url": "https://www.example.com",
    "custom_code": "example"
  }'

這裡我們用簡易的 shell script 測試

#!/bin/bash
# Smoke test for the URL shortener: health check, create, stats, redirect, stats again.
# Requires curl and jq. BASE_URL may be overridden to test a non-local instance.

BASE_URL="${BASE_URL:-http://localhost:3000}"

echo "=== URL Shortener 測試 ==="

# 1. Health check
echo "1. 健康檢查"
curl -s "$BASE_URL/health" | jq .

# 2. Create a short URL
echo -e "\n2. 創建短 URL"
RESPONSE=$(curl -s -X POST "$BASE_URL/api/shorten" \
  -H "Content-Type: application/json" \
  -d '{"url": "https://www.rust-lang.org", "expires_in_days": 7}')
# Quote the expansion so the JSON reaches jq unmodified (no word splitting or globbing).
echo "$RESPONSE" | jq .

# Extract the short code: the last path segment of short_url.
SHORT_CODE=$(echo "$RESPONSE" | jq -r '.short_url // empty' | sed 's|.*/||')

# Stop here when creation failed; otherwise the remaining steps would query a bogus code.
if [ -z "$SHORT_CODE" ]; then
  echo "無法取得短碼,中止測試" >&2
  exit 1
fi

# 3. Fetch statistics
echo -e "\n3. 獲取統計 ($SHORT_CODE)"
curl -s "$BASE_URL/api/stats/$SHORT_CODE" | jq .

# 4. Visit the short URL (it redirects; only the headers are shown)
echo -e "\n4. 測試重定向 ($SHORT_CODE)"
curl -s -I "$BASE_URL/$SHORT_CODE"

# 5. Fetch statistics again (the click count should have increased)
echo -e "\n5. 再次獲取統計 (點擊次數應該增加)"
curl -s "$BASE_URL/api/stats/$SHORT_CODE" | jq .

打完收工!


上一篇
天氣查詢 API 客戶端 - 整合第三方天氣服務
下一篇
RSS 訂閱閱讀器 - 抓取並解析 RSS feeds
系列文
Rust 實戰專案集:30 個漸進式專案從工具到服務13
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言