0

## 梯度下降 用在曲線擬合y=axe^bx

### 2 個回答

0

iT邦研究生 2 級 ‧ 2019-12-20 02:57:17

canon760d iT邦新手 5 級 ‧ 2019-12-20 14:54:37 檢舉

``````# 更正於後面留言
``````
canon760d iT邦新手 5 級 ‧ 2019-12-20 15:44:38 檢舉

EXCEL算出來的最佳值是這樣

``````import math

# 讀取資料點的檔案，該檔案每行為一組 x,y ，以空白分隔
datas=[]
with open(fName, 'r', encoding='utf-8') as f:
for line in f:
arr=line.strip().split()
x=float(arr[0])
y=float(arr[1])
z=math.log(y/x)
datas.append((x,y,z))
return datas
return None

#求解，利用偏微分=0去解聯立
def getSolution(dataPoints):
sumX=0
sumX2=0
sumXY=0
sumY=0
sum1=len(dataPoints)
for p in dataPoints:
x=p[0]
y=p[2]
sumX+=x
sumX2+=x*x
sumXY+=x*y
sumY+=y
d= sumX2*sum1-sumX*sumX
dx=sumXY*sum1-sumY*sumX
dy=sumX2*sumY-sumX*sumXY
return (dx/d, dy/d)

sol=getSolution(dataPoints)
print(sol)
print('a=%f, b=%f' % (math.exp(sol[1]), sol[0]))
``````

canon760d iT邦新手 5 級 ‧ 2019-12-20 19:47:03 檢舉

canon760d iT邦新手 5 級 ‧ 2019-12-20 20:06:13 檢舉

2.268178503 的自然對數是 9.661786 應該還差一些

``````import math

# 讀取資料點的檔案，該檔案每行為一組 x,y ，以空白分隔
datas=[]
with open(fName, 'r', encoding='utf-8') as f:
for line in f:
arr=line.strip().split()
datas.append((float(arr[0]), float(arr[1])))
return datas
return None

#計算在(a,b)的誤差
def Err(a, b, dataPoints):
sum=0
for p in dataPoints:
x=p[0]
y=p[1]
sum+=math.pow(y-a*x*math.exp(b*x), 2)
return sum

#計算在(a,b)的梯度，回傳 (da, db, len) 其中 len 為 (da,db)向量的長度
sumA=0
sumB=0
for p in dataPoints:
x=p[0]
y=p[1]
t=x*math.exp(b*x)
t=2*t*(a*t-y)
sumA+=t
sumB+=a*x*t
return (sumA, sumB, math.sqrt(sumA*sumA+sumB*sumB))

# 利用轉換後的線性方法求近似解作為起始值
def getInitialVal(dataPoints):
sumX=0
sumX2=0
sumXY=0
sumY=0
sum1=len(dataPoints)
for p in dataPoints:
x=p[0]
y=math.log(p[1]/p[0])
sumX+=x
sumX2+=x*x
sumXY+=x*y
sumY+=y
d= sumX2*sum1-sumX*sumX
dx=sumXY*sum1-sumY*sumX
dy=sumX2*sumY-sumX*sumXY
return (math.exp(dy/d), dx/d)

#求解
def getSolution(a, b, r, dataPoints):
err=Err(a, b, dataPoints)
count=0
while(count<2000 and (count==0 or (grad[2]>1e-30 and d>=1e-10))):
count=count+1
print('(a, b, err)=(%f, %f, %f)' % (a , b , err))
print('(da, db, len)=(%f, %f, %e)' % grad)
print(count)
d=r
while(d>=1e-10):
a=newA
b=newB
break
else:
d=d/2
else:
break
return (a,b,err)

initVal=getInitialVal(dataPoints)
sol=getSolution(initVal[0], initVal[1], 1, dataPoints)
print(sol)
``````
canon760d iT邦新手 5 級 ‧ 2019-12-21 01:13:56 檢舉

canon760d iT邦新手 5 級 ‧ 2019-12-25 01:32:38 檢舉

1

iT邦高手 1 級 ‧ 2019-12-19 08:06:24

1. y1/y2 = ( ax1 e^(bx2) ) / ( ax2 e^(bx2) )
2. y1/y2 = ( x1 e^(bx2) ) / ( x2 e^(bx2) )
3. x2y1/x1y2 = ( e^(bx2) ) / ( e^(bx2) ) = e^( b(x1-x2) )
4. b = (ln(x2y1/x1y2)) / (x1-x2)
b出來了回頭帶入原式獲得a
5. y = axe^(bx)
6. a = y / (xe^(bx))
a也出來了

canon760d iT邦新手 5 級 ‧ 2019-12-19 08:48:27 檢舉

a,b的式子已經幫您列出來了，剩下的給計算機算吧~~

1. 資料點是一堆 (x,y) 數對
2. 資料點未必會在擬合的曲線上面
3. 必須找出最佳的曲線，使得 (Δy)^2 的總和為最小