添加全局验证器,优化白名单创建请求的参数验证
This commit is contained in:
55
web/globals/validator.go
Normal file
55
web/globals/validator.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package globals
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/locales/zh"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
zhtrans "github.com/go-playground/validator/v10/translations/zh"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
var Validator *ValidatorHolder
|
||||
|
||||
type ValidatorHolder struct {
|
||||
validator *validator.Validate
|
||||
translator ut.Translator
|
||||
}
|
||||
|
||||
func (v *ValidatorHolder) Validate(c *fiber.Ctx, data any) error {
|
||||
|
||||
if err := c.BodyParser(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if errs := v.validator.Struct(data); errs != nil {
|
||||
var sb = strings.Builder{}
|
||||
var typeErrs validator.ValidationErrors
|
||||
errors.As(errs, &typeErrs)
|
||||
for i, err := range typeErrs {
|
||||
sb.WriteString(err.Translate(v.translator))
|
||||
if i < len(typeErrs)-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return fiber.NewError(fiber.StatusBadRequest, sb.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func InitValidator() {
|
||||
var validate = validator.New(validator.WithRequiredStructEnabled())
|
||||
|
||||
var translator = ut.New(zh.New()).GetFallback()
|
||||
err := zhtrans.RegisterDefaultTranslations(validate, translator)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
Validator = &ValidatorHolder{
|
||||
validator: validate,
|
||||
translator: translator,
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net"
|
||||
"platform/web/auth"
|
||||
"platform/web/common"
|
||||
g "platform/web/globals"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"platform/web/services"
|
||||
@@ -69,8 +71,8 @@ func ListWhitelist(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
type CreateWhitelistReq struct {
|
||||
Host string `json:"host" validate:"required"`
|
||||
Remark string `json:"remark"`
|
||||
Host string `json:"host" validate:"required,ip"`
|
||||
Remark string `json:"remark,omitempty"`
|
||||
}
|
||||
|
||||
func CreateWhitelist(c *fiber.Ctx) error {
|
||||
@@ -83,21 +85,22 @@ func CreateWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 解析请求参数
|
||||
req := new(CreateWhitelistReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
err = g.Validator.Validate(c, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if req.Host == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "host is required")
|
||||
|
||||
err = secureAddr(req.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建白名单
|
||||
whitelist := &m.Whitelist{
|
||||
err = q.Whitelist.Create(&m.Whitelist{
|
||||
UserID: authContext.Payload.Id,
|
||||
Host: req.Host,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
|
||||
err = q.Whitelist.Create(whitelist)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -183,3 +186,15 @@ func RemoveWhitelist(c *fiber.Ctx) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func secureAddr(str string) error {
|
||||
var addr = net.ParseIP(str)
|
||||
if addr == nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "IP 解析失败")
|
||||
}
|
||||
|
||||
if addr.IsGlobalUnicast() {
|
||||
return nil
|
||||
}
|
||||
return fiber.NewError(fiber.StatusBadRequest, "IP 地址不可用")
|
||||
}
|
||||
|
||||
127
web/handlers/whitelist_test.go
Normal file
127
web/handlers/whitelist_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package handlers
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_secureAddr(t *testing.T) {
|
||||
type args struct {
|
||||
str string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
// 有效的公网 IP 地址
|
||||
{
|
||||
name: "有效公网IPv4地址",
|
||||
args: args{str: "203.0.113.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "有效公网IPv6地址",
|
||||
args: args{str: "2001:db8::1"},
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// 私有地址
|
||||
{
|
||||
name: "IPv4私有地址(10.x.x.x)",
|
||||
args: args{str: "10.0.0.1"},
|
||||
wantErr: false, // 取决于需求,通常私有地址是被允许的全局单播地址
|
||||
},
|
||||
{
|
||||
name: "IPv4私有地址(172.16.x.x)",
|
||||
args: args{str: "172.16.0.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IPv4私有地址(192.168.x.x)",
|
||||
args: args{str: "192.168.0.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6私有地址(ULA)",
|
||||
args: args{str: "fd00::1"},
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// 广播地址
|
||||
{
|
||||
name: "IPv4本地广播地址",
|
||||
args: args{str: "255.255.255.255"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 未指定地址
|
||||
{
|
||||
name: "IPv4未指定地址",
|
||||
args: args{str: "0.0.0.0"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6未指定地址",
|
||||
args: args{str: "::"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 回环地址
|
||||
{
|
||||
name: "IPv4回环地址",
|
||||
args: args{str: "127.0.0.1"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6回环地址",
|
||||
args: args{str: "::1"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 组播地址
|
||||
{
|
||||
name: "IPv4组播地址",
|
||||
args: args{str: "224.0.0.1"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6组播地址",
|
||||
args: args{str: "ff00::1"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 链路本地地址
|
||||
{
|
||||
name: "IPv4链路本地地址",
|
||||
args: args{str: "169.254.0.1"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6链路本地地址",
|
||||
args: args{str: "fe80::1"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 格式错误的地址
|
||||
{
|
||||
name: "格式错误的IP地址",
|
||||
args: args{str: "not-an-ip"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "不完整的IP地址",
|
||||
args: args{str: "192.168.0"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "超出范围的IP地址",
|
||||
args: args{str: "256.256.256.256"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := secureAddr(tt.args.str); (err != nil) != tt.wantErr {
|
||||
t.Errorf("secureAddr() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -43,10 +43,11 @@ func (s *Server) Run() error {
|
||||
g.InitAlipay()
|
||||
// g.InitWechatPay()
|
||||
g.InitAliyun()
|
||||
g.InitValidator()
|
||||
|
||||
// config
|
||||
s.fiber = fiber.New(fiber.Config{
|
||||
ProxyHeader: fiber.HeaderXForwardedFor,
|
||||
ProxyHeader: fiber.HeaderXForwardedFor,
|
||||
ErrorHandler: ErrorHandler,
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user