package auth import ( "context" "crypto/sha256" "encoding/base64" "encoding/json" "errors" "platform/pkg/env" "platform/pkg/u" g "platform/web/globals" "platform/web/globals/orm" m "platform/web/models" q "platform/web/queries" "strings" "time" "github.com/gofiber/fiber/v2" "github.com/google/uuid" "gorm.io/gorm" ) // AuthorizeGet 授权端点 func AuthorizeGet(ctx *fiber.Ctx) error { // 检查请求 req := new(AuthorizeGetReq) if err := g.Validator.ParseQuery(ctx, req); err != nil { return err } // 检查客户端 client, err := authClient(req.ClientID) if err != nil { return err } if client.RedirectURI == nil || *client.RedirectURI != req.RedirectURI { return errors.New("客户端重定向URI错误") } // todo 检查 scope // 授权确认页面 return nil } type AuthorizeGetReq struct { ResponseType string `json:"response_type" validate:"eq=code"` ClientID string `json:"client_id" validate:"required"` RedirectURI string `json:"redirect_uri" validate:"required"` Scope string `json:"scope"` State string `json:"state"` } func AuthorizePost(ctx *fiber.Ctx) error { // todo 解析用户授权的范围 return nil } type AuthorizePostReq struct { Accept bool `json:"accept"` Scope string `json:"scope"` } // Token 令牌端点 func Token(c *fiber.Ctx) error { now := time.Now() // 验证请求参数 req := new(TokenReq) if err := c.BodyParser(req); err != nil { return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数") } if req.GrantType == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type") } switch req.GrantType { // 授权码模式 case GrantAuthorizationCode: if req.Code == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code") } // 刷新令牌模式 case GrantRefreshToken: if req.RefreshToken == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token") } // 密码模式 case GrantPassword: if req.LoginType == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password_type") } if req.Username == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username") } if req.Password == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password") } } // 验证客户端身份 authCtx := GetAuthCtx(c) if authCtx == nil { authCtx = &AuthCtx{} } if authCtx.Client == nil { client, err := authClient(req.ClientID, req.ClientSecret) if err != nil { return sendError(c, err) } authCtx.Client = client } // 处理授权 var session *m.Session var err error switch req.GrantType { // 授权码模式 case GrantAuthorizationCode: session, err = authAuthorizationCode(c, authCtx, req, now) // 客户端凭证模式 case GrantClientCredentials: session, err = authClientCredential(c, authCtx, req, now) // 刷新令牌模式 case GrantRefreshToken: session, err = authRefreshToken(c, authCtx, req, now) // 密码模式 case GrantPassword: session, err = authPassword(c, authCtx, req, now) default: return sendError(c, ErrAuthorizeUnsupportedGrantType) } if err != nil { return sendError(c, err) } // 返回响应 return c.JSON(&TokenResp{ TokenType: "Bearer", AccessToken: session.AccessToken, RefreshToken: u.Z(session.RefreshToken), ExpiresIn: int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()), Scope: u.Z(session.Scopes), }) } type TokenReq struct { GrantType GrantType `json:"grant_type" form:"grant_type"` ClientID string `json:"client_id" form:"client_id"` ClientSecret string `json:"client_secret" form:"client_secret"` Scope string `json:"scope" form:"scope"` GrantCodeData GrantClientData GrantRefreshData GrantPasswordData } type GrantCodeData struct { Code string `json:"code" form:"code"` RedirectURI string `json:"redirect_uri" form:"redirect_uri"` CodeVerifier string `json:"code_verifier" form:"code_verifier"` } type GrantClientData struct { } type GrantRefreshData struct { RefreshToken string `json:"refresh_token" form:"refresh_token"` } type GrantPasswordData struct { LoginType PwdLoginType `json:"login_type" form:"login_type"` LoginPool PwdLoginPool `json:"login_pool" form:"login_pool"` Username string `json:"username" form:"username"` Password string `json:"password" form:"password"` Remember bool `json:"remember" form:"remember"` } type GrantType string const ( GrantAuthorizationCode = GrantType("authorization_code") // 授权码模式 GrantClientCredentials = GrantType("client_credentials") // 客户端凭证模式 GrantRefreshToken = GrantType("refresh_token") // 刷新令牌模式 GrantPassword = GrantType("password") // 密码模式(私有扩展) ) type PwdLoginType string const ( PwdLoginByPassword = PwdLoginType("password") // 账号密码 PwdLoginByPhone = PwdLoginType("phone_code") // 手机验证码 PwdLoginByEmail = PwdLoginType("email_code") // 邮箱验证码 ) type PwdLoginPool string const ( PwdLoginAsUser = PwdLoginPool("user") // 用户池 PwdLoginAsAdmin = PwdLoginPool("admin") // 管理员池 ) type TokenResp struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` ExpiresIn int `json:"expires_in"` TokenType string `json:"token_type"` Scope string `json:"scope,omitempty"` } type TokenErrResp struct { Error string `json:"error"` Description string `json:"error_description,omitempty"` } func authAuthorizationCode(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) { // 检查 code 获取用户授权信息 data, err := g.Redis.Get(context.Background(), req.Code).Result() if err != nil { return nil, err } var codeCtx CodeContext if err := json.Unmarshal([]byte(data), &codeCtx); err != nil { return nil, err } // 检查 PKCE if codeCtx.CodeChallengeMethod != "" { if req.CodeVerifier == "" { return nil, ErrAuthorizeInvalidPKCE } switch codeCtx.CodeChallengeMethod { case "plain": if req.CodeVerifier != codeCtx.CodeChallenge { return nil, ErrAuthorizeInvalidPKCE } case "S256": hash := sha256.Sum256([]byte(req.CodeVerifier)) verifier := base64.RawURLEncoding.EncodeToString(hash[:]) if verifier != codeCtx.CodeChallenge { return nil, ErrAuthorizeInvalidPKCE } default: return nil, ErrAuthorizeInvalidPKCE } } user, err := q.User.Where( q.User.ID.Eq(codeCtx.UserID), q.User.Status.Eq(int(m.UserStatusEnabled)), ).First() if err != nil { return nil, err } // todo 检查 scope // 生成会话 ip, _ := orm.ParseInet(c.IP()) // 可空字段,忽略异常 ua := u.X(c.Get(fiber.HeaderUserAgent)) session := &m.Session{ IP: ip, UA: ua, UserID: &user.ID, ClientID: &auth.Client.ID, Scopes: u.P(strings.Join(codeCtx.Scopes, " ")), AccessToken: uuid.NewString(), AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second), } if codeCtx.Remember { session.RefreshToken = u.P(uuid.NewString()) session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)) } err = SaveSession(q.Q, session) if err != nil { return nil, err } return session, nil } func authClientCredential(c *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) { // todo 检查 scope scopes := strings.Join(auth.Scopes, " ") // 生成会话 ip, _ := orm.ParseInet(c.IP()) // 可空字段,忽略异常 ua := u.X(c.Get(fiber.HeaderUserAgent)) session := &m.Session{ IP: ip, UA: ua, ClientID: &auth.Client.ID, AccessToken: uuid.NewString(), AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second), Scopes: &scopes, } // 保存会话 err := SaveSession(q.Q, session) if err != nil { return nil, err } return session, nil } func authPassword(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) { ip, _ := orm.ParseInet(c.IP()) // 可空字段,忽略异常 ua := u.X(c.Get(fiber.HeaderUserAgent)) // 分池认证 var err error var user *m.User var admin *m.Admin var scopes []string pool := req.LoginPool if pool == "" { pool = PwdLoginAsUser } switch pool { case PwdLoginAsUser: user, err = authUser(req.LoginType, req.Username, req.Password) if err != nil { if req.LoginType != PwdLoginByPhone || !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } // 手机号首次登录的自动创建用户 user = &m.User{ Phone: req.Username, Status: m.UserStatusEnabled, } } // 更新用户的登录时间 user.LastLogin = u.P(time.Now()) user.LastLoginIP = ip user.LastLoginUA = ua case PwdLoginAsAdmin: admin, err = authAdmin(req.LoginType, req.Username, req.Password) if err != nil { return nil, err } scopes, err = adminScopes(admin) if err != nil { return nil, err } // 更新管理员登录时间 admin.LastLogin = u.P(time.Now()) admin.LastLoginIP = ip admin.LastLoginUA = ua default: return nil, ErrAuthorizeInvalidRequest } // 生成会话 session := &m.Session{ IP: ip, UA: ua, ClientID: &auth.Client.ID, Scopes: u.X(strings.Join(scopes, " ")), AccessToken: uuid.NewString(), AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second), } if req.Remember { session.RefreshToken = u.P(uuid.NewString()) session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)) } // 保存用户更新和会话 err = q.Q.Transaction(func(tx *q.Query) error { if user != nil { if err := tx.User.Save(user); err != nil { return err } session.UserID = &user.ID } if admin != nil { if err := tx.Admin.Save(admin); err != nil { return err } session.AdminID = &admin.ID } if err := SaveSession(tx, session); err != nil { return err } return nil }) if err != nil { return nil, err } return session, nil } func authRefreshToken(_ *fiber.Ctx, _ *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) { session, err := FindSessionByRefresh(req.RefreshToken, now) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrAuthorizeInvalidGrant } return nil, err } // todo 检查权限 // 生成令牌 session.AccessToken = uuid.NewString() session.AccessTokenExpires = now.Add(time.Duration(env.SessionAccessExpire) * time.Second) if session.RefreshToken != nil { session.RefreshToken = u.P(uuid.NewString()) session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)) } // 保存令牌 err = SaveSession(q.Q, session) if err != nil { return nil, err } return session, nil } func sendError(c *fiber.Ctx, err error, description ...string) error { var sErr AuthErr if errors.As(err, &sErr) { status := fiber.StatusBadRequest var desc string switch { case errors.Is(sErr, ErrAuthorizeInvalidRequest): desc = "无效的请求" case errors.Is(sErr, ErrAuthorizeInvalidClient): status = fiber.StatusUnauthorized desc = "无效的客户端凭证" case errors.Is(sErr, ErrAuthorizeInvalidGrant): desc = "无效的授权凭证" case errors.Is(sErr, ErrAuthorizeInvalidScope): desc = "无效的授权范围" case errors.Is(sErr, ErrAuthorizeUnauthorizedClient): desc = "未授权的客户端" case errors.Is(sErr, ErrAuthorizeUnsupportedGrantType): desc = "不支持的授权类型" } if len(description) > 0 { desc = description[0] } return c.Status(status).JSON(TokenErrResp{ Error: string(sErr), Description: desc, }) } return err } // Revoke 令牌撤销端点 func Revoke(ctx *fiber.Ctx) error { _, err := GetAuthCtx(ctx).PermitUser() if err != nil { // 用户未登录 return nil } // 解析请求参数 req := new(RevokeReq) if err := ctx.BodyParser(req); err != nil { return err } // 删除会话 err = RemoveSession(ctx.Context(), req.AccessToken, req.RefreshToken) if err != nil { return err } return nil } type RevokeReq struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` } // Introspect 令牌检查端点 func Introspect(ctx *fiber.Ctx) error { authCtx := GetAuthCtx(ctx) // 尝试验证用户权限 if _, err := authCtx.PermitUser(); err == nil { return introspectUser(ctx, authCtx) } // 尝试验证管理员权限 if _, err := authCtx.PermitAdmin(); err == nil { return introspectAdmin(ctx, authCtx) } return ErrAuthenticateForbidden } // introspectUser 获取并返回用户信息 func introspectUser(ctx *fiber.Ctx, authCtx *AuthCtx) error { // 获取用户信息 profile, err := q.User. Where(q.User.ID.Eq(authCtx.User.ID)). Omit(q.User.DeletedAt). Take() if err != nil { return err } // 检查用户是否设置了密码 hasPassword := false if profile.Password != nil && *profile.Password != "" { hasPassword = true profile.Password = nil // 不返回密码 } // 掩码敏感信息 if profile.Phone != "" { profile.Phone = maskPhone(profile.Phone) } if profile.IDNo != nil && *profile.IDNo != "" { profile.IDNo = u.P(maskIdNo(*profile.IDNo)) } return ctx.JSON(struct { m.User HasPassword bool `json:"has_password"` // 是否设置了密码 }{*profile, hasPassword}) } // introspectAdmin 获取并返回管理员信息 func introspectAdmin(ctx *fiber.Ctx, authCtx *AuthCtx) error { // 获取管理员信息 profile, err := q.Admin. Preload(q.Admin.Roles, q.Admin.Roles.Permissions). Where(q.Admin.ID.Eq(authCtx.Admin.ID)). Omit(q.Admin.DeletedAt, q.Admin.Password). Take() if err != nil { return err } // 整理权限列表 scopes := make(map[string]struct{}, 0) for _, role := range profile.Roles { for _, permission := range role.Permissions { scopes[permission.Name] = struct{}{} } } list := make([]string, 0, len(scopes)) for scope := range scopes { list = append(list, scope) } return ctx.JSON(struct { *m.Admin Scopes []string `json:"scopes"` }{profile, list}) } func maskPhone(phone string) string { if len(phone) < 11 { return phone } return phone[:3] + "****" + phone[7:] } func maskIdNo(idNo string) string { if len(idNo) < 18 { return idNo } return idNo[:3] + "*********" + idNo[14:] } type CodeContext struct { UserID int32 `json:"user_id"` ClientID int32 `json:"client_id"` Scopes []string `json:"scopes"` Remember bool `json:"remember"` CodeChallenge string `json:"code_challenge"` CodeChallengeMethod string `json:"code_challenge_method"` }