diff --git a/web/auth/endpoints.go b/web/auth/endpoints.go index 3cb12e5..9b11f4e 100644 --- a/web/auth/endpoints.go +++ b/web/auth/endpoints.go @@ -353,6 +353,9 @@ func authPassword(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m admin.LastLogin = u.P(time.Now()) admin.LastLoginIP = ip admin.LastLoginUA = ua + + default: + return nil, ErrAuthorizeInvalidRequest } // 生成会话 @@ -364,12 +367,7 @@ func authPassword(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m AccessToken: uuid.NewString(), AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second), } - if user != nil { - session.UserID = &user.ID - } - if admin != nil { - session.AdminID = &admin.ID - } + if req.Remember { session.RefreshToken = u.P(uuid.NewString()) session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)) @@ -377,18 +375,20 @@ func authPassword(c *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m // 保存用户更新和会话 err = q.Q.Transaction(func(tx *q.Query) error { - if err := SaveSession(tx, session); err != nil { - return err - } if user != nil { if err := tx.User.Save(user); err != nil { return err } + session.UserID = &user.ID } if admin != nil { if err := tx.Admin.Save(admin); err != nil { return err } + session.AdminID = &admin.ID + } + if err := SaveSession(tx, session); err != nil { + return err } return nil })