Files
platform/web/auth/endpoints.go

599 lines
15 KiB
Go

package auth
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"platform/pkg/env"
"platform/pkg/u"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"gorm.io/gorm"
)
// AuthorizeGet is the authorization endpoint (GET).
//
// It parses and validates the query parameters into an AuthorizeGetReq,
// looks up the client by client_id and requires the request's redirect_uri
// to equal the client's registered RedirectURI exactly. Scope checking and
// the consent page are not implemented yet; on success it returns nil.
func AuthorizeGet(ctx *fiber.Ctx) error {
	var req AuthorizeGetReq
	if err := g.Validator.ParseQuery(ctx, &req); err != nil {
		return err
	}

	client, err := authClient(req.ClientID)
	if err != nil {
		return err
	}

	// The client must have a registered redirect URI and it must match exactly.
	redirectMatches := client.RedirectURI != nil && *client.RedirectURI == req.RedirectURI
	if !redirectMatches {
		return errors.New("客户端重定向URI错误")
	}

	// TODO: check the requested scope.
	// TODO: render the authorization consent page.
	return nil
}
// AuthorizeGetReq holds the query parameters of an authorization request
// (RFC 6749 §4.1.1). Validation rules come from the `validate` tags.
// NOTE(review): only `json` tags are declared; whether g.Validator.ParseQuery
// maps query keys through them is not visible here — confirm.
type AuthorizeGetReq struct {
	ResponseType string `json:"response_type" validate:"eq=code"` // must be "code"
	ClientID     string `json:"client_id" validate:"required"`
	RedirectURI  string `json:"redirect_uri" validate:"required"` // must equal the client's registered RedirectURI
	Scope        string `json:"scope"`                            // not checked yet
	State        string `json:"state"`
}
// AuthorizePost handles the submission of the authorization consent form.
// It is currently a stub that does nothing and returns nil.
func AuthorizePost(ctx *fiber.Ctx) error {
	// TODO: parse the scope the user granted.
	return nil
}
// AuthorizePostReq is the consent form body intended for AuthorizePost.
// NOTE(review): AuthorizePost does not parse it yet.
type AuthorizePostReq struct {
	Accept bool   `json:"accept"` // presumably whether the user approved the request
	Scope  string `json:"scope"`  // presumably the scope the user approved
}
// Token is the OAuth 2.0 token endpoint (RFC 6749 §3.2).
//
// It parses the grant request from the body, checks the parameters required
// by the selected grant type and authenticates the client, unless an AuthCtx
// already carries one. It then dispatches to the grant handler and writes a
// TokenResp.
//
// Errors that are (or wrap) an AuthErr are rendered as RFC 6749 §5.2 error
// bodies by sendError. Any other error is returned to fiber's error handler.
func Token(c *fiber.Ctx) error {
	now := time.Now()

	// Token endpoint responses must not be cached (RFC 6749 §5.1). Set the
	// headers up front so error responses carry them too.
	c.Set(fiber.HeaderCacheControl, "no-store")
	c.Set(fiber.HeaderPragma, "no-cache")

	// Parse and validate the request parameters.
	req := new(TokenReq)
	if err := c.BodyParser(req); err != nil {
		return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数")
	}
	if req.GrantType == "" {
		return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type")
	}
	switch req.GrantType {
	case GrantAuthorizationCode:
		if req.Code == "" {
			return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code")
		}
	case GrantRefreshToken:
		if req.RefreshToken == "" {
			return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token")
		}
	case GrantPassword:
		// The parameter is named login_type (see GrantPasswordData); the
		// message previously referred to a non-existent "password_type".
		if req.LoginType == "" {
			return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: login_type")
		}
		if req.Username == "" {
			return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username")
		}
		if req.Password == "" {
			return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password")
		}
	}

	// Authenticate the client, unless an earlier stage already did.
	authCtx := GetAuthCtx(c)
	if authCtx == nil {
		authCtx = &AuthCtx{}
	}
	if authCtx.Client == nil {
		client, err := authClient(req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		authCtx.Client = client
	}

	// Dispatch to the grant handler.
	var session *m.Session
	var err error
	switch req.GrantType {
	case GrantAuthorizationCode:
		session, err = authAuthorizationCode(c, authCtx, req, now)
	case GrantClientCredentials:
		session, err = authClientCredential(c, authCtx, req, now)
	case GrantRefreshToken:
		session, err = authRefreshToken(c, authCtx, req, now)
	case GrantPassword:
		session, err = authPassword(c, authCtx, req, now)
	default:
		return sendError(c, ErrAuthorizeUnsupportedGrantType)
	}
	if err != nil {
		return sendError(c, err)
	}

	// Successful response (RFC 6749 §5.1).
	return c.JSON(&TokenResp{
		TokenType:    "Bearer",
		AccessToken:  session.AccessToken,
		RefreshToken: u.Z(session.RefreshToken),
		ExpiresIn:    int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()),
		Scope:        u.Z(session.Scopes),
	})
}
// TokenReq is the body of a token request, read from JSON or form-encoded
// bodies by c.BodyParser in Token. Each embedded struct carries the
// parameters of one grant type; only the set matching GrantType is used.
type TokenReq struct {
	GrantType    GrantType `json:"grant_type" form:"grant_type"`
	ClientID     string    `json:"client_id" form:"client_id"`         // used only when no client is already on the AuthCtx
	ClientSecret string    `json:"client_secret" form:"client_secret"` // used only when no client is already on the AuthCtx
	Scope        string    `json:"scope" form:"scope"`                 // NOTE(review): not read by any grant handler in this file
	GrantCodeData
	GrantClientData
	GrantRefreshData
	GrantPasswordData
}
// GrantCodeData holds the authorization_code grant parameters
// (RFC 6749 §4.1.3) plus the PKCE code_verifier (RFC 7636 §4.5).
type GrantCodeData struct {
	Code         string `json:"code" form:"code"`                   // authorization code (also the Redis key of its CodeContext)
	RedirectURI  string `json:"redirect_uri" form:"redirect_uri"`   // NOTE(review): not compared against anything in this file
	CodeVerifier string `json:"code_verifier" form:"code_verifier"` // PKCE verifier, required when the code has a challenge
}

