552 lines
14 KiB
Go
552 lines
14 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"platform/pkg/env"
|
|
"platform/pkg/u"
|
|
g "platform/web/globals"
|
|
"platform/web/globals/orm"
|
|
m "platform/web/models"
|
|
q "platform/web/queries"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/google/uuid"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// AuthorizeGet 授权端点
|
|
func AuthorizeGet(ctx *fiber.Ctx) error {
|
|
|
|
// 检查请求
|
|
req := new(AuthorizeGetReq)
|
|
if err := g.Validator.ParseQuery(ctx, req); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 检查客户端
|
|
client, err := authClient(req.ClientID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if client.RedirectURI == nil || *client.RedirectURI != req.RedirectURI {
|
|
return errors.New("客户端重定向URI错误")
|
|
}
|
|
|
|
// todo 检查 scope
|
|
|
|
// 授权确认页面
|
|
return nil
|
|
}
|
|
|
|
type AuthorizeGetReq struct {
|
|
ResponseType string `json:"response_type" validate:"eq=code"`
|
|
ClientID string `json:"client_id" validate:"required"`
|
|
RedirectURI string `json:"redirect_uri" validate:"required"`
|
|
Scope string `json:"scope"`
|
|
State string `json:"state"`
|
|
}
|
|
|
|
func AuthorizePost(ctx *fiber.Ctx) error {
|
|
|
|
// todo 解析用户授权的范围
|
|
|
|
return nil
|
|
}
|
|
|
|
type AuthorizePostReq struct {
|
|
Accept bool `json:"accept"`
|
|
Scope string `json:"scope"`
|
|
}
|
|
|
|
// Token 令牌端点
|
|
func Token(c *fiber.Ctx) error {
|
|
now := time.Now()
|
|
|
|
// 验证请求参数
|
|
req := new(TokenReq)
|
|
if err := c.BodyParser(req); err != nil {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数")
|
|
}
|
|
if req.GrantType == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type")
|
|
}
|
|
|
|
switch req.GrantType {
|
|
// 授权码模式
|
|
case GrantAuthorizationCode:
|
|
if req.Code == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code")
|
|
}
|
|
// 刷新令牌模式
|
|
case GrantRefreshToken:
|
|
if req.RefreshToken == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token")
|
|
}
|
|
// 密码模式
|
|
case GrantPassword:
|
|
if req.LoginType == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password_type")
|
|
}
|
|
if req.Username == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username")
|
|
}
|
|
if req.Password == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password")
|
|
}
|
|
}
|
|
|
|
// 验证客户端身份
|
|
authCtx := GetAuthCtx(c)
|
|
if authCtx == nil {
|
|
authCtx = &AuthCtx{}
|
|
}
|
|
if authCtx.Client == nil {
|
|
client, err := authClient(req.ClientID, req.ClientSecret)
|
|
if err != nil {
|
|
return sendError(c, err)
|
|
}
|
|
authCtx.Client = client
|
|
}
|
|
|
|
// 处理授权
|
|
var session *m.Session
|
|
var err error
|
|
switch req.GrantType {
|
|
// 授权码模式
|
|
case GrantAuthorizationCode:
|
|
session, err = authAuthorizationCode(c, authCtx, req, now)
|
|
// 客户端凭证模式
|
|
case GrantClientCredentials:
|
|
session, err = authClientCredential(c, authCtx, req, now)
|
|
// 刷新令牌模式
|
|
case GrantRefreshToken:
|
|
session, err = authRefreshToken(c, authCtx, req, now)
|
|
// 密码模式
|
|
case GrantPassword:
|
|
session, err = authPassword(c, authCtx, req, now)
|
|
default:
|
|
return sendError(c, ErrAuthorizeUnsupportedGrantType)
|
|
}
|
|
if err != nil {
|
|
return sendError(c, err)
|
|
}
|
|
|
|
// 返回响应
|
|
return c.JSON(&TokenResp{
|
|
TokenType: "Bearer",
|
|
AccessToken: session.AccessToken,
|
|
RefreshToken: u.Z(session.RefreshToken),
|
|
ExpiresIn: int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()),
|
|
Scope: u.Z(session.Scopes),
|
|
})
|
|
}
|
|
|
|
// GrantType identifies an OAuth 2.0 grant, as sent in the grant_type parameter.
type GrantType string

const (
	GrantAuthorizationCode GrantType = "authorization_code" // authorization code grant
	GrantClientCredentials GrantType = "client_credentials" // client credentials grant
	GrantRefreshToken      GrantType = "refresh_token"      // refresh token grant
	GrantPassword          GrantType = "password"           // password grant (private extension)
)

