完成白名单接口
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
package web
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -171,3 +171,48 @@ func authBasic(ctx context.Context, token string) (*services.AuthContext, error)
|
||||
Metadata: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (*services.AuthContext, error) {
|
||||
// 获取令牌
|
||||
var header = c.Get("Authorization")
|
||||
var split = strings.Split(header, " ")
|
||||
if len(split) != 2 {
|
||||
return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌")
|
||||
}
|
||||
|
||||
var token = split[1]
|
||||
if token == "" {
|
||||
return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌")
|
||||
}
|
||||
|
||||
var auth *services.AuthContext
|
||||
var err error
|
||||
switch split[0] {
|
||||
case "Bearer":
|
||||
auth, err = authBearer(c.Context(), token)
|
||||
case "Basic":
|
||||
if !slices.Contains(types, services.PayloadClientConfidential) {
|
||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
|
||||
}
|
||||
auth, err = authBasic(c.Context(), token)
|
||||
default:
|
||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
if !slices.Contains(types, auth.Payload.Type) {
|
||||
return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问")
|
||||
}
|
||||
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
|
||||
return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问")
|
||||
}
|
||||
|
||||
// 将认证信息存储在上下文中
|
||||
c.Locals("auth", auth)
|
||||
c.Locals("access_token", token) // 存储原始令牌,便于后续操作
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
@@ -18,9 +18,10 @@ type LoginReq struct {
|
||||
}
|
||||
|
||||
type LoginResp struct {
|
||||
Token string `json:"token"`
|
||||
Expires int64 `json:"expires"`
|
||||
Auth services.AuthContext `json:"auth"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Expires int64 `json:"expires"`
|
||||
Auth services.AuthContext `json:"auth"`
|
||||
}
|
||||
|
||||
func Login(c *fiber.Ctx) error {
|
||||
@@ -105,8 +106,9 @@ func loginByPhone(c *fiber.Ctx, req *LoginReq) error {
|
||||
}
|
||||
|
||||
return c.JSON(LoginResp{
|
||||
Token: token.AccessToken,
|
||||
Expires: token.AccessTokenExpires.Unix(),
|
||||
Auth: auth,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
Expires: token.AccessTokenExpires.Unix(),
|
||||
Auth: auth,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -126,7 +126,10 @@ func refreshToken(c *fiber.Ctx, req *TokenReq) error {
|
||||
scope := strings.Split(req.Scope, ",")
|
||||
token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
|
||||
if err != nil {
|
||||
return sendError(c, err.(services.AuthServiceOauthError))
|
||||
if errors.Is(err, services.ErrInvalidToken) {
|
||||
return sendError(c, services.ErrOauthInvalidGrant)
|
||||
}
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
187
web/handlers/whitelist.go
Normal file
187
web/handlers/whitelist.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"platform/web/auth"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"platform/web/services"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type ListWhitelistReq struct {
|
||||
Page int `json:"page" validate:"required"`
|
||||
Size int `json:"size" validate:"required"`
|
||||
}
|
||||
|
||||
type ListWhitelistResp struct {
|
||||
Total int64 `json:"total"`
|
||||
List []*m.Whitelist `json:"list"`
|
||||
Page int `json:"page"`
|
||||
Size int `json:"size"`
|
||||
}
|
||||
|
||||
func ListWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
req := new(ListWhitelistReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
var page = req.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
var size = req.Size
|
||||
if size < 1 {
|
||||
size = 10
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
list, err := q.Whitelist.
|
||||
Where(q.Whitelist.UserID.Eq(authContext.Payload.Id)).
|
||||
Offset((page - 1) * size).
|
||||
Limit(size).
|
||||
Order(q.Whitelist.CreatedAt.Desc()).
|
||||
Find()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count, err := q.Whitelist.
|
||||
Where(q.Whitelist.UserID.Eq(authContext.Payload.Id)).
|
||||
Count()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 返回结果
|
||||
return c.Status(fiber.StatusOK).JSON(ListWhitelistResp{
|
||||
Total: count,
|
||||
List: list,
|
||||
Page: page,
|
||||
Size: size,
|
||||
})
|
||||
}
|
||||
|
||||
type CreateWhitelistReq struct {
|
||||
Host string `json:"host" validate:"required"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
func CreateWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
req := new(CreateWhitelistReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.Host == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "host is required")
|
||||
}
|
||||
|
||||
// 创建白名单
|
||||
whitelist := &m.Whitelist{
|
||||
UserID: authContext.Payload.Id,
|
||||
Host: req.Host,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
|
||||
err = q.Whitelist.Create(whitelist)
|
||||
return nil
|
||||
}
|
||||
|
||||
type UpdateWhitelistReq struct {
|
||||
ID int32 `json:"id" validate:"required"`
|
||||
Host string `json:"host"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
func UpdateWhitelist(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
req := new(UpdateWhitelistReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.ID == 0 {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
// 更新白名单
|
||||
_, err = q.Whitelist.
|
||||
Where(
|
||||
q.Whitelist.ID.Eq(req.ID),
|
||||
q.Whitelist.UserID.Eq(authContext.Payload.Id),
|
||||
).
|
||||
Updates(&m.Whitelist{
|
||||
ID: req.ID,
|
||||
Host: req.Host,
|
||||
Remark: req.Remark,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RemoveWhitelistReq struct {
|
||||
ID int32 `json:"id" validate:"required"`
|
||||
}
|
||||
|
||||
func RemoveWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
var req []RemoveWhitelistReq
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(req) == 0 {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
var ids = make([]int32, len(req))
|
||||
for i, item := range req {
|
||||
if item.ID == 0 {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
ids[i] = item.ID
|
||||
}
|
||||
|
||||
// 删除白名单
|
||||
_, err = q.Whitelist.
|
||||
Where(
|
||||
q.Whitelist.ID.In(ids...),
|
||||
q.Whitelist.UserID.Eq(authContext.Payload.Id),
|
||||
).
|
||||
Update(
|
||||
q.Whitelist.DeletedAt, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -30,6 +30,7 @@ type Channel struct {
|
||||
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
ProxyHost string `gorm:"column:proxy_host;not null" json:"proxy_host"`
|
||||
Node *Node `json:"node"`
|
||||
User *User `json:"user"`
|
||||
Proxy *Proxy `json:"proxy"`
|
||||
|
||||
@@ -20,6 +20,7 @@ type Whitelist struct {
|
||||
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
Remark string `gorm:"column:remark" json:"remark"`
|
||||
}
|
||||
|
||||
// TableName Whitelist's table name
|
||||
|
||||
@@ -43,6 +43,7 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel {
|
||||
_channel.CreatedAt = field.NewTime(tableName, "created_at")
|
||||
_channel.UpdatedAt = field.NewTime(tableName, "updated_at")
|
||||
_channel.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
_channel.ProxyHost = field.NewString(tableName, "proxy_host")
|
||||
|
||||
_channel.fillFieldMap()
|
||||
|
||||
@@ -69,6 +70,7 @@ type channel struct {
|
||||
CreatedAt field.Time // 创建时间
|
||||
UpdatedAt field.Time // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
ProxyHost field.String
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
@@ -101,6 +103,7 @@ func (c *channel) updateTableName(table string) *channel {
|
||||
c.CreatedAt = field.NewTime(table, "created_at")
|
||||
c.UpdatedAt = field.NewTime(table, "updated_at")
|
||||
c.DeletedAt = field.NewField(table, "deleted_at")
|
||||
c.ProxyHost = field.NewString(table, "proxy_host")
|
||||
|
||||
c.fillFieldMap()
|
||||
|
||||
@@ -117,7 +120,7 @@ func (c *channel) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
}
|
||||
|
||||
func (c *channel) fillFieldMap() {
|
||||
c.fieldMap = make(map[string]field.Expr, 16)
|
||||
c.fieldMap = make(map[string]field.Expr, 17)
|
||||
c.fieldMap["id"] = c.ID
|
||||
c.fieldMap["user_id"] = c.UserID
|
||||
c.fieldMap["proxy_id"] = c.ProxyID
|
||||
@@ -134,6 +137,7 @@ func (c *channel) fillFieldMap() {
|
||||
c.fieldMap["created_at"] = c.CreatedAt
|
||||
c.fieldMap["updated_at"] = c.UpdatedAt
|
||||
c.fieldMap["deleted_at"] = c.DeletedAt
|
||||
c.fieldMap["proxy_host"] = c.ProxyHost
|
||||
}
|
||||
|
||||
func (c channel) clone(db *gorm.DB) channel {
|
||||
|
||||
@@ -33,6 +33,7 @@ func newWhitelist(db *gorm.DB, opts ...gen.DOOption) whitelist {
|
||||
_whitelist.CreatedAt = field.NewTime(tableName, "created_at")
|
||||
_whitelist.UpdatedAt = field.NewTime(tableName, "updated_at")
|
||||
_whitelist.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
_whitelist.Remark = field.NewString(tableName, "remark")
|
||||
|
||||
_whitelist.fillFieldMap()
|
||||
|
||||
@@ -49,6 +50,7 @@ type whitelist struct {
|
||||
CreatedAt field.Time // 创建时间
|
||||
UpdatedAt field.Time // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
Remark field.String
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
@@ -71,6 +73,7 @@ func (w *whitelist) updateTableName(table string) *whitelist {
|
||||
w.CreatedAt = field.NewTime(table, "created_at")
|
||||
w.UpdatedAt = field.NewTime(table, "updated_at")
|
||||
w.DeletedAt = field.NewField(table, "deleted_at")
|
||||
w.Remark = field.NewString(table, "remark")
|
||||
|
||||
w.fillFieldMap()
|
||||
|
||||
@@ -87,13 +90,14 @@ func (w *whitelist) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
}
|
||||
|
||||
func (w *whitelist) fillFieldMap() {
|
||||
w.fieldMap = make(map[string]field.Expr, 6)
|
||||
w.fieldMap = make(map[string]field.Expr, 7)
|
||||
w.fieldMap["id"] = w.ID
|
||||
w.fieldMap["user_id"] = w.UserID
|
||||
w.fieldMap["host"] = w.Host
|
||||
w.fieldMap["created_at"] = w.CreatedAt
|
||||
w.fieldMap["updated_at"] = w.UpdatedAt
|
||||
w.fieldMap["deleted_at"] = w.DeletedAt
|
||||
w.fieldMap["remark"] = w.Remark
|
||||
}
|
||||
|
||||
func (w whitelist) clone(db *gorm.DB) whitelist {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
auth2 "platform/web/auth"
|
||||
"platform/web/handlers"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -11,14 +12,21 @@ func ApplyRouters(app *fiber.App) {
|
||||
|
||||
// 认证
|
||||
auth := api.Group("/auth")
|
||||
auth.Post("/verify/sms", PermitDevice(), handlers.SmsCode)
|
||||
auth.Post("/login/sms", PermitDevice(), handlers.Login)
|
||||
auth.Post("/verify/sms", auth2.PermitDevice(), handlers.SmsCode)
|
||||
auth.Post("/login/sms", auth2.PermitDevice(), handlers.Login)
|
||||
auth.Post("/token", handlers.Token)
|
||||
|
||||
// 通道
|
||||
channel := api.Group("/channel")
|
||||
channel.Post("/create", PermitAll(), handlers.CreateChannel)
|
||||
channel.Post("/remove", PermitAll(), handlers.RemoveChannels)
|
||||
channel.Post("/create", auth2.PermitAll(), handlers.CreateChannel)
|
||||
channel.Post("/remove", auth2.PermitAll(), handlers.RemoveChannels)
|
||||
|
||||
// 白名单
|
||||
whitelist := api.Group("/whitelist")
|
||||
whitelist.Post("/list", handlers.ListWhitelist)
|
||||
whitelist.Post("/create", handlers.CreateWhitelist)
|
||||
whitelist.Post("/update", handlers.UpdateWhitelist)
|
||||
whitelist.Post("/remove", handlers.RemoveWhitelist)
|
||||
|
||||
// 临时
|
||||
app.Get("/collect", handlers.CreateChannelGet)
|
||||
|
||||
@@ -77,7 +77,12 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *models
|
||||
// OauthRefreshToken 验证刷新令牌
|
||||
func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) {
|
||||
// TODO: 从数据库验证刷新令牌
|
||||
return nil, errors.New("TODO")
|
||||
details, err := Session.Refresh(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return details, nil
|
||||
}
|
||||
|
||||
type GrantType int
|
||||
|
||||
Reference in New Issue
Block a user