完成白名单接口

This commit is contained in:
2025-04-08 09:35:19 +08:00
parent eca58c7032
commit e1c4bb5c03
14 changed files with 290 additions and 21 deletions

View File

@@ -20,6 +20,8 @@
- [ ] Limiter
- [ ] Compress
查端口需要通过外部接口实现,防止不同环境下的端口覆盖。提供一个额外的简便方法用来实现端口覆盖
业务代码和测试代码共用的控制变量可以优化为环境变量
channel 优化:
@@ -45,10 +47,15 @@ oauth token 验证授权范围
短信发送日志
## 环境变量
## 环境变量和脚本
在 init/env 中有定义和默认值
开发环境数据库迁移:
```powershell
pg-schema-diff apply --schema-dir .\scripts\sql --dsn "host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai"
```
## 枚举字典
### 产品

2
go.mod
View File

@@ -40,7 +40,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/stripe/pg-schema-diff v0.9.0 // indirect
github.com/stretchr/testify v1.8.2 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.59.0 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect

9
go.sum
View File

@@ -81,13 +81,14 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stripe/pg-schema-diff v0.9.0 h1:qzm2VUdbZ2kYwqxoQqtEP3uLQI0B+ymS947zqFTZGBk=
github.com/stripe/pg-schema-diff v0.9.0/go.mod h1:cl2VC6te/cCTOewTRvv4pYsgQqAOhvRQmatCHfYwy8c=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDpRI=

View File

@@ -19,6 +19,7 @@ services:
restart: always
ports:
- "6379:6379"
command: redis-server --requirepass ${REDIS_PASS}
volumes:
- redis_data:/data

View File

@@ -1,4 +1,4 @@
package web
package auth
import (
"context"
@@ -171,3 +171,48 @@ func authBasic(ctx context.Context, token string) (*services.AuthContext, error)
Metadata: nil,
}, nil
}
func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (*services.AuthContext, error) {
// 获取令牌
var header = c.Get("Authorization")
var split = strings.Split(header, " ")
if len(split) != 2 {
return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌")
}
var token = split[1]
if token == "" {
return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌")
}
var auth *services.AuthContext
var err error
switch split[0] {
case "Bearer":
auth, err = authBearer(c.Context(), token)
case "Basic":
if !slices.Contains(types, services.PayloadClientConfidential) {
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
}
auth, err = authBasic(c.Context(), token)
default:
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
}
if err != nil {
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
}
// 检查权限
if !slices.Contains(types, auth.Payload.Type) {
return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问")
}
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问")
}
// 将认证信息存储在上下文中
c.Locals("auth", auth)
c.Locals("access_token", token) // 存储原始令牌,便于后续操作
return auth, nil
}

View File

