// Package auth provides a gin middleware that protects selected request paths
// with a shared-secret check against the SECRET environment variable.
package auth

import (
	"crypto/subtle"
	"encoding/base64"
	"log/slog"
	"net/http"
	"os"
	"proxy-server/server/pkg/resp"
	"slices"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
)

// middleware enforces secret authentication on the paths listed in
// securedPaths; requests to any other path pass straight through.
//
// A missing or malformed Authorization header is answered with 400 Bad
// Request, a wrong secret with 401 Unauthorized. On failure the handler chain
// is aborted and the downstream handlers are not run.
func middleware(c *gin.Context) {
	if !check(c) {
		c.Next()
		return
	}

	secret, err := getSecret(c)
	if err != nil {
		// slog takes alternating key/value pairs; a bare err would be logged
		// under the "!BADKEY" key.
		slog.Error("认证失败", "path", c.Request.URL.Path, "error", err)
		fail400(c, err)
		return
	}

	if err := authenticate(c, secret); err != nil {
		slog.Error("认证失败", "path", c.Request.URL.Path, "error", err)
		fail401(c, err)
		return
	}

	c.Next()
}

var (
	// securedPaths lists the request paths (matched exactly against
	// URL.Path) that require authentication.
	securedPaths = []string{
		"/connect",
	}
)

// check reports whether the request path is one that requires
// authentication.
func check(c *gin.Context) bool {
	return slices.Contains(securedPaths, c.Request.URL.Path)
}

// getSecret extracts the client secret from an Authorization header of the
// form "Secret <base64url(secret)>".
//
// The scheme name is matched case-insensitively (RFC 7235 §2.1). The header
// must contain exactly one space separating scheme and parameter. The
// parameter is decoded with padded URL-safe base64 (base64.URLEncoding).
//
// It returns an error if the header is absent or malformed, uses another
// scheme, or the parameter is not valid base64.
func getSecret(c *gin.Context) (string, error) {
	// Read the authentication information.
	header := strings.Split(c.GetHeader("Authorization"), " ")
	if len(header) != 2 {
		return "", errors.New("无认证信息")
	}

	// Check the authentication scheme.
	schema := header[0]
	if !strings.EqualFold(schema, "Secret") {
		return "", errors.New("不支持的认证类型 " + schema)
	}

	// Decode the secret.
	parameters := header[1]
	result, err := base64.URLEncoding.DecodeString(parameters)
	if err != nil {
		return "", errors.Wrap(err, "密钥解析失败")
	}
	return string(result), nil
}

// authenticate verifies secret against the SECRET environment variable.
//
// It fails closed: if SECRET is unset or empty, every request is rejected.
// Otherwise an empty client secret would match the empty expected value and
// bypass authentication. The comparison is constant-time, so response timing
// does not reveal how much of the secret matched. ConstantTimeCompare still
// returns early on a length mismatch, so the length itself may leak.
func authenticate(_ *gin.Context, secret string) error {
	expected := os.Getenv("SECRET")
	if expected == "" {
		// Log the misconfiguration for operators, but return the generic
		// error to the client so server configuration is not disclosed.
		slog.Error("未配置 SECRET 环境变量，拒绝所有认证请求")
		return errors.New("认证失败")
	}
	if subtle.ConstantTimeCompare([]byte(secret), []byte(expected)) != 1 {
		return errors.New("认证失败")
	}
	return nil
}

// fail400 aborts the request with 400 Bad Request and a failure body built
// from err.
func fail400(c *gin.Context, err error) {
	abortWithError(c, http.StatusBadRequest, err)
}

// fail401 aborts the request with 401 Unauthorized and a failure body built
// from err.
func fail401(c *gin.Context, err error) {
	abortWithError(c, http.StatusUnauthorized, err)
}

// abortWithError attaches err to the gin context, stops the handler chain,
// and writes status with resp.Fail(err.Error()) as the JSON body.
// NOTE(review): resp.Fail is assumed to build the project's standard failure
// envelope — confirm in pkg/resp.
func abortWithError(c *gin.Context, status int, err error) {
	// c.Error only returns its argument wrapped; nothing to handle.
	_ = c.Error(err)
	c.AbortWithStatusJSON(status, resp.Fail(err.Error()))
}