// Package main provides FX providers for Auth Service. // This file creates the service inline to avoid importing internal packages. package main import ( "context" "crypto/rand" "crypto/sha256" "encoding/hex" "fmt" "net" "time" authv1 "git.dcentral.systems/toolz/goplt/api/proto/generated/auth/v1" "git.dcentral.systems/toolz/goplt/internal/ent/refreshtoken" "git.dcentral.systems/toolz/goplt/internal/infra/database" "git.dcentral.systems/toolz/goplt/pkg/config" "git.dcentral.systems/toolz/goplt/pkg/logger" "git.dcentral.systems/toolz/goplt/pkg/services" "github.com/golang-jwt/jwt/v5" "go.uber.org/fx" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" ) const ( accessTokenLifetime = 15 * time.Minute refreshTokenLifetime = 7 * 24 * time.Hour ) // authService provides authentication functionality. type authService struct { client *database.Client logger logger.Logger identityClient services.IdentityServiceClient authzClient services.AuthzServiceClient jwtSecret []byte accessTokenExpiry time.Duration refreshTokenExpiry time.Duration } // hashToken hashes a token using SHA256. func hashToken(token string) string { h := sha256.Sum256([]byte(token)) return hex.EncodeToString(h[:]) } // generateRefreshToken generates a random refresh token. func generateRefreshToken() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("failed to generate token: %w", err) } return hex.EncodeToString(b), nil } // generateAccessToken generates a JWT access token. func (s *authService) generateAccessToken(userID, email string, roles []string) (string, int64, error) { expiresAt := time.Now().Add(s.accessTokenExpiry) claims := jwt.MapClaims{ "sub": userID, "email": email, "roles": roles, "exp": expiresAt.Unix(), "iat": time.Now().Unix(), "token_type": "access", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(s.jwtSecret) if err != nil { return "", 0, fmt.Errorf("failed to sign token: %w", err) } return tokenString, int64(s.accessTokenExpiry.Seconds()), nil } // generateRefreshToken generates a refresh token and stores it in the database. func (s *authService) generateRefreshToken(ctx context.Context, userID string) (string, error) { token, err := generateRefreshToken() if err != nil { return "", err } tokenHash := hashToken(token) expiresAt := time.Now().Add(s.refreshTokenExpiry) // Store refresh token in database _, err = s.client.RefreshToken.Create(). SetID(fmt.Sprintf("%d-%d", time.Now().Unix(), time.Now().UnixNano()%1000000)). SetUserID(userID). SetTokenHash(tokenHash). SetExpiresAt(expiresAt). Save(ctx) if err != nil { return "", fmt.Errorf("failed to store refresh token: %w", err) } return token, nil } // validateAccessToken validates a JWT access token. func (s *authService) validateAccessToken(tokenString string) (*jwt.Token, jwt.MapClaims, error) { token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return s.jwtSecret, nil }) if err != nil { return nil, nil, fmt.Errorf("failed to parse token: %w", err) } if !token.Valid { return nil, nil, fmt.Errorf("invalid token") } claims, ok := token.Claims.(jwt.MapClaims) if !ok { return nil, nil, fmt.Errorf("invalid token claims") } // Check expiration if exp, ok := claims["exp"].(float64); ok { if time.Now().Unix() > int64(exp) { return nil, nil, fmt.Errorf("token expired") } } return token, claims, nil } // validateRefreshToken validates a refresh token. func (s *authService) validateRefreshToken(ctx context.Context, tokenString string) (string, error) { tokenHash := hashToken(tokenString) // Find refresh token by hash rt, err := s.client.RefreshToken.Query(). Where(refreshtoken.TokenHash(tokenHash)). Only(ctx) if err != nil { return "", fmt.Errorf("invalid refresh token") } // Check if token has expired if rt.ExpiresAt.Before(time.Now()) { // Delete expired token _ = s.client.RefreshToken.DeleteOneID(rt.ID).Exec(ctx) return "", fmt.Errorf("refresh token expired") } return rt.UserID, nil } // revokeRefreshToken revokes a refresh token. func (s *authService) revokeRefreshToken(ctx context.Context, tokenString string) error { tokenHash := hashToken(tokenString) // Find and delete refresh token rt, err := s.client.RefreshToken.Query(). Where(refreshtoken.TokenHash(tokenHash)). Only(ctx) if err != nil { // Token not found, consider it already revoked return nil } return s.client.RefreshToken.DeleteOneID(rt.ID).Exec(ctx) } // login authenticates a user and returns tokens. func (s *authService) login(ctx context.Context, email, password string) (*authv1.LoginResponse, error) { // Verify credentials with Identity Service user, err := s.identityClient.VerifyPassword(ctx, email, password) if err != nil { return nil, fmt.Errorf("invalid credentials") } // Get user roles from Authz Service roles := []string{} if s.authzClient != nil { userRoles, err := s.authzClient.GetUserRoles(ctx, user.ID) if err != nil { s.logger.Warn("Failed to get user roles", zap.String("user_id", user.ID), zap.Error(err), ) // Continue without roles rather than failing login } else { for _, role := range userRoles { roles = append(roles, role.Name) } } } // Generate tokens accessToken, expiresIn, err := s.generateAccessToken(user.ID, user.Email, roles) if err != nil { return nil, fmt.Errorf("failed to generate access token: %w", err) } refreshToken, err := s.generateRefreshToken(ctx, user.ID) if err != nil { return nil, fmt.Errorf("failed to generate refresh token: %w", err) } return &authv1.LoginResponse{ AccessToken: accessToken, RefreshToken: refreshToken, ExpiresIn: expiresIn, TokenType: "Bearer", }, nil } // refreshToken refreshes an access token. func (s *authService) refreshToken(ctx context.Context, refreshTokenString string) (*authv1.RefreshTokenResponse, error) { // Validate refresh token userID, err := s.validateRefreshToken(ctx, refreshTokenString) if err != nil { return nil, err } // Get user from Identity Service user, err := s.identityClient.GetUser(ctx, userID) if err != nil { return nil, fmt.Errorf("user not found") } // Get user roles from Authz Service roles := []string{} if s.authzClient != nil { userRoles, err := s.authzClient.GetUserRoles(ctx, user.ID) if err != nil { s.logger.Warn("Failed to get user roles", zap.String("user_id", user.ID), zap.Error(err), ) // Continue without roles rather than failing refresh } else { for _, role := range userRoles { roles = append(roles, role.Name) } } } // Generate new tokens accessToken, expiresIn, err := s.generateAccessToken(user.ID, user.Email, roles) if err != nil { return nil, fmt.Errorf("failed to generate access token: %w", err) } // Generate new refresh token (rotate) newRefreshToken, err := s.generateRefreshToken(ctx, user.ID) if err != nil { return nil, fmt.Errorf("failed to generate refresh token: %w", err) } // Revoke old refresh token _ = s.revokeRefreshToken(ctx, refreshTokenString) return &authv1.RefreshTokenResponse{ AccessToken: accessToken, RefreshToken: newRefreshToken, ExpiresIn: expiresIn, TokenType: "Bearer", }, nil } // validateToken validates a JWT token. func (s *authService) validateToken(tokenString string) (*authv1.ValidateTokenResponse, error) { _, claims, err := s.validateAccessToken(tokenString) if err != nil { return nil, err } userID, _ := claims["sub"].(string) email, _ := claims["email"].(string) exp, _ := claims["exp"].(float64) roles := []string{} if rolesClaim, ok := claims["roles"].([]interface{}); ok { for _, r := range rolesClaim { if role, ok := r.(string); ok { roles = append(roles, role) } } } return &authv1.ValidateTokenResponse{ UserId: userID, Email: email, Roles: roles, ExpiresAt: int64(exp), }, nil } // logout invalidates a refresh token. func (s *authService) logout(ctx context.Context, refreshTokenString string) error { return s.revokeRefreshToken(ctx, refreshTokenString) } // authServerImpl implements the AuthService gRPC server. type authServerImpl struct { authv1.UnimplementedAuthServiceServer service *authService logger *zap.Logger } // Login authenticates a user and returns tokens. func (s *authServerImpl) Login(ctx context.Context, req *authv1.LoginRequest) (*authv1.LoginResponse, error) { resp, err := s.service.login(ctx, req.Email, req.Password) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "login failed: %v", err) } return resp, nil } // RefreshToken refreshes an access token. func (s *authServerImpl) RefreshToken(ctx context.Context, req *authv1.RefreshTokenRequest) (*authv1.RefreshTokenResponse, error) { resp, err := s.service.refreshToken(ctx, req.RefreshToken) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "refresh failed: %v", err) } return resp, nil } // ValidateToken validates a JWT token. func (s *authServerImpl) ValidateToken(ctx context.Context, req *authv1.ValidateTokenRequest) (*authv1.ValidateTokenResponse, error) { resp, err := s.service.validateToken(req.Token) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "validation failed: %v", err) } return resp, nil } // Logout invalidates a refresh token. func (s *authServerImpl) Logout(ctx context.Context, req *authv1.LogoutRequest) (*authv1.LogoutResponse, error) { if err := s.service.logout(ctx, req.RefreshToken); err != nil { return nil, status.Errorf(codes.Internal, "logout failed: %v", err) } return &authv1.LogoutResponse{Success: true}, nil } // provideAuthService creates the auth service and gRPC server. func provideAuthService() fx.Option { return fx.Options( // Auth service fx.Provide(func( client *database.Client, log logger.Logger, identityClient services.IdentityServiceClient, authzClient services.AuthzServiceClient, cfg config.ConfigProvider, ) (*authService, error) { jwtSecret := cfg.GetString("auth.jwt_secret") if jwtSecret == "" { return nil, fmt.Errorf("auth.jwt_secret is required in configuration") } return &authService{ client: client, logger: log, identityClient: identityClient, authzClient: authzClient, jwtSecret: []byte(jwtSecret), accessTokenExpiry: accessTokenLifetime, refreshTokenExpiry: refreshTokenLifetime, }, nil }), // gRPC server implementation fx.Provide(func(authService *authService, log logger.Logger) (*authServerImpl, error) { zapLogger, _ := zap.NewProduction() return &authServerImpl{ service: authService, logger: zapLogger, }, nil }), // gRPC server wrapper fx.Provide(func( serverImpl *authServerImpl, cfg config.ConfigProvider, log logger.Logger, ) (*grpcServerWrapper, error) { port := cfg.GetInt("services.auth.port") if port == 0 { port = 8081 } addr := fmt.Sprintf("0.0.0.0:%d", port) listener, err := net.Listen("tcp", addr) if err != nil { return nil, fmt.Errorf("failed to listen on %s: %w", addr, err) } grpcServer := grpc.NewServer() authv1.RegisterAuthServiceServer(grpcServer, serverImpl) // Register health service healthServer := health.NewServer() grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) // Set serving status for the default service (empty string) - this is what Consul checks healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING) // Also set for the specific service name healthServer.SetServingStatus("auth.v1.AuthService", grpc_health_v1.HealthCheckResponse_SERVING) // Register reflection for grpcurl reflection.Register(grpcServer) return &grpcServerWrapper{ server: grpcServer, listener: listener, port: port, logger: log, }, nil }), ) }