From e1c4bb5c038f5a1025e62b235488760d0b16f7ca Mon Sep 17 00:00:00 2001 From: luorijun Date: Tue, 8 Apr 2025 09:35:19 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E7=99=BD=E5=90=8D=E5=8D=95?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 9 +- go.mod | 2 +- go.sum | 9 +- scripts/prev/docker-compose.yaml | 1 + web/{ => auth}/auth.go | 47 +++++++- web/handlers/login.go | 14 ++- web/handlers/oauth.go | 5 +- web/handlers/whitelist.go | 187 +++++++++++++++++++++++++++++++ web/models/channel.gen.go | 1 + web/models/whitelist.gen.go | 1 + web/queries/channel.gen.go | 6 +- web/queries/whitelist.gen.go | 6 +- web/router.go | 16 ++- web/services/auth.go | 7 +- 14 files changed, 290 insertions(+), 21 deletions(-) rename web/{ => auth}/auth.go (74%) create mode 100644 web/handlers/whitelist.go diff --git a/README.md b/README.md index 8097328..04bb3a1 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ - [ ] Limiter - [ ] Compress +查端口需要通过外部接口实现,防止不同环境下的端口覆盖。提供一个额外的简便方法用来实现端口覆盖 + 业务代码和测试代码共用的控制变量可以优化为环境变量 channel 优化: @@ -45,10 +47,15 @@ oauth token 验证授权范围 短信发送日志 -## 环境变量 +## 环境变量和脚本 在 init/env 中有定义和默认值 +开发环境数据库迁移: +```powershell +pg-schema-diff apply --schema-dir .\scripts\sql --dsn "host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai" +``` + ## 枚举字典 ### 产品 diff --git a/go.mod b/go.mod index 7e02d7c..7071590 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/stripe/pg-schema-diff v0.9.0 // indirect + github.com/stretchr/testify v1.8.2 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.59.0 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect diff --git a/go.sum b/go.sum index 9abc75c..e7f962d 100644 --- a/go.sum +++ b/go.sum @@ -81,13 +81,14 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stripe/pg-schema-diff v0.9.0 h1:qzm2VUdbZ2kYwqxoQqtEP3uLQI0B+ymS947zqFTZGBk= -github.com/stripe/pg-schema-diff v0.9.0/go.mod h1:cl2VC6te/cCTOewTRvv4pYsgQqAOhvRQmatCHfYwy8c= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDpRI= diff --git a/scripts/prev/docker-compose.yaml b/scripts/prev/docker-compose.yaml index cfcba4a..d34bb28 100644 --- a/scripts/prev/docker-compose.yaml +++ b/scripts/prev/docker-compose.yaml @@ -19,6 +19,7 @@ services: restart: always ports: - "6379:6379" + command: redis-server --requirepass ${REDIS_PASS} volumes: - redis_data:/data diff --git a/web/auth.go b/web/auth/auth.go similarity index 74% rename from web/auth.go rename to web/auth/auth.go index 89d0115..ece5770 100644 --- a/web/auth.go +++ b/web/auth/auth.go @@ -1,4 +1,4 @@ -package web +package auth import ( "context" @@ -171,3 +171,48 @@ func authBasic(ctx context.Context, token string) (*services.AuthContext, error) Metadata: nil, }, nil } + +func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (*services.AuthContext, error) { + // 获取令牌 + var header = c.Get("Authorization") + var split = strings.Split(header, " ") + if len(split) != 2 { + return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌") + } + + var token = split[1] + if token == "" { + return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌") + } + + var auth *services.AuthContext + var err error + switch split[0] { + case "Bearer": + auth, err = authBearer(c.Context(), token) + case "Basic": + if !slices.Contains(types, services.PayloadClientConfidential) { + return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") + } + auth, err = authBasic(c.Context(), token) + default: + return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") + } + if err != nil { + return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") + } + + // 检查权限 + if !slices.Contains(types, auth.Payload.Type) { + return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问") + } + if len(permissions) > 0 && !auth.AnyPermission(permissions...) { + return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问") + } + + // 将认证信息存储在上下文中 + c.Locals("auth", auth) + c.Locals("access_token", token) // 存储原始令牌,便于后续操作 + + return auth, nil +} diff --git a/web/handlers/login.go b/web/handlers/login.go index d142df5..c85883a 100644 --- a/web/handlers/login.go +++ b/web/handlers/login.go @@ -18,9 +18,10 @@ type LoginReq struct { } type LoginResp struct { - Token string `json:"token"` - Expires int64 `json:"expires"` - Auth services.AuthContext `json:"auth"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Expires int64 `json:"expires"` + Auth services.AuthContext `json:"auth"` } func Login(c *fiber.Ctx) error { @@ -105,8 +106,9 @@ func loginByPhone(c *fiber.Ctx, req *LoginReq) error { } return c.JSON(LoginResp{ - Token: token.AccessToken, - Expires: token.AccessTokenExpires.Unix(), - Auth: auth, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + Expires: token.AccessTokenExpires.Unix(), + Auth: auth, }) } diff --git a/web/handlers/oauth.go b/web/handlers/oauth.go index 6207bca..e29bc09 100644 --- a/web/handlers/oauth.go +++ b/web/handlers/oauth.go @@ -126,7 +126,10 @@ func refreshToken(c *fiber.Ctx, req *TokenReq) error { scope := strings.Split(req.Scope, ",") token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope) if err != nil { - return sendError(c, err.(services.AuthServiceOauthError)) + if errors.Is(err, services.ErrInvalidToken) { + return sendError(c, services.ErrOauthInvalidGrant) + } + return sendError(c, err) } return sendSuccess(c, token) diff --git a/web/handlers/whitelist.go b/web/handlers/whitelist.go new file mode 100644 index 0000000..c1967be --- /dev/null +++ b/web/handlers/whitelist.go @@ -0,0 +1,187 @@ +package handlers + +import ( + "platform/web/auth" + m "platform/web/models" + q "platform/web/queries" + "platform/web/services" + "time" + + "github.com/gofiber/fiber/v2" +) + +type ListWhitelistReq struct { + Page int `json:"page" validate:"required"` + Size int `json:"size" validate:"required"` +} + +type ListWhitelistResp struct { + Total int64 `json:"total"` + List []*m.Whitelist `json:"list"` + Page int `json:"page"` + Size int `json:"size"` +} + +func ListWhitelist(c *fiber.Ctx) error { + + // 检查权限 + authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + if err != nil { + return err + } + + // 解析请求参数 + req := new(ListWhitelistReq) + if err := c.BodyParser(req); err != nil { + return err + } + var page = req.Page + if page < 1 { + page = 1 + } + var size = req.Size + if size < 1 { + size = 10 + } + + // 获取用户信息 + list, err := q.Whitelist. + Where(q.Whitelist.UserID.Eq(authContext.Payload.Id)). + Offset((page - 1) * size). + Limit(size). + Order(q.Whitelist.CreatedAt.Desc()). + Find() + if err != nil { + return err + } + count, err := q.Whitelist. + Where(q.Whitelist.UserID.Eq(authContext.Payload.Id)). + Count() + if err != nil { + return err + } + + // 返回结果 + return c.Status(fiber.StatusOK).JSON(ListWhitelistResp{ + Total: count, + List: list, + Page: page, + Size: size, + }) +} + +type CreateWhitelistReq struct { + Host string `json:"host" validate:"required"` + Remark string `json:"remark"` +} + +func CreateWhitelist(c *fiber.Ctx) error { + + // 检查权限 + authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + if err != nil { + return err + } + + // 解析请求参数 + req := new(CreateWhitelistReq) + if err := c.BodyParser(req); err != nil { + return err + } + if req.Host == "" { + return fiber.NewError(fiber.StatusBadRequest, "host is required") + } + + // 创建白名单 + whitelist := &m.Whitelist{ + UserID: authContext.Payload.Id, + Host: req.Host, + Remark: req.Remark, + } + + err = q.Whitelist.Create(whitelist) + return nil +} + +type UpdateWhitelistReq struct { + ID int32 `json:"id" validate:"required"` + Host string `json:"host"` + Remark string `json:"remark"` +} + +func UpdateWhitelist(c *fiber.Ctx) error { + // 检查权限 + authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + if err != nil { + return err + } + + // 解析请求参数 + req := new(UpdateWhitelistReq) + if err := c.BodyParser(req); err != nil { + return err + } + if req.ID == 0 { + return fiber.NewError(fiber.StatusBadRequest, "id is required") + } + + // 更新白名单 + _, err = q.Whitelist. + Where( + q.Whitelist.ID.Eq(req.ID), + q.Whitelist.UserID.Eq(authContext.Payload.Id), + ). + Updates(&m.Whitelist{ + ID: req.ID, + Host: req.Host, + Remark: req.Remark, + }) + if err != nil { + return err + } + return nil +} + +type RemoveWhitelistReq struct { + ID int32 `json:"id" validate:"required"` +} + +func RemoveWhitelist(c *fiber.Ctx) error { + + // 检查权限 + authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + if err != nil { + return err + } + + // 解析请求参数 + var req []RemoveWhitelistReq + if err := c.BodyParser(&req); err != nil { + return err + } + if len(req) == 0 { + return fiber.NewError(fiber.StatusBadRequest, "id is required") + } + + var ids = make([]int32, len(req)) + for i, item := range req { + if item.ID == 0 { + return fiber.NewError(fiber.StatusBadRequest, "id is required") + } + ids[i] = item.ID + } + + // 删除白名单 + _, err = q.Whitelist. + Where( + q.Whitelist.ID.In(ids...), + q.Whitelist.UserID.Eq(authContext.Payload.Id), + ). + Update( + q.Whitelist.DeletedAt, time.Now(), + ) + if err != nil { + return err + } + return nil +} diff --git a/web/models/channel.gen.go b/web/models/channel.gen.go index 7257b57..46528c5 100644 --- a/web/models/channel.gen.go +++ b/web/models/channel.gen.go @@ -30,6 +30,7 @@ type Channel struct { CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 + ProxyHost string `gorm:"column:proxy_host;not null" json:"proxy_host"` Node *Node `json:"node"` User *User `json:"user"` Proxy *Proxy `json:"proxy"` diff --git a/web/models/whitelist.gen.go b/web/models/whitelist.gen.go index 3e56ac2..7ae8d3b 100644 --- a/web/models/whitelist.gen.go +++ b/web/models/whitelist.gen.go @@ -20,6 +20,7 @@ type Whitelist struct { CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 + Remark string `gorm:"column:remark" json:"remark"` } // TableName Whitelist's table name diff --git a/web/queries/channel.gen.go b/web/queries/channel.gen.go index 5ceccad..341e28f 100644 --- a/web/queries/channel.gen.go +++ b/web/queries/channel.gen.go @@ -43,6 +43,7 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel { _channel.CreatedAt = field.NewTime(tableName, "created_at") _channel.UpdatedAt = field.NewTime(tableName, "updated_at") _channel.DeletedAt = field.NewField(tableName, "deleted_at") + _channel.ProxyHost = field.NewString(tableName, "proxy_host") _channel.fillFieldMap() @@ -69,6 +70,7 @@ type channel struct { CreatedAt field.Time // 创建时间 UpdatedAt field.Time // 更新时间 DeletedAt field.Field // 删除时间 + ProxyHost field.String fieldMap map[string]field.Expr } @@ -101,6 +103,7 @@ func (c *channel) updateTableName(table string) *channel { c.CreatedAt = field.NewTime(table, "created_at") c.UpdatedAt = field.NewTime(table, "updated_at") c.DeletedAt = field.NewField(table, "deleted_at") + c.ProxyHost = field.NewString(table, "proxy_host") c.fillFieldMap() @@ -117,7 +120,7 @@ func (c *channel) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (c *channel) fillFieldMap() { - c.fieldMap = make(map[string]field.Expr, 16) + c.fieldMap = make(map[string]field.Expr, 17) c.fieldMap["id"] = c.ID c.fieldMap["user_id"] = c.UserID c.fieldMap["proxy_id"] = c.ProxyID @@ -134,6 +137,7 @@ func (c *channel) fillFieldMap() { c.fieldMap["created_at"] = c.CreatedAt c.fieldMap["updated_at"] = c.UpdatedAt c.fieldMap["deleted_at"] = c.DeletedAt + c.fieldMap["proxy_host"] = c.ProxyHost } func (c channel) clone(db *gorm.DB) channel { diff --git a/web/queries/whitelist.gen.go b/web/queries/whitelist.gen.go index 7f9e612..bfc7025 100644 --- a/web/queries/whitelist.gen.go +++ b/web/queries/whitelist.gen.go @@ -33,6 +33,7 @@ func newWhitelist(db *gorm.DB, opts ...gen.DOOption) whitelist { _whitelist.CreatedAt = field.NewTime(tableName, "created_at") _whitelist.UpdatedAt = field.NewTime(tableName, "updated_at") _whitelist.DeletedAt = field.NewField(tableName, "deleted_at") + _whitelist.Remark = field.NewString(tableName, "remark") _whitelist.fillFieldMap() @@ -49,6 +50,7 @@ type whitelist struct { CreatedAt field.Time // 创建时间 UpdatedAt field.Time // 更新时间 DeletedAt field.Field // 删除时间 + Remark field.String fieldMap map[string]field.Expr } @@ -71,6 +73,7 @@ func (w *whitelist) updateTableName(table string) *whitelist { w.CreatedAt = field.NewTime(table, "created_at") w.UpdatedAt = field.NewTime(table, "updated_at") w.DeletedAt = field.NewField(table, "deleted_at") + w.Remark = field.NewString(table, "remark") w.fillFieldMap() @@ -87,13 +90,14 @@ func (w *whitelist) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (w *whitelist) fillFieldMap() { - w.fieldMap = make(map[string]field.Expr, 6) + w.fieldMap = make(map[string]field.Expr, 7) w.fieldMap["id"] = w.ID w.fieldMap["user_id"] = w.UserID w.fieldMap["host"] = w.Host w.fieldMap["created_at"] = w.CreatedAt w.fieldMap["updated_at"] = w.UpdatedAt w.fieldMap["deleted_at"] = w.DeletedAt + w.fieldMap["remark"] = w.Remark } func (w whitelist) clone(db *gorm.DB) whitelist { diff --git a/web/router.go b/web/router.go index 0875840..60cb28f 100644 --- a/web/router.go +++ b/web/router.go @@ -1,6 +1,7 @@ package web import ( + auth2 "platform/web/auth" "platform/web/handlers" "github.com/gofiber/fiber/v2" @@ -11,14 +12,21 @@ func ApplyRouters(app *fiber.App) { // 认证 auth := api.Group("/auth") - auth.Post("/verify/sms", PermitDevice(), handlers.SmsCode) - auth.Post("/login/sms", PermitDevice(), handlers.Login) + auth.Post("/verify/sms", auth2.PermitDevice(), handlers.SmsCode) + auth.Post("/login/sms", auth2.PermitDevice(), handlers.Login) auth.Post("/token", handlers.Token) // 通道 channel := api.Group("/channel") - channel.Post("/create", PermitAll(), handlers.CreateChannel) - channel.Post("/remove", PermitAll(), handlers.RemoveChannels) + channel.Post("/create", auth2.PermitAll(), handlers.CreateChannel) + channel.Post("/remove", auth2.PermitAll(), handlers.RemoveChannels) + + // 白名单 + whitelist := api.Group("/whitelist") + whitelist.Post("/list", handlers.ListWhitelist) + whitelist.Post("/create", handlers.CreateWhitelist) + whitelist.Post("/update", handlers.UpdateWhitelist) + whitelist.Post("/remove", handlers.RemoveWhitelist) // 临时 app.Get("/collect", handlers.CreateChannelGet) diff --git a/web/services/auth.go b/web/services/auth.go index 8f41190..3c8ae5e 100644 --- a/web/services/auth.go +++ b/web/services/auth.go @@ -77,7 +77,12 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *models // OauthRefreshToken 验证刷新令牌 func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) { // TODO: 从数据库验证刷新令牌 - return nil, errors.New("TODO") + details, err := Session.Refresh(ctx, refreshToken) + if err != nil { + return nil, err + } + + return details, nil } type GrantType int