@@ -18,9 +18,10 @@ type LoginReq struct {
}
type LoginResp struct {
Token string `json:"token"`
Expires int64 `json:"expires"`
Auth services.AuthContext `json:"auth"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
Expires int64 `json:"expires"`
Auth services.AuthContext `json:"auth"`
}
func Login(c *fiber.Ctx) error {
@@ -105,8 +106,9 @@ func loginByPhone(c *fiber.Ctx, req *LoginReq) error {
}
return c.JSON(LoginResp{
Token: token.AccessToken,
Expires: token.AccessTokenExpires.Unix(),
Auth: auth,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
Expires: token.AccessTokenExpires.Unix(),
Auth: auth,
})
}

View File

@@ -126,7 +126,10 @@ func refreshToken(c *fiber.Ctx, req *TokenReq) error {
scope := strings.Split(req.Scope, ",")
token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
if err != nil {
return sendError(c, err.(services.AuthServiceOauthError))
if errors.Is(err, services.ErrInvalidToken) {
return sendError(c, services.ErrOauthInvalidGrant)
}
return sendError(c, err)
}
return sendSuccess(c, token)

187
web/handlers/whitelist.go Normal file
View File

@@ -0,0 +1,187 @@
package handlers
import (
"platform/web/auth"
m "platform/web/models"
q "platform/web/queries"
"platform/web/services"
"time"
"github.com/gofiber/fiber/v2"
)
type ListWhitelistReq struct {
Page int `json:"page" validate:"required"`
Size int `json:"size" validate:"required"`
}
type ListWhitelistResp struct {
Total int64 `json:"total"`
List []*m.Whitelist `json:"list"`
Page int `json:"page"`
Size int `json:"size"`
}
func ListWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
if err != nil {
return err
}
// 解析请求参数
req := new(ListWhitelistReq)
if err := c.BodyParser(req); err != nil {
return err
}
var page = req.Page
if page < 1 {
page = 1
}
var size = req.Size
if size < 1 {
size = 10
}
// 获取用户信息
list, err := q.Whitelist.
Where(q.Whitelist.UserID.Eq(authContext.Payload.Id)).
Offset((page - 1) * size).
Limit(size).
Order(q.Whitelist.CreatedAt.Desc()).
Find()
if err != nil {
return err
}
count, err := q.Whitelist.
Where(q.Whitelist.UserID.Eq(authContext.Payload.Id)).
Count()
if err != nil {
return err
}
// 返回结果
return c.Status(fiber.StatusOK).JSON(ListWhitelistResp{
Total: count,
List: list,
Page: page,
Size: size,
})
}
type CreateWhitelistReq struct {
Host string `json:"host" validate:"required"`
Remark string `json:"remark"`
}
func CreateWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
if err != nil {
return err
}
// 解析请求参数
req := new(CreateWhitelistReq)
if err := c.BodyParser(req); err != nil {
return err
}
if req.Host == "" {
return fiber.NewError(fiber.StatusBadRequest, "host is required")
}
// 创建白名单
whitelist := &m.Whitelist{
UserID: authContext.Payload.Id,
Host: req.Host,
Remark: req.Remark,
}
err = q.Whitelist.Create(whitelist)
return nil
}
type UpdateWhitelistReq struct {
ID int32 `json:"id" validate:"required"`
Host string `json:"host"`
Remark string `json:"remark"`
}
func UpdateWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
if err != nil {
return err
}
// 解析请求参数
req := new(UpdateWhitelistReq)
if err := c.BodyParser(req); err != nil {
return err
}
if req.ID == 0 {
return fiber.NewError(fiber.StatusBadRequest, "id is required")
}
// 更新白名单
_, err = q.Whitelist.
Where(
q.Whitelist.ID.Eq(req.ID),
q.Whitelist.UserID.Eq(authContext.Payload.Id),
).
Updates(&m.Whitelist{
ID: req.ID,
Host: req.Host,
Remark: req.Remark,
})
if err != nil {
return err
}
return nil
}
type RemoveWhitelistReq struct {
ID int32 `json:"id" validate:"required"`
}
func RemoveWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
if err != nil {
return err
}
// 解析请求参数
var req []RemoveWhitelistReq
if err := c.BodyParser(&req); err != nil {
return err
}
if len(req) == 0 {
return fiber.NewError(fiber.StatusBadRequest, "id is required")
}
var ids = make([]int32, len(req))
for i, item := range req {
if item.ID == 0 {
return fiber.NewError(fiber.StatusBadRequest, "id is required")
}
ids[i] = item.ID
}
// 删除白名单
_, err = q.Whitelist.
Where(
q.Whitelist.ID.In(ids...),
q.Whitelist.UserID.Eq(authContext.Payload.Id),
).
Update(
q.Whitelist.DeletedAt, time.Now(),
)
if err != nil {
return err
}
return nil
}

View File

@@ -30,6 +30,7 @@ type Channel struct {
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
ProxyHost string `gorm:"column:proxy_host;not null" json:"proxy_host"`
Node *Node `json:"node"`
User *User `json:"user"`
Proxy *Proxy `json:"proxy"`

View File

@@ -20,6 +20,7 @@ type Whitelist struct {
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
Remark string `gorm:"column:remark" json:"remark"`
}
// TableName Whitelist's table name

View File

@@ -43,6 +43,7 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel {
_channel.CreatedAt = field.NewTime(tableName, "created_at")
_channel.UpdatedAt = field.NewTime(tableName, "updated_at")
_channel.DeletedAt = field.NewField(tableName, "deleted_at")
_channel.ProxyHost = field.NewString(tableName, "proxy_host")
_channel.fillFieldMap()
@@ -69,6 +70,7 @@ type channel struct {
CreatedAt field.Time // 创建时间
UpdatedAt field.Time // 更新时间
DeletedAt field.Field // 删除时间
ProxyHost field.String
fieldMap map[string]field.Expr
}
@@ -101,6 +103,7 @@ func (c *channel) updateTableName(table string) *channel {
c.CreatedAt = field.NewTime(table, "created_at")
c.UpdatedAt = field.NewTime(table, "updated_at")
c.DeletedAt = field.NewField(table, "deleted_at")
c.ProxyHost = field.NewString(table, "proxy_host")
c.fillFieldMap()
@@ -117,7 +120,7 @@ func (c *channel) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
}
func (c *channel) fillFieldMap() {
c.fieldMap = make(map[string]field.Expr, 16)
c.fieldMap = make(map[string]field.Expr, 17)
c.fieldMap["id"] = c.ID
c.fieldMap["user_id"] = c.UserID
c.fieldMap["proxy_id"] = c.ProxyID
@@ -134,6 +137,7 @@ func (c *channel) fillFieldMap() {
c.fieldMap["created_at"] = c.CreatedAt
c.fieldMap["updated_at"] = c.UpdatedAt
c.fieldMap["deleted_at"] = c.DeletedAt
c.fieldMap["proxy_host"] = c.ProxyHost
}
func (c channel) clone(db *gorm.DB) channel {

View File

@@ -33,6 +33,7 @@ func newWhitelist(db *gorm.DB, opts ...gen.DOOption) whitelist {
_whitelist.CreatedAt = field.NewTime(tableName, "created_at")
_whitelist.UpdatedAt = field.NewTime(tableName, "updated_at")
_whitelist.DeletedAt = field.NewField(tableName, "deleted_at")
_whitelist.Remark = field.NewString(tableName, "remark")
_whitelist.fillFieldMap()
@@ -49,6 +50,7 @@ type whitelist struct {
CreatedAt field.Time // 创建时间
UpdatedAt field.Time // 更新时间
DeletedAt field.Field // 删除时间
Remark field.String
fieldMap map[string]field.Expr
}
@@ -71,6 +73,7 @@ func (w *whitelist) updateTableName(table string) *whitelist {
w.CreatedAt = field.NewTime(table, "created_at")
w.UpdatedAt = field.NewTime(table, "updated_at")
w.DeletedAt = field.NewField(table, "deleted_at")
w.Remark = field.NewString(table, "remark")
w.fillFieldMap()
@@ -87,13 +90,14 @@ func (w *whitelist) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
}
func (w *whitelist) fillFieldMap() {
w.fieldMap = make(map[string]field.Expr, 6)
w.fieldMap = make(map[string]field.Expr, 7)
w.fieldMap["id"] = w.ID
w.fieldMap["user_id"] = w.UserID
w.fieldMap["host"] = w.Host
w.fieldMap["created_at"] = w.CreatedAt
w.fieldMap["updated_at"] = w.UpdatedAt
w.fieldMap["deleted_at"] = w.DeletedAt
w.fieldMap["remark"] = w.Remark
}
func (w whitelist) clone(db *gorm.DB) whitelist {

View File

@@ -1,6 +1,7 @@
package web
import (
auth2 "platform/web/auth"
"platform/web/handlers"
"github.com/gofiber/fiber/v2"
@@ -11,14 +12,21 @@ func ApplyRouters(app *fiber.App) {
// 认证
auth := api.Group("/auth")
auth.Post("/verify/sms", PermitDevice(), handlers.SmsCode)
auth.Post("/login/sms", PermitDevice(), handlers.Login)
auth.Post("/verify/sms", auth2.PermitDevice(), handlers.SmsCode)
auth.Post("/login/sms", auth2.PermitDevice(), handlers.Login)
auth.Post("/token", handlers.Token)
// 通道
channel := api.Group("/channel")
channel.Post("/create", PermitAll(), handlers.CreateChannel)
channel.Post("/remove", PermitAll(), handlers.RemoveChannels)
channel.Post("/create", auth2.PermitAll(), handlers.CreateChannel)
channel.Post("/remove", auth2.PermitAll(), handlers.RemoveChannels)
// 白名单
whitelist := api.Group("/whitelist")
whitelist.Post("/list", handlers.ListWhitelist)
whitelist.Post("/create", handlers.CreateWhitelist)
whitelist.Post("/update", handlers.UpdateWhitelist)
whitelist.Post("/remove", handlers.RemoveWhitelist)
// 临时
app.Get("/collect", handlers.CreateChannelGet)

View File

@@ -77,7 +77,12 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *models
// OauthRefreshToken 验证刷新令牌
func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) {
// TODO: 从数据库验证刷新令牌
return nil, errors.New("TODO")
details, err := Session.Refresh(ctx, refreshToken)
if err != nil {
return nil, err
}
return details, nil
}
type GrantType int