diff --git a/README.md b/README.md index e2d313c..2d2222b 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ - [ ] Limiter - [ ] Compress +废弃 password 授权模式,迁移到 authorization code 授权模式 + 使用 fiber 自带 validator 进行参数验证 增加 domain 层,缓解同包字段过长的问题 diff --git a/web/auth/auth.go b/web/auth/auth.go index e45a96b..7770d5a 100644 --- a/web/auth/auth.go +++ b/web/auth/auth.go @@ -100,12 +100,12 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) ( var header = c.Get("Authorization") var split = strings.Split(header, " ") if len(split) != 2 { - return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌") + return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } var token = split[1] if token == "" { - return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌") + return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } var auth *services.AuthContext @@ -115,7 +115,7 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) ( auth, err = authBearer(c.Context(), token) case "Basic": if !slices.Contains(types, services.PayloadClientConfidential) { - return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") + return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") } auth, err = authBasic(c.Context(), token) default: @@ -127,10 +127,10 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) ( // 检查权限 if !slices.Contains(types, auth.Payload.Type) { - return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问") + return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") } if len(permissions) > 0 && !auth.AnyPermission(permissions...) { - return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问") + return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") } // 将认证信息存储在上下文中 diff --git a/web/handlers/auth.go b/web/handlers/auth.go new file mode 100644 index 0000000..b7e23be --- /dev/null +++ b/web/handlers/auth.go @@ -0,0 +1,292 @@ +package handlers + +import ( + "encoding/base64" + "errors" + "log/slog" + "platform/web/auth" + m "platform/web/models" + q "platform/web/queries" + s "platform/web/services" + "strings" + "time" + + "github.com/gofiber/fiber/v2" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +// region /token + +type TokenReq struct { + GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"` + ClientID string `json:"client_id" form:"client_id"` + ClientSecret string `json:"client_secret" form:"client_secret"` + Scope string `json:"scope" form:"scope"` + s.GrantCodeData + s.GrantClientData + s.GrantRefreshData + s.GrantPasswordData +} + +type TokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + Scope string `json:"scope,omitempty"` +} + +type TokenErrResp struct { + Error string `json:"error"` + Description string `json:"error_description,omitempty"` +} + +// Token 处理 OAuth2.0 授权请求 +func Token(c *fiber.Ctx) error { + + // 验证请求参数 + req := new(TokenReq) + if err := c.BodyParser(req); err != nil { + return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数") + } + if req.GrantType == "" { + return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:grant_type") + } + + slog.Debug("oauth token", slog.String("grant_type", + string(req.GrantType)), + slog.String("client_id", req.ClientID), + ) + + // 基于授权类型处理请求 + switch req.GrantType { + + // 授权码模式 + case s.OauthGrantTypeAuthorizationCode: + if req.Code == "" { + return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code") + } + + client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret) + if err != nil { + return sendError(c, err) + } + + token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier) + if err != nil { + return sendError(c, err.(s.AuthServiceError)) + } + + return sendSuccess(c, token) + + // 客户端凭证模式 + case s.OauthGrantTypeClientCredentials: + client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret) + if err != nil { + return sendError(c, err) + } + + scope := strings.Split(req.Scope, ",") + token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...) + if err != nil { + return sendError(c, err.(s.AuthServiceError)) + } + + return sendSuccess(c, token) + + // 刷新令牌模式 + case s.OauthGrantTypeRefreshToken: + if req.RefreshToken == "" { + return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token") + } + + client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret) + if err != nil { + return sendError(c, err) + } + + scope := strings.Split(req.Scope, ",") + token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope) + if err != nil { + if errors.Is(err, s.ErrInvalidToken) { + return sendError(c, s.ErrOauthInvalidGrant) + } + return sendError(c, err) + } + + return sendSuccess(c, token) + + // 密码模式 + case s.OauthGrantTypePassword: + if req.LoginType == "" { + return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type") + } + if req.Username == "" { + return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:username") + } + if req.Password == "" { + return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password") + } + + client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret) + if err != nil { + return sendError(c, err) + } + + token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent")) + if err != nil { + return err + } + + return sendSuccess(c, token) + + default: + return sendError(c, s.ErrOauthUnsupportedGrantType) + } +} + +// 检查客户端凭证 +func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) { + header := c.Get("Authorization") + if header != "" { + basic := strings.TrimPrefix(header, "Basic ") + if basic != "" { + base, err := base64.URLEncoding.DecodeString(basic) + if err != nil { + return nil, err + } + parts := strings.SplitN(string(base), ":", 2) + if len(parts) == 2 { + clientId = parts[0] + clientSecret = parts[1] + } + } + } + + // 查找客户端 + if clientId == "" { + return nil, s.ErrOauthInvalidRequest + } + client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, s.ErrOauthInvalidClient + } + return nil, err + } + + // 验证客户端状态 + if client.Status != 1 { + return nil, s.ErrOauthUnauthorizedClient + } + + // 验证授权类型 + switch grant { + case s.OauthGrantTypeAuthorizationCode: + if !client.GrantCode { + return nil, s.ErrOauthUnauthorizedClient + } + case s.OauthGrantTypeClientCredentials: + if !client.GrantClient || client.Spec != 0 { + return nil, s.ErrOauthUnauthorizedClient + } + case s.OauthGrantTypeRefreshToken: + if !client.GrantRefresh { + return nil, s.ErrOauthUnauthorizedClient + } + case s.OauthGrantTypePassword: + if !client.GrantPassword { + return nil, s.ErrOauthUnauthorizedClient + } + } + + // 如果客户端是 confidential,验证 client_secret,失败返回错误 + if client.Spec == 0 { + if clientSecret == "" { + return nil, s.ErrOauthInvalidRequest + } + if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil { + return nil, s.ErrOauthInvalidClient + } + } + + return client, nil +} + +// 发送成功响应 +func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error { + return c.JSON(TokenResp{ + AccessToken: details.AccessToken, + TokenType: "Bearer", + ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()), + RefreshToken: details.RefreshToken, + }) +} + +// 发送错误响应 +func sendError(c *fiber.Ctx, err error, description ...string) error { + var sErr s.AuthServiceError + if errors.As(err, &sErr) { + status := fiber.StatusBadRequest + var desc string + switch { + case errors.Is(sErr, s.ErrOauthInvalidRequest): + desc = "无效的请求" + case errors.Is(sErr, s.ErrOauthInvalidClient): + status = fiber.StatusUnauthorized + desc = "无效的客户端凭证" + case errors.Is(sErr, s.ErrOauthInvalidGrant): + desc = "无效的授权凭证" + case errors.Is(sErr, s.ErrOauthInvalidScope): + desc = "无效的授权范围" + case errors.Is(sErr, s.ErrOauthUnauthorizedClient): + desc = "未授权的客户端" + case errors.Is(sErr, s.ErrOauthUnsupportedGrantType): + desc = "不支持的授权类型" + } + if len(description) > 0 { + desc = description[0] + } + + return c.Status(status).JSON(TokenErrResp{ + Error: string(sErr), + Description: desc, + }) + } + + return err +} + +// endregion + +// region /revoke + +type RevokeReq struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +func Revoke(c *fiber.Ctx) error { + _, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + if err != nil { + // 用户未登录 + return nil + } + + // 解析请求参数 + req := new(RevokeReq) + if err := c.BodyParser(req); err != nil { + return err + } + + // 删除会话 + err = s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken) + if err != nil { + return err + } + + return nil +} + +// endregion diff --git a/web/handlers/login.go b/web/handlers/login.go deleted file mode 100644 index 26e747c..0000000 --- a/web/handlers/login.go +++ /dev/null @@ -1,145 +0,0 @@ -package handlers - -import ( - "errors" - "platform/web/auth" - m "platform/web/models" - q "platform/web/queries" - s "platform/web/services" - "time" - - "github.com/gofiber/fiber/v2" - "gorm.io/gorm" -) - -type LoginReq struct { - Username string `json:"username"` - Password string `json:"password"` - Remember bool `json:"remember"` -} - -type LoginResp struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - Expires int64 `json:"expires"` - Auth s.AuthContext `json:"auth"` - Profile *m.User `json:"profile"` -} - -func Login(c *fiber.Ctx) error { - - // 验证请求参数 - req := new(LoginReq) - if err := c.BodyParser(req); err != nil { - return err - } - if req.Username == "" { - return fiber.NewError(fiber.StatusBadRequest, "手机号不能为空") - } - if req.Password == "" { - return fiber.NewError(fiber.StatusBadRequest, "验证码不能为空") - } - - return loginByPhone(c, req) -} - -func loginByPhone(c *fiber.Ctx, req *LoginReq) error { - - // 验证验证码 - err := s.Verifier.VerifySms(c.Context(), req.Username, req.Password) - if err != nil { - if errors.Is(err, s.ErrVerifierServiceInvalid) { - return fiber.NewError(fiber.StatusBadRequest, "验证码错误") - } - return err - } - - // 查找用户 todo 获取权限信息 - var user *m.User - err = q.Q.Transaction(func(tx *q.Query) error { - user, err = tx.User. - Where(tx.User.Phone.Eq(req.Username)). - Take() - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - - // 如果用户不存在,初始化用户 todo 保存默认权限信息 - if user == nil { - user = &m.User{ - Phone: req.Username, - Username: req.Username, - } - } - - // 更新用户的登录时间 - user.LastLogin = time.Now() - user.LastLoginHost = c.IP() - user.LastLoginAgent = c.Get("User-Agent") - if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil { - return err - } - - return nil - }) - if err != nil { - return err - } - - // 保存到会话 - auth := s.AuthContext{ - Permissions: map[string]struct{}{ - "user": {}, - }, - Payload: s.Payload{ - Id: user.ID, - Type: s.PayloadUser, - Name: user.Name, - Avatar: user.Avatar, - }, - } - duration := time.Hour * 24 - if req.Remember { - duration *= 7 - } - token, err := s.Session.Create(c.Context(), auth) - if err != nil { - return err - } - - user.Password = "" - return c.JSON(LoginResp{ - AccessToken: token.AccessToken, - RefreshToken: token.RefreshToken, - Expires: token.AccessTokenExpires.Unix(), - Auth: auth, - Profile: user, - }) -} - -type LogoutReq struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` -} - -func Logout(c *fiber.Ctx) error { - _, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) - if err != nil { - // 用户未登录 - return nil - } - - // 解析请求参数 - req := new(LogoutReq) - if err := c.BodyParser(req); err != nil { - return err - } - - // 删除会话 - err = s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken) - if err != nil { - return err - } - - return nil -} diff --git a/web/handlers/oauth.go b/web/handlers/oauth.go deleted file mode 100644 index 0549d76..0000000 --- a/web/handlers/oauth.go +++ /dev/null @@ -1,371 +0,0 @@ -package handlers - -import ( - "encoding/base64" - "errors" - "log/slog" - m "platform/web/models" - q "platform/web/queries" - s "platform/web/services" - "strings" - "time" - - "github.com/gofiber/fiber/v2" - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" -) - -// region Token - -type TokenReq struct { - GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"` - ClientID string `json:"client_id" form:"client_id"` - ClientSecret string `json:"client_secret" form:"client_secret"` - Scope string `json:"scope" form:"scope"` - TokenReqCode - TokenReqClient - TokenReqRefresh - TokenReqPassword -} - -type TokenReqCode struct { - Code string `json:"code" form:"code"` - RedirectURI string `json:"redirect_uri" form:"redirect_uri"` - CodeVerifier string `json:"code_verifier" form:"code_verifier"` -} - -type TokenReqClient struct { -} - -type TokenReqRefresh struct { - RefreshToken string `json:"refresh_token" form:"refresh_token"` -} - -type TokenReqPassword struct { - LoginType s.OauthGrantLoginType `json:"login_type" form:"login_type"` - Username string `json:"username" form:"username"` - Password string `json:"password" form:"password"` - Remember bool `json:"remember" form:"remember"` -} - -type TokenResp struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token,omitempty"` - ExpiresIn int `json:"expires_in"` - TokenType string `json:"token_type"` - Scope string `json:"scope,omitempty"` -} - -type TokenErrResp struct { - Error string `json:"error"` - Description string `json:"error_description,omitempty"` -} - -// Token 处理 OAuth2.0 授权请求 -func Token(c *fiber.Ctx) error { - - // 验证请求参数 - req := new(TokenReq) - if err := c.BodyParser(req); err != nil { - return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数") - } - if req.GrantType == "" { - return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:grant_type") - } - - slog.Debug("oauth token", slog.String("grant_type", - string(req.GrantType)), - slog.String("client_id", req.ClientID), - ) - - // 基于授权类型处理请求 - switch req.GrantType { - - case s.OauthGrantTypeAuthorizationCode: - return authorizationCode(c, req) - - case s.OauthGrantTypeClientCredentials: - return clientCredentials(c, req) - - case s.OauthGrantTypeRefreshToken: - return refreshToken(c, req) - - case s.OauthGrantTypePassword: - return password(c, req) - - default: - return sendError(c, s.ErrOauthUnsupportedGrantType) - } -} - -// 授权码 -func authorizationCode(c *fiber.Ctx, req *TokenReq) error { - if req.Code == "" { - return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code") - } - - client, err := protect(c, s.OauthGrantTypeAuthorizationCode, req.ClientID, req.ClientSecret) - if err != nil { - return sendError(c, err) - } - - token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier) - if err != nil { - return sendError(c, err.(s.AuthServiceOauthError)) - } - - return sendSuccess(c, token) -} - -// 客户端凭证 -func clientCredentials(c *fiber.Ctx, req *TokenReq) error { - client, err := protect(c, s.OauthGrantTypeClientCredentials, req.ClientID, req.ClientSecret) - if err != nil { - return sendError(c, err) - } - - scope := strings.Split(req.Scope, ",") - token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...) - if err != nil { - return sendError(c, err.(s.AuthServiceOauthError)) - } - - return sendSuccess(c, token) -} - -// 刷新令牌 -func refreshToken(c *fiber.Ctx, req *TokenReq) error { - if req.RefreshToken == "" { - return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token") - } - - client, err := protect(c, s.OauthGrantTypeRefreshToken, req.ClientID, req.ClientSecret) - if err != nil { - return sendError(c, err) - } - - scope := strings.Split(req.Scope, ",") - token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope) - if err != nil { - if errors.Is(err, s.ErrInvalidToken) { - return sendError(c, s.ErrOauthInvalidGrant) - } - return sendError(c, err) - } - - return sendSuccess(c, token) -} - -func password(c *fiber.Ctx, req *TokenReq) error { - if req.LoginType == "" { - return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type") - } - if req.Username == "" { - return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:username") - } - if req.Password == "" { - return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password") - } - - // 验证客户端凭证 - _, err := protect(c, s.OauthGrantTypePassword, req.ClientID, req.ClientSecret) - if err != nil { - return sendError(c, err) - } - - // 验证验证码 - err = s.Verifier.VerifySms(c.Context(), req.Username, req.Password) - if err != nil { - if errors.Is(err, s.ErrVerifierServiceInvalid) { - return fiber.NewError(fiber.StatusBadRequest, "验证码错误") - } - return err - } - - // 查找用户 - var user *m.User - err = q.Q.Transaction(func(tx *q.Query) error { - - switch req.LoginType { - case s.OauthGrantPasswordTypePhoneCode: - user, err = tx.User.Where(tx.User.Phone.Eq(req.Username)).Take() - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - case s.OauthGrantPasswordTypeEmailCode: - user, err = tx.User.Where(tx.User.Email.Eq(req.Username)).Take() - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - case s.OauthGrantPasswordTypePassword: - user, err = tx.User. - Where(tx.User.Or( - tx.User.Phone.Eq(req.Username), - tx.User.Email.Eq(req.Username), - tx.User.Username.Eq(req.Username), - )). - Take() - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - default: - return sendError(c, s.ErrOauthInvalidRequest, "无效的登录类型") - } - - // 如果用户不存在,初始化用户 todo 初始化默认权限信息 - if user == nil { - user = &m.User{ - Phone: req.Username, - Username: req.Username, - } - } - - // 更新用户的登录时间 - user.LastLogin = time.Now() - user.LastLoginHost = c.IP() - user.LastLoginAgent = c.Get("User-Agent") - if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil { - return err - } - - return nil - }) - if err != nil { - return err - } - - // 保存到会话 - auth := s.AuthContext{ - Payload: s.Payload{ - Id: user.ID, - Type: s.PayloadUser, - Name: user.Name, - Avatar: user.Avatar, - }, - } - - duration := s.DefaultSessionConfig - if !req.Remember { - duration.RefreshTokenDuration = 0 - } - token, err := s.Session.Create(c.Context(), auth) - if err != nil { - return err - } - - return sendSuccess(c, token) -} - -// 检查客户端凭证 -func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) { - header := c.Get("Authorization") - if header != "" { - basic := strings.TrimPrefix(header, "Basic ") - if basic != "" { - base, err := base64.URLEncoding.DecodeString(basic) - if err != nil { - return nil, err - } - parts := strings.SplitN(string(base), ":", 2) - if len(parts) == 2 { - clientId = parts[0] - clientSecret = parts[1] - } - } - } - - // 查找客户端 - if clientId == "" { - return nil, s.ErrOauthInvalidRequest - } - client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take() - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, s.ErrOauthInvalidClient - } - return nil, err - } - - // 验证客户端状态 - if client.Status != 1 { - return nil, s.ErrOauthUnauthorizedClient - } - - // 验证授权类型 - switch grant { - case s.OauthGrantTypeAuthorizationCode: - if !client.GrantCode { - return nil, s.ErrOauthUnauthorizedClient - } - case s.OauthGrantTypeClientCredentials: - if !client.GrantClient || client.Spec != 0 { - return nil, s.ErrOauthUnauthorizedClient - } - case s.OauthGrantTypeRefreshToken: - if !client.GrantRefresh { - return nil, s.ErrOauthUnauthorizedClient - } - case s.OauthGrantTypePassword: - if !client.GrantPassword { - return nil, s.ErrOauthUnauthorizedClient - } - } - - // 如果客户端是 confidential,验证 client_secret,失败返回错误 - if client.Spec == 0 { - if clientSecret == "" { - return nil, s.ErrOauthInvalidRequest - } - if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil { - return nil, s.ErrOauthInvalidClient - } - } - - return client, nil -} - -// 发送成功响应 -func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error { - return c.JSON(TokenResp{ - AccessToken: details.AccessToken, - TokenType: "Bearer", - ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()), - RefreshToken: details.RefreshToken, - }) -} - -// 发送错误响应 -func sendError(c *fiber.Ctx, err error, description ...string) error { - var sErr s.AuthServiceOauthError - if errors.As(err, &sErr) { - status := fiber.StatusBadRequest - var desc string - switch { - case errors.Is(sErr, s.ErrOauthInvalidRequest): - desc = "无效的请求" - case errors.Is(sErr, s.ErrOauthInvalidClient): - status = fiber.StatusUnauthorized - desc = "无效的客户端凭证" - case errors.Is(sErr, s.ErrOauthInvalidGrant): - desc = "无效的授权凭证" - case errors.Is(sErr, s.ErrOauthInvalidScope): - desc = "无效的授权范围" - case errors.Is(sErr, s.ErrOauthUnauthorizedClient): - desc = "未授权的客户端" - case errors.Is(sErr, s.ErrOauthUnsupportedGrantType): - desc = "不支持的授权类型" - } - if len(description) > 0 { - desc = description[0] - } - - return c.Status(status).JSON(TokenErrResp{ - Error: string(sErr), - Description: desc, - }) - } - - return err -} - -// endregion diff --git a/web/router.go b/web/router.go index 42580e8..07ff5c2 100644 --- a/web/router.go +++ b/web/router.go @@ -12,10 +12,9 @@ func ApplyRouters(app *fiber.App) { // 认证 auth := api.Group("/auth") - auth.Post("/verify/sms", handlers.SmsCode) - auth.Post("/login/sms", auth2.PermitDevice(), handlers.Login) - auth.Post("/logout", handlers.Logout) auth.Post("/token", handlers.Token) + auth.Post("/revoke", handlers.Revoke) + auth.Post("/verify/sms", handlers.SmsCode) // 通道 channel := api.Group("/channel") diff --git a/web/services/auth.go b/web/services/auth.go index 7c77554..6ffa463 100644 --- a/web/services/auth.go +++ b/web/services/auth.go @@ -3,42 +3,25 @@ package services import ( "context" "errors" - "platform/web/models" + m "platform/web/models" + q "platform/web/queries" + "time" + + "gorm.io/gorm" ) var Auth = &authService{} -type AuthServiceError string - -func (e AuthServiceError) Error() string { - return string(e) -} - -type AuthServiceOauthError string - -func (e AuthServiceOauthError) Error() string { - return string(e) -} - -var ( - ErrOauthInvalidRequest = AuthServiceOauthError("invalid_request") - ErrOauthInvalidClient = AuthServiceOauthError("invalid_client") - ErrOauthInvalidGrant = AuthServiceOauthError("invalid_grant") - ErrOauthInvalidScope = AuthServiceOauthError("invalid_scope") - ErrOauthUnauthorizedClient = AuthServiceOauthError("unauthorized_client") - ErrOauthUnsupportedGrantType = AuthServiceOauthError("unsupported_grant_type") -) - type authService struct{} // OauthAuthorizationCode 验证授权码 -func (s *authService) OauthAuthorizationCode(ctx context.Context, client *models.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) { +func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) { // TODO: 从数据库验证授权码 return nil, errors.New("TODO") } // OauthClientCredentials 验证客户端凭证 -func (s *authService) OauthClientCredentials(ctx context.Context, client *models.Client, scope ...string) (*TokenDetails, error) { +func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) { var clientType PayloadType switch client.Spec { @@ -75,7 +58,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *models } // OauthRefreshToken 验证刷新令牌 -func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) { +func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) { // TODO: 从数据库验证刷新令牌 details, err := Session.Refresh(ctx, refreshToken) if err != nil { @@ -85,6 +68,114 @@ func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Clie return details, nil } +// OauthPassword 验证密码 +func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *GrantPasswordData, ip, agent string) (*TokenDetails, error) { + var user *m.User + err := q.Q.Transaction(func(tx *q.Query) error { + + switch data.LoginType { + case OauthGrantPasswordTypePhoneCode: + // 验证验证码 + err := Verifier.VerifySms(ctx, data.Username, data.Password) + if err != nil { + if errors.Is(err, ErrVerifierServiceInvalid) { + return ErrOauthInvalidRequest + } + return err + } + + // 查找用户 + user, err = + tx.User.Where(tx.User.Phone.Eq(data.Username)).Take() + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + case OauthGrantPasswordTypeEmailCode: + var err error + user, err = tx.User.Where(tx.User.Email.Eq(data.Username)).Take() + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + case OauthGrantPasswordTypePassword: + var err error + user, err = tx.User. + Where(tx.User.Or( + tx.User.Phone.Eq(data.Username), + tx.User.Email.Eq(data.Username), + tx.User.Username.Eq(data.Username), + )). + Take() + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + default: + return ErrOauthInvalidRequest + } + + // 如果用户不存在,初始化用户 todo 初始化默认权限信息 + if user == nil { + user = &m.User{ + Phone: data.Username, + Username: data.Username, + } + } + + // 更新用户的登录时间 + user.LastLogin = time.Now() + user.LastLoginHost = ip + user.LastLoginAgent = agent + if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err + } + + // 保存到会话 + auth := AuthContext{ + Payload: Payload{ + Id: user.ID, + Type: PayloadUser, + Name: user.Name, + Avatar: user.Avatar, + }, + } + + duration := DefaultSessionConfig + if !data.Remember { + duration.RefreshTokenDuration = 0 + } + token, err := Session.Create(ctx, auth) + if err != nil { + return nil, err + } + + return token, nil +} + +type GrantCodeData struct { + Code string `json:"code" form:"code"` + RedirectURI string `json:"redirect_uri" form:"redirect_uri"` + CodeVerifier string `json:"code_verifier" form:"code_verifier"` +} + +type GrantClientData struct { +} + +type GrantRefreshData struct { + RefreshToken string `json:"refresh_token" form:"refresh_token"` +} + +type GrantPasswordData struct { + LoginType OauthGrantLoginType `json:"login_type" form:"login_type"` + Username string `json:"username" form:"username"` + Password string `json:"password" form:"password"` + Remember bool `json:"remember" form:"remember"` +} + type OauthGrantType string const ( @@ -101,3 +192,18 @@ const ( OauthGrantPasswordTypePhoneCode = OauthGrantLoginType("phone_code") OauthGrantPasswordTypeEmailCode = OauthGrantLoginType("email_code") ) + +type AuthServiceError string + +func (e AuthServiceError) Error() string { + return string(e) +} + +var ( + ErrOauthInvalidRequest = AuthServiceError("invalid_request") + ErrOauthInvalidClient = AuthServiceError("invalid_client") + ErrOauthInvalidGrant = AuthServiceError("invalid_grant") + ErrOauthInvalidScope = AuthServiceError("invalid_scope") + ErrOauthUnauthorizedClient = AuthServiceError("unauthorized_client") + ErrOauthUnsupportedGrantType = AuthServiceError("unsupported_grant_type") +)