api/transfer.go
package api
import (
	"database/sql"
	"errors"
	"fmt"
	"net/http"

	db "github.com/Kcih4518/simpleBank_2023/db/sqlc"
	"github.com/gin-gonic/gin"
)
// transferRequest is the JSON body bound by createTransfer.
type transferRequest struct {
	// FromAccountID is the ID of the account the transfer is from; required and >= 1.
	FromAccountID int64 `json:"from_account_id" binding:"required,min=1"`
	// ToAccountID is the ID of the account the transfer goes to; required and >= 1.
	ToAccountID int64 `json:"to_account_id" binding:"required,min=1"`
	// Amount is the amount to transfer; required and strictly positive.
	Amount int64 `json:"amount" binding:"required,gt=0"`
	// Currency is checked by the custom "currency" validation tag, which must be
	// registered on gin's validator engine before any request is bound.
	Currency string `json:"currency" binding:"required,currency"`
}
// createTransfer handles a money-transfer request.
//
// It binds and validates the JSON body, checks that both accounts exist and
// use the requested currency, runs the transfer transaction through the store,
// and replies with the transaction result. Every failure path writes an error
// response and returns immediately, so exactly one response is written.
func (server *Server) createTransfer(ctx *gin.Context) {
	var req transferRequest
	if err := ctx.ShouldBindJSON(&req); err != nil {
		ctx.JSON(http.StatusBadRequest, errorResponse(err))
		// Without this return the handler would continue with an unvalidated
		// request, query the store, and write a second response.
		return
	}

	// validAccount writes the error response itself whenever it returns false.
	if !server.validAccount(ctx, req.FromAccountID, req.Currency) {
		return
	}
	if !server.validAccount(ctx, req.ToAccountID, req.Currency) {
		return
	}

	arg := db.TransferTxParams{
		FromAccountID: req.FromAccountID,
		ToAccountID:   req.ToAccountID,
		Amount:        req.Amount,
	}

	result, err := server.store.TransferTx(ctx, arg)
	if err != nil {
		ctx.JSON(http.StatusInternalServerError, errorResponse(err))
		return
	}

	ctx.JSON(http.StatusOK, result)
}
// validAccount reports whether the account with accountID exists and uses
// currency. When it returns false it has already written the error response:
//   - 404 Not Found if the lookup returns sql.ErrNoRows,
//   - 500 Internal Server Error for any other lookup error,
//   - 400 Bad Request if the account's currency differs from currency.
func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) bool {
	account, err := server.store.GetAccount(ctx, accountID)
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows if the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			ctx.JSON(http.StatusNotFound, errorResponse(err))
			return false
		}
		ctx.JSON(http.StatusInternalServerError, errorResponse(err))
		return false
	}

	if account.Currency != currency {
		err := fmt.Errorf("account [%d] currency mismatch: %s vs %s", account.ID, account.Currency, currency)
		ctx.JSON(http.StatusBadRequest, errorResponse(err))
		return false
	}

	return true
}
創建 `transferRequest` 結構：
api/transfer.go
// transferRequest is the JSON body bound by createTransfer.
type transferRequest struct {
	// FromAccountID is the ID of the account the transfer is from; required and >= 1.
	FromAccountID int64 `json:"from_account_id" binding:"required,min=1"`
	// ToAccountID is the ID of the account the transfer goes to; required and >= 1.
	ToAccountID int64 `json:"to_account_id" binding:"required,min=1"`
	// Amount is the amount to transfer; required and strictly positive.
	Amount int64 `json:"amount" binding:"required,gt=0"`
	// Currency is required and restricted by the oneof tag to USD, EUR or CAD.
	Currency string `json:"currency" binding:"required,oneof=USD EUR CAD"`
}
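- `FromAccountID`（類型：`int64`）：轉出帳戶的 ID，必填且必須 ≥ 1。
- `ToAccountID`（類型：`int64`）：轉入帳戶的 ID，必填且必須 ≥ 1。
- `Amount`（類型：`int64`）：轉帳金額，必填且必須大於 0。
- `Currency`（類型：`string`）：貨幣，必填，且只能是 USD、EUR 或 CAD 其中之一。

`createTransfer` handler：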
// createTransfer binds the JSON body into a transferRequest, runs the
// transfer transaction through the store, and replies with its result.
// A bad body yields 400; a failed transaction yields 500.
func (server *Server) createTransfer(ctx *gin.Context) {
	var req transferRequest

	bindErr := ctx.ShouldBindJSON(&req)
	if bindErr != nil {
		ctx.JSON(http.StatusBadRequest, errorResponse(bindErr))
		return
	}

	result, txErr := server.store.TransferTx(ctx, db.TransferTxParams{
		FromAccountID: req.FromAccountID,
		ToAccountID:   req.ToAccountID,
		Amount:        req.Amount,
	})
	if txErr != nil {
		ctx.JSON(http.StatusInternalServerError, errorResponse(txErr))
		return
	}

	ctx.JSON(http.StatusOK, result)
}
1. 使用 `ctx.ShouldBindJSON(&req)` 方法將輸入參數綁定到 `req` 對象。如果綁定失敗，則回傳 `http.StatusBadRequest` 與錯誤訊息到客戶端。
2. 建立一個 `db.TransferTxParams` 對象，其中：
   - `FromAccountID` 是 `req.FromAccountID`
   - `ToAccountID` 是 `req.ToAccountID`
   - `Amount` 是 `req.Amount`
3. 呼叫 `server.store.TransferTx(ctx, arg)` 以執行金錢轉帳交易。如果發生錯誤，則回傳 `http.StatusInternalServerError`。
4. 否則，將 `TransferTxResult` 對象作為回應返回給客戶端，HTTP 狀態碼為 `http.StatusOK`。

在執行 Transfer 時，要驗證 Account 的 Currency 是否與 Request 的相同，以避免無效交易，因此在這裡實現 `validAccount` 來驗證 Account 的合法性：
// validAccount reports whether the account with accountID exists and uses
// currency. When it returns false it has already written the error response:
//   - 404 Not Found if the lookup returns sql.ErrNoRows,
//   - 500 Internal Server Error for any other lookup error,
//   - 400 Bad Request if the account's currency differs from currency.
func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) bool {
	account, err := server.store.GetAccount(ctx, accountID)
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows if the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			ctx.JSON(http.StatusNotFound, errorResponse(err))
			return false
		}
		ctx.JSON(http.StatusInternalServerError, errorResponse(err))
		return false
	}

	if account.Currency != currency {
		err := fmt.Errorf("account [%d] currency mismatch: %s vs %s", account.ID, account.Currency, currency)
		ctx.JSON(http.StatusBadRequest, errorResponse(err))
		return false
	}

	return true
}
`validAccount` 函數是 `Server` 結構的一個方法，接受三個參數：一個 `gin.Context`、一個賬戶 ID（`accountID`）和一個貨幣字符串（`currency`），並返回一個 `boolean`。

1. 使用 `server.store.GetAccount()` 方法從數據庫中查詢賬戶信息。這個函數將返回一個賬戶對象或一個錯誤。
2. 如果賬戶不存在（錯誤為 `sql.ErrNoRows`），則向客戶端發送 `http.StatusNotFound` 狀態碼，並返回 false。
3. 如果是其他錯誤，則向客戶端發送 `http.StatusInternalServerError` 狀態碼，並返回 false。
4. `Currency` 驗證：如果賬戶的貨幣與請求中的貨幣不一致，則以 `http.StatusBadRequest` 狀態碼通過 `ctx.JSON()` 方法將錯誤回應發送給客戶端，然後返回 false；若所有檢查都通過，則返回 true。

最後，在 `NewServer` 中註冊 `POST /transfers` 路由：
// NewServer builds a Server backed by store and wires its HTTP routes onto a
// gin.Default() router.
func NewServer(store db.Store) *Server {
	router := gin.Default()
	server := &Server{store: store}

	// Account endpoints.
	router.POST("/accounts", server.createAccount)
	router.GET("/accounts/:id", server.getAccount)
	router.GET("/accounts", server.listAccounts)

	// Transfer endpoints.
	router.POST("/transfers", server.createTransfer)

	server.router = router
	return server
}
api/transfer.go
// transferRequest is the JSON body bound by createTransfer.
type transferRequest struct {
	// FromAccountID is the ID of the account the transfer is from; required and >= 1.
	FromAccountID int64 `json:"from_account_id" binding:"required,min=1"`
	// ToAccountID is the ID of the account the transfer goes to; required and >= 1.
	ToAccountID int64 `json:"to_account_id" binding:"required,min=1"`
	// Amount is the amount to transfer; required and strictly positive.
	Amount int64 `json:"amount" binding:"required,gt=0"`
	// Currency is required and hard-coded via oneof to USD, EUR or CAD.
	Currency string `json:"currency" binding:"required,oneof=USD EUR CAD"`
}
api/account.go
// createAccountRequest holds the owner and currency JSON fields; presumably
// bound by the createAccount handler (not shown here).
type createAccountRequest struct {
	// Owner is required.
	Owner string `json:"owner" binding:"required"`
	// Currency is required and hard-coded via oneof to USD, EUR or CAD.
	Currency string `json:"currency" binding:"required,oneof=USD EUR CAD"`
}
在當前的 `transferRequest` 和 `createAccountRequest` 結構中，`Currency` 的綁定條件是寫死（hard-coded）的，只允許 USD、EUR 和 CAD 這三種貨幣。如果將來需要支援 100 種不同的 Currency，把這 100 種貨幣值都放在 `oneof` 標籤中將非常難以閱讀，也容易出錯。
為了避免這種情況，可以透過一個自定義 validator 來驗證 Currency 的類別。接下來在 `api` 文件夾中創建一個新的文件 `validator.go`，並在其中宣告一個新的變量 `validCurrency`，其類型為 `validator.Func`。
api/validator.go
package api
import (
"github.com/go-playground/validator/v10"
"github.com/techschool/simplebank/util"
)
var validCurrency validator.Func = func(fieldLevel validator.FieldLevel) bool {
if currency, ok := fieldLevel.Field().Interface().(string); ok {
return util.IsSupportedCurrency(currency)
}
return false
}
util/currency.go
package util
// Currency codes accepted by IsSupportedCurrency.
const (
	USD = "USD"
	EUR = "EUR"
	CAD = "CAD"
)

// IsSupportedCurrency reports whether currency is exactly one of USD, EUR or
// CAD. The comparison is case-sensitive, so "usd" is rejected.
func IsSupportedCurrency(currency string) bool {
	return currency == USD || currency == EUR || currency == CAD
}
var validCurrency validator.Func = func(fieldLevel validator.FieldLevel) bool
這一行定義了一個名為 `validCurrency` 的變量，其類型為 `validator.Func`，值是一個接受 `validator.FieldLevel` 類型參數並返回布爾值的函數。
var validCurrency validator.Func = func(fieldLevel validator.FieldLevel) bool
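等號右側的 `func(fieldLevel validator.FieldLevel) bool { ... }` 是一個匿名函數（function literal，或稱為 lambda 函數），被賦值給 `validCurrency` 變量。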
if currency, ok := fieldLevel.Field().Interface().(string); ok {
這一行首先調用 `fieldLevel.Field()` 來獲得正在驗證的字段的反射值（`reflect.Value`），然後調用 `Interface()` 方法來獲得其底層的值，並嘗試將其斷言為字符串。如果成功（即 `ok` 為真），則進入 if 語句的主體。
return util.IsSupportedCurrency(currency)
如果上面的類型斷言成功，則調用 `util.IsSupportedCurrency(currency)` 函數來檢查 `currency` 是否是一個受支援的貨幣類型，並返回結果。
return false
如果類型斷言失敗(即字段的值不是一個字符串),則函數返回 false,表示驗證失敗。
Visual Studio Code 已經自動為我們導入了 validator 套件。然而，我們需要在導入路徑的末尾添加 `/v10`，因為我們想要使用這個套件的 v10 版本：
"github.com/go-playground/validator/v10"
package api
import (
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
)
func NewServer(store db.Store) *Server {
server := &Server{store: store}
router := gin.Default()
// register various handler functions
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("currency", validCurrency); err != nil {
// Handle the error, for example, log it or return an error
log.Printf("Failed to register 'currency' custom validation: %v", err)
// You can take appropriate error-handling measures here, such as returning an error message
}
}
router.POST("/accounts", server.createAccount)
router.GET("/accounts/:id", server.getAccount)
router.GET("/accounts", server.listAccounts)
router.POST("/transfers", server.createTransfer)
server.router = router
return server
}
在此,我們在創建 Gin 路由器後,調用 binding.Validator.Engine()
來獲得 Gin 當前正在使用的驗證引擎(binding 是 Gin 的一個子套件)。
注意,這個函數會返回一個通用的介面類型,默認情況下,它是指向 go-playground/validator/v10 套件的 validator 對象的指針。
因此，我們需要透過型別斷言（type assertion）把它轉成 `*validator.Validate` 物件指針。如果斷言成功（`ok` 為真），就可以調用 `v.RegisterValidation()` 來註冊我們之前實現的自定義驗證函數 `validCurrency`。

這個函數的第一個參數是驗證標籤的名稱 `currency`，第二個參數則是我們之前實現的 `validCurrency` 函數。
需要檢查 `v.RegisterValidation` 回傳的錯誤（確認是否註冊成功），以避免 `golangci-lint` 的 errcheck 規則報錯：
golangci-lint run -v ./...
api/server.go:21:23: Error return value of `v.RegisterValidation` is not checked (errcheck)
v.RegisterValidation("currency", validCurrency)
api/transfer.go
// transferRequest is the JSON body bound by createTransfer.
type transferRequest struct {
	// FromAccountID is the ID of the account the transfer is from; required and >= 1.
	FromAccountID int64 `json:"from_account_id" binding:"required,min=1"`
	// ToAccountID is the ID of the account the transfer goes to; required and >= 1.
	ToAccountID int64 `json:"to_account_id" binding:"required,min=1"`
	// Amount is the amount to transfer; required and strictly positive.
	Amount int64 `json:"amount" binding:"required,gt=0"`
	// Currency is checked by the custom "currency" validation tag instead of a
	// hard-coded oneof list; the tag must be registered on gin's validator
	// engine before any request is bound.
	Currency string `json:"currency" binding:"required,currency"`
}
api/account.go
// createAccountRequest holds the owner and currency JSON fields; presumably
// bound by the createAccount handler (not shown here).
type createAccountRequest struct {
	// Owner is required.
	Owner string `json:"owner" binding:"required"`
	// Currency is checked by the custom "currency" validation tag; the tag must
	// be registered on gin's validator engine before any request is bound.
	Currency string `json:"currency" binding:"required,currency"`
}
什麼是 reflection？

在程式設計中，當我們想在運行時更深入地了解一個變量（例如它的類型或它包含的值），或者想動態地修改它的值時，會使用一種叫做「reflection」（反射）的技術。
if currency, ok := fieldLevel.Field().Interface().(string); ok {
return util.IsSupportedCurrency(currency)
}
- `fieldLevel.Field()`：獲得當前正在檢查的字段的反射值（`reflect.Value`），也就是該字段的「背景信息」。可以把它想像成一本包含了該字段所有資料的手冊。
- `Interface()`：這個方法就像是打開那本手冊，讓我們可以查看或提取該字段的實際值（在我們的案例中是貨幣的字符串值）。
- `.(string)`：這一部分是型別斷言，表示我們期望這個值是一個字符串。如果真的是字符串，那麼 `ok` 會是 `true`，然後就可以繼續檢查該字符串是否是一個受支援的貨幣類型。

reflection 的使用情境？
package api
import (
"bytes"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
mockdb "github.com/Kcih4518/simpleBank_2023/db/mock"
db "github.com/Kcih4518/simpleBank_2023/db/sqlc"
"github.com/Kcih4518/simpleBank_2023/util"
"github.com/gin-gonic/gin"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
// TestCreateTransfer drives POST /transfers through the real router with a
// mocked store. Each case gives the request body, the store calls it expects,
// and the HTTP status the handler must return. gomock fails the test on any
// unexpected or missing store call.
func TestCreateTransfer(t *testing.T) {
	account1 := randomAccount()
	account2 := randomAccount()
	account3 := randomAccount()

	// account1 and account2 share a currency; account3 deliberately differs so
	// the currency-mismatch cases can be exercised.
	account1.Currency = util.USD
	account2.Currency = util.USD
	account3.Currency = util.EUR

	testCases := []struct {
		name          string
		body          gin.H
		buildStubs    func(store *mockdb.MockStore)
		checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
	}{
		{
			name: "OK",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account2.ID,
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(1).Return(account2, nil)
				arg := db.TransferTxParams{
					FromAccountID: account1.ID,
					ToAccountID:   account2.ID,
					Amount:        10,
				}
				store.EXPECT().TransferTx(gomock.Any(), gomock.Eq(arg)).Times(1)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusOK, recorder.Code)
			},
		},
		{
			name: "FromAccountNotFound",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account2.ID,
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows)
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(0)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusNotFound, recorder.Code)
			},
		},
		{
			name: "ToAccountNotFound",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account2.ID,
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusNotFound, recorder.Code)
			},
		},
		{
			name: "FromAccountCurrencyMismatch",
			body: gin.H{
				"from_account_id": account3.ID,
				"to_account_id":   account2.ID,
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account3.ID)).Times(1).Return(account3, nil)
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(0)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusBadRequest, recorder.Code)
			},
		},
		{
			name: "ToAccountCurrencyMismatch",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account3.ID,
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account3.ID)).Times(1).Return(account3, nil)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusBadRequest, recorder.Code)
			},
		},
		{
			name: "InvalidFromAccountID",
			body: gin.H{
				"from_account_id": 0, // fails `required`
				"to_account_id":   account2.ID,
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusBadRequest, recorder.Code)
			},
		},
		{
			name: "InvalidToAccountID",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   -1, // fails `min=1`
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusBadRequest, recorder.Code)
			},
		},
		{
			name: "NegativeAmount",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account2.ID,
				"amount":          -10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusBadRequest, recorder.Code)
			},
		},
		{
			name: "ZeroAmount",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account2.ID,
				"amount":          0, // fails `required` / `gt=0`
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusBadRequest, recorder.Code)
			},
		},
		{
			name: "InvalidCurrency",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account2.ID,
				"amount":          10,
				"currency":        "TWD",
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusBadRequest, recorder.Code)
			},
		},
		{
			name: "GetAccountError",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account2.ID,
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(1).Return(db.Account{}, sql.ErrConnDone)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusInternalServerError, recorder.Code)
			},
		},
		{
			name: "TransferTxError",
			body: gin.H{
				"from_account_id": account1.ID,
				"to_account_id":   account2.ID,
				"amount":          10,
				"currency":        util.USD,
			},
			buildStubs: func(store *mockdb.MockStore) {
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
				store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(1).Return(account2, nil)
				store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(1).Return(db.TransferTxResult{}, sql.ErrTxDone)
			},
			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
				require.Equal(t, http.StatusInternalServerError, recorder.Code)
			},
		},
	}

	for i := range testCases {
		tc := testCases[i] // per-iteration copy captured by the subtest closure

		t.Run(tc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			store := mockdb.NewMockStore(ctrl)
			tc.buildStubs(store)

			server := NewServer(store)
			recorder := httptest.NewRecorder()

			// Marshal body data to JSON
			data, err := json.Marshal(tc.body)
			require.NoError(t, err)

			url := "/transfers"
			request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
			require.NoError(t, err)

			server.router.ServeHTTP(recorder, request)
			tc.checkResponse(t, recorder)
		})
	}
}