iT邦幫忙

0

Day9視覺化

d9
  • 分享至 

  • xImage
  •  

K 線 + 技術指標圖(收盤價 + MA20/MA50 + 布林通道 Bollinger Bands + RSI + MACD)
預測結果 vs 真實結果(在價格走勢圖上標出模型預測的漲/跌(買/賣)點)

# -------------------------
# Step 4: 訓練 ML 模型
# -------------------------
def train_ml_model_walkforward(X, y, n_splits=5, df=None, visualize=False):
    """
    Train and evaluate a RandomForest classifier with walk-forward validation.

    ``TimeSeriesSplit`` produces folds whose test rows always come after the
    training rows, so the model is never evaluated on data older than what
    it was trained on.

    :param X: feature DataFrame (rows are selected positionally via ``.iloc``)
    :param y: label Series, row-for-row aligned with ``X``
    :param n_splits: number of walk-forward splits passed to ``TimeSeriesSplit``
    :param df: optional source DataFrame used only for plotting; it is indexed
               with the same positional ``test_index`` as ``X`` — presumably it
               must be row-aligned with ``X`` (TODO confirm against
               ``prepare_ml_data``)
    :param visualize: when True and ``df`` is given, plot each split's
                      predictions via ``plot_predictions``
    :return: list of per-split accuracy scores
    """
    tscv = TimeSeriesSplit(n_splits=n_splits)
    all_scores = []

    for split_num, (train_index, test_index) in enumerate(tscv.split(X), start=1):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]

        # A fresh model per split so no state leaks between folds.
        model = RandomForestClassifier(n_estimators=100, random_state=42)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        # Evaluate this fold.
        acc = accuracy_score(y_test, y_pred)
        all_scores.append(acc)

        print(f"=== Split {split_num} ===")
        print(f"Train size: {len(train_index)}, Test size: {len(test_index)}")
        print(f"Accuracy: {acc:.4f}")
        print("Confusion Matrix:")
        print(confusion_matrix(y_test, y_pred))
        print("Classification Report:")
        print(classification_report(y_test, y_pred))
        print("\n")

        # Optionally visualize this fold's predictions on the price chart.
        if visualize and df is not None:
            plot_predictions(df, y_test, y_pred, test_index,
                             title=f"Predictions vs Actual (Split {split_num})")

    # Guard against an empty score list rather than dividing by zero.
    if all_scores:
        print("平均準確率:", sum(all_scores) / len(all_scores))
    return all_scores

# -------------------------
# Step 5:畫 K 線 + 技術指標
# -------------------------
def plot_technical_indicators(df, title="BTC/USDT Technical Indicators"):
    """
    Draw a three-panel chart that shares one time axis.

    Top panel: close price with MA20, MA50 and the Bollinger Band bounds.
    Middle panel: RSI with dashed guide lines at 70 and 30.
    Bottom panel: MACD and its signal line with a dashed zero line.

    :param df: DataFrame with 'timestamp', 'close', 'MA20', 'MA50',
               'BB_Upper', 'BB_Lower', 'RSI', 'MACD' and 'Signal' columns
    :param title: title shown above the top panel
    """
    times = df['timestamp']
    _fig, (price_ax, rsi_ax, macd_ax) = plt.subplots(3, 1, figsize=(14,10), sharex=True)

    # Top panel: price, moving averages, Bollinger Bands.
    price_ax.plot(times, df['close'], label='Close Price', color='black')
    for column, colour in (('MA20', 'blue'), ('MA50', 'orange')):
        price_ax.plot(times, df[column], label=column, color=colour)
    for column, legend_name in (('BB_Upper', 'BB Upper'), ('BB_Lower', 'BB Lower')):
        price_ax.plot(times, df[column], linestyle='--', color='gray', alpha=0.7, label=legend_name)
    price_ax.set_title(title)

    # Middle panel: RSI with overbought / oversold guides.
    rsi_ax.plot(times, df['RSI'], label='RSI', color='purple')
    for level, colour in ((70, 'red'), (30, 'green')):
        rsi_ax.axhline(level, color=colour, linestyle='--', alpha=0.6)
    rsi_ax.set_ylabel("RSI")

    # Bottom panel: MACD against its signal line.
    for column, colour in (('MACD', 'blue'), ('Signal', 'orange')):
        macd_ax.plot(times, df[column], label=column, color=colour)
    macd_ax.axhline(0, color='black', linestyle='--', alpha=0.7)
    macd_ax.set_ylabel("MACD")

    # No artists are added after this point, so legends can be built together.
    for ax in (price_ax, rsi_ax, macd_ax):
        ax.legend()

    # Date ticks on the shared (bottom) x-axis.
    macd_ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

