TDD(測試驅動開發 Test Driven Develop)是一種軟體開發方法,它要求在編寫任何程式碼之前,先撰寫測試案例。這樣可以確保程式碼的品質和可靠性,並促進重構和重用。TDD的基本流程是:
相對於 Flink,Airflow 比較容易達成 TDD,那就讓我們來試著寫一個吧。
就跟最一開始的例子一樣,我們來抓個股票的日成交資訊,並整理出前50大股票,寫入某個 db 吧。
資料來源拿 open data 盤後資訊 > 個股日成交資訊 | 政府資料開放平臺 (data.gov.tw) 的這個網址
https://www.twse.com.tw/exchangeReport/STOCK_DAY_ALL?response=open_data
最後的載入 db,就假裝我們有個 postgres 吧
讓我們想想看這個需求的邏輯,來源有可能不是 open data,而是去官網爬蟲,或是其他地方,所以我們先把 source 拆出去。
寫入 db 的部份,我們可以測試是不是筆數剛好 50 筆,並且是前50大。但一樣我們可能是寫到 GCP, AWS, 或是自架 DB,所以一樣拆出來。
def test_dag_flow_filter_max_50_stocks(dump_stock_list: list):
source_dao = MagicMock()
source_dao.return_value = dump_stock_list # 塞個 100 筆假資料進去
sink_dao = MagicMock()
# 主要程式入口
task_main_flow(source_dao, sink_dao)
sorted_list = sorted(dump_stock_list, key=lambda x: x[1], reverse=True)
max_50 = sorted_list[:49]
sink_dao.assert_called_once_with(max_50)
ok,我們寫了第一個測試案例,會 mock source 跟 sink,並塞給主要的程式入口點。
我們確保了主要邏輯是 source_dao 會提供一些資料,過濾出最大50筆後會塞給 sink_dao。
測試 ⇒ 理所當然失敗,畢竟我們主程式都沒寫。
這時候當然來寫個簡單的主程式
with DAG() as dag:
@task
def main_flow():
source_dao = StockDAO.open_data_daily_stock # 先放一個空 method
sink_dao= StockDAO.insert_data_to_postgres # 先放一個空 method
task_main_flow(source_dao, sink_dao)
def task_main_flow(source_dao, sink_dao):
stock_list = source_dao()
sorted_list = sorted(stock_list , key=lambda x: x[1], reverse=True)
max_50 = sorted_list[:49]
sink_dao(max_50)
好了,測試 → 應該會通過,因為兩邊邏輯差不多一模一樣。
接下來我們就能重構了,首先中間的 sort 邏輯,我想再拆出來成另外一個 method,以防我之後想改變需求變抓前100筆,或是看漲幅最大的 10 筆。
with DAG() as dag:
@task
def main_flow():
source_dao = StockDAO.open_data_daily_stock # 先放一個空 method
sink_dao= StockDAO.insert_data_to_postgres # 先放一個空 method
task_main_flow(source_dao, sink_dao)
def task_main_flow(source_dao, sink_dao):
stock_list = source_dao() # E
stocks = filter_stocks(stock_list) # T
sink_dao(stocks) # L
def filter_stocks(stock_list):
sorted_list = sorted(stock_list, key=lambda x: x[1], reverse=True)
max_50 = sorted_list[:49]
return max_50
ok 我現在把中間拆了一個 filter_stocks
method 出來,但我們不針對它寫測試。因為他的主邏輯已被一開始的 dag_flow 測掉了。
現在再跑一次測試,理論上你還是通過的,證明我們的重構並沒有改壞任何東西。
我們在主程式內的 StockDAO.open_data_daily_stock
目前還是空的,讓我們先來想想他的測試吧。
通常我們會 mock 掉 request ,因為我們不需要真的去抓資料,而且我們相信這個 package 會正常運作。
但是外部連線用多了總會遇到鬼,所以我們需要驗證當 response status code 非 200 時,會拋出一個 exception 來警告使用者
import unittest
from unittest.mock import patch
from requests.exceptions import RequestException
from my_module import open_data_daily_stock
class TestOpenDataDailyStock(unittest.TestCase):
@patch('my_module.requests.get')
def test_open_data_daily_stock_request_failure(self, mock_get):
# 模擬一個請求失敗的情況,HTTP狀態碼為404
mock_response = mock_get.return_value
mock_response.status_code = 404
# 測試是否拋出 RequestException
with self.assertRaises(RequestException):
open_data_daily_stock()
if __name__ == "__main__":
unittest.main()
一樣,現在跑測試一定不會過,我們先回頭補上主程式吧。
import requests
import csv
from io import StringIO
def open_data_daily_stock():
# 發送 GET 請求並取得網頁內容
url = "https://www.twse.com.tw/exchangeReport/STOCK_DAY_ALL?response=open_data"
response = requests.get(url)
# 確保請求成功
if response.status_code == 200:
# 將回應的內容解析為 CSV 格式
csv_text = response.text
csv_data = list(csv.reader(StringIO(csv_text), delimiter=','))
# 提取資料行(不含標頭)並將其轉換為元組
data_tuples = [(row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9]) for row in csv_data[1:]]
# 返回轉換後的資料
return data_tuples
else:
# 請求失敗,引發 RequestException
raise requests.exceptions.RequestException(f"請求失敗,狀態碼:{response.status_code}")
現在執行測試,應該會通過了。
當你要測試的對象層級越高,換句話說越抽象,你要測的項目就要隨著調整。同時,應該適當地隔離存取資料的那層,以便之後可以更容易地抽換方法。首先完成滿足業務需求的基本測試,然後再逐步增加實作 method 的細節測試,這樣可以確保你的測試不容易被重構破壞掉。