// GrantClientData holds client_credentials grant parameters. The grant
// needs nothing beyond the client credentials on TokenReq.
type GrantClientData struct {
}

// GrantRefreshData holds the refresh_token grant parameters (RFC 6749 §6).
type GrantRefreshData struct {
	RefreshToken string `json:"refresh_token" form:"refresh_token"`
}

// GrantPasswordData holds the parameters of the private "password" grant
// extension.
type GrantPasswordData struct {
	LoginType PwdLoginType `json:"login_type" form:"login_type"` // password or phone/email verification code
	LoginPool PwdLoginPool `json:"login_pool" form:"login_pool"` // account pool; empty means the user pool
	Username  string       `json:"username" form:"username"`
	Password  string       `json:"password" form:"password"` // password or verification code, depending on LoginType
	Remember  bool         `json:"remember" form:"remember"` // when true a refresh token is issued
}
// GrantType is the value of the grant_type parameter of a token request.
type GrantType string

// Grant types understood by Token.
const (
	GrantAuthorizationCode GrantType = "authorization_code" // authorization code grant
	GrantClientCredentials GrantType = "client_credentials" // client credentials grant
	GrantRefreshToken      GrantType = "refresh_token"      // refresh token grant
	GrantPassword          GrantType = "password"           // password grant (private extension)
)
// PwdLoginType selects how the password grant interprets its credentials.
type PwdLoginType string

// Login types of the password grant.
const (
	PwdLoginByPassword PwdLoginType = "password"   // account + password
	PwdLoginByPhone    PwdLoginType = "phone_code" // phone + verification code
	PwdLoginByEmail    PwdLoginType = "email_code" // email + verification code
)
// PwdLoginPool selects the account pool the password grant authenticates against.
type PwdLoginPool string

// Account pools of the password grant.
const (
	PwdLoginAsUser  PwdLoginPool = "user"  // end-user pool
	PwdLoginAsAdmin PwdLoginPool = "admin" // administrator pool
)
// TokenResp is the successful token response body (RFC 6749 §5.1).
type TokenResp struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"` // omitted when empty
	ExpiresIn    int    `json:"expires_in"`              // access token lifetime in seconds
	TokenType    string `json:"token_type"`
	Scope        string `json:"scope,omitempty"` // space-separated scopes; omitted when empty
}

