package logger import ( "context" "net/http" "net/http/httptest" "testing" "git.dcentral.systems/toolz/goplt/pkg/logger" "github.com/gin-gonic/gin" ) func TestRequestIDMiddleware_GenerateNewID(t *testing.T) { // Cannot run in parallel: gin.SetMode() modifies global state gin.SetMode(gin.TestMode) router := gin.New() router.Use(RequestIDMiddleware()) router.GET("/test", func(c *gin.Context) { requestID := RequestIDFromContext(c.Request.Context()) if requestID == "" { t.Error("Request ID should be generated") } c.JSON(http.StatusOK, gin.H{"request_id": requestID}) }) req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } // Check that request ID is in response header requestID := w.Header().Get(RequestIDHeader) if requestID == "" { t.Error("Request ID should be in response header") } } func TestRequestIDMiddleware_UseExistingID(t *testing.T) { // Cannot run in parallel: gin.SetMode() modifies global state gin.SetMode(gin.TestMode) router := gin.New() router.Use(RequestIDMiddleware()) router.GET("/test", func(c *gin.Context) { requestID := RequestIDFromContext(c.Request.Context()) if requestID != "existing-id" { t.Errorf("Expected request ID 'existing-id', got %q", requestID) } c.JSON(http.StatusOK, gin.H{"request_id": requestID}) }) req := httptest.NewRequest(http.MethodGet, "/test", nil) req.Header.Set(RequestIDHeader, "existing-id") w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } // Check that the same request ID is in response header requestID := w.Header().Get(RequestIDHeader) if requestID != "existing-id" { t.Errorf("Expected request ID 'existing-id' in header, got %q", requestID) } } func TestRequestIDFromContext(t *testing.T) { t.Parallel() tests := []struct { name string ctx context.Context want string wantEmpty bool }{ { name: "with request ID", ctx: context.WithValue(context.Background(), requestIDKey, "test-id"), want: "test-id", wantEmpty: false, }, { name: "without request ID", ctx: context.Background(), want: "", wantEmpty: true, }, { name: "with wrong type", ctx: context.WithValue(context.Background(), requestIDKey, 123), want: "", wantEmpty: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() got := RequestIDFromContext(tt.ctx) if tt.wantEmpty && got != "" { t.Errorf("RequestIDFromContext() = %q, want empty string", got) } if !tt.wantEmpty && got != tt.want { t.Errorf("RequestIDFromContext() = %q, want %q", got, tt.want) } }) } } func TestSetRequestID(t *testing.T) { t.Parallel() ctx := context.Background() newCtx := SetRequestID(ctx, "test-id") requestID := RequestIDFromContext(newCtx) if requestID != "test-id" { t.Errorf("SetRequestID failed, got %q, want %q", requestID, "test-id") } // Original context should not have the ID originalID := RequestIDFromContext(ctx) if originalID != "" { t.Error("Original context should not have request ID") } } func TestSetUserID(t *testing.T) { t.Parallel() ctx := context.Background() newCtx := SetUserID(ctx, "user-123") userID := UserIDFromContext(newCtx) if userID != "user-123" { t.Errorf("SetUserID failed, got %q, want %q", userID, "user-123") } // Original context should not have the ID originalID := UserIDFromContext(ctx) if originalID != "" { t.Error("Original context should not have user ID") } } func TestUserIDFromContext(t *testing.T) { t.Parallel() tests := []struct { name string ctx context.Context want string wantEmpty bool }{ { name: "with user ID", ctx: context.WithValue(context.Background(), userIDKey, "user-123"), want: "user-123", wantEmpty: false, }, { name: "without user ID", ctx: context.Background(), want: "", wantEmpty: true, }, { name: "with wrong type", ctx: context.WithValue(context.Background(), userIDKey, 123), want: "", wantEmpty: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() got := UserIDFromContext(tt.ctx) if tt.wantEmpty && got != "" { t.Errorf("UserIDFromContext() = %q, want empty string", got) } if !tt.wantEmpty && got != tt.want { t.Errorf("UserIDFromContext() = %q, want %q", got, tt.want) } }) } } func TestLoggingMiddleware(t *testing.T) { // Cannot run in parallel: gin.SetMode() modifies global state // Create a mock logger that records log calls mockLog := &mockLoggerForMiddleware{} mockLog.logs = make([]logEntry, 0) gin.SetMode(gin.TestMode) router := gin.New() router.Use(RequestIDMiddleware()) router.Use(LoggingMiddleware(mockLog)) router.GET("/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "test"}) }) req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } // Verify that logs were recorded if len(mockLog.logs) < 2 { t.Fatalf("Expected at least 2 log entries (request + response), got %d", len(mockLog.logs)) } // Check request log requestLog := mockLog.logs[0] if requestLog.message != "HTTP request" { t.Errorf("Expected 'HTTP request' log, got %q", requestLog.message) } // Check response log responseLog := mockLog.logs[1] if responseLog.message != "HTTP response" { t.Errorf("Expected 'HTTP response' log, got %q", responseLog.message) } } func TestLoggingMiddleware_WithRequestID(t *testing.T) { // Cannot run in parallel: gin.SetMode() modifies global state // Create a mock logger mockLog := &mockLoggerForMiddleware{} mockLog.logs = make([]logEntry, 0) gin.SetMode(gin.TestMode) router := gin.New() router.Use(RequestIDMiddleware()) router.Use(LoggingMiddleware(mockLog)) router.GET("/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "test"}) }) req := httptest.NewRequest(http.MethodGet, "/test", nil) req.Header.Set(RequestIDHeader, "custom-request-id") w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } // Verify that request ID is in the logs if len(mockLog.logs) < 2 { t.Fatalf("Expected at least 2 log entries, got %d", len(mockLog.logs)) } // The logger should have received context with request ID // (We can't easily verify this without exposing internal state, but we can check the logs were made) } func TestRequestIDMiddleware_MultipleRequests(t *testing.T) { // Cannot run in parallel: gin.SetMode() modifies global state gin.SetMode(gin.TestMode) router := gin.New() router.Use(RequestIDMiddleware()) router.GET("/test", func(c *gin.Context) { requestID := RequestIDFromContext(c.Request.Context()) c.JSON(http.StatusOK, gin.H{"request_id": requestID}) }) // Make multiple requests requestIDs := make(map[string]bool) for i := 0; i < 10; i++ { req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) requestID := w.Header().Get(RequestIDHeader) if requestID == "" { t.Errorf("Request %d: Request ID should be generated", i) continue } if requestIDs[requestID] { t.Errorf("Request ID %q was duplicated", requestID) } requestIDs[requestID] = true } } func TestSetRequestID_Overwrite(t *testing.T) { t.Parallel() ctx := context.WithValue(context.Background(), requestIDKey, "old-id") newCtx := SetRequestID(ctx, "new-id") requestID := RequestIDFromContext(newCtx) if requestID != "new-id" { t.Errorf("SetRequestID failed to overwrite, got %q, want %q", requestID, "new-id") } } func TestSetUserID_Overwrite(t *testing.T) { t.Parallel() ctx := context.WithValue(context.Background(), userIDKey, "old-user") newCtx := SetUserID(ctx, "new-user") userID := UserIDFromContext(newCtx) if userID != "new-user" { t.Errorf("SetUserID failed to overwrite, got %q, want %q", userID, "new-user") } } // mockLoggerForMiddleware is a mock logger that records log calls for testing type mockLoggerForMiddleware struct { logs []logEntry } type logEntry struct { message string fields []logger.Field } func (m *mockLoggerForMiddleware) Debug(msg string, fields ...logger.Field) { m.logs = append(m.logs, logEntry{message: msg, fields: fields}) } func (m *mockLoggerForMiddleware) Info(msg string, fields ...logger.Field) { m.logs = append(m.logs, logEntry{message: msg, fields: fields}) } func (m *mockLoggerForMiddleware) Warn(msg string, fields ...logger.Field) { m.logs = append(m.logs, logEntry{message: msg, fields: fields}) } func (m *mockLoggerForMiddleware) Error(msg string, fields ...logger.Field) { m.logs = append(m.logs, logEntry{message: msg, fields: fields}) } func (m *mockLoggerForMiddleware) With(_ ...logger.Field) logger.Logger { return m } func (m *mockLoggerForMiddleware) WithContext(_ context.Context) logger.Logger { return m }