今天我們將使用 Rust 實作一個功能完整的 WebSocket 聊天室,體驗 Rust 在處理並發連線時的強大能力。
聊天室是相當常見的應用,過去我曾用 Go 與 Node.js 實作聊天室功能,並應用在各種專案中;
今天是我第一次嘗試用 Rust 實作相關功能,這也是我比較重視的學習目標之一,讓我們開始吧!
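本次實作的功能:公開聊天(廣播給所有人)、私聊、線上使用者列表。
(後端支援私聊;下方的前端範例只做了公開聊天,私聊可以送出 `{"type":"privatemessage","to":"對方暱稱","content":"..."}` 來使用。)
我們將使用以下 crate:
tokio: 非同步運行時(runtime)
warp: 輕量級 Web 框架,內建 WebSocket 支援
tokio-stream: 處理非同步串流
futures-util: 提供 `split`、`SinkExt`、`StreamExt`,用來拆分並操作 WebSocket 的讀寫兩端
serde / serde_json: 訊息的序列化/反序列化(JSON)
uuid: 生成唯一使用者 ID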
cargo new websocket-chat
cd websocket-chat
加入依賴
Cargo.toml
# Cargo manifest for the WebSocket chat server.
[package]
name = "websocket-chat"
version = "0.1.0"
edition = "2021"
[dependencies]
# Async runtime; "full" turns on every tokio feature, including the
# #[tokio::main] macro and the tokio::sync primitives (RwLock, mpsc) used here.
tokio = { version = "1.40", features = ["full"] }
# Web framework: HTTP routing plus the WebSocket upgrade (warp::ws).
warp = "0.3"
# Stream adapters, e.g. UnboundedReceiverStream in the connection handler.
tokio-stream = "0.1"
# Derive macros for the JSON message enums.
serde = { version = "1.0", features = ["derive"] }
# JSON encoding/decoding of client and server messages.
serde_json = "1.0"
# Random (v4) UUIDs used as per-connection user ids.
uuid = { version = "1.10", features = ["v4"] }
# SinkExt/StreamExt and `split()` for the two halves of the WebSocket.
futures-util = "0.3"
src/message.rs
use serde::{Deserialize, Serialize};
/// A frame sent by a browser client.
///
/// Encoded as a JSON object whose `"type"` field is the lowercased variant
/// name, e.g. `{"type":"join","username":"amy"}`. The tags are therefore
/// `"join"`, `"message"`, `"privatemessage"` and `"leave"`; field names are
/// not renamed.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "type")]
pub enum ClientMessage {
    /// Request to enter the room under the display name `username`.
    Join {
        username: String,
    },
    /// Public chat text for everyone in the room.
    Message {
        content: String,
    },
    /// Direct message to the user whose display name is `to`.
    PrivateMessage {
        to: String,
        content: String,
    },
    /// Request to leave the room.
    Leave,
}
/// A frame sent from the server to a browser client.
///
/// Serialized as a JSON object whose `"type"` field is the lowercased variant
/// name: `"joined"`, `"userlist"`, `"message"`, `"privatemessage"`,
/// `"userjoined"`, `"userleft"` or `"error"`. Only variant names are renamed;
/// fields keep their snake_case names (`user_id`, `from_name`, ...).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "type")]
pub enum ServerMessage {
    /// Acknowledges the recipient's own `Join` and reports its assigned id.
    Joined { user_id: String, username: String },

    /// Snapshot of everyone currently in the room.
    UserList { users: Vec<User> },

    /// A public chat line; `timestamp` is Unix time in seconds.
    Message {
        user_id: String,
        username: String,
        content: String,
        timestamp: i64,
    },

    /// A direct message addressed to the recipient; `timestamp` as above.
    PrivateMessage {
        from_id: String,
        from_name: String,
        content: String,
        timestamp: i64,
    },

    /// Another user entered the room.
    UserJoined { user_id: String, username: String },

    /// Another user left the room.
    UserLeft { user_id: String, username: String },

    /// Human-readable description of a rejected request.
    Error { message: String },
}
/// Public view of a connected user, as listed in `ServerMessage::UserList`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct User {
    /// Per-connection unique id.
    pub id: String,
    /// Display name chosen at join time (not guaranteed unique).
    pub username: String,
}
src/chat_room.rs
use crate::message::{ServerMessage, User};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use warp::ws::Message;
/// Sending half of one user's outgoing channel. Every queued item is a
/// WebSocket frame wrapped in a `Result`.
pub type Tx = mpsc::UnboundedSender<Result<Message, warp::Error>>;

/// Registry of connected users keyed by user id, shared by all connection tasks.
pub type Users = Arc<RwLock<HashMap<String, UserConnection>>>;

/// One connected user: their display name and the channel used to push frames to them.
#[derive(Clone, Debug)]
pub struct UserConnection {
    pub username: String,
    pub tx: Tx,
}

/// Shared chat-room state. All access to the user registry goes through its
/// inner `RwLock`, so a single `ChatRoom` can be shared across tasks via `Arc`.
#[derive(Debug)]
pub struct ChatRoom {
    users: Users,
}
impl ChatRoom {
pub fn new() -> Self {
Self {
users: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn users(&self) -> Users {
Arc::clone(&self.users)
}
pub async fn add_user(&self, user_id: String, username: String, tx: Tx) {
let user_conn = UserConnection {
username: username.clone(),
tx,
};
self.users.write().await.insert(user_id.clone(), user_conn);
// 廣播新使用者加入
let message = ServerMessage::UserJoined {
user_id: user_id.clone(),
username: username.clone(),
};
self.broadcast(&user_id, message).await;
}
pub async fn remove_user(&self, user_id: &str) -> Option<String> {
let username = self
.users
.write()
.await
.remove(user_id)
.map(|conn| conn.username);
if let Some(ref name) = username {
let message = ServerMessage::UserLeft {
user_id: user_id.to_string(),
username: name.clone(),
};
self.broadcast(user_id, message).await;
}
username
}
pub async fn broadcast(&self, exclude_id: &str, message: ServerMessage) {
let users = self.users.read().await;
let msg_json = serde_json::to_string(&message).unwrap();
let ws_msg = Message::text(msg_json);
for (user_id, conn) in users.iter() {
if user_id != exclude_id {
let _ = conn.tx.send(Ok(ws_msg.clone()));
}
}
}
pub async fn broadcast_all(&self, message: ServerMessage) {
let users = self.users.read().await;
let msg_json = serde_json::to_string(&message).unwrap();
let ws_msg = Message::text(msg_json);
for conn in users.values() {
let _ = conn.tx.send(Ok(ws_msg.clone()));
}
}
pub async fn send_to_user(&self, user_id: &str, message: ServerMessage) {
let users = self.users.read().await;
if let Some(conn) = users.get(user_id) {
let msg_json = serde_json::to_string(&message).unwrap();
let _ = conn.tx.send(Ok(Message::text(msg_json)));
}
}
pub async fn get_user_list(&self) -> Vec<User> {
self.users
.read()
.await
.iter()
.map(|(id, conn)| User {
id: id.clone(),
username: conn.username.clone(),
})
.collect()
}
pub async fn find_user_id_by_name(&self, username: &str) -> Option<String> {
self.users
.read()
.await
.iter()
.find(|(_, conn)| conn.username == username)
.map(|(id, _)| id.clone())
}
}
src/handler.rs
use crate::chat_room::{ChatRoom, Tx};
use crate::message::{ClientMessage, ServerMessage};
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid;
use warp::ws::{Message, WebSocket};
pub async fn handle_connection(ws: WebSocket, chat_room: Arc<ChatRoom>) {
let (mut ws_tx, mut ws_rx) = ws.split();
let (tx, rx) = mpsc::unbounded_channel();
let mut rx = UnboundedReceiverStream::new(rx);
// 生成唯一使用者 ID
let user_id = Uuid::new_v4().to_string();
let user_id_clone = user_id.clone();
let chat_room_clone = Arc::clone(&chat_room);
// 發送訊息任務
tokio::task::spawn(async move {
while let Some(message) = rx.next().await {
if let Ok(msg) = message {
if ws_tx.send(msg).await.is_err() {
break;
}
}
}
});
let mut username: Option<String> = None;
let mut joined = false;
// 接收訊息任務
while let Some(result) = ws_rx.next().await {
let msg = match result {
Ok(msg) => msg,
Err(e) => {
eprintln!("WebSocket error: {}", e);
break;
}
};
if msg.is_close() {
break;
}
if let Ok(text) = msg.to_str() {
if let Ok(client_msg) = serde_json::from_str::<ClientMessage>(text) {
match client_msg {
ClientMessage::Join { username: name } => {
if !joined {
username = Some(name.clone());
joined = true;
// 加入聊天室
chat_room_clone
.add_user(user_id.clone(), name.clone(), tx.clone())
.await;
// 發送加入成功訊息
send_message(
&tx,
ServerMessage::Joined {
user_id: user_id.clone(),
username: name.clone(),
},
);
// 發送使用者列表
let users = chat_room_clone.get_user_list().await;
send_message(&tx, ServerMessage::UserList { users });
}
}
ClientMessage::Message { content } => {
if joined {
let timestamp = chrono::Utc::now().timestamp();
let message = ServerMessage::Message {
user_id: user_id.clone(),
username: username.clone().unwrap_or_default(),
content,
timestamp,
};
chat_room_clone.broadcast_all(message).await;
}
}
ClientMessage::PrivateMessage { to, content } => {
if joined {
if let Some(target_id) = chat_room_clone.find_user_id_by_name(&to).await
{
let timestamp = chrono::Utc::now().timestamp();
let message = ServerMessage::PrivateMessage {
from_id: user_id.clone(),
from_name: username.clone().unwrap_or_default(),
content,
timestamp,
};
chat_room_clone.send_to_user(&target_id, message).await;
} else {
send_message(
&tx,
ServerMessage::Error {
message: format!("User '{}' not found", to),
},
);
}
}
}
ClientMessage::Leave => {
break;
}
}
}
}
}
// 使用者離線
if joined {
chat_room_clone.remove_user(&user_id_clone).await;
}
}
/// Encodes `message` as JSON and queues it on `tx` as a text frame.
///
/// Best-effort: a serialization failure, or a closed channel (the writer task
/// has already exited), is silently ignored.
fn send_message(tx: &Tx, message: ServerMessage) {
    let frame = serde_json::to_string(&message).ok().map(Message::text);
    if let Some(frame) = frame {
        let _ = tx.send(Ok(frame));
    }
}
mod chat_room;
mod handler;
mod message;
use chat_room::ChatRoom;
use std::sync::Arc;
use warp::Filter;
#[tokio::main]
async fn main() {
let chat_room = Arc::new(ChatRoom::new());
// WebSocket 路由
let chat_route = warp::path("ws")
.and(warp::ws())
.and(with_chat_room(chat_room.clone()))
.map(|ws: warp::ws::Ws, chat_room| {
ws.on_upgrade(move |socket| handler::handle_connection(socket, chat_room))
});
// 靜態文件路由
let static_route = warp::path::end().map(|| {
warp::reply::html(include_str!("../static/index.html"))
});
let routes = static_route.or(chat_route);
println!("🚀 WebSocket Chat Server started at http://127.0.0.1:3030");
println!("📝 Open http://127.0.0.1:3030 in your browser");
warp::serve(routes).run(([127, 0, 0, 1], 3030)).await;
}
/// Warp filter that gives every matching request its own `Arc` handle to the
/// shared [`ChatRoom`]. It never rejects, hence the `Infallible` error type.
fn with_chat_room(
    chat_room: Arc<ChatRoom>,
) -> impl Filter<Extract = (Arc<ChatRoom>,), Error = std::convert::Infallible> + Clone {
    let inject = move || Arc::clone(&chat_room);
    warp::any().map(inject)
}
(因為目前的主軸是 Rust,所以前端 HTML 是用 AI 快速產生的。)
如果有需要,也可以改用 React 或 Vue 進行前後端分離,沒有問題。
static/index.html
<!DOCTYPE html>
<html lang="zh-TW">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>WebSocket 聊天室</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
width: 90%;
max-width: 800px;
height: 600px;
background: white;
border-radius: 15px;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.3);
display: flex;
flex-direction: column;
overflow: hidden;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 20px;
text-align: center;
}
.header h1 {
font-size: 24px;
margin-bottom: 5px;
}
.status {
display: flex;
justify-content: space-between;
padding: 10px 20px;
background: #f8f9fa;
border-bottom: 1px solid #dee2e6;
font-size: 14px;
}
.status-indicator {
display: flex;
align-items: center;
gap: 8px;
}
.status-dot {
width: 10px;
height: 10px;
border-radius: 50%;
background: #dc3545;
}
.status-dot.connected {
background: #28a745;
}
.chat-area {
flex: 1;
padding: 20px;
overflow-y: auto;
background: #f8f9fa;
}
.message {
margin-bottom: 15px;
animation: slideIn 0.3s ease;
}
@keyframes slideIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.message-header {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 5px;
}
.username {
font-weight: bold;
color: #667eea;
}
.timestamp {
font-size: 12px;
color: #6c757d;
}
.message-content {
background: white;
padding: 10px 15px;
border-radius: 10px;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.message.system {
text-align: center;
color: #6c757d;
font-style: italic;
}
.message.system .message-content {
background: #e9ecef;
display: inline-block;
}
.input-area {
padding: 20px;
background: white;
border-top: 1px solid #dee2e6;
}
.input-group {
display: flex;
gap: 10px;
}
input[type="text"] {
flex: 1;
padding: 12px 15px;
border: 2px solid #dee2e6;
border-radius: 25px;
font-size: 14px;
outline: none;
transition: border-color 0.3s;
}
input[type="text"]:focus {
border-color: #667eea;
}
button {
padding: 12px 30px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 25px;
cursor: pointer;
font-size: 14px;
font-weight: bold;
transition: transform 0.2s, box-shadow 0.2s;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
button:disabled {
background: #6c757d;
cursor: not-allowed;
transform: none;
}
.login-screen {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
height: 100%;
gap: 20px;
}
.login-screen h2 {
color: #667eea;
}
.login-input {
width: 300px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🚀 WebSocket 聊天室</h1>
<p>即時通訊體驗</p>
</div>
<div class="status">
<div class="status-indicator">
<div class="status-dot" id="statusDot"></div>
<span id="statusText">未連接</span>
</div>
<div id="userCount">線上: 0 人</div>
</div>
<div class="chat-area" id="chatArea" style="display: none;">
</div>
<div class="login-screen" id="loginScreen">
<h2>請輸入您的暱稱</h2>
<input type="text" id="usernameInput" class="login-input" placeholder="輸入暱稱..." maxlength="20">
<button onclick="joinChat()">加入聊天室</button>
</div>
<div class="input-area" id="inputArea" style="display: none;">
<div class="input-group">
<input type="text" id="messageInput" placeholder="輸入訊息..." onkeypress="handleKeyPress(event)">
<button onclick="sendMessage()">發送</button>
</div>
</div>
</div>
<script>
let ws;
let username = '';
let userId = '';
// Everyone currently online, keyed by user_id. Rebuilt from each `userlist`
// snapshot and updated by `userjoined` / `userleft` notices, so the online
// counter also follows other users' arrivals and departures.
const onlineUsers = new Map();

// Join button handler: validates the nickname, then opens the connection.
function joinChat() {
    // Repeated clicks must not open a second socket while one is pending or live.
    if (ws && (ws.readyState === WebSocket.CONNECTING || ws.readyState === WebSocket.OPEN)) {
        return;
    }
    username = document.getElementById('usernameInput').value.trim();
    if (!username) {
        alert('請輸入暱稱');
        return;
    }
    connectWebSocket();
}

// Opens the WebSocket, wires its lifecycle handlers, and sends `join` once open.
function connectWebSocket() {
    // Connect back to the host that served this page rather than a hard-coded
    // 127.0.0.1, so it also works via localhost, a LAN address or a proxy;
    // use wss:// when the page itself was served over HTTPS.
    const scheme = location.protocol === 'https:' ? 'wss' : 'ws';
    ws = new WebSocket(`${scheme}://${location.host}/ws`);
    ws.onopen = () => {
        updateStatus(true);
        ws.send(JSON.stringify({
            type: 'join',
            username: username
        }));
    };
    ws.onmessage = (event) => {
        const message = JSON.parse(event.data);
        handleServerMessage(message);
    };
    ws.onclose = () => {
        updateStatus(false);
        addSystemMessage('已斷開連接');
    };
    ws.onerror = (error) => {
        console.error('WebSocket error:', error);
        addSystemMessage('連接錯誤');
    };
}

// Dispatches one decoded server frame by its lowercase `type` tag.
function handleServerMessage(message) {
    switch (message.type) {
        case 'joined':
            userId = message.user_id;
            document.getElementById('loginScreen').style.display = 'none';
            document.getElementById('chatArea').style.display = 'block';
            document.getElementById('inputArea').style.display = 'block';
            addSystemMessage(`歡迎 ${message.username} 加入聊天室!`);
            break;
        case 'userlist':
            onlineUsers.clear();
            for (const user of message.users) {
                onlineUsers.set(user.id, user.username);
            }
            renderUserCount();
            break;
        case 'message':
            addMessage(message.username, message.content, message.timestamp);
            break;
        case 'privatemessage':
            addMessage(`${message.from_name} (私訊)`, message.content, message.timestamp);
            break;
        case 'userjoined':
            onlineUsers.set(message.user_id, message.username);
            renderUserCount();
            addSystemMessage(`${message.username} 加入了聊天室`);
            break;
        case 'userleft':
            onlineUsers.delete(message.user_id);
            renderUserCount();
            addSystemMessage(`${message.username} 離開了聊天室`);
            break;
        case 'error':
            addSystemMessage(`錯誤: ${message.message}`);
            break;
    }
}

// Writes the online counter from `onlineUsers`.
function renderUserCount() {
    document.getElementById('userCount').textContent = `線上: ${onlineUsers.size} 人`;
}

// Sends the composer's text as a public chat message.
function sendMessage() {
    const input = document.getElementById('messageInput');
    const content = input.value.trim();
    if (!content) return;
    // send() throws while CONNECTING and silently discards data once
    // CLOSING/CLOSED; keep the draft and tell the user instead.
    if (!ws || ws.readyState !== WebSocket.OPEN) {
        addSystemMessage('尚未連接,無法發送訊息');
        return;
    }
    ws.send(JSON.stringify({
        type: 'message',
        content: content
    }));
    input.value = '';
}

// Appends a chat bubble. All text, including the sender name (chosen by
// another client), goes in via textContent so it is never parsed as HTML.
function addMessage(sender, content, timestamp) {
    const time = new Date(timestamp * 1000).toLocaleTimeString('zh-TW', {
        hour: '2-digit',
        minute: '2-digit'
    });
    const header = document.createElement('div');
    header.className = 'message-header';
    header.appendChild(createTextElement('span', 'username', sender));
    header.appendChild(createTextElement('span', 'timestamp', time));

    const messageDiv = document.createElement('div');
    messageDiv.className = 'message';
    messageDiv.appendChild(header);
    messageDiv.appendChild(createTextElement('div', 'message-content', content));
    appendToChat(messageDiv);
}

// Appends a centered system notice. Notices embed usernames and server error
// text, so they are inserted as plain text as well.
function addSystemMessage(content) {
    const messageDiv = document.createElement('div');
    messageDiv.className = 'message system';
    messageDiv.appendChild(createTextElement('div', 'message-content', content));
    appendToChat(messageDiv);
}

// Creates <tag class="className"> whose only content is `text` as plain text.
function createTextElement(tag, className, text) {
    const el = document.createElement(tag);
    el.className = className;
    el.textContent = text;
    return el;
}

// Adds a node to the chat log and keeps the view scrolled to the newest entry.
function appendToChat(node) {
    const chatArea = document.getElementById('chatArea');
    chatArea.appendChild(node);
    chatArea.scrollTop = chatArea.scrollHeight;
}

// Toggles the connection indicator dot and label.
function updateStatus(connected) {
    const statusDot = document.getElementById('statusDot');
    const statusText = document.getElementById('statusText');
    if (connected) {
        statusDot.classList.add('connected');
        statusText.textContent = '已連接';
    } else {
        statusDot.classList.remove('connected');
        statusText.textContent = '未連接';
    }
}

// Enter sends the message, except while an IME (e.g. Zhuyin/Pinyin) is
// still composing, where Enter only confirms the candidate.
function handleKeyPress(event) {
    if (event.key === 'Enter' && !event.isComposing) {
        sendMessage();
    }
}
</script>
</body>
</html>
完成!執行 `cargo run`,再用瀏覽器打開 http://127.0.0.1:3030(可以多開幾個分頁互相測試),就能開始聊天了。