// TokenErrResp is the error response body (RFC 6749 §5.2) written by sendError.
type TokenErrResp struct {
	Error       string `json:"error"`                       // the AuthErr value as a string
	Description string `json:"error_description,omitempty"` // human-readable description
}
// authAuthorizationCode handles the authorization_code grant (RFC 6749 §4.1.3).
//
// Steps:
//   - Load the CodeContext stored in Redis under the raw code.
//   - Check that the code was issued to the authenticated client.
//   - Consume the code so it cannot be replayed.
//   - Verify PKCE (RFC 7636) when a challenge was recorded.
//   - Load the enabled user and create a new session.
//
// A refresh token is issued only when the code was created with Remember set.
// Invalid, foreign, replayed or PKCE-failing codes yield
// ErrAuthorizeInvalidGrant or ErrAuthorizeInvalidPKCE.
func authAuthorizationCode(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
	// Load the authorization context recorded for this code.
	// NOTE(review): an unknown or expired code presumably surfaces here as the
	// Redis client's "nil" error and reaches Token as a non-AuthErr error.
	// Consider mapping it to ErrAuthorizeInvalidGrant.
	data, err := g.Redis.Get(context.Background(), req.Code).Result()
	if err != nil {
		return nil, err
	}
	var codeCtx CodeContext
	if err := json.Unmarshal([]byte(data), &codeCtx); err != nil {
		return nil, err
	}

	// The code must have been issued to the client redeeming it (RFC 6749
	// §4.1.3). Checking this before deleting also keeps a client from
	// deleting unrelated Redis keys by passing their names as a code.
	if codeCtx.ClientID != auth.Client.ID {
		return nil, ErrAuthorizeInvalidGrant
	}

	// Authorization codes are single-use (RFC 6749 §4.1.2). DEL is atomic, so
	// when two requests race with the same code only one sees a count of 1.
	// The code is consumed before PKCE verification, so a wrong verifier
	// cannot be retried against it.
	deleted, err := g.Redis.Del(context.Background(), req.Code).Result()
	if err != nil {
		return nil, err
	}
	if deleted != 1 {
		return nil, ErrAuthorizeInvalidGrant
	}

	// Verify PKCE when the authorization request carried a challenge.
	if codeCtx.CodeChallengeMethod != "" {
		if req.CodeVerifier == "" {
			return nil, ErrAuthorizeInvalidPKCE
		}
		switch codeCtx.CodeChallengeMethod {
		case "plain":
			if req.CodeVerifier != codeCtx.CodeChallenge {
				return nil, ErrAuthorizeInvalidPKCE
			}
		case "S256":
			// BASE64URL(SHA256(verifier)) without padding (RFC 7636 §4.6).
			hash := sha256.Sum256([]byte(req.CodeVerifier))
			verifier := base64.RawURLEncoding.EncodeToString(hash[:])
			if verifier != codeCtx.CodeChallenge {
				return nil, ErrAuthorizeInvalidPKCE
			}
		default:
			return nil, ErrAuthorizeInvalidPKCE
		}
	}

	// The user must still exist and be enabled.
	user, err := q.User.Where(
		q.User.ID.Eq(codeCtx.UserID),
		q.User.Status.Eq(int(m.UserStatusEnabled)),
	).First()
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrAuthorizeInvalidGrant
		}
		return nil, err
	}

	// TODO: check the requested scope.

	// Create the session.
	ip, _ := orm.ParseInet(c.IP()) // nullable column; a parse failure is deliberately ignored
	ua := u.X(c.Get(fiber.HeaderUserAgent))
	session := &m.Session{
		IP:                 ip,
		UA:                 ua,
		UserID:             &user.ID,
		ClientID:           &auth.Client.ID,
		Scopes:             u.P(strings.Join(codeCtx.Scopes, " ")),
		AccessToken:        uuid.NewString(),
		AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second),
	}
	if codeCtx.Remember {
		session.RefreshToken = u.P(uuid.NewString())
		session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second))
	}
	if err := SaveSession(q.Q, session); err != nil {
		return nil, err
	}
	return session, nil
}
// authClientCredential handles the client_credentials grant.
//
// It creates a session owned by the authenticated client only (no user or
// admin). The scopes are the space-joined auth.Scopes of the current AuthCtx;
// Scopes is always set, even to an empty string. No refresh token is issued.
func authClientCredential(c *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) {
	// TODO: check the requested scope.
	joinedScopes := strings.Join(auth.Scopes, " ")

	addr, _ := orm.ParseInet(c.IP()) // nullable column; a parse failure is deliberately ignored
	agent := u.X(c.Get(fiber.HeaderUserAgent))

	sess := &m.Session{
		IP:                 addr,
		UA:                 agent,
		ClientID:           &auth.Client.ID,
		AccessToken:        uuid.NewString(),
		AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second),
		Scopes:             &joinedScopes,
	}

	if saveErr := SaveSession(q.Q, sess); saveErr != nil {
		return nil, saveErr
	}
	return sess, nil
}
// authPassword handles the private "password" grant.
//
// The account pool comes from req.LoginPool (empty means the user pool):
//   - user:  authUser verifies the credentials. For a phone-code login whose
//     account does not exist (gorm.ErrRecordNotFound), a new enabled user is
//     created for the phone number.
//   - admin: authAdmin verifies the credentials and adminScopes supplies the
//     session scopes.
//
// Any other pool yields ErrAuthorizeInvalidRequest. The account's
// last-login time, IP and UA are updated and saved together with the new
// session in one transaction. A refresh token is issued only when
// req.Remember is set.
func authPassword(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
	ip, _ := orm.ParseInet(c.IP()) // nullable column; a parse failure is deliberately ignored
	ua := u.X(c.Get(fiber.HeaderUserAgent))

	// Authenticate against the selected pool.
	var (
		err    error
		user   *m.User
		admin  *m.Admin
		scopes []string
	)
	pool := req.LoginPool
	if pool == "" {
		pool = PwdLoginAsUser
	}
	switch pool {
	case PwdLoginAsUser:
		user, err = authUser(req.LoginType, req.Username, req.Password)
		if err != nil {
			if req.LoginType != PwdLoginByPhone || !errors.Is(err, gorm.ErrRecordNotFound) {
				return nil, err
			}
			// First phone-code login: create the user on the fly.
			// NOTE(review): this is only safe if authUser verified the
			// verification code BEFORE reporting gorm.ErrRecordNotFound —
			// confirm, otherwise any phone number could be registered.
			user = &m.User{
				Phone:  req.Username,
				Status: m.UserStatusEnabled,
			}
		}
		// Record the login. Use the request timestamp so it matches the
		// session's expiry base.
		user.LastLogin = u.P(now)
		user.LastLoginIP = ip
		user.LastLoginUA = ua
	case PwdLoginAsAdmin:
		admin, err = authAdmin(req.LoginType, req.Username, req.Password)
		if err != nil {
			return nil, err
		}
		scopes, err = adminScopes(admin)
		if err != nil {
			return nil, err
		}
		// Record the login. Use the request timestamp so it matches the
		// session's expiry base.
		admin.LastLogin = u.P(now)
		admin.LastLoginIP = ip
		admin.LastLoginUA = ua
	default:
		return nil, ErrAuthorizeInvalidRequest
	}

	// Build the session; the owner IDs are filled in inside the transaction.
	session := &m.Session{
		IP:                 ip,
		UA:                 ua,
		ClientID:           &auth.Client.ID,
		Scopes:             u.X(strings.Join(scopes, " ")),
		AccessToken:        uuid.NewString(),
		AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second),
	}
	if req.Remember {
		session.RefreshToken = u.P(uuid.NewString())
		session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second))
	}

	// Save the account update and the session atomically. A newly created
	// user gets its ID from Save, so UserID is set afterwards.
	err = q.Q.Transaction(func(tx *q.Query) error {
		if user != nil {
			if err := tx.User.Save(user); err != nil {
				return err
			}
			session.UserID = &user.ID
		}
		if admin != nil {
			if err := tx.Admin.Save(admin); err != nil {
				return err
			}
			session.AdminID = &admin.ID
		}
		return SaveSession(tx, session)
	})
	if err != nil {
		return nil, err
	}
	return session, nil
}
// authRefreshToken handles the refresh_token grant (RFC 6749 §6).
//
// It looks up the session by refresh token via FindSessionByRefresh, which
// also receives `now` (presumably for expiry filtering — confirm). The
// session must belong to the authenticated client. The access token is then
// rotated and, when the session has a refresh token, so is the refresh token.
// An unknown token or a client mismatch yields ErrAuthorizeInvalidGrant.
func authRefreshToken(_ *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
	session, err := FindSessionByRefresh(req.RefreshToken, now)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrAuthorizeInvalidGrant
		}
		return nil, err
	}

	// The refresh token must have been issued to the authenticated client
	// (RFC 6749 §6). Otherwise any client holding a leaked refresh token
	// could mint access tokens for another client's session.
	if auth == nil || auth.Client == nil || session.ClientID == nil || *session.ClientID != auth.Client.ID {
		return nil, ErrAuthorizeInvalidGrant
	}

	// TODO: check permissions.

	// Rotate the tokens.
	session.AccessToken = uuid.NewString()
	session.AccessTokenExpires = now.Add(time.Duration(env.SessionAccessExpire) * time.Second)
	if session.RefreshToken != nil {
		session.RefreshToken = u.P(uuid.NewString())
		session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second))
	}

	if err := SaveSession(q.Q, session); err != nil {
		return nil, err
	}
	return session, nil
}
// sendError renders err as an OAuth 2.0 error response (RFC 6749 §5.2) when
// it is, or wraps, an AuthErr.
//
// The status is 401 for ErrAuthorizeInvalidClient and 400 otherwise. The
// body's error_description defaults to a built-in Chinese message for known
// codes; the first element of description, when given, overrides it.
// Errors that are not AuthErr values are returned unchanged for fiber's
// error handler.
func sendError(c *fiber.Ctx, err error, description ...string) error {
	var authErr AuthErr
	if !errors.As(err, &authErr) {
		return err
	}

	code := fiber.StatusBadRequest
	desc := ""
	switch {
	case errors.Is(authErr, ErrAuthorizeInvalidRequest):
		desc = "无效的请求"
	case errors.Is(authErr, ErrAuthorizeInvalidClient):
		code = fiber.StatusUnauthorized
		desc = "无效的客户端凭证"
	case errors.Is(authErr, ErrAuthorizeInvalidGrant):
		desc = "无效的授权凭证"
	case errors.Is(authErr, ErrAuthorizeInvalidScope):
		desc = "无效的授权范围"
	case errors.Is(authErr, ErrAuthorizeUnauthorizedClient):
		desc = "未授权的客户端"
	case errors.Is(authErr, ErrAuthorizeUnsupportedGrantType):
		desc = "不支持的授权类型"
	}

	if len(description) != 0 {
		desc = description[0]
	}

	body := TokenErrResp{
		Error:       string(authErr),
		Description: desc,
	}
	return c.Status(code).JSON(body)
}
// Revoke is the token revocation endpoint.
//
// Callers not authorized as a user get nil back and nothing is revoked.
// Otherwise the body is parsed into a RevokeReq and both tokens are passed
// to RemoveSession.
// NOTE(review): nothing here checks that the tokens belong to the caller
// (RFC 7009 §2.1); confirm RemoveSession enforces ownership.
func Revoke(ctx *fiber.Ctx) error {
	if _, permitErr := GetAuthCtx(ctx).PermitUser(); permitErr != nil {
		// Not logged in: silently do nothing.
		return nil
	}

	var body RevokeReq
	if err := ctx.BodyParser(&body); err != nil {
		return err
	}

	if err := RemoveSession(ctx.Context(), body.AccessToken, body.RefreshToken); err != nil {
		return err
	}
	return nil
}
// RevokeReq is the body of a revocation request. Both fields are passed to
// RemoveSession as given (empty strings when absent).
type RevokeReq struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
}
// Introspect returns the caller's profile.
//
// If the request is authorized as a user, it writes the user profile. Failing
// that, if it is authorized as an admin, it writes the admin profile.
// Otherwise it returns ErrAuthenticateForbidden.
// NOTE(review): despite the name, the response is not an RFC 7662
// introspection document (there is no "active" field).
func Introspect(ctx *fiber.Ctx) error {
	auth := GetAuthCtx(ctx)

	_, userErr := auth.PermitUser()
	if userErr == nil {
		return introspectUser(ctx, auth)
	}

	_, adminErr := auth.PermitAdmin()
	if adminErr == nil {
		return introspectAdmin(ctx, auth)
	}

	return ErrAuthenticateForbidden
}
// introspectUser loads the authenticated user and writes it as JSON.
//
// The response adds a has_password flag that reports whether a non-empty
// password is set. The password field itself is always cleared. The phone
// number and ID number are masked before serialization.
func introspectUser(ctx *fiber.Ctx, authCtx *AuthCtx) error {
	profile, err := q.User.
		Where(q.User.ID.Eq(authCtx.User.ID)).
		Omit(q.User.DeletedAt).
		Take()
	if err != nil {
		return err
	}

	// Compute the flag first, then always drop the password so the field
	// never reaches the response, whatever it holds (including "").
	hasPassword := profile.Password != nil && *profile.Password != ""
	profile.Password = nil

	// Mask sensitive fields.
	if profile.Phone != "" {
		profile.Phone = maskPhone(profile.Phone)
	}
	if profile.IDNo != nil && *profile.IDNo != "" {
		profile.IDNo = u.P(maskIdNo(*profile.IDNo))
	}

	return ctx.JSON(struct {
		m.User
		HasPassword bool `json:"has_password"` // whether a password has been set
	}{*profile, hasPassword})
}
// introspectAdmin loads the authenticated admin and writes it as JSON.
//
// The admin's roles and their permissions are preloaded; deleted_at and
// password are omitted. The response adds "scopes": the de-duplicated
// permission names of all roles, in first-seen order (roles in preload order,
// then permissions within each role). It is always a JSON array, never null.
func introspectAdmin(ctx *fiber.Ctx, authCtx *AuthCtx) error {
	profile, err := q.Admin.
		Preload(q.Admin.Roles, q.Admin.Roles.Permissions).
		Where(q.Admin.ID.Eq(authCtx.Admin.ID)).
		Omit(q.Admin.DeletedAt, q.Admin.Password).
		Take()
	if err != nil {
		return err
	}

	// Flatten the permission names. Ranging over a map (as before) made the
	// order random on every response; track "seen" separately and append in
	// iteration order so the output is stable.
	seen := make(map[string]struct{})
	scopes := make([]string, 0)
	for _, role := range profile.Roles {
		for _, permission := range role.Permissions {
			if _, dup := seen[permission.Name]; dup {
				continue
			}
			seen[permission.Name] = struct{}{}
			scopes = append(scopes, permission.Name)
		}
	}

	return ctx.JSON(struct {
		*m.Admin
		Scopes []string `json:"scopes"`
	}{profile, scopes})
}
// maskPhone replaces bytes 3..6 of a phone number with "****"
// (e.g. 13812345678 -> 138****5678). Inputs shorter than 11 bytes are
// returned unchanged.
func maskPhone(phone string) string {
	const keepHead, resumeAt = 3, 7
	if len(phone) >= 11 {
		return phone[:keepHead] + "****" + phone[resumeAt:]
	}
	return phone
}
// maskIdNo masks an ID number, keeping the first 3 bytes and everything from
// byte 14 on (the last 4 characters of an 18-character ID). For example,
// 110101199003071234 becomes 110***********1234.
//
// The masked span is replaced by the same number of '*' characters, so the
// result has the same length as the input. The old code wrote only 9 stars
// for 11 hidden characters. Inputs shorter than 18 bytes are returned
// unchanged.
func maskIdNo(idNo string) string {
	const maskStart, maskEnd = 3, 14
	if len(idNo) < 18 {
		return idNo
	}
	return idNo[:maskStart] + strings.Repeat("*", maskEnd-maskStart) + idNo[maskEnd:]
}
// CodeContext is the JSON value stored in Redis under an authorization code
// and read back by authAuthorizationCode when the code is exchanged.
type CodeContext struct {
	UserID              int32    `json:"user_id"`               // user the session is created for (must be enabled)
	ClientID            int32    `json:"client_id"`             // presumably the client the code was issued to
	Scopes              []string `json:"scopes"`                // granted scopes; joined with spaces into the session
	Remember            bool     `json:"remember"`              // when true a refresh token is issued
	CodeChallenge       string   `json:"code_challenge"`        // PKCE challenge (RFC 7636)
	CodeChallengeMethod string   `json:"code_challenge_method"` // "plain" or "S256"; empty means PKCE is not checked
}