- Remove t.Parallel() from tests that use gin.SetMode()
- gin.SetMode() modifies global state and is not thread-safe
- Tests affected:
  * TestRequestIDMiddleware_GenerateNewID
  * TestRequestIDMiddleware_UseExistingID
  * TestLoggingMiddleware
  * TestLoggingMiddleware_WithRequestID
  * TestRequestIDMiddleware_MultipleRequests
- Add comments explaining why these tests cannot run in parallel
- All tests now pass with race detector enabled (-race flag)
This fixes data race warnings that were occurring when running tests
with the race detector, specifically when multiple tests tried to set
Gin's mode concurrently.
360 lines
9.1 KiB
Go
360 lines
9.1 KiB
Go
package logger
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"git.dcentral.systems/toolz/goplt/pkg/logger"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func TestRequestIDMiddleware_GenerateNewID(t *testing.T) {
|
|
// Cannot run in parallel: gin.SetMode() modifies global state
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(RequestIDMiddleware())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
requestID := RequestIDFromContext(c.Request.Context())
|
|
if requestID == "" {
|
|
t.Error("Request ID should be generated")
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
// Check that request ID is in response header
|
|
requestID := w.Header().Get(RequestIDHeader)
|
|
if requestID == "" {
|
|
t.Error("Request ID should be in response header")
|
|
}
|
|
}
|
|
|
|
func TestRequestIDMiddleware_UseExistingID(t *testing.T) {
|
|
// Cannot run in parallel: gin.SetMode() modifies global state
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(RequestIDMiddleware())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
requestID := RequestIDFromContext(c.Request.Context())
|
|
if requestID != "existing-id" {
|
|
t.Errorf("Expected request ID 'existing-id', got %q", requestID)
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.Header.Set(RequestIDHeader, "existing-id")
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
// Check that the same request ID is in response header
|
|
requestID := w.Header().Get(RequestIDHeader)
|
|
if requestID != "existing-id" {
|
|
t.Errorf("Expected request ID 'existing-id' in header, got %q", requestID)
|
|
}
|
|
}
|
|
|
|
func TestRequestIDFromContext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
ctx context.Context
|
|
want string
|
|
wantEmpty bool
|
|
}{
|
|
{
|
|
name: "with request ID",
|
|
ctx: context.WithValue(context.Background(), RequestIDKey(), "test-id"),
|
|
want: "test-id",
|
|
wantEmpty: false,
|
|
},
|
|
{
|
|
name: "without request ID",
|
|
ctx: context.Background(),
|
|
want: "",
|
|
wantEmpty: true,
|
|
},
|
|
{
|
|
name: "with wrong type",
|
|
ctx: context.WithValue(context.Background(), RequestIDKey(), 123),
|
|
want: "",
|
|
wantEmpty: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := RequestIDFromContext(tt.ctx)
|
|
if tt.wantEmpty && got != "" {
|
|
t.Errorf("RequestIDFromContext() = %q, want empty string", got)
|
|
}
|
|
if !tt.wantEmpty && got != tt.want {
|
|
t.Errorf("RequestIDFromContext() = %q, want %q", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetRequestID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
newCtx := SetRequestID(ctx, "test-id")
|
|
|
|
requestID := RequestIDFromContext(newCtx)
|
|
if requestID != "test-id" {
|
|
t.Errorf("SetRequestID failed, got %q, want %q", requestID, "test-id")
|
|
}
|
|
|
|
// Original context should not have the ID
|
|
originalID := RequestIDFromContext(ctx)
|
|
if originalID != "" {
|
|
t.Error("Original context should not have request ID")
|
|
}
|
|
}
|
|
|
|
func TestSetUserID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
newCtx := SetUserID(ctx, "user-123")
|
|
|
|
userID := UserIDFromContext(newCtx)
|
|
if userID != "user-123" {
|
|
t.Errorf("SetUserID failed, got %q, want %q", userID, "user-123")
|
|
}
|
|
|
|
// Original context should not have the ID
|
|
originalID := UserIDFromContext(ctx)
|
|
if originalID != "" {
|
|
t.Error("Original context should not have user ID")
|
|
}
|
|
}
|
|
|
|
func TestUserIDFromContext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
ctx context.Context
|
|
want string
|
|
wantEmpty bool
|
|
}{
|
|
{
|
|
name: "with user ID",
|
|
ctx: context.WithValue(context.Background(), UserIDKey(), "user-123"),
|
|
want: "user-123",
|
|
wantEmpty: false,
|
|
},
|
|
{
|
|
name: "without user ID",
|
|
ctx: context.Background(),
|
|
want: "",
|
|
wantEmpty: true,
|
|
},
|
|
{
|
|
name: "with wrong type",
|
|
ctx: context.WithValue(context.Background(), UserIDKey(), 123),
|
|
want: "",
|
|
wantEmpty: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := UserIDFromContext(tt.ctx)
|
|
if tt.wantEmpty && got != "" {
|
|
t.Errorf("UserIDFromContext() = %q, want empty string", got)
|
|
}
|
|
if !tt.wantEmpty && got != tt.want {
|
|
t.Errorf("UserIDFromContext() = %q, want %q", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLoggingMiddleware(t *testing.T) {
|
|
// Cannot run in parallel: gin.SetMode() modifies global state
|
|
|
|
// Create a mock logger that records log calls
|
|
mockLog := &mockLoggerForMiddleware{}
|
|
mockLog.logs = make([]logEntry, 0)
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(RequestIDMiddleware())
|
|
router.Use(LoggingMiddleware(mockLog))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "test"})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
// Verify that logs were recorded
|
|
if len(mockLog.logs) < 2 {
|
|
t.Fatalf("Expected at least 2 log entries (request + response), got %d", len(mockLog.logs))
|
|
}
|
|
|
|
// Check request log
|
|
requestLog := mockLog.logs[0]
|
|
if requestLog.message != "HTTP request" {
|
|
t.Errorf("Expected 'HTTP request' log, got %q", requestLog.message)
|
|
}
|
|
|
|
// Check response log
|
|
responseLog := mockLog.logs[1]
|
|
if responseLog.message != "HTTP response" {
|
|
t.Errorf("Expected 'HTTP response' log, got %q", responseLog.message)
|
|
}
|
|
}
|
|
|
|
func TestLoggingMiddleware_WithRequestID(t *testing.T) {
|
|
// Cannot run in parallel: gin.SetMode() modifies global state
|
|
|
|
// Create a mock logger
|
|
mockLog := &mockLoggerForMiddleware{}
|
|
mockLog.logs = make([]logEntry, 0)
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(RequestIDMiddleware())
|
|
router.Use(LoggingMiddleware(mockLog))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "test"})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.Header.Set(RequestIDHeader, "custom-request-id")
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
// Verify that request ID is in the logs
|
|
if len(mockLog.logs) < 2 {
|
|
t.Fatalf("Expected at least 2 log entries, got %d", len(mockLog.logs))
|
|
}
|
|
|
|
// The logger should have received context with request ID
|
|
// (We can't easily verify this without exposing internal state, but we can check the logs were made)
|
|
}
|
|
|
|
func TestRequestIDMiddleware_MultipleRequests(t *testing.T) {
|
|
// Cannot run in parallel: gin.SetMode() modifies global state
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(RequestIDMiddleware())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
requestID := RequestIDFromContext(c.Request.Context())
|
|
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
|
|
})
|
|
|
|
// Make multiple requests
|
|
requestIDs := make(map[string]bool)
|
|
for i := 0; i < 10; i++ {
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
requestID := w.Header().Get(RequestIDHeader)
|
|
if requestID == "" {
|
|
t.Errorf("Request %d: Request ID should be generated", i)
|
|
continue
|
|
}
|
|
|
|
if requestIDs[requestID] {
|
|
t.Errorf("Request ID %q was duplicated", requestID)
|
|
}
|
|
requestIDs[requestID] = true
|
|
}
|
|
}
|
|
|
|
func TestSetRequestID_Overwrite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.WithValue(context.Background(), RequestIDKey(), "old-id")
|
|
newCtx := SetRequestID(ctx, "new-id")
|
|
|
|
requestID := RequestIDFromContext(newCtx)
|
|
if requestID != "new-id" {
|
|
t.Errorf("SetRequestID failed to overwrite, got %q, want %q", requestID, "new-id")
|
|
}
|
|
}
|
|
|
|
func TestSetUserID_Overwrite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.WithValue(context.Background(), UserIDKey(), "old-user")
|
|
newCtx := SetUserID(ctx, "new-user")
|
|
|
|
userID := UserIDFromContext(newCtx)
|
|
if userID != "new-user" {
|
|
t.Errorf("SetUserID failed to overwrite, got %q, want %q", userID, "new-user")
|
|
}
|
|
}
|
|
|
|
// mockLoggerForMiddleware is a mock logger that records log calls for testing
|
|
type mockLoggerForMiddleware struct {
|
|
logs []logEntry
|
|
}
|
|
|
|
type logEntry struct {
|
|
message string
|
|
fields []logger.Field
|
|
}
|
|
|
|
func (m *mockLoggerForMiddleware) Debug(msg string, fields ...logger.Field) {
|
|
m.logs = append(m.logs, logEntry{message: msg, fields: fields})
|
|
}
|
|
|
|
func (m *mockLoggerForMiddleware) Info(msg string, fields ...logger.Field) {
|
|
m.logs = append(m.logs, logEntry{message: msg, fields: fields})
|
|
}
|
|
|
|
func (m *mockLoggerForMiddleware) Warn(msg string, fields ...logger.Field) {
|
|
m.logs = append(m.logs, logEntry{message: msg, fields: fields})
|
|
}
|
|
|
|
func (m *mockLoggerForMiddleware) Error(msg string, fields ...logger.Field) {
|
|
m.logs = append(m.logs, logEntry{message: msg, fields: fields})
|
|
}
|
|
|
|
func (m *mockLoggerForMiddleware) With(fields ...logger.Field) logger.Logger {
|
|
return m
|
|
}
|
|
|
|
func (m *mockLoggerForMiddleware) WithContext(ctx context.Context) logger.Logger {
|
|
return m
|
|
}
|