// PwdLoginType selects how the password grant verifies the account
// (sent as login_type).
type PwdLoginType string

const (
	PwdLoginByPassword PwdLoginType = "password"   // account + password
	PwdLoginByPhone    PwdLoginType = "phone_code" // phone verification code
	PwdLoginByEmail    PwdLoginType = "email_code" // email verification code
)

// PwdLoginPool selects which account pool the password grant authenticates
// against (sent as login_pool).
type PwdLoginPool string

const (
	PwdLoginAsUser  PwdLoginPool = "user"  // user pool
	PwdLoginAsAdmin PwdLoginPool = "admin" // administrator pool
)

// TokenReq is the body of a token request. It carries the common parameters
// plus the grant-specific ones (promoted from the embedded structs). Every
// field is tagged for both JSON and form binding.
type TokenReq struct {
	GrantType    GrantType `json:"grant_type" form:"grant_type"`
	ClientID     string    `json:"client_id" form:"client_id"`
	ClientSecret string    `json:"client_secret" form:"client_secret"`
	Scope        string    `json:"scope" form:"scope"`

	GrantCodeData
	GrantClientData
	GrantRefreshData
	GrantPasswordData
}

// GrantCodeData holds the authorization_code grant parameters.
type GrantCodeData struct {
	Code         string `json:"code" form:"code"`
	RedirectURI  string `json:"redirect_uri" form:"redirect_uri"`
	CodeVerifier string `json:"code_verifier" form:"code_verifier"` // PKCE verifier (RFC 7636)
}

// GrantClientData holds the client_credentials grant parameters (none so far).
type GrantClientData struct{}

// GrantRefreshData holds the refresh_token grant parameters.
type GrantRefreshData struct {
	RefreshToken string `json:"refresh_token" form:"refresh_token"`
}

// GrantPasswordData holds the password grant parameters.
type GrantPasswordData struct {
	LoginType PwdLoginType `json:"login_type" form:"login_type"`
	LoginPool PwdLoginPool `json:"login_pool" form:"login_pool"` // empty means the user pool
	Username  string       `json:"username" form:"username"`
	Password  string       `json:"password" form:"password"`
	Remember  bool         `json:"remember" form:"remember"` // also issue a refresh token
}

// TokenResp is a successful token response (RFC 6749 §5.1).
type TokenResp struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ExpiresIn    int    `json:"expires_in"` // seconds until the access token expires
	TokenType    string `json:"token_type"`
	Scope        string `json:"scope,omitempty"`
}

