在網路應用開發中,檔案下載是一個常見需求。當下載大檔案時,網路中斷或程式異常終止都可能導致下載失敗,
需要重新開始。今天我們要實作一個支援斷點續傳的 HTTP 下載器,
能夠在下載中斷後從上次停止的位置繼續下載,大幅提升使用體驗。
這是我個人非常感興趣的主題之一。
HTTP Range
請求允許用戶端請求檔案的特定部分,這是實現斷點續傳的核心技術
GET /large-file.zip HTTP/1.1
Host: example.com
Range: bytes=1024-2047
這時候會收到如下的 response:
HTTP/1.1 206 Partial Content
Content-Range: bytes 1024-2047/5000
Content-Length: 1024
以上原理和概述了解之後就開始實作
老樣子
cargo new http_downloader
cd http_downloader
[dependencies]
reqwest = { version = "0.11", features = ["stream"] }
tokio = { version = "1.0", features = ["full"] }
tokio-util = { version = "0.7", features = ["io"] }
indicatif = "0.17"
clap = { version = "4.0", features = ["derive"] }
anyhow = "1.0"
src/downloader.rs
use anyhow::{Context, Result};
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::Client;
use std::fs::{File, OpenOptions};
use std::io::{Seek, SeekFrom, Write};
use std::path::Path;
use tokio::io::AsyncWriteExt;
use tokio_util::io::StreamReader;
/// HTTP downloader with optional resume support, built on a single
/// shared `reqwest::Client`.
#[derive(Debug, Clone, Default)]
pub struct Downloader {
    client: Client,
}
impl Downloader {
pub fn new() -> Self {
let client = Client::new();
Self { client }
}
pub async fn download(
&self,
url: &str,
output_path: &Path,
resume: bool,
) -> Result<()> {
// 檢查檔案是否已存在
let existing_size = if resume && output_path.exists() {
std::fs::metadata(output_path)
.context("Failed to get file metadata")?
.len()
} else {
0
};
println!("Starting download from: {}", url);
if existing_size > 0 {
println!("Resuming from byte: {}", existing_size);
}
// 建立 HTTP 請求
let mut request = self.client.get(url);
if existing_size > 0 {
request = request.header("Range", format!("bytes={}-", existing_size));
}
let response = request
.send()
.await
.context("Failed to send request")?;
// 檢查回應狀態
if !response.status().is_success() && response.status().as_u16() != 206 {
anyhow::bail!("Server returned error: {}", response.status());
}
// 獲取檔案總大小
let content_length = if existing_size > 0 {
// 斷點續傳情況,從 Content-Range 頭部獲取總大小
response
.headers()
.get("content-range")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.split('/').nth(1))
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(existing_size + response.content_length().unwrap_or(0))
} else {
response.content_length().unwrap_or(0)
};
// 設置進度條
let progress = ProgressBar::new(content_length);
progress.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.progress_chars("#>-"),
);
progress.set_position(existing_size);
// 開啟檔案用於寫入
let mut file = if existing_size > 0 {
OpenOptions::new()
.create(true)
.append(true)
.open(output_path)
.context("Failed to open file for appending")?
} else {
File::create(output_path)
.context("Failed to create output file")?
};
// 下載並寫入檔案
let mut stream = response.bytes_stream();
let mut downloaded = existing_size;
while let Some(chunk) = stream.next().await {
let chunk = chunk.context("Failed to read chunk")?;
file.write_all(&chunk)
.context("Failed to write chunk to file")?;
downloaded += chunk.len() as u64;
progress.set_position(downloaded);
}
progress.finish_with_message("Download completed!");
println!("File saved to: {:?}", output_path);
Ok(())
}
// 檢查伺服器是否支援 Range 請求
pub async fn check_resume_support(&self, url: &str) -> Result<bool> {
let response = self
.client
.head(url)
.send()
.await
.context("Failed to send HEAD request")?;
Ok(response
.headers()
.get("accept-ranges")
.map(|v| v.to_str().unwrap_or("").contains("bytes"))
.unwrap_or(false))
}
// 獲取檔案資訊
pub async fn get_file_info(&self, url: &str) -> Result<FileInfo> {
let response = self
.client
.head(url)
.send()
.await
.context("Failed to send HEAD request")?;
let size = response.content_length();
let filename = response
.headers()
.get("content-disposition")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.split("filename=").nth(1))
.map(|s| s.trim_matches('"'))
.or_else(|| url.split('/').last())
.unwrap_or("download")
.to_string();
let supports_resume = response
.headers()
.get("accept-ranges")
.map(|v| v.to_str().unwrap_or("").contains("bytes"))
.unwrap_or(false);
Ok(FileInfo {
filename,
size,
supports_resume,
})
}
}
/// Metadata about a remote file, gathered from a HEAD request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileInfo {
    // Best-effort filename: Content-Disposition, URL tail, or "download".
    pub filename: String,
    // Content-Length, if the server reported one.
    pub size: Option<u64>,
    // Whether the server advertises `Accept-Ranges: bytes`.
    pub supports_resume: bool,
}
這裡我們做一個 multi-thread downloader
src/multithread_downloader.rs
use anyhow::{Context, Result};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use reqwest::Client;
use std::fs::{File, OpenOptions};
use std::io::{Seek, SeekFrom, Write};
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Downloads a file by fetching byte ranges concurrently and merging
/// the resulting part files in order.
#[derive(Debug, Clone)]
pub struct MultiThreadDownloader {
    client: Client,
    // Requested number of concurrent range-download tasks.
    thread_count: usize,
}
impl MultiThreadDownloader {
pub fn new(thread_count: usize) -> Self {
let client = Client::new();
Self {
client,
thread_count,
}
}
pub async fn download(
&self,
url: &str,
output_path: &Path,
) -> Result<()> {
// 獲取檔案資訊
let file_info = self.get_file_info(url).await?;
if !file_info.supports_resume {
anyhow::bail!("Server doesn't support range requests");
}
let total_size = file_info.size
.context("Cannot determine file size")?;
println!("File size: {} bytes", total_size);
println!("Using {} threads", self.thread_count);
// 計算每個線程的下載範圍
let chunk_size = total_size / self.thread_count as u64;
let mut ranges = Vec::new();
for i in 0..self.thread_count {
let start = i as u64 * chunk_size;
let end = if i == self.thread_count - 1 {
total_size - 1
} else {
(i + 1) as u64 * chunk_size - 1
};
ranges.push((start, end));
}
// 建立臨時檔案
let temp_files: Vec<String> = (0..self.thread_count)
.map(|i| format!("{}.part{}", output_path.display(), i))
.collect();
// 建立進度條
let multi_progress = MultiProgress::new();
let main_progress = multi_progress.add(ProgressBar::new(total_size));
main_progress.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} Total: [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.progress_chars("#>-"),
);
// 為每個線程建立進度條
let thread_progress: Vec<_> = (0..self.thread_count)
.map(|i| {
let pb = multi_progress.add(ProgressBar::new(ranges[i].1 - ranges[i].0 + 1));
pb.set_style(
ProgressStyle::default_bar()
.template(&format!("Thread {}: [{{wide_bar:.yellow/red}}] {{bytes}}/{{total_bytes}}", i))?
.progress_chars("#>-"),
);
pb
})
.collect::<Result<Vec<_>, _>>()?;
// 並行下載
let handles: Vec<_> = ranges
.into_iter()
.enumerate()
.map(|(i, (start, end))| {
let client = self.client.clone();
let url = url.to_string();
let temp_file = temp_files[i].clone();
let progress = thread_progress[i].clone();
let main_progress = main_progress.clone();
tokio::spawn(async move {
Self::download_range(
client,
&url,
start,
end,
&temp_file,
progress,
main_progress,
).await
})
})
.collect();
// 等待所有下載完成
for handle in handles {
handle.await
.context("Thread panicked")?
.context("Download failed")?;
}
// 合併檔案
println!("Merging files...");
self.merge_files(&temp_files, output_path)?;
// 清理臨時檔案
for temp_file in &temp_files {
let _ = std::fs::remove_file(temp_file);
}
main_progress.finish_with_message("Download completed!");
println!("File saved to: {:?}", output_path);
Ok(())
}
async fn download_range(
client: Client,
url: &str,
start: u64,
end: u64,
temp_file: &str,
progress: ProgressBar,
main_progress: ProgressBar,
) -> Result<()> {
let response = client
.get(url)
.header("Range", format!("bytes={}-{}", start, end))
.send()
.await
.context("Failed to send range request")?;
if response.status().as_u16() != 206 {
anyhow::bail!("Server doesn't support partial content");
}
let mut file = File::create(temp_file)
.context("Failed to create temp file")?;
let mut stream = response.bytes_stream();
let mut downloaded = 0u64;
while let Some(chunk) = stream.next().await {
let chunk = chunk.context("Failed to read chunk")?;
file.write_all(&chunk)
.context("Failed to write chunk")?;
downloaded += chunk.len() as u64;
progress.inc(chunk.len() as u64);
main_progress.inc(chunk.len() as u64);
}
Ok(())
}
fn merge_files(&self, temp_files: &[String], output_path: &Path) -> Result<()> {
let mut output_file = File::create(output_path)
.context("Failed to create output file")?;
for temp_file in temp_files {
let mut input_file = File::open(temp_file)
.context("Failed to open temp file")?;
std::io::copy(&mut input_file, &mut output_file)
.context("Failed to copy temp file")?;
}
Ok(())
}
async fn get_file_info(&self, url: &str) -> Result<crate::downloader::FileInfo> {
let response = self
.client
.head(url)
.send()
.await
.context("Failed to send HEAD request")?;
let size = response.content_length();
let filename = response
.headers()
.get("content-disposition")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.split("filename=").nth(1))
.map(|s| s.trim_matches('"'))
.or_else(|| url.split('/').last())
.unwrap_or("download")
.to_string();
let supports_resume = response
.headers()
.get("accept-ranges")
.map(|v| v.to_str().unwrap_or("").contains("bytes"))
.unwrap_or(false);
Ok(crate::downloader::FileInfo {
filename,
size,
supports_resume,
})
}
}
開始利用 command line tool 組合我們的 Downloader,multi-downloader
src/main.rs
mod downloader;
mod multithread_downloader;
use anyhow::Result;
use clap::{Parser, Subcommand};
use std::path::PathBuf;
// Top-level CLI definition; clap derives the argument parser from this
// struct. (Plain `//` comments are used so clap's help text is unchanged.)
#[derive(Parser)]
#[command(name = "http-downloader")]
#[command(about = "A HTTP downloader with resume support")]
struct Cli {
    // Which subcommand (download / multi-download / info) was invoked.
    #[command(subcommand)]
    command: Commands,
}
// CLI subcommands. The `///` lines below are clap help text and are kept
// byte-identical; short flags are -o/-r for download and -o/-t for
// multi-download.
#[derive(Subcommand)]
enum Commands {
    /// Download a file with resume support
    Download {
        /// URL to download
        url: String,
        /// Output file path
        #[arg(short, long)]
        output: Option<PathBuf>,
        /// Resume download if file exists
        #[arg(short, long)]
        resume: bool,
    },
    /// Download using multiple threads
    MultiDownload {
        /// URL to download
        url: String,
        /// Output file path
        #[arg(short, long)]
        output: Option<PathBuf>,
        /// Number of threads to use
        #[arg(short, long, default_value = "4")]
        threads: usize,
    },
    /// Get file information
    Info {
        /// URL to check
        url: String,
    },
}
/// CLI entry point: parse arguments and dispatch to the matching downloader.
#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();
    match cli.command {
        Commands::Download { url, output, resume } => {
            let downloader = downloader::Downloader::new();
            // Fall back to the server-reported filename when -o was omitted.
            let output_path = if let Some(path) = output {
                path
            } else {
                let file_info = downloader.get_file_info(&url).await?;
                PathBuf::from(file_info.filename)
            };
            downloader.download(&url, &output_path, resume).await?;
        }
        Commands::MultiDownload { url, output, threads } => {
            let downloader = multithread_downloader::MultiThreadDownloader::new(threads);
            // Same filename fallback as the single-threaded path.
            let output_path = if let Some(path) = output {
                path
            } else {
                let file_info = downloader.get_file_info(&url).await?;
                PathBuf::from(file_info.filename)
            };
            downloader.download(&url, &output_path).await?;
        }
        Commands::Info { url } => {
            let downloader = downloader::Downloader::new();
            let info = downloader.get_file_info(&url).await?;
            println!("File Information:");
            println!(" Filename: {}", info.filename);
            // `{}` (not `{:?}`) — the value is an already-formatted String,
            // and Debug formatting would print spurious surrounding quotes.
            println!(
                " Size: {}",
                info.size
                    .map(|s| format!("{} bytes", s))
                    .unwrap_or_else(|| "Unknown".to_string())
            );
            println!(" Supports Resume: {}", info.supports_resume);
        }
    }
    Ok(())
}
# 下載檔案
cargo run -- download https://example.com/large-file.zip
# 指定輸出路徑
cargo run -- download https://example.com/file.zip -o ./downloads/file.zip
# 斷點續傳
cargo run -- download https://example.com/file.zip -o ./file.zip --resume
# 使用4個線程下載
cargo run -- multi-download https://example.com/large-file.zip -t 4
# 使用8個線程下載到指定位置
cargo run -- multi-download https://example.com/file.zip -t 8 -o ./downloads/file.zip
cargo run -- info https://example.com/file.zip
/// Tuning knobs for exponential-backoff retries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryConfig {
    // Retries after the initial attempt (total attempts = max_retries + 1).
    pub max_retries: usize,
    // Delay before the first retry; doubled for each subsequent retry.
    pub base_delay: Duration,
    // Upper bound that the doubling never exceeds.
    pub max_delay: Duration,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
base_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
}
}
}
// Retry wrapper for the downloader: re-runs a failed download with
// exponential backoff, resuming from the partial file on each attempt.
async fn download_with_retry(&self, url: &str, output_path: &Path, config: &RetryConfig) -> Result<()> {
    let mut last_error = None;
    // Attempt 0 is the initial try; up to `max_retries` further attempts.
    for attempt in 0..=config.max_retries {
        match self.download(url, output_path, true).await {
            Ok(_) => return Ok(()),
            Err(e) => {
                last_error = Some(e);
                if attempt < config.max_retries {
                    // Saturating math so a large retry count cannot overflow
                    // the u32 power (the original `2_u32.pow(attempt)`
                    // panics in debug builds once attempt >= 32) or the
                    // Duration multiplication.
                    let factor = 2_u32.saturating_pow(attempt.min(31) as u32);
                    let delay = std::cmp::min(
                        config.base_delay.saturating_mul(factor),
                        config.max_delay,
                    );
                    println!("Download failed, retrying in {:?}...", delay);
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }
    // The loop body ran at least once, so `last_error` is always Some here.
    Err(last_error.unwrap())
}
打完收工!