# -------------------------
# Step 6:預測結果 vs 真實結果
# -------------------------
def plot_predictions(df, y_true, y_pred, test_index, title="Model Predictions vs Actual"):
    """
    Overlay a test fold's up/down predictions on the close-price line.

    Predicted-up rows get a green '^', predicted-down rows get a red 'v', and
    rows where the prediction differs from the true label are circled.

    :param df: source DataFrame with at least 'timestamp' and 'close' columns
    :param y_true: true labels for the fold, same length and order as test_index
    :param y_pred: model predictions for the fold, same length and order as test_index
    :param test_index: positional row indices of the fold (from TimeSeriesSplit)
    :param title: figure title
    """
    # NOTE(review): test_index is positional into the feature matrix X; if
    # prepare_ml_data drops rows of df (e.g. indicator warm-up NaNs) this
    # selects shifted rows — confirm df and X are row-aligned.
    df_test = df.iloc[test_index].copy()
    # Assign both label columns positionally. Assigning a Series directly
    # would use pandas index alignment, which silently yields NaN/misaligned
    # values when y_true's index differs from df's, while y_pred (an array)
    # is always positional.
    df_test['True'] = list(y_true)
    df_test['Pred'] = list(y_pred)

    predicted_up = df_test['Pred'] == 1
    predicted_down = df_test['Pred'] == 0
    wrong = df_test['Pred'] != df_test['True']

    plt.figure(figsize=(14,6))
    plt.plot(df_test['timestamp'], df_test['close'], label="Close Price", color="black")

    # Predicted up: green upward triangle.
    plt.scatter(
        df_test.loc[predicted_up, 'timestamp'],
        df_test.loc[predicted_up, 'close'],
        marker='^', color='green', label='Predicted Up', alpha=0.8
    )

    # Predicted down: red downward triangle.
    plt.scatter(
        df_test.loc[predicted_down, 'timestamp'],
        df_test.loc[predicted_down, 'close'],
        marker='v', color='red', label='Predicted Down', alpha=0.8
    )

    # Compare against the truth: circle every wrong prediction.
    plt.scatter(
        df_test.loc[wrong, 'timestamp'],
        df_test.loc[wrong, 'close'],
        marker='o', facecolors='none', edgecolors='gray', s=120,
        label='Wrong Prediction'
    )

    plt.title(title)
    plt.xlabel("Time")
    plt.ylabel("Price")
    plt.legend()
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

# -------------------------
# 主程式
# -------------------------
if __name__ == "__main__":
    # Fetch 500 hourly BTC/USDT bars and attach technical indicators.
    df = fetch_crypto_data("BTC/USDT", "1h", 500)
    df = add_indicators(df)

    X, y = prepare_ml_data(df)

    # Technical-indicator chart (price + MA + Bollinger Bands, RSI, MACD).
    plot_technical_indicators(df)

    # Walk-forward training with a prediction plot for every split.
    # (The original call was truncated at `visualize`, a SyntaxError.)
    scores = train_ml_model_walkforward(X, y, df=df, n_splits=5, visualize=True)


圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言