今天要實作一個 URL 縮短服務，類似 bit.ly、TinyURL 等服務。
這個專案將展示如何使用 Rust 建構一個完整的 Web 服務，包含 REST API、資料儲存，以及短網址重新導向（redirect）功能。
Web 框架: Axum (高性能、現代化的 Rust Web 框架)
資料庫: SQLite (輕量級、嵌入式資料庫)
資料庫存取: SQLx (異步 SQL 工具包；它不是 ORM，而是直接撰寫 SQL 並在編譯期檢查查詢)
序列化: Serde (JSON 處理)
短碼生成: 從 Base62 字元集 (0-9、A-Z、a-z) 隨機挑選字元
url_shortener/
├── Cargo.toml
├── src/
│ ├── main.rs
│ ├── models.rs
│ ├── handlers.rs
│ ├── database.rs
│ └── utils.rs
└── migrations/
    └── 001_initial.sql
[package]
name = "url_shortener"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = "0.7"
tokio = { version = "1.0", features = ["full"] }
# sqlx 0.7 takes the async runtime and the TLS backend as separate features; the
# combined `runtime-tokio-rustls` name is only a deprecated alias. `macros`
# (query!/query_as!) and `migrate` (migrate!) are listed explicitly because the
# code depends on both.
sqlx = { version = "0.7", features = ["runtime-tokio", "tls-rustls", "sqlite", "chrono", "macros", "migrate"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
chrono = { version = "0.4", features = ["serde"] }
anyhow = "1.0"
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "trace"] }
tracing = "0.1"
tracing-subscriber = "0.3"
url = "2.4"
rand = "0.8"
src/models.rs
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
/// A shortened-URL record as stored in the `urls` table.
///
/// Loaded with `sqlx::query_as!` in `database.rs` and also serialized directly
/// as JSON by the `GET /api/urls` listing endpoint.
#[derive(Debug, FromRow, Serialize)]
pub struct Url {
    /// Row id (`INTEGER PRIMARY KEY AUTOINCREMENT`).
    pub id: i64,
    /// Destination URL that the short code redirects to.
    pub original_url: String,
    /// Unique short code (the path segment in `/:short_code`).
    pub short_code: String,
    /// Number of redirects recorded for this code.
    pub clicks: i64,
    /// Creation time in UTC.
    pub created_at: DateTime<Utc>,
    /// Optional expiry time in UTC; `None` means the link never expires.
    pub expires_at: Option<DateTime<Utc>>,
}
/// JSON body accepted by `POST /api/shorten`.
#[derive(Debug, Deserialize)]
pub struct CreateUrlRequest {
    /// The URL to shorten; must be an `http` or `https` URL.
    pub url: String,
    /// Optional caller-chosen short code; a random 6-character code is generated when absent.
    pub custom_code: Option<String>,
    /// Optional lifetime in days counted from creation; `None` means no expiry.
    pub expires_in_days: Option<i64>,
}
/// JSON body returned (with `201 Created`) after a short URL has been stored.
#[derive(Debug, Serialize)]
pub struct CreateUrlResponse {
    /// Full short link, i.e. the service base URL followed by the short code.
    pub short_url: String,
    /// The URL that the short link points to.
    pub original_url: String,
    /// Expiry time in UTC, if one was requested.
    pub expires_at: Option<DateTime<Utc>>,
}
/// Statistics returned by `GET /api/stats/:short_code`.
///
/// Mirrors [`Url`] without the internal row id.
#[derive(Debug, Serialize)]
pub struct UrlStats {
    /// Destination URL.
    pub original_url: String,
    /// The short code these stats belong to.
    pub short_code: String,
    /// Number of redirects recorded so far.
    pub clicks: i64,
    /// Creation time in UTC.
    pub created_at: DateTime<Utc>,
    /// Expiry time in UTC, or `None` if the link never expires.
    pub expires_at: Option<DateTime<Utc>>,
}
/// JSON error body (`{"error": "..."}`) that the HTTP handlers pair with a non-2xx status.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Human-readable error message.
    pub error: String,
}
接著是基本工具函式 (utils) 的部分
src/utils.rs
use rand::{thread_rng, Rng};
/// Alphabet for generated short codes: the 62 ASCII symbols `0-9`, `A-Z`, `a-z`.
const BASE62_CHARS: &[u8] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
/// Build a random short code of `length` characters, each drawn from [`BASE62_CHARS`]
/// using the thread-local RNG.
///
/// Uniqueness is not guaranteed; callers must check the database for collisions.
pub fn generate_short_code(length: usize) -> String {
    let mut rng = thread_rng();
    let mut code = String::with_capacity(length);
    for _ in 0..length {
        let pick = rng.gen_range(0..BASE62_CHARS.len());
        code.push(char::from(BASE62_CHARS[pick]));
    }
    code
}
/// Check that `url` parses as a URL whose scheme is `http` or `https`.
///
/// # Errors
/// Returns a user-facing message when the string cannot be parsed, or when it
/// uses any other scheme.
pub fn validate_url(url: &str) -> Result<(), String> {
    let parsed = url::Url::parse(url).map_err(|_| "無效的 URL 格式".to_string())?;
    match parsed.scheme() {
        "http" | "https" => Ok(()),
        _ => Err("URL 必須使用 http 或 https 協議".to_string()),
    }
}
/// Validate a caller-supplied custom short code.
///
/// A valid code is 3–20 ASCII letters or digits and is not one of the reserved
/// codes that collide with fixed routes.
///
/// # Errors
/// Returns a user-facing message describing the first rule that fails.
pub fn validate_custom_code(code: &str) -> Result<(), String> {
    // `GET /health` is a static route and takes priority over `GET /:short_code`,
    // so a short code equal to it could be stored but would never redirect.
    const RESERVED_CODES: [&str; 1] = ["health"];

    // `len()` counts bytes. That is acceptable here because any non-ASCII input
    // is rejected by the character check below anyway.
    if code.len() < 3 || code.len() > 20 {
        return Err("短碼長度必須在 3-20 字符之間".to_string());
    }
    if !code.chars().all(|c| c.is_ascii_alphanumeric()) {
        return Err("短碼只能包含字母和數字".to_string());
    }
    if RESERVED_CODES.contains(&code) {
        return Err("此短碼為系統保留字,請使用其他短碼".to_string());
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_short_code() {
        let generated = generate_short_code(6);
        assert_eq!(generated.len(), 6);
        let in_alphabet = |ch: char| BASE62_CHARS.contains(&(ch as u8));
        assert!(generated.chars().all(in_alphabet));
    }

    #[test]
    fn test_validate_url() {
        for accepted in ["https://example.com", "http://example.com"] {
            assert!(validate_url(accepted).is_ok());
        }
        // Wrong scheme, then not a URL at all.
        for rejected in ["ftp://example.com", "invalid-url"] {
            assert!(validate_url(rejected).is_err());
        }
    }

    #[test]
    fn test_validate_custom_code() {
        assert!(validate_custom_code("abc123").is_ok());
        let too_long = "a".repeat(21);
        // Too short, too long, and containing an illegal character.
        for rejected in ["ab", too_long.as_str(), "abc-123"] {
            assert!(validate_custom_code(rejected).is_err());
        }
    }
}
src/database.rs
use anyhow::Result;
use chrono::{DateTime, Utc};
use sqlx::{Row, SqlitePool};
use crate::models::Url;
/// Data-access layer for the `urls` table, backed by a SQLite connection pool.
pub struct Database {
    /// SQLx pool that every query in this impl runs against.
    pool: SqlitePool,
}
impl Database {
    /// Open the SQLite database at `database_url`, creating the file if it does
    /// not exist yet, and apply the migrations embedded from `./migrations`.
    ///
    /// # Errors
    /// Fails if `database_url` is not a valid SQLite URL, the database cannot be
    /// opened, or a migration fails.
    pub async fn new(database_url: &str) -> Result<Self> {
        // `SqlitePool::connect` leaves `create_if_missing` at its default of
        // `false`, so the first run against a fresh path such as
        // `sqlite:./urls.db` would fail with "unable to open database file".
        let options = database_url
            .parse::<sqlx::sqlite::SqliteConnectOptions>()?
            .create_if_missing(true);
        let pool = SqlitePool::connect_with(options).await?;
        // Apply the schema migrations (embedded into the binary at compile time).
        sqlx::migrate!("./migrations").run(&pool).await?;
        Ok(Self { pool })
    }

    /// Insert a new short URL with zero clicks and `created_at` set to now.
    ///
    /// Returns the stored record; its id comes from SQLite's last inserted rowid.
    ///
    /// # Errors
    /// Propagates database errors, including the UNIQUE-constraint violation
    /// raised when `short_code` is already taken.
    pub async fn create_url(
        &self,
        original_url: &str,
        short_code: &str,
        expires_at: Option<DateTime<Utc>>,
    ) -> Result<Url> {
        let now = Utc::now();
        let result = sqlx::query!(
            r#"
            INSERT INTO urls (original_url, short_code, clicks, created_at, expires_at)
            VALUES (?, ?, 0, ?, ?)
            "#,
            original_url,
            short_code,
            now,
            expires_at
        )
        .execute(&self.pool)
        .await?;
        let url = Url {
            id: result.last_insert_rowid(),
            original_url: original_url.to_string(),
            short_code: short_code.to_string(),
            clicks: 0,
            created_at: now,
            expires_at,
        };
        Ok(url)
    }

    /// Look up a record by its exact short code; `Ok(None)` when it does not exist.
    ///
    /// Expired rows are returned as well; callers decide what to do with `expires_at`.
    pub async fn find_by_short_code(&self, short_code: &str) -> Result<Option<Url>> {
        let url = sqlx::query_as!(
            Url,
            "SELECT id, original_url, short_code, clicks, created_at, expires_at FROM urls WHERE short_code = ?",
            short_code
        )
        .fetch_optional(&self.pool)
        .await?;
        Ok(url)
    }

    /// Return whether a row with this short code already exists.
    pub async fn short_code_exists(&self, short_code: &str) -> Result<bool> {
        let count = sqlx::query!(
            "SELECT COUNT(*) as count FROM urls WHERE short_code = ?",
            short_code
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(count.count > 0)
    }

    /// Atomically add one to the click counter of `short_code`.
    ///
    /// Updating a code that does not exist affects no rows and is not an error.
    pub async fn increment_clicks(&self, short_code: &str) -> Result<()> {
        sqlx::query!(
            "UPDATE urls SET clicks = clicks + 1 WHERE short_code = ?",
            short_code
        )
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Fetch the record used for the statistics endpoint (same as [`Self::find_by_short_code`]).
    pub async fn get_url_stats(&self, short_code: &str) -> Result<Option<Url>> {
        self.find_by_short_code(short_code).await
    }

    /// Delete every row whose `expires_at` lies in the past; returns how many rows were removed.
    pub async fn cleanup_expired_urls(&self) -> Result<u64> {
        let now = Utc::now();
        let result = sqlx::query!(
            "DELETE FROM urls WHERE expires_at IS NOT NULL AND expires_at < ?",
            now
        )
        .execute(&self.pool)
        .await?;
        Ok(result.rows_affected())
    }

    /// List records newest-first, skipping `offset` rows and returning at most `limit`.
    pub async fn list_urls(&self, limit: i64, offset: i64) -> Result<Vec<Url>> {
        let urls = sqlx::query_as!(
            Url,
            "SELECT id, original_url, short_code, clicks, created_at, expires_at FROM urls ORDER BY created_at DESC LIMIT ? OFFSET ?",
            limit,
            offset
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(urls)
    }
}
http handler 相關
src/handlers.rs
use axum::{
extract::{Path, Query, State},
http::{header, StatusCode},
response::{IntoResponse, Redirect},
Json,
};
use chrono::{Duration, Utc};
use serde::Deserialize;
use std::collections::HashMap;
use crate::{
database::Database,
models::{CreateUrlRequest, CreateUrlResponse, ErrorResponse, UrlStats},
utils::{generate_short_code, validate_custom_code, validate_url},
};
/// Shared application state: the database handle inside an `Arc`, so every
/// handler and the background cleanup task in `main` can hold a cheap clone.
pub type AppState = std::sync::Arc<Database>;
/// 創建短 URL
pub async fn create_short_url(
State(db): State<AppState>,
Json(payload): Json<CreateUrlRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
// 驗證 URL 格式
if let Err(msg) = validate_url(&payload.url) {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse { error: msg }),
));
}
// 確定短碼
let short_code = if let Some(custom_code) = &payload.custom_code {
// 驗證自定義短碼
if let Err(msg) = validate_custom_code(custom_code) {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse { error: msg }),
));
}
// 檢查是否已存在
match db.short_code_exists(custom_code).await {
Ok(true) => {
return Err((
StatusCode::CONFLICT,
Json(ErrorResponse {
error: "短碼已被使用".to_string(),
}),
));
}
Ok(false) => custom_code.clone(),
Err(_) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "資料庫錯誤".to_string(),
}),
));
}
}
} else {
// 生成隨機短碼,確保不重複
loop {
let code = generate_short_code(6);
match db.short_code_exists(&code).await {
Ok(false) => break code,
Ok(true) => continue,
Err(_) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "資料庫錯誤".to_string(),
}),
));
}
}
}
};
// 計算過期時間
let expires_at = payload
.expires_in_days
.map(|days| Utc::now() + Duration::days(days));
// 創建 URL 記錄
match db.create_url(&payload.url, &short_code, expires_at).await {
Ok(_) => {
let response = CreateUrlResponse {
short_url: format!("http://localhost:3000/{}", short_code),
original_url: payload.url,
expires_at,
};
Ok((StatusCode::CREATED, Json(response)))
}
Err(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "創建短 URL 失敗".to_string(),
}),
)),
}
}
/// redirect 到原始 URL
pub async fn redirect_url(
State(db): State<AppState>,
Path(short_code): Path<String>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
match db.find_by_short_code(&short_code).await {
Ok(Some(url_record)) => {
// 檢查是否過期
if let Some(expires_at) = url_record.expires_at {
if Utc::now() > expires_at {
return Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "短 URL 已過期".to_string(),
}),
));
}
}
// 增加點擊次數
if let Err(_) = db.increment_clicks(&short_code).await {
tracing::warn!("Failed to increment clicks for {}", short_code);
}
Ok(Redirect::permanent(&url_record.original_url))
}
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "短 URL 不存在".to_string(),
}),
)),
Err(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "資料庫錯誤".to_string(),
}),
)),
}
}
/// Return statistics for a short code as JSON (`GET /api/stats/:short_code`).
///
/// Answers `404` with an error body for an unknown code and `500` when the lookup
/// fails. Expired codes that have not been cleaned up yet are still reported.
pub async fn get_url_stats(
    State(db): State<AppState>,
    Path(short_code): Path<String>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let record = match db.get_url_stats(&short_code).await {
        Err(_) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: "資料庫錯誤".to_string(),
                }),
            ));
        }
        Ok(None) => {
            return Err((
                StatusCode::NOT_FOUND,
                Json(ErrorResponse {
                    error: "短 URL 不存在".to_string(),
                }),
            ));
        }
        Ok(Some(record)) => record,
    };

    Ok(Json(UrlStats {
        original_url: record.original_url,
        short_code: record.short_code,
        clicks: record.clicks,
        created_at: record.created_at,
        expires_at: record.expires_at,
    }))
}
/// Query-string parameters for `GET /api/urls`.
#[derive(Deserialize)]
pub struct ListQuery {
    /// Page size; `list_urls` defaults it to 10 and caps it at 100.
    pub limit: Option<i64>,
    /// Number of rows to skip; `list_urls` defaults it to 0.
    pub offset: Option<i64>,
}
/// 獲取 URL 列表
pub async fn list_urls(
State(db): State<AppState>,
Query(params): Query<ListQuery>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
let limit = params.limit.unwrap_or(10).min(100); // 最多 100 條
let offset = params.offset.unwrap_or(0);
match db.list_urls(limit, offset).await {
Ok(urls) => Ok(Json(urls)),
Err(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "獲取 URL 列表失敗".to_string(),
}),
)),
}
}
/// 健康檢查
pub async fn health_check() -> impl IntoResponse {
Json(serde_json::json!({
"status": "ok",
"timestamp": Utc::now()
}))
}
migrations/001_initial.sql
-- Create the urls table.
CREATE TABLE IF NOT EXISTS urls (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    original_url TEXT NOT NULL,
    -- UNIQUE makes SQLite create an index on short_code automatically; it backs
    -- the short-code lookups, so no separate index is needed for this column.
    short_code TEXT NOT NULL UNIQUE,
    clicks INTEGER NOT NULL DEFAULT 0,
    created_at DATETIME NOT NULL,
    expires_at DATETIME
);
-- Secondary indexes for the listing order (created_at) and the expiry cleanup (expires_at).
CREATE INDEX IF NOT EXISTS idx_created_at ON urls(created_at);
CREATE INDEX IF NOT EXISTS idx_expires_at ON urls(expires_at);
最後是 main.rs,
把前面的模組全部組合起來
use axum::{
extract::MatchedPath,
http::Request,
middleware::{self, Next},
response::Response,
routing::{get, post},
Router,
};
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing_subscriber;
mod database;
mod handlers;
mod models;
mod utils;
use database::Database;
use handlers::{
create_short_url, get_url_stats, health_check, list_urls, redirect_url, AppState,
};
async fn track_metrics<B>(req: Request<B>, next: Next<B>) -> Response {
let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
matched_path.as_str().to_owned()
} else {
req.uri().path().to_owned()
};
let start = std::time::Instant::now();
let response = next.run(req).await;
let latency = start.elapsed();
tracing::info!(
"path = {}, status = {}, latency = {:?}",
path,
response.status(),
latency
);
response
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// 初始化日誌
tracing_subscriber::init();
// 初始化資料庫
let database_url = std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "sqlite:./urls.db".to_string());
let db = Database::new(&database_url).await?;
let app_state = std::sync::Arc::new(db);
// 定期清理過期 URL 的背景任務
let cleanup_db = app_state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(3600)); // 每小時清理一次
loop {
interval.tick().await;
match cleanup_db.cleanup_expired_urls().await {
Ok(deleted) => {
if deleted > 0 {
tracing::info!("Cleaned up {} expired URLs", deleted);
}
}
Err(e) => tracing::error!("Failed to cleanup expired URLs: {}", e),
}
}
});
// 構建路由
let app = Router::new()
.route("/", get(|| async { "URL Shortener Service" }))
.route("/health", get(health_check))
.route("/api/shorten", post(create_short_url))
.route("/api/urls", get(list_urls))
.route("/api/stats/:short_code", get(get_url_stats))
.route("/:short_code", get(redirect_url))
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CorsLayer::permissive())
.layer(middleware::from_fn(track_metrics)),
)
.with_state(app_state);
// 啟動服務器
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
.await
.unwrap();
tracing::info!("Server running on http://0.0.0.0:3000");
axum::serve(listener, app).await.unwrap();
Ok(())
}
# sqlx::query!/query_as! check every SQL statement against a real database at
# compile time, so the database and its schema must exist before the first build.
# This needs sqlx-cli: `cargo install sqlx-cli`.
export DATABASE_URL=sqlite:./urls.db
sqlx database create
sqlx migrate run
cargo run

