// Package main provides FX providers for Auth Service. // This file creates the service inline to avoid importing internal packages. package main import ( "context" "crypto/rand" "crypto/sha256" "encoding/hex" "fmt" "net" "time" authv1 "git.dcentral.systems/toolz/goplt/api/proto/generated/auth/v1" "git.dcentral.systems/toolz/goplt/internal/ent" "git.dcentral.systems/toolz/goplt/pkg/config" "git.dcentral.systems/toolz/goplt/pkg/logger" "git.dcentral.systems/toolz/goplt/pkg/services" "github.com/golang-jwt/jwt/v5" "go.uber.org/fx" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" ) const ( accessTokenLifetime = 15 * time.Minute refreshTokenLifetime = 7 * 24 * time.Hour ) // authService provides authentication functionality. type authService struct { client *ent.Client logger logger.Logger identityClient services.IdentityServiceClient jwtSecret []byte accessTokenExpiry time.Duration refreshTokenExpiry time.Duration } // hashToken hashes a token using SHA256. func hashToken(token string) string { h := sha256.Sum256([]byte(token)) return hex.EncodeToString(h[:]) } // generateRefreshToken generates a random refresh token. func generateRefreshToken() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("failed to generate token: %w", err) } return hex.EncodeToString(b), nil } // generateAccessToken generates a JWT access token. func (s *authService) generateAccessToken(userID, email string, roles []string) (string, int64, error) { expiresAt := time.Now().Add(s.accessTokenExpiry) claims := jwt.MapClaims{ "sub": userID, "email": email, "roles": roles, "exp": expiresAt.Unix(), "iat": time.Now().Unix(), "token_type": "access", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(s.jwtSecret) if err != nil { return "", 0, fmt.Errorf("failed to sign token: %w", err) } return tokenString, int64(s.accessTokenExpiry.Seconds()), nil } // generateRefreshToken generates a refresh token and stores it in the database. // Note: This is a simplified version - RefreshToken entity needs to be generated first func (s *authService) generateRefreshToken(ctx context.Context, userID string) (string, error) { token, err := generateRefreshToken() if err != nil { return "", err } // TODO: Store refresh token in database using RefreshToken entity once generated // For now, we'll just return the token // tokenHash := hashToken(token) // expiresAt := time.Now().Add(s.refreshTokenExpiry) // _, err = s.client.RefreshToken.Create()... return token, nil } // validateAccessToken validates a JWT access token. func (s *authService) validateAccessToken(tokenString string) (*jwt.Token, jwt.MapClaims, error) { token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return s.jwtSecret, nil }) if err != nil { return nil, nil, fmt.Errorf("failed to parse token: %w", err) } if !token.Valid { return nil, nil, fmt.Errorf("invalid token") } claims, ok := token.Claims.(jwt.MapClaims) if !ok { return nil, nil, fmt.Errorf("invalid token claims") } // Check expiration if exp, ok := claims["exp"].(float64); ok { if time.Now().Unix() > int64(exp) { return nil, nil, fmt.Errorf("token expired") } } return token, claims, nil } // validateRefreshToken validates a refresh token. // Note: This is a simplified version - RefreshToken entity needs to be generated first func (s *authService) validateRefreshToken(ctx context.Context, tokenString string) (string, error) { // TODO: Use RefreshToken entity once generated // tokenHash := hashToken(tokenString) // rt, err := s.client.RefreshToken.Query()... // return rt.UserID, nil // For now, return error to indicate this needs proper implementation return "", fmt.Errorf("refresh token validation not yet implemented - RefreshToken entity needs to be generated") } // revokeRefreshToken revokes a refresh token. // Note: This is a simplified version - RefreshToken entity needs to be generated first func (s *authService) revokeRefreshToken(ctx context.Context, tokenString string) error { // TODO: Implement once RefreshToken entity is generated // tokenHash := hashToken(tokenString) // rt, err := s.client.RefreshToken.Query()... // return s.client.RefreshToken.DeleteOneID(rt.ID).Exec(ctx) return nil // Placeholder } // login authenticates a user and returns tokens. func (s *authService) login(ctx context.Context, email, password string) (*authv1.LoginResponse, error) { // Verify credentials with Identity Service user, err := s.identityClient.GetUserByEmail(ctx, email) if err != nil { return nil, fmt.Errorf("invalid credentials") } // Note: In a real implementation, we'd call VerifyPassword on Identity Service // For now, we'll assume Identity Service validates the password // This is a simplified version - the Identity Service should expose VerifyPassword // Get user roles (simplified - would come from Authz Service) roles := []string{} // TODO: Get from Authz Service // Generate tokens accessToken, expiresIn, err := s.generateAccessToken(user.ID, user.Email, roles) if err != nil { return nil, fmt.Errorf("failed to generate access token: %w", err) } refreshToken, err := s.generateRefreshToken(ctx, user.ID) if err != nil { return nil, fmt.Errorf("failed to generate refresh token: %w", err) } return &authv1.LoginResponse{ AccessToken: accessToken, RefreshToken: refreshToken, ExpiresIn: expiresIn, TokenType: "Bearer", }, nil } // refreshToken refreshes an access token. func (s *authService) refreshToken(ctx context.Context, refreshTokenString string) (*authv1.RefreshTokenResponse, error) { // Validate refresh token userID, err := s.validateRefreshToken(ctx, refreshTokenString) if err != nil { return nil, err } // Get user from Identity Service user, err := s.identityClient.GetUser(ctx, userID) if err != nil { return nil, fmt.Errorf("user not found") } // Get user roles (simplified) roles := []string{} // TODO: Get from Authz Service // Generate new tokens accessToken, expiresIn, err := s.generateAccessToken(user.ID, user.Email, roles) if err != nil { return nil, fmt.Errorf("failed to generate access token: %w", err) } // Generate new refresh token (rotate) newRefreshToken, err := s.generateRefreshToken(ctx, user.ID) if err != nil { return nil, fmt.Errorf("failed to generate refresh token: %w", err) } // Revoke old refresh token _ = s.revokeRefreshToken(ctx, refreshTokenString) return &authv1.RefreshTokenResponse{ AccessToken: accessToken, RefreshToken: newRefreshToken, ExpiresIn: expiresIn, TokenType: "Bearer", }, nil } // validateToken validates a JWT token. func (s *authService) validateToken(tokenString string) (*authv1.ValidateTokenResponse, error) { _, claims, err := s.validateAccessToken(tokenString) if err != nil { return nil, err } userID, _ := claims["sub"].(string) email, _ := claims["email"].(string) exp, _ := claims["exp"].(float64) roles := []string{} if rolesClaim, ok := claims["roles"].([]interface{}); ok { for _, r := range rolesClaim { if role, ok := r.(string); ok { roles = append(roles, role) } } } return &authv1.ValidateTokenResponse{ UserId: userID, Email: email, Roles: roles, ExpiresAt: int64(exp), }, nil } // logout invalidates a refresh token. func (s *authService) logout(ctx context.Context, refreshTokenString string) error { return s.revokeRefreshToken(ctx, refreshTokenString) } // authServerImpl implements the AuthService gRPC server. type authServerImpl struct { authv1.UnimplementedAuthServiceServer service *authService logger *zap.Logger } // Login authenticates a user and returns tokens. func (s *authServerImpl) Login(ctx context.Context, req *authv1.LoginRequest) (*authv1.LoginResponse, error) { resp, err := s.service.login(ctx, req.Email, req.Password) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "login failed: %v", err) } return resp, nil } // RefreshToken refreshes an access token. func (s *authServerImpl) RefreshToken(ctx context.Context, req *authv1.RefreshTokenRequest) (*authv1.RefreshTokenResponse, error) { resp, err := s.service.refreshToken(ctx, req.RefreshToken) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "refresh failed: %v", err) } return resp, nil } // ValidateToken validates a JWT token. func (s *authServerImpl) ValidateToken(ctx context.Context, req *authv1.ValidateTokenRequest) (*authv1.ValidateTokenResponse, error) { resp, err := s.service.validateToken(req.Token) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "validation failed: %v", err) } return resp, nil } // Logout invalidates a refresh token. func (s *authServerImpl) Logout(ctx context.Context, req *authv1.LogoutRequest) (*authv1.LogoutResponse, error) { if err := s.service.logout(ctx, req.RefreshToken); err != nil { return nil, status.Errorf(codes.Internal, "logout failed: %v", err) } return &authv1.LogoutResponse{Success: true}, nil } // provideAuthService creates the auth service and gRPC server. func provideAuthService() fx.Option { return fx.Options( // Auth service fx.Provide(func( client *ent.Client, log logger.Logger, identityClient services.IdentityServiceClient, cfg config.ConfigProvider, ) (*authService, error) { jwtSecret := cfg.GetString("auth.jwt_secret") if jwtSecret == "" { jwtSecret = "default-secret-change-in-production" // TODO: Generate or require } return &authService{ client: client, logger: log, identityClient: identityClient, jwtSecret: []byte(jwtSecret), accessTokenExpiry: accessTokenLifetime, refreshTokenExpiry: refreshTokenLifetime, }, nil }), // gRPC server implementation fx.Provide(func(authService *authService, log logger.Logger) (*authServerImpl, error) { zapLogger, _ := zap.NewProduction() return &authServerImpl{ service: authService, logger: zapLogger, }, nil }), // gRPC server wrapper fx.Provide(func( serverImpl *authServerImpl, cfg config.ConfigProvider, log logger.Logger, ) (*grpcServerWrapper, error) { port := cfg.GetInt("services.auth.port") if port == 0 { port = 8081 } addr := fmt.Sprintf("0.0.0.0:%d", port) listener, err := net.Listen("tcp", addr) if err != nil { return nil, fmt.Errorf("failed to listen on %s: %w", addr, err) } grpcServer := grpc.NewServer() authv1.RegisterAuthServiceServer(grpcServer, serverImpl) // Register health service healthServer := health.NewServer() grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) healthServer.SetServingStatus("auth.v1.AuthService", grpc_health_v1.HealthCheckResponse_SERVING) // Register reflection for grpcurl reflection.Register(grpcServer) return &grpcServerWrapper{ server: grpcServer, listener: listener, port: port, logger: log, }, nil }), ) }