// TokenErrResp is a token error response (RFC 6749 §5.2).
type TokenErrResp struct {
	Error       string `json:"error"`
	Description string `json:"error_description,omitempty"`
}
|
|
|
|
func authAuthorizationCode(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
|
|
|
// 检查 code 获取用户授权信息
|
|
data, err := g.Redis.Get(context.Background(), req.Code).Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var codeCtx CodeContext
|
|
if err := json.Unmarshal([]byte(data), &codeCtx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 检查 PKCE
|
|
if codeCtx.CodeChallengeMethod != "" {
|
|
if req.CodeVerifier == "" {
|
|
return nil, ErrAuthorizeInvalidPKCE
|
|
}
|
|
|
|
switch codeCtx.CodeChallengeMethod {
|
|
case "plain":
|
|
if req.CodeVerifier != codeCtx.CodeChallenge {
|
|
return nil, ErrAuthorizeInvalidPKCE
|
|
}
|
|
case "S256":
|
|
hash := sha256.Sum256([]byte(req.CodeVerifier))
|
|
verifier := base64.RawURLEncoding.EncodeToString(hash[:])
|
|
if verifier != codeCtx.CodeChallenge {
|
|
return nil, ErrAuthorizeInvalidPKCE
|
|
}
|
|
default:
|
|
return nil, ErrAuthorizeInvalidPKCE
|
|
}
|
|
}
|
|
|
|
user, err := q.User.Where(
|
|
q.User.ID.Eq(codeCtx.UserID),
|
|
q.User.Status.Eq(int(m.UserStatusEnabled)),
|
|
).First()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// todo 检查 scope
|
|
|
|
// 生成会话
|
|
ip, _ := orm.ParseInet(c.IP()) // 可空字段,忽略异常
|
|
ua := u.X(c.Get(fiber.HeaderUserAgent))
|
|
session := &m.Session{
|
|
IP: ip,
|
|
UA: ua,
|
|
UserID: &user.ID,
|
|
ClientID: &auth.Client.ID,
|
|
Scopes: u.P(strings.Join(codeCtx.Scopes, " ")),
|
|
AccessToken: uuid.NewString(),
|
|
AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second),
|
|
}
|
|
if codeCtx.Remember {
|
|
session.RefreshToken = u.P(uuid.NewString())
|
|
session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second))
|
|
}
|
|
|
|
err = SaveSession(q.Q, session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func authClientCredential(c *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) {
|
|
// todo 检查 scope
|
|
|
|
// 生成会话
|
|
ip, _ := orm.ParseInet(c.IP()) // 可空字段,忽略异常
|
|
ua := u.X(c.Get(fiber.HeaderUserAgent))
|
|
session := &m.Session{
|
|
IP: ip,
|
|
UA: ua,
|
|
ClientID: &auth.Client.ID,
|
|
AccessToken: uuid.NewString(),
|
|
AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second),
|
|
}
|
|
|
|
// 保存会话
|
|
err := SaveSession(q.Q, session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func authPassword(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
|
ip, _ := orm.ParseInet(c.IP()) // 可空字段,忽略异常
|
|
ua := u.X(c.Get(fiber.HeaderUserAgent))
|
|
|
|
// 分池认证
|
|
var err error
|
|
var user *m.User
|
|
var admin *m.Admin
|
|
|
|
pool := req.LoginPool
|
|
if pool == "" {
|
|
pool = PwdLoginAsUser
|
|
}
|
|
switch pool {
|
|
case PwdLoginAsUser:
|
|
user, err = authUser(req.LoginType, req.Username, req.Password)
|
|
if err != nil {
|
|
if req.LoginType != PwdLoginByPhone || !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
|
|
// 手机号首次登录的自动创建用户
|
|
user = &m.User{
|
|
Phone: req.Username,
|
|
Username: u.P(req.Username),
|
|
Status: m.UserStatusEnabled,
|
|
}
|
|
}
|
|
|
|
// 更新用户的登录时间
|
|
user.LastLogin = u.P(time.Now())
|
|
user.LastLoginIP = ip
|
|
user.LastLoginUA = ua
|
|
|
|
case PwdLoginAsAdmin:
|
|
admin, err = authAdmin(req.LoginType, req.Username, req.Password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 更新管理员登录时间
|
|
admin.LastLogin = u.P(time.Now())
|
|
admin.LastLoginIP = ip
|
|
admin.LastLoginUA = ua
|
|
}
|
|
|
|
// 生成会话
|
|
session := &m.Session{
|
|
IP: ip,
|
|
UA: ua,
|
|
ClientID: &auth.Client.ID,
|
|
Scopes: u.X(req.Scope),
|
|
AccessToken: uuid.NewString(),
|
|
AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second),
|
|
}
|
|
if user != nil {
|
|
session.UserID = &user.ID
|
|
}
|
|
if admin != nil {
|
|
session.AdminID = &admin.ID
|
|
}
|
|
if req.Remember {
|
|
session.RefreshToken = u.P(uuid.NewString())
|
|
session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second))
|
|
}
|
|
|
|
// 保存用户更新和会话
|
|
err = q.Q.Transaction(func(tx *q.Query) error {
|
|
if err := SaveSession(tx, session); err != nil {
|
|
return err
|
|
}
|
|
if user != nil {
|
|
if err := tx.User.Save(user); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if admin != nil {
|
|
if err := tx.Admin.Save(admin); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func authRefreshToken(_ *fiber.Ctx, _ *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
|
session, err := FindSessionByRefresh(req.RefreshToken, now)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrAuthorizeInvalidGrant
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// todo 检查权限
|
|
|
|
// 生成令牌
|
|
session.AccessToken = uuid.NewString()
|
|
session.AccessTokenExpires = now.Add(time.Duration(env.SessionAccessExpire) * time.Second)
|
|
if session.RefreshToken != nil {
|
|
session.RefreshToken = u.P(uuid.NewString())
|
|
session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second))
|
|
}
|
|
|
|
// 保存令牌
|
|
err = SaveSession(q.Q, session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
|
var sErr AuthErr
|
|
if errors.As(err, &sErr) {
|
|
status := fiber.StatusBadRequest
|
|
var desc string
|
|
switch {
|
|
case errors.Is(sErr, ErrAuthorizeInvalidRequest):
|
|
desc = "无效的请求"
|
|
case errors.Is(sErr, ErrAuthorizeInvalidClient):
|
|
status = fiber.StatusUnauthorized
|
|
desc = "无效的客户端凭证"
|
|
case errors.Is(sErr, ErrAuthorizeInvalidGrant):
|
|
desc = "无效的授权凭证"
|
|
case errors.Is(sErr, ErrAuthorizeInvalidScope):
|
|
desc = "无效的授权范围"
|
|
case errors.Is(sErr, ErrAuthorizeUnauthorizedClient):
|
|
desc = "未授权的客户端"
|
|
case errors.Is(sErr, ErrAuthorizeUnsupportedGrantType):
|
|
desc = "不支持的授权类型"
|
|
}
|
|
if len(description) > 0 {
|
|
desc = description[0]
|
|
}
|
|
|
|
return c.Status(status).JSON(TokenErrResp{
|
|
Error: string(sErr),
|
|
Description: desc,
|
|
})
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// Revoke 令牌撤销端点
|
|
func Revoke(ctx *fiber.Ctx) error {
|
|
_, err := GetAuthCtx(ctx).PermitUser()
|
|
if err != nil {
|
|
// 用户未登录
|
|
return nil
|
|
}
|
|
|
|
// 解析请求参数
|
|
req := new(RevokeReq)
|
|
if err := ctx.BodyParser(req); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 删除会话
|
|
err = RemoveSession(ctx.Context(), req.AccessToken, req.RefreshToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type RevokeReq struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
}
|
|
|
|
// Introspect 令牌检查端点
|
|
func Introspect(ctx *fiber.Ctx) error {
|
|
// 验证权限
|
|
authCtx, err := GetAuthCtx(ctx).PermitUser()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 获取用户信息
|
|
profile, err := q.User.
|
|
Where(q.User.ID.Eq(authCtx.User.ID)).
|
|
Omit(q.User.DeletedAt).
|
|
Take()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 检查用户是否设置了密码
|
|
hasPassword := false
|
|
if profile.Password != nil && *profile.Password != "" {
|
|
hasPassword = true
|
|
profile.Password = nil // 不返回密码
|
|
}
|
|
|
|
// 掩码敏感信息
|
|
if profile.Phone != "" {
|
|
profile.Phone = maskPhone(profile.Phone)
|
|
}
|
|
if profile.IDNo != nil && *profile.IDNo != "" {
|
|
profile.IDNo = u.P(maskIdNo(*profile.IDNo))
|
|
}
|
|
return ctx.JSON(IntrospectResp{*profile, hasPassword})
|
|
}
|
|
|
|
type IntrospectResp struct {
|
|
m.User
|
|
HasPassword bool `json:"has_password"` // 是否设置了密码
|
|
}
|
|
|
|
// maskPhone hides part of a phone number for display. Bytes 3–6 are replaced
// with "****"; the first three bytes and everything from byte 7 on are kept,
// so the length is unchanged. Strings shorter than 11 bytes are returned as-is.
func maskPhone(phone string) string {
	const keepHead, hideUntil = 3, 7
	if len(phone) >= 11 {
		return phone[:keepHead] + strings.Repeat("*", hideUntil-keepHead) + phone[hideUntil:]
	}
	return phone
}
|
|
|
|
// maskIdNo hides the middle of an ID number for display. The first 3 and last
// 4 bytes are kept, and every byte in between becomes '*', so the length is
// preserved. This matches maskPhone and works for both 18- and 15-character
// IDs. Inputs of 7 bytes or fewer have no middle to hide and are masked
// entirely, so nothing short is leaked verbatim.
func maskIdNo(idNo string) string {
	const head, tail = 3, 4
	if len(idNo) <= head+tail {
		return strings.Repeat("*", len(idNo))
	}
	return idNo[:head] + strings.Repeat("*", len(idNo)-head-tail) + idNo[len(idNo)-tail:]
}
|
|
|
|
// CodeContext is the authorization state stored, JSON-encoded, under an
// authorization code and read back by authAuthorizationCode.
type CodeContext struct {
	UserID   int32    `json:"user_id"`   // user the resulting session is issued for
	ClientID int32    `json:"client_id"` // client the code was issued to
	Scopes   []string `json:"scopes"`    // joined with spaces into the session's scopes
	Remember bool     `json:"remember"`  // when true, redeeming the code also issues a refresh token

	// PKCE (RFC 7636) parameters; an empty method means no challenge was recorded.
	CodeChallenge       string `json:"code_challenge"`
	CodeChallengeMethod string `json:"code_challenge_method"` // "plain" or "S256"
}
|