# In another terminal: create a short URL that expires in 30 days.
curl -X POST http://localhost:3000/api/shorten \
  -H "Content-Type: application/json" \
  -d '{
    "url": "https://www.example.com/very/long/url/that/needs/shortening",
    "expires_in_days": 30
  }'
使用自定義短碼:
# Create a short URL with the caller-chosen code "example".
curl --request POST http://localhost:3000/api/shorten \
  --header "Content-Type: application/json" \
  --data '{
    "url": "https://www.example.com",
    "custom_code": "example"
  }'
這裡我們用簡易的 shell script 測試
#!/bin/bash
# Smoke test for the URL shortener. Expects the server on localhost:3000 and
# requires curl and jq.
set -u

BASE_URL="http://localhost:3000"

echo "=== URL Shortener 測試 ==="

# 1. Health check
echo "1. 健康檢查"
curl -s "$BASE_URL/health" | jq .

# 2. Create a short URL
echo -e "\n2. 創建短 URL"
RESPONSE=$(curl -s -X POST "$BASE_URL/api/shorten" \
  -H "Content-Type: application/json" \
  -d '{"url": "https://www.rust-lang.org", "expires_in_days": 7}')
# Quote the variable so jq receives the JSON verbatim (no word splitting or globbing).
echo "$RESPONSE" | jq .

# Extract the short code. `// empty` yields nothing, rather than the string
# "null", when creation failed and the response has no short_url.
SHORT_CODE=$(echo "$RESPONSE" | jq -r '.short_url // empty' | sed "s|^$BASE_URL/||")
if [ -z "$SHORT_CODE" ]; then
  echo "無法取得短碼,停止測試" >&2
  exit 1
fi

# 3. Fetch stats
echo -e "\n3. 獲取統計 ($SHORT_CODE)"
curl -s "$BASE_URL/api/stats/$SHORT_CODE" | jq .

# 4. Visit the short URL. -I sends a HEAD request and prints only the response headers.
echo -e "\n4. 測試重定向 ($SHORT_CODE)"
curl -s -I "$BASE_URL/$SHORT_CODE"

# 5. Fetch stats again (the click count should have increased)
echo -e "\n5. 再次獲取統計 (點擊次數應該增加)"
curl -s "$BASE_URL/api/stats/$SHORT_CODE" | jq .
打完收工!