diff --git a/.env.example b/.env.example index ad5cb43..38647ab 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,9 @@ # 应用配置 RUN_MODE=development DEBUG_HTTP_DUMP=false +UPLOAD_DIR=./data/uploads +UPLOAD_PUBLIC_BASE_URL= +ARTICLE_UPLOAD_MAX_BYTES=5242880 # 数据库配置 DB_HOST=127.0.0.1 diff --git a/docker-compose.yaml b/docker-compose.yaml index 5b73742..537f53a 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -40,6 +40,13 @@ services: depends_on: - redis + gost: + image: gogost/gost + network_mode: host + command: + - -api + - :8900 + volumes: postgres_data: redis_data: diff --git a/docs/api.yaml b/docs/api.yaml index 19fef9a..fb0076e 100644 --- a/docs/api.yaml +++ b/docs/api.yaml @@ -1902,6 +1902,27 @@ components: required: - id + ArticleUploadResponse: + type: object + properties: + url: + type: string + path: + type: string + original_name: + type: string + size: + type: integer + format: int64 + mime_type: + type: string + required: + - url + - path + - original_name + - size + - mime_type + PageArticleGroupRequest: allOf: - $ref: "#/components/schemas/PageRequest" @@ -5776,6 +5797,34 @@ paths: default: $ref: "#/components/responses/PlainTextError" + /api/admin/article/upload: + post: + tags: [admin/article] + summary: 上传文章图片 + security: + - bearerAuth: [] + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + properties: + file: + type: string + format: binary + required: + - file + responses: + "200": + description: 上传结果 + content: + application/json: + schema: + $ref: "#/components/schemas/ArticleUploadResponse" + default: + $ref: "#/components/responses/PlainTextError" + /api/admin/article-group/page: post: tags: [admin/article-group] diff --git a/pkg/env/env.go b/pkg/env/env.go index 81b51eb..0e94616 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -18,12 +18,15 @@ const ( ) var ( - RunMode = RunModeProd - LogLevel = slog.LevelDebug - TradeExpire = 15 * 60 // 交易过期时间,单位秒。默认 900 秒(15 分钟) - SessionAccessExpire = 60 * 60 * 2 // 访问令牌过期时间,单位秒。默认 2 小时 - SessionRefreshExpire = 60 * 60 * 24 * 7 // 刷新令牌过期时间,单位秒。默认 7 天 - DebugHttpDump = false // 是否打印请求和响应的原始数据 + RunMode = RunModeProd + LogLevel = slog.LevelDebug + TradeExpire = 15 * 60 // 交易过期时间,单位秒。默认 900 秒(15 分钟) + SessionAccessExpire = 60 * 60 * 2 // 访问令牌过期时间,单位秒。默认 2 小时 + SessionRefreshExpire = 60 * 60 * 24 * 7 // 刷新令牌过期时间,单位秒。默认 7 天 + DebugHttpDump = false // 是否打印请求和响应的原始数据 + UploadDir = "./data/uploads" + UploadPublicBaseURL = "" + ArticleUploadMaxBytes = 5 * 1024 * 1024 DbHost = "localhost" DbPort = "5432" @@ -106,6 +109,9 @@ func Init() { errs = append(errs, parse(&SessionAccessExpire, "SESSION_ACCESS_EXPIRE", true, nil)) errs = append(errs, parse(&SessionRefreshExpire, "SESSION_REFRESH_EXPIRE", true, nil)) errs = append(errs, parse(&DebugHttpDump, "DEBUG_HTTP_DUMP", true, nil)) + errs = append(errs, parse(&UploadDir, "UPLOAD_DIR", true, nil)) + errs = append(errs, parse(&UploadPublicBaseURL, "UPLOAD_PUBLIC_BASE_URL", true, nil)) + errs = append(errs, parse(&ArticleUploadMaxBytes, "ARTICLE_UPLOAD_MAX_BYTES", true, nil)) errs = append(errs, parse(&DbHost, "DB_HOST", true, nil)) errs = append(errs, parse(&DbPort, "DB_PORT", true, nil)) diff --git a/web/handlers/article.go b/web/handlers/article.go index aea8fef..5b5e994 100644 --- a/web/handlers/article.go +++ b/web/handlers/article.go @@ -1,10 +1,12 @@ package handlers import ( + "platform/pkg/env" "platform/web/auth" "platform/web/core" g "platform/web/globals" s "platform/web/services" + "strings" "github.com/gofiber/fiber/v2" ) @@ -126,3 +128,43 @@ func DeleteArticle(c *fiber.Ctx) error { return s.Article.Delete(req.Id) } + +func UploadArticleImage(c *fiber.Ctx) error { + _, err := auth.GetAuthCtx(c).PermitAdmin(core.ScopeArticleWrite) + if err != nil { + return err + } + + fileHeader, err := c.FormFile("file") + if err != nil { + return fiber.NewError(fiber.StatusBadRequest, "缺少上传文件 file") + } + + result, err := s.Article.UploadImage(fileHeader, articleUploadBaseURL(c)) + if err != nil { + return err + } + + return c.JSON(result) +} + +func articleUploadBaseURL(c *fiber.Ctx) string { + if env.UploadPublicBaseURL != "" { + return strings.TrimRight(env.UploadPublicBaseURL, "/") + } + + scheme := c.Protocol() + if forwardedProto := c.Get("X-Forwarded-Proto"); forwardedProto != "" { + scheme = strings.TrimSpace(strings.Split(forwardedProto, ",")[0]) + } + + host := c.Get(fiber.HeaderHost) + if forwardedHost := c.Get("X-Forwarded-Host"); forwardedHost != "" { + host = strings.TrimSpace(strings.Split(forwardedHost, ",")[0]) + } + + if host == "" { + return "" + } + return scheme + "://" + host +} diff --git a/web/middlewares.go b/web/middlewares.go index 1724424..80ccc70 100644 --- a/web/middlewares.go +++ b/web/middlewares.go @@ -1,11 +1,14 @@ package web import ( + "net/http" + "platform/pkg/env" "platform/web/auth" "github.com/gofiber/contrib/otelfiber/v2" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" + "github.com/gofiber/fiber/v2/middleware/filesystem" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/recover" "github.com/gofiber/fiber/v2/middleware/requestid" @@ -66,6 +69,11 @@ func ApplyMiddlewares(app *fiber.App) { }, })) + // static uploads + app.Use("/uploads", filesystem.New(filesystem.Config{ + Root: http.Dir(env.UploadDir), + })) + // authenticate app.Use(auth.Authenticate()) } diff --git a/web/routes.go b/web/routes.go index 52f7ac4..1913a6b 100644 --- a/web/routes.go +++ b/web/routes.go @@ -300,6 +300,7 @@ func adminRouter(api fiber.Router) { article.Post("/create", handlers.CreateArticle) article.Post("/update", handlers.UpdateArticle) article.Post("/remove", handlers.DeleteArticle) + article.Post("/upload", handlers.UploadArticleImage) // article-group 文档分组 var articleGroup = api.Group("/article-group") diff --git a/web/services/article.go b/web/services/article.go index 735c87b..ac82272 100644 --- a/web/services/article.go +++ b/web/services/article.go @@ -2,12 +2,22 @@ package services import ( "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path" + "path/filepath" + "platform/pkg/env" "platform/pkg/u" "platform/web/core" m "platform/web/models" q "platform/web/queries" + "strings" "time" + "github.com/google/uuid" "gorm.io/gen/field" "gorm.io/gorm" ) @@ -16,6 +26,128 @@ var Article = &articleService{} type articleService struct{} +var articleUploadMimeExt = map[string]string{ + "image/gif": ".gif", + "image/jpeg": ".jpg", + "image/png": ".png", + "image/webp": ".webp", +} + +type ArticleUploadResult struct { + URL string `json:"url"` + Path string `json:"path"` + OriginalName string `json:"original_name"` + Size int64 `json:"size"` + MimeType string `json:"mime_type"` +} + +func (s *articleService) UploadImage(fileHeader *multipart.FileHeader, baseURL string) (*ArticleUploadResult, error) { + if fileHeader == nil { + return nil, core.NewBizErr("缺少上传文件") + } + if fileHeader.Size > int64(env.ArticleUploadMaxBytes) { + return nil, core.NewBizErr(fmt.Sprintf("图片大小不能超过 %s", formatUploadSizeLimit(env.ArticleUploadMaxBytes))) + } + + mimeType, ext, err := detectArticleImage(fileHeader) + if err != nil { + return nil, err + } + + now := time.Now() + year := now.Format("2006") + month := now.Format("01") + fileName := uuid.NewString() + ext + relativePath := path.Join("/uploads", "article", year, month, fileName) + targetDir := filepath.Join(env.UploadDir, "article", year, month) + finalPath := filepath.Join(targetDir, fileName) + + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return nil, core.NewServErr("创建上传目录失败", err) + } + + src, err := fileHeader.Open() + if err != nil { + return nil, core.NewServErr("打开上传文件失败", err) + } + defer src.Close() + + tmp, err := os.CreateTemp(targetDir, "upload-*"+ext) + if err != nil { + return nil, core.NewServErr("创建临时文件失败", err) + } + + tmpPath := tmp.Name() + finished := false + defer func() { + if !finished { + _ = tmp.Close() + _ = os.Remove(tmpPath) + } + }() + + limitedReader := &io.LimitedReader{R: src, N: int64(env.ArticleUploadMaxBytes) + 1} + written, err := io.Copy(tmp, limitedReader) + if err != nil { + return nil, core.NewServErr("保存上传文件失败", err) + } + if written > int64(env.ArticleUploadMaxBytes) { + return nil, core.NewBizErr(fmt.Sprintf("图片大小不能超过 %s", formatUploadSizeLimit(env.ArticleUploadMaxBytes))) + } + if err := tmp.Close(); err != nil { + return nil, core.NewServErr("关闭临时文件失败", err) + } + if err := os.Rename(tmpPath, finalPath); err != nil { + return nil, core.NewServErr("保存上传文件失败", err) + } + finished = true + + cleanBaseURL := strings.TrimRight(baseURL, "/") + url := relativePath + if cleanBaseURL != "" { + url = cleanBaseURL + relativePath + } + + return &ArticleUploadResult{ + URL: url, + Path: relativePath, + OriginalName: filepath.Base(fileHeader.Filename), + Size: written, + MimeType: mimeType, + }, nil +} + +func detectArticleImage(fileHeader *multipart.FileHeader) (string, string, error) { + file, err := fileHeader.Open() + if err != nil { + return "", "", core.NewServErr("打开上传文件失败", err) + } + defer file.Close() + + buf := make([]byte, 512) + n, err := io.ReadFull(file, buf) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return "", "", core.NewServErr("读取上传文件失败", err) + } + + mimeType := http.DetectContentType(buf[:n]) + ext, ok := articleUploadMimeExt[mimeType] + if !ok { + return "", "", core.NewBizErr("仅支持 JPG、PNG、WEBP、GIF 图片") + } + return mimeType, ext, nil +} + +func formatUploadSizeLimit(bytes int) string { + if bytes%(1024*1024) == 0 { + return fmt.Sprintf("%d MB", bytes/(1024*1024)) + } + if bytes%1024 == 0 { + return fmt.Sprintf("%d KB", bytes/1024) + } + return fmt.Sprintf("%d bytes", bytes) +} + func (s *articleService) Page(req *PageArticleReq) (result []*m.Article, count int64, err error) { do := q.Article.Where() if req.Keyword != nil && *req.Keyword != "" { diff --git a/web/services/article_upload_test.go b/web/services/article_upload_test.go new file mode 100644 index 0000000..37e26f7 --- /dev/null +++ b/web/services/article_upload_test.go @@ -0,0 +1,144 @@ +package services + +import ( + "bytes" + "encoding/base64" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "platform/pkg/env" +) + +func TestArticleUploadImageSuccess(t *testing.T) { + restore := snapshotUploadEnv() + defer restore() + + env.UploadDir = t.TempDir() + env.ArticleUploadMaxBytes = 5 * 1024 * 1024 + + fileHeader := newMultipartFileHeader(t, "file", "pixel.png", mustDecodeBase64(t, onePixelPNGBase64)) + + result, err := Article.UploadImage(fileHeader, "https://example.com") + if err != nil { + t.Fatalf("UploadImage returned error: %v", err) + } + + if result.MimeType != "image/png" { + t.Fatalf("unexpected mime type: %s", result.MimeType) + } + if !strings.HasPrefix(result.Path, "/uploads/article/") { + t.Fatalf("unexpected path: %s", result.Path) + } + if result.URL != "https://example.com"+result.Path { + t.Fatalf("unexpected url: %s", result.URL) + } + if result.OriginalName != "pixel.png" { + t.Fatalf("unexpected original name: %s", result.OriginalName) + } + + savedPath := filepath.Join(env.UploadDir, filepath.FromSlash(strings.TrimPrefix(result.Path, "/uploads/"))) + info, err := os.Stat(savedPath) + if err != nil { + t.Fatalf("saved file not found: %v", err) + } + if info.Size() != result.Size { + t.Fatalf("unexpected saved size: got %d want %d", info.Size(), result.Size) + } +} + +func TestArticleUploadImageRejectsUnsupportedType(t *testing.T) { + restore := snapshotUploadEnv() + defer restore() + + env.UploadDir = t.TempDir() + env.ArticleUploadMaxBytes = 5 * 1024 * 1024 + + fileHeader := newMultipartFileHeader(t, "file", "note.txt", []byte("not an image")) + + _, err := Article.UploadImage(fileHeader, "https://example.com") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "仅支持 JPG、PNG、WEBP、GIF 图片") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestArticleUploadImageRejectsOversizeFile(t *testing.T) { + restore := snapshotUploadEnv() + defer restore() + + env.UploadDir = t.TempDir() + env.ArticleUploadMaxBytes = 8 + + fileHeader := newMultipartFileHeader(t, "file", "large.png", bytes.Repeat([]byte("a"), 9)) + + _, err := Article.UploadImage(fileHeader, "https://example.com") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "图片大小不能超过") { + t.Fatalf("unexpected error: %v", err) + } +} + +func newMultipartFileHeader(t *testing.T, fieldName string, fileName string, content []byte) *multipart.FileHeader { + t.Helper() + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + part, err := writer.CreateFormFile(fieldName, fileName) + if err != nil { + t.Fatalf("CreateFormFile failed: %v", err) + } + if _, err := part.Write(content); err != nil { + t.Fatalf("Write content failed: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("Close multipart writer failed: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/", &body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + if err := req.ParseMultipartForm(int64(body.Len()) + 1024); err != nil { + t.Fatalf("ParseMultipartForm failed: %v", err) + } + + file, fileHeader, err := req.FormFile(fieldName) + if err != nil { + t.Fatalf("FormFile failed: %v", err) + } + _ = file.Close() + + return fileHeader +} + +func mustDecodeBase64(t *testing.T, value string) []byte { + t.Helper() + + data, err := base64.StdEncoding.DecodeString(value) + if err != nil { + t.Fatalf("DecodeString failed: %v", err) + } + return data +} + +func snapshotUploadEnv() func() { + uploadDir := env.UploadDir + uploadPublicBaseURL := env.UploadPublicBaseURL + articleUploadMaxBytes := env.ArticleUploadMaxBytes + + return func() { + env.UploadDir = uploadDir + env.UploadPublicBaseURL = uploadPublicBaseURL + env.ArticleUploadMaxBytes = articleUploadMaxBytes + } +} + +const onePixelPNGBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+a7KQAAAAASUVORK5CYII=" diff --git a/web/web.go b/web/web.go index a283601..33624fc 100644 --- a/web/web.go +++ b/web/web.go @@ -53,7 +53,6 @@ func RunApp(pCtx context.Context) error { var fs embed.FS func RunWeb(ctx context.Context) error { - fiber := fiber.New(fiber.Config{ ProxyHeader: fiber.HeaderXForwardedFor, ErrorHandler: ErrorHandler,