iT邦幫忙

0

梯度下降 用在曲線擬合y=axe^bx

請問梯度下降可以用在非線性y=axe^bx 這一個嗎
現在有資料點,想辦法找出a和b
不知道各位大大有沒有範例 不能使用模組來算,網路上找不到範例
不知道有沒有大大可以指點迷津

2 個回答

0
淺水員
iT邦新手 1 級 ‧ 2019-12-20 02:57:17
最佳解答

https://ithelp.ithome.com.tw/upload/images/20191220/20112943SuIWn4buk2.png
公式部分偏微分上面計算出來有錯,於下面留言中有更正

之後選一組(a,b)作為起點,再配合 梯度下降 選合適的步長逼近最小值

看更多先前的回應...收起先前的回應...
canon760d iT邦新手 5 級 ‧ 2019-12-20 14:54:37 檢舉

我的想法和你一樣 概念是這樣不過code部分我很有障礙..
這部分可以和你討論嗎QQ

淺水員 iT邦新手 1 級 ‧ 2019-12-20 15:29:23 檢舉

我早上稍微寫了一下,覺得這個函式有可能不適合用梯度下降去做
殘餘的程式碼(暫時沒空閒繼續玩,就先放著)

# 更正於後面留言
canon760d iT邦新手 5 級 ‧ 2019-12-20 15:44:38 檢舉

我想請問 假如將這組數據轉為線性 就是ln(y/x) = ln(a) + bx
這樣的做法呢 是否就可以梯度下降 轉換後做出來是參數是2.268178503和-2.473308766 這樣表示以後收集到數據都要先做轉換再帶到這組ln(y/x) = ln(a) + bx 來去做預測呢?

因為非線性的關西所以這組函數的確可能會卡住不是合去做

https://ithelp.ithome.com.tw/upload/images/20191220/201173098fIlwZYHeu.jpg

EXCEL算出來的最佳值是這樣
想知道他們都怎麼fit出來的

謝謝您的回答 萬分感激!

淺水員 iT邦新手 1 級 ‧ 2019-12-20 18:02:22 檢舉

線性的話不用梯度下降,直接偏微分後=0 解聯立就可
轉換後的 2.268178503 跟 -2.473308766 的確算得出來
只是目前我還不知道怎麼找 9.897362 跟 -2.5318...

import math

# 讀取資料點的檔案,該檔案每行為一組 x,y ,以空白分隔
def readFile(fName):
	datas=[]
	with open(fName, 'r', encoding='utf-8') as f:
		for line in f:
			arr=line.strip().split()
			x=float(arr[0])
			y=float(arr[1])
			z=math.log(y/x)
			datas.append((x,y,z))
		return datas
	return None

#求解,利用偏微分=0去解聯立
def getSolution(dataPoints):
	sumX=0
	sumX2=0
	sumXY=0
	sumY=0
	sum1=len(dataPoints)
	for p in dataPoints:
		x=p[0]
		y=p[2]
		sumX+=x
		sumX2+=x*x
		sumXY+=x*y
		sumY+=y
	d= sumX2*sum1-sumX*sumX
	dx=sumXY*sum1-sumY*sumX
	dy=sumX2*sumY-sumX*sumXY
	return (dx/d, dy/d)

dataPoints=readFile('aaa.txt')
sol=getSolution(dataPoints)
print(sol)
print('a=%f, b=%f' % (math.exp(sol[1]), sol[0]))
淺水員 iT邦新手 1 級 ‧ 2019-12-20 18:06:31 檢舉

對了 excel 那個 9.897362 跟 -2.5318 是怎麼求出來的可以教一下嗎?函式名稱是?

canon760d iT邦新手 5 級 ‧ 2019-12-20 19:47:03 檢舉

用excel裡面的規劃求解他會幫你算出來
所以才有這個最佳解

canon760d iT邦新手 5 級 ‧ 2019-12-20 20:06:13 檢舉

希望還是可以用梯度下降來解這一題QQ
對了我發現9.897362是 2.268178503自然對數 也就是推回去而已

淺水員 iT邦新手 1 級 ‧ 2019-12-20 20:14:58 檢舉

2.268178503 的自然對數是 9.661786 應該還差一些
轉換後算出來的很接近正確答案是很合理的
只是還差一些

淺水員 iT邦新手 1 級 ‧ 2019-12-20 22:08:49 檢舉

我有做出來了,上面的公式對 b 的偏微分有計算錯
重新算偏微分的部分應該是:
https://ithelp.ithome.com.tw/upload/images/20191220/20112943oOEzXhqEoS.png

程式碼部分,你再看自己需求調整一下。

import math

# 讀取資料點的檔案,該檔案每行為一組 x,y ,以空白分隔
def readFile(fName):
	datas=[]
	with open(fName, 'r', encoding='utf-8') as f:
		for line in f:
			arr=line.strip().split()
			datas.append((float(arr[0]), float(arr[1])))
		return datas
	return None

#計算在(a,b)的誤差
def Err(a, b, dataPoints):
	sum=0
	for p in dataPoints:
		x=p[0]
		y=p[1]
		sum+=math.pow(y-a*x*math.exp(b*x), 2)
	return sum

