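Now that we have covered CNNs, I'd like to introduce another kind of neural network: the RNN (recurrent neural network). Unlike a CNN, an RNN processes its data in sequence order, and the results computed at earlier steps influence the computations at later steps. The three main RNN models are Simple RNN, LSTM, and GRU. LSTM addresses Simple RNN's weakness at retaining long-term memory, and GRU is a simplified version of LSTM. Below, we use stock price prediction as a worked example to show how RNN models operate.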
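To make "earlier results affect later computation" concrete, here is a minimal sketch of a Simple RNN forward pass in plain Python. It assumes the standard recurrence h_t = tanh(W_x x_t + W_h h_{t-1} + b). The function name `simple_rnn_forward` is hypothetical, and the weights are arbitrary illustrative values, not trained ones.

```python
import math
from typing import List, Sequence


def simple_rnn_forward(
    inputs: Sequence[float],
    w_x: List[float],
    w_h: List[List[float]],
    b: List[float],
) -> List[List[float]]:
    """Run a Simple RNN over a sequence of scalar inputs.

    At each step t the hidden state is updated as
        h_t = tanh(w_x * x_t + w_h @ h_{t-1} + b),
    so h_t depends on every earlier input through h_{t-1}.
    Returns the hidden state after each step.
    """
    hidden_size = len(b)
    h = [0.0] * hidden_size  # initial hidden state h_0 = 0
    states = []
    for x in inputs:
        new_h = []
        for i in range(hidden_size):
            # contribution carried over from the previous hidden state
            recurrent = sum(w_h[i][j] * h[j] for j in range(hidden_size))
            new_h.append(math.tanh(w_x[i] * x + recurrent + b[i]))
        h = new_h
        states.append(h)
    return states


# Illustrative (untrained) weights for a 2-unit hidden state
w_x = [0.5, -0.3]
w_h = [[0.1, 0.2], [-0.4, 0.3]]
b = [0.0, 0.1]

# Same values, different order
forward = simple_rnn_forward([1.0, 2.0, 3.0], w_x, w_h, b)
backward = simple_rnn_forward([3.0, 2.0, 1.0], w_x, w_h, b)
print(forward[-1])
print(backward[-1])
```

The two final hidden states differ even though the inputs contain the same numbers, because each step mixes in the state left by the steps before it. This order sensitivity is what makes RNNs a natural fit for time series such as daily stock prices.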
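First, install the Python package twstock, which is dedicated to the Taiwan stock market and provides data on Taiwanese stocks. Run the following command to install it (the leading `!` runs a shell command from a Jupyter or Colab notebook cell; in a terminal, omit it):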
```
! pip install twstock
```
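Next, we use twstock to retrieve data for Formosa Plastics (台塑, stock code 1301), using 2021 as the example year. The first run creates the CSV file. After that, each run fetches and saves three months of data at a time, so that our IP does not get blocked for sending too many requests. Before each run, just change the range of the for loop: its lower and upper bounds are the months to fetch. Note that Python's `range` excludes its upper bound, so `range(1, 4)` covers January through March.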
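Instead of working out the loop bounds by hand before every run, you could precompute the three-month batches. This is a minimal sketch; `month_batches` is a hypothetical helper, not part of twstock, and the batch size of three simply follows the text.

```python
from typing import List


def month_batches(first_month: int = 1, last_month: int = 12, size: int = 3) -> List[range]:
    """Split the months first_month..last_month (inclusive) into consecutive ranges of `size` months."""
    batches = []
    for start in range(first_month, last_month + 1, size):
        # range's upper bound is exclusive, hence the + 1 on last_month
        end = min(start + size, last_month + 1)
        batches.append(range(start, end))
    return batches


for batch in month_batches():
    print(f"range({batch.start}, {batch.stop})")
# prints range(1, 4), range(4, 7), range(7, 10), range(10, 13) on separate lines
```

Each printed range is what goes into the for loop on one run: the first run uses `range(1, 4)` and the last uses `range(10, 13)`, matching the two scripts in this post.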
```python
import csv
import os

import twstock

filepath = "stock2021.csv"

# First run only: create the file if it doesn't exist yet
if not os.path.isfile(filepath):
    # Column headers: date, shares traded, turnover, open, high, low, close,
    # price change, number of transactions (Chinese text, hence the big5 encoding below)
    title = ["日期", "成交股數", "成交金額", "開盤價", "最高價", "最低價", "收盤價", "漲跌價差", "成交筆數"]
    data = []
    # Look up the stock by its code (1301 = Formosa Plastics).
    # Create it once, outside the loop, to avoid sending extra requests.
    stock = twstock.Stock("1301")
    for month in range(1, 4):  # months 1-3; the upper bound is exclusive
        stocklist = stock.fetch(2021, month)
        for record in stocklist:
            # convert the datetime object into a string
            strdate = record.date.strftime("%Y-%m-%d")
            li = [strdate, record.capacity, record.turnover, record.open, record.high,
                  record.low, record.close, record.change, record.transaction]
            data.append(li)
    # "w" mode: create the CSV file, then write the header followed by the rows
    with open(filepath, "w", newline="", encoding="big5") as outputfile:
        outputwriter = csv.writer(outputfile)
        outputwriter.writerow(title)
        outputwriter.writerows(data)
```
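To see what each row written by the loop looks like without hitting the network, the sketch below builds a stand-in record. It assumes only the nine attributes the script reads (`date`, `capacity`, `turnover`, `open`, `high`, `low`, `close`, `change`, `transaction`). `FakeData` and `to_row` are hypothetical names, and the numbers are made-up placeholders, not real Formosa Plastics prices.

```python
import csv
import io
from collections import namedtuple
from datetime import datetime

# Hypothetical stand-in exposing the same nine fields the script reads from each record
FakeData = namedtuple(
    "FakeData",
    ["date", "capacity", "turnover", "open", "high", "low", "close", "change", "transaction"],
)


def to_row(record) -> list:
    """Convert one record into a CSV row, formatting the datetime as YYYY-MM-DD."""
    return [record.date.strftime("%Y-%m-%d"), record.capacity, record.turnover,
            record.open, record.high, record.low, record.close, record.change,
            record.transaction]


# Placeholder values for illustration only
sample = FakeData(datetime(2021, 1, 4), 1000, 50000, 50.0, 51.0, 49.5, 50.5, 0.5, 100)
row = to_row(sample)

# Write the row to an in-memory buffer to see the exact CSV text
buffer = io.StringIO()
csv.writer(buffer).writerow(row)
print(buffer.getvalue())
```

Only the date needs converting: `strftime` turns the `datetime` into a plain string, while the numeric fields are written as-is.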
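From the second run onward, before executing, change the bounds of the for loop and change the file mode from `w` to `a`, that is, from write to append. This way the previously downloaded data is not overwritten; the new rows are added to the end of the file instead, and the header is not written again.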
```python
import csv
import os

import twstock

filepath = "stock2021.csv"

# Later runs: only append when the file already exists (its header was written on the first run)
if os.path.isfile(filepath):
    data = []
    # Look up the stock by its code; create it once, outside the loop
    stock = twstock.Stock("1301")
    for month in range(10, 13):  # change the lower and upper bounds on every run
        stocklist = stock.fetch(2021, month)
        for record in stocklist:
            # convert the datetime object into a string
            strdate = record.date.strftime("%Y-%m-%d")
            li = [strdate, record.capacity, record.turnover, record.open, record.high,
                  record.low, record.close, record.change, record.transaction]
            data.append(li)
    # "a" mode: add the new rows after the existing ones instead of overwriting the file
    with open(filepath, "a", newline="", encoding="big5") as outputfile:
        outputwriter = csv.writer(outputfile)
        outputwriter.writerows(data)
```
Repeat this for each three-month range, and you end up with Formosa Plastics' stock data for January through December 2021!
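The two scripts differ only in the file mode and in whether the header is written. A minimal sketch of merging them, assuming you are happy to decide automatically from whether the file already exists: `append_rows` is a hypothetical helper. Here it is exercised on a temporary file with placeholder rows, so it runs without twstock or network access.

```python
import csv
import os
import tempfile
from typing import Iterable, Sequence

TITLE = ["日期", "成交股數", "成交金額", "開盤價", "最高價", "最低價", "收盤價", "漲跌價差", "成交筆數"]


def append_rows(filepath: str, rows: Iterable[Sequence], title: Sequence[str] = TITLE) -> None:
    """Append rows to a big5-encoded CSV, writing the header only if the file is new."""
    is_new = not os.path.isfile(filepath)
    # "a" also creates the file when it doesn't exist, so one mode covers every run
    with open(filepath, "a", newline="", encoding="big5") as f:
        writer = csv.writer(f)
        if is_new:
            writer.writerow(title)
        writer.writerows(rows)


# Demo on a temporary file with placeholder rows (not real stock data)
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "stock2021.csv")
    append_rows(path, [["2021-01-04", 1, 2, 3, 4, 5, 6, 7, 8]])  # first batch
    append_rows(path, [["2021-10-01", 1, 2, 3, 4, 5, 6, 7, 8]])  # a later batch
    with open(path, newline="", encoding="big5") as f:
        lines = list(csv.reader(f))
    print(len(lines))  # prints 3 (header + 2 rows)
```

In the real scripts you would pass the rows built from `stock.fetch(...)` instead of the placeholder rows. Then only the month range needs to change between runs, and there is no mode to edit by hand.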