#計算在(a,b)的梯度,回傳 (da, db, len) 其中 len 為 (da,db)向量的長度
def Grad(a, b, dataPoints):
	sumA=0
	sumB=0
	for p in dataPoints:
		x=p[0]
		y=p[1]
		t=x*math.exp(b*x)
		t=2*t*(a*t-y)
		sumA+=t
		sumB+=a*x*t
	return (sumA, sumB, math.sqrt(sumA*sumA+sumB*sumB))

# 利用轉換後的線性方法求近似解作為起始值
def getInitialVal(dataPoints):
	sumX=0
	sumX2=0
	sumXY=0
	sumY=0
	sum1=len(dataPoints)
	for p in dataPoints:
		x=p[0]
		y=math.log(p[1]/p[0])
		sumX+=x
		sumX2+=x*x
		sumXY+=x*y
		sumY+=y
	d= sumX2*sum1-sumX*sumX
	dx=sumXY*sum1-sumY*sumX
	dy=sumX2*sumY-sumX*sumXY
	return (math.exp(dy/d), dx/d)

#求解
def getSolution(a, b, r, dataPoints):
	err=Err(a, b, dataPoints)
	count=0
	while(count<2000 and (count==0 or (grad[2]>1e-30 and d>=1e-10))):
		count=count+1
		print('(a, b, err)=(%f, %f, %f)' % (a , b , err))
		grad=Grad(a, b, dataPoints)
		print('(da, db, len)=(%f, %f, %e)' % grad)
		print(count)
		if(grad[2]>1e-10):
			d=r
			while(d>=1e-10):
				newA=a-d*grad[0]/grad[2]
				newB=b-d*grad[1]/grad[2]
				newErr=Err(newA, newB, dataPoints)
				if(newErr <= err):
					err=newErr
					a=newA
					b=newB
					break
				else:
					d=d/2
		else:
			break
	return (a,b,err)

dataPoints=readFile('aaa.txt')
initVal=getInitialVal(dataPoints)
sol=getSolution(initVal[0], initVal[1], 1, dataPoints)
print(sol)
canon760d iT邦新手 5 級 ‧ 2019-12-21 01:13:56 檢舉

大大 你太神了
https://ithelp.ithome.com.tw/upload/images/20191221/20117309Szd6mWAzWL.jpg
想請問這個也是梯度下降對吧
如果是的話這些方法差別在哪邊呢
真的非常感謝大大

淺水員 iT邦新手 1 級 ‧ 2019-12-21 10:31:39 檢舉

他那個不是擬合線性函數 ax+b 嗎?

淺水員 iT邦新手 1 級 ‧ 2019-12-21 10:31:39 檢舉

刪除 (手機發文連按到)

淺水員 iT邦新手 1 級 ‧ 2019-12-21 10:31:39 檢舉

刪除 (手機發文連按到)

淺水員 iT邦新手 1 級 ‧ 2019-12-21 10:31:39 檢舉

刪除 (手機發文連按到)

canon760d iT邦新手 5 級 ‧ 2019-12-25 01:32:38 檢舉

對 因為想說資料轉線性 然後用這個試試看
對這方面的觀念不太好 有錯誤還請多指教QQ
也謝謝大大幫助這麼多 再度感謝阿!!

1
舜~
iT邦好手 1 級 ‧ 2019-12-19 08:06:24

需要解聯立方程式,看起來會用到ln...然後我log/ln忘得差不多了@@

跑去複習指數律、對數律 ...

您說現在有資料點,我就當作您有多組正確的 x,y

  1. y1/y2 = ( ax1 e^(bx2) ) / ( ax2 e^(bx2) )
  2. y1/y2 = ( x1 e^(bx2) ) / ( x2 e^(bx2) )
  3. x2y1/x1y2 = ( e^(bx2) ) / ( e^(bx2) ) = e^( b(x1-x2) )
  4. b = (ln(x2y1/x1y2)) / (x1-x2)
    b出來了回頭帶入原式獲得a
  5. y = axe^(bx)
  6. a = y / (xe^(bx))
    a也出來了

也不用真的算出值,交給計算機XD

看更多先前的回應...收起先前的回應...

數學 /images/emoticon/emoticon46.gif

canon760d iT邦新手 5 級 ‧ 2019-12-19 08:48:27 檢舉

取LN轉線性 不過是想再不轉換成線性情況下做這有辦法做到嗎QQ

舜~ iT邦好手 1 級 ‧ 2019-12-19 10:16:57 檢舉

a,b的式子已經幫您列出來了,剩下的給計算機算吧~~

舜~ iT邦好手 1 級 ‧ 2019-12-19 10:53:47 檢舉

這題目讓我想到以前的數學作業XD

淺水員 iT邦新手 1 級 ‧ 2019-12-20 01:41:03 檢舉

我不知道跟我想的是不是一樣
自己對問題的理解是:

  1. 資料點是一堆 (x,y) 數對
  2. 資料點未必會在擬合的曲線上面
  3. 必須找出最佳的曲線,使得 (Δy)^2 的總和為最小

有點像是 維基百科的圖

我要發表回答

立即登入回答