Files
goplt/cmd/auth-service/auth_service_fx.go

426 lines
12 KiB
Go

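// Package main provides FX providers for the Auth Service.
// This file implements the service logic inline rather than importing it from an
// internal package, although it still depends on internal/ent/refreshtoken and
// internal/infra/database for persistence.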
package main
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"time"
authv1 "git.dcentral.systems/toolz/goplt/api/proto/generated/auth/v1"
"git.dcentral.systems/toolz/goplt/internal/ent/refreshtoken"
"git.dcentral.systems/toolz/goplt/internal/infra/database"
"git.dcentral.systems/toolz/goplt/pkg/config"
"git.dcentral.systems/toolz/goplt/pkg/logger"
"git.dcentral.systems/toolz/goplt/pkg/services"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/fx"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/health"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
)
const (
	// accessTokenLifetime is the validity window of an issued JWT access token.
	accessTokenLifetime = 15 * time.Minute

	// refreshTokenLifetime is the validity window of a stored refresh token
	// (one week).
	refreshTokenLifetime = 168 * time.Hour
)
// authService holds the dependencies and settings for the Auth Service logic:
// password login, JWT access-token issuance/validation, and refresh-token
// storage, rotation and revocation.
type authService struct {
	// client is the database client used to store and look up refresh-token hashes.
	client *database.Client

	// logger reports non-fatal problems such as failed role lookups.
	logger logger.Logger

	// identityClient verifies credentials and loads users; authzClient may be
	// nil and, when set, supplies the role names embedded in access tokens.
	identityClient services.IdentityServiceClient
	authzClient    services.AuthzServiceClient

	// jwtSecret is the HMAC key used to sign and verify access tokens.
	jwtSecret []byte

	// Lifetimes applied to newly issued access and refresh tokens.
	accessTokenExpiry, refreshTokenExpiry time.Duration
}
// hashToken returns the lowercase hex encoding of the SHA-256 digest of token.
// Only this digest is persisted, so a database leak does not expose usable
// refresh tokens.
func hashToken(token string) string {
	hasher := sha256.New()
	// hash.Hash.Write never returns an error.
	_, _ = hasher.Write([]byte(token))
	return hex.EncodeToString(hasher.Sum(nil))
}
// generateRefreshToken returns a fresh opaque refresh token: 32 bytes from
// crypto/rand, hex-encoded (64 characters).
func generateRefreshToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("failed to generate token: %w", err)
	}
	return hex.EncodeToString(raw[:]), nil
}
// generateAccessToken signs an HS256 JWT access token for the given user.
//
// The token carries the claims sub (userID), email, roles, exp, iat and
// token_type="access", and is signed with s.jwtSecret. It returns the signed
// token, its lifetime in whole seconds (s.accessTokenExpiry), which callers
// place in ExpiresIn, and any signing error.
func (s *authService) generateAccessToken(userID, email string, roles []string) (string, int64, error) {
	// Read the clock once so iat and exp are derived from the same instant.
	now := time.Now()
	claims := jwt.MapClaims{
		"sub":        userID,
		"email":      email,
		"roles":      roles,
		"exp":        now.Add(s.accessTokenExpiry).Unix(),
		"iat":        now.Unix(),
		"token_type": "access",
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(s.jwtSecret)
	if err != nil {
		return "", 0, fmt.Errorf("failed to sign token: %w", err)
	}
	return signed, int64(s.accessTokenExpiry.Seconds()), nil
}
// generateRefreshToken creates a new opaque refresh token for userID and stores
// its SHA-256 hash (never the token itself) with an expiry of
// s.refreshTokenExpiry from now. It returns the plaintext token to hand to the
// client.
func (s *authService) generateRefreshToken(ctx context.Context, userID string) (string, error) {
	// Unqualified name resolves to the package-level generator, not this method.
	token, err := generateRefreshToken()
	if err != nil {
		return "", err
	}
	now := time.Now()
	id, err := newRefreshTokenID(now)
	if err != nil {
		return "", err
	}
	_, err = s.client.RefreshToken.Create().
		SetID(id).
		SetUserID(userID).
		SetTokenHash(hashToken(token)).
		SetExpiresAt(now.Add(s.refreshTokenExpiry)).
		Save(ctx)
	if err != nil {
		return "", fmt.Errorf("failed to store refresh token: %w", err)
	}
	return token, nil
}

// newRefreshTokenID returns a refresh-token row ID of the form
// "<unix-seconds>-<16 hex chars>". The 64 random bits make collisions between
// rows created in the same second negligible. A purely clock-derived ID can
// repeat under concurrency or on coarse clocks and fail the insert.
func newRefreshTokenID(now time.Time) (string, error) {
	var suffix [8]byte
	if _, err := rand.Read(suffix[:]); err != nil {
		return "", fmt.Errorf("failed to generate token id: %w", err)
	}
	return fmt.Sprintf("%d-%s", now.Unix(), hex.EncodeToString(suffix[:])), nil
}
// validateAccessToken parses tokenString and verifies that it was signed with
// an HMAC method keyed by s.jwtSecret.
//
// The jwt parser also validates the registered time claims. "exp" is required,
// so a token without one is rejected instead of being valid forever. On success
// it returns the parsed token and its MapClaims.
func (s *authService) validateAccessToken(tokenString string) (*jwt.Token, jwt.MapClaims, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC is accepted, so a token cannot select "none" or an
		// asymmetric algorithm to be checked against the shared secret.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return s.jwtSecret, nil
	}, jwt.WithExpirationRequired())
	if err != nil {
		// Covers malformed input, bad signatures, and a missing or past exp:
		// jwt v5 checks expiry inside Parse, so no separate check is needed.
		return nil, nil, fmt.Errorf("failed to parse token: %w", err)
	}
	if !token.Valid {
		return nil, nil, fmt.Errorf("invalid token")
	}
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, nil, fmt.Errorf("invalid token claims")
	}
	return token, claims, nil
}
// validateRefreshToken looks up the stored record for a presented refresh token
// and returns the ID of the user it was issued to.
//
// Records are matched by the SHA-256 hash of the token. Any lookup failure,
// including no match, is reported as "invalid refresh token". A token past its
// expiry is deleted on a best-effort basis and reported as
// "refresh token expired".
func (s *authService) validateRefreshToken(ctx context.Context, tokenString string) (string, error) {
	record, err := s.client.RefreshToken.Query().
		Where(refreshtoken.TokenHash(hashToken(tokenString))).
		Only(ctx)
	if err != nil {
		return "", fmt.Errorf("invalid refresh token")
	}
	if !record.ExpiresAt.Before(time.Now()) {
		return record.UserID, nil
	}
	// Expired: purge the row. A failed delete is deliberately ignored, because
	// the caller is rejected either way and an expired row never validates.
	_ = s.client.RefreshToken.DeleteOneID(record.ID).Exec(ctx)
	return "", fmt.Errorf("refresh token expired")
}
// revokeRefreshToken deletes every stored refresh-token row whose hash matches
// tokenString.
//
// It issues a single DELETE ... WHERE token_hash = ?, so there is no window
// between lookup and delete. An unknown or already-revoked token deletes zero
// rows and is treated as success. Database errors are returned rather than
// being mistaken for "already revoked".
func (s *authService) revokeRefreshToken(ctx context.Context, tokenString string) error {
	if _, err := s.client.RefreshToken.Delete().
		Where(refreshtoken.TokenHash(hashToken(tokenString))).
		Exec(ctx); err != nil {
		return fmt.Errorf("failed to revoke refresh token: %w", err)
	}
	return nil
}
// login verifies email/password with the Identity Service and, on success,
// issues a new access token and refresh token for the user.
//
// Any VerifyPassword error is collapsed into the generic "invalid credentials"
// message; the underlying cause is not propagated. Role lookup is best-effort
// (see fetchRoleNames).
func (s *authService) login(ctx context.Context, email, password string) (*authv1.LoginResponse, error) {
	user, err := s.identityClient.VerifyPassword(ctx, email, password)
	if err != nil {
		return nil, fmt.Errorf("invalid credentials")
	}
	roles := s.fetchRoleNames(ctx, user.ID)
	accessToken, expiresIn, err := s.generateAccessToken(user.ID, user.Email, roles)
	if err != nil {
		return nil, fmt.Errorf("failed to generate access token: %w", err)
	}
	refreshToken, err := s.generateRefreshToken(ctx, user.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to generate refresh token: %w", err)
	}
	return &authv1.LoginResponse{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ExpiresIn:    expiresIn,
		TokenType:    "Bearer",
	}, nil
}

// refreshToken exchanges a valid refresh token for a new access token and a new
// refresh token (rotation).
//
// The presented token is single-use. It is deleted before its replacement is
// stored, and a request that finds it already deleted is rejected with
// "invalid refresh token". If storing the replacement fails after the old token
// was consumed, the client has to log in again.
func (s *authService) refreshToken(ctx context.Context, refreshTokenString string) (*authv1.RefreshTokenResponse, error) {
	userID, err := s.validateRefreshToken(ctx, refreshTokenString)
	if err != nil {
		return nil, err
	}
	user, err := s.identityClient.GetUser(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("user not found")
	}
	roles := s.fetchRoleNames(ctx, user.ID)
	accessToken, expiresIn, err := s.generateAccessToken(user.ID, user.Email, roles)
	if err != nil {
		return nil, fmt.Errorf("failed to generate access token: %w", err)
	}
	// Consume the old token before storing its replacement. The DELETE is the
	// serialization point for concurrent refreshes of the same token. Only one
	// of them removes the row; the others see zero affected rows and are
	// rejected, so a token cannot be exchanged twice. A failed DELETE aborts the
	// refresh instead of leaving the old token valid alongside a new one.
	consumed, err := s.client.RefreshToken.Delete().
		Where(refreshtoken.TokenHash(hashToken(refreshTokenString))).
		Exec(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to revoke refresh token: %w", err)
	}
	if consumed == 0 {
		return nil, fmt.Errorf("invalid refresh token")
	}
	newRefreshToken, err := s.generateRefreshToken(ctx, user.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to generate refresh token: %w", err)
	}
	return &authv1.RefreshTokenResponse{
		AccessToken:  accessToken,
		RefreshToken: newRefreshToken,
		ExpiresIn:    expiresIn,
		TokenType:    "Bearer",
	}, nil
}

// fetchRoleNames returns the names of the roles assigned to userID.
//
// The result is always non-nil, so the "roles" claim encodes as [] rather than
// null. When no Authz client is configured, or the lookup fails, the result is
// empty. A failure is logged and authentication proceeds without roles rather
// than being rejected.
func (s *authService) fetchRoleNames(ctx context.Context, userID string) []string {
	if s.authzClient == nil {
		return []string{}
	}
	userRoles, err := s.authzClient.GetUserRoles(ctx, userID)
	if err != nil {
		s.logger.Warn("Failed to get user roles",
			zap.String("user_id", userID),
			zap.Error(err),
		)
		return []string{}
	}
	roles := make([]string, 0, len(userRoles))
	for _, role := range userRoles {
		roles = append(roles, role.Name)
	}
	return roles
}
// validateToken verifies an access token and projects its claims into a
// ValidateTokenResponse.
//
// sub, email and exp are copied when they have the expected dynamic type and
// left at their zero value otherwise. Non-string entries in the roles claim are
// skipped.
func (s *authService) validateToken(tokenString string) (*authv1.ValidateTokenResponse, error) {
	_, claims, err := s.validateAccessToken(tokenString)
	if err != nil {
		return nil, err
	}
	resp := &authv1.ValidateTokenResponse{Roles: []string{}}
	resp.UserId, _ = claims["sub"].(string)
	resp.Email, _ = claims["email"].(string)
	// JSON numbers decode as float64 in MapClaims.
	if exp, ok := claims["exp"].(float64); ok {
		resp.ExpiresAt = int64(exp)
	}
	rawRoles, _ := claims["roles"].([]interface{})
	for _, item := range rawRoles {
		if name, isString := item.(string); isString {
			resp.Roles = append(resp.Roles, name)
		}
	}
	return resp, nil
}
// logout ends a session by revoking the given refresh token. It returns
// whatever error revokeRefreshToken reports.
func (s *authService) logout(ctx context.Context, refreshTokenString string) error {
	if err := s.revokeRefreshToken(ctx, refreshTokenString); err != nil {
		return err
	}
	return nil
}
// authServerImpl implements the AuthService gRPC server by delegating each RPC
// to authService and converting failures into gRPC status errors.
type authServerImpl struct {
	// Embedded generated stub; presumably makes RPCs not implemented here return
	// codes.Unimplemented (standard protoc-gen-go-grpc pattern) — confirm.
	authv1.UnimplementedAuthServiceServer

	// service carries the actual authentication logic.
	service *authService

	// logger is populated by provideAuthService; NOTE(review): not referenced by
	// the handlers visible in this file.
	logger *zap.Logger
}
// Login authenticates a user by email and password and returns an access token
// and refresh token. Any failure is reported as codes.Unauthenticated carrying
// the service error text.
func (s *authServerImpl) Login(ctx context.Context, req *authv1.LoginRequest) (*authv1.LoginResponse, error) {
	out, failure := s.service.login(ctx, req.Email, req.Password)
	if failure == nil {
		return out, nil
	}
	return nil, status.Errorf(codes.Unauthenticated, "login failed: %v", failure)
}
// RefreshToken exchanges a refresh token for a new access/refresh token pair.
// Any failure is reported as codes.Unauthenticated carrying the service error
// text.
func (s *authServerImpl) RefreshToken(ctx context.Context, req *authv1.RefreshTokenRequest) (*authv1.RefreshTokenResponse, error) {
	out, failure := s.service.refreshToken(ctx, req.RefreshToken)
	if failure == nil {
		return out, nil
	}
	return nil, status.Errorf(codes.Unauthenticated, "refresh failed: %v", failure)
}
// ValidateToken verifies a JWT access token and returns the user ID, email,
// roles and expiry it carries. Any failure is reported as
// codes.Unauthenticated. The context is not used.
func (s *authServerImpl) ValidateToken(ctx context.Context, req *authv1.ValidateTokenRequest) (*authv1.ValidateTokenResponse, error) {
	out, failure := s.service.validateToken(req.Token)
	if failure == nil {
		return out, nil
	}
	return nil, status.Errorf(codes.Unauthenticated, "validation failed: %v", failure)
}
// Logout revokes the supplied refresh token and reports Success: true. A
// revocation failure is reported as codes.Internal.
func (s *authServerImpl) Logout(ctx context.Context, req *authv1.LogoutRequest) (*authv1.LogoutResponse, error) {
	failure := s.service.logout(ctx, req.RefreshToken)
	if failure == nil {
		return &authv1.LogoutResponse{Success: true}, nil
	}
	return nil, status.Errorf(codes.Internal, "logout failed: %v", failure)
}
// provideAuthService returns the FX options that construct the Auth Service:
//   - *authService, from the database client, logger, Identity/Authz clients
//     and the required "auth.jwt_secret" config value;
//   - *authServerImpl, the gRPC handlers backed by *authService;
//   - *grpcServerWrapper, a gRPC server with the auth, health and reflection
//     services registered and a TCP listener bound on 0.0.0.0 at
//     "services.auth.port" (8081 when unset or zero).
func provideAuthService() fx.Option {
	return fx.Options(
		// Auth service
		fx.Provide(func(
			client *database.Client,
			log logger.Logger,
			identityClient services.IdentityServiceClient,
			authzClient services.AuthzServiceClient,
			cfg config.ConfigProvider,
		) (*authService, error) {
			jwtSecret := cfg.GetString("auth.jwt_secret")
			if jwtSecret == "" {
				return nil, fmt.Errorf("auth.jwt_secret is required in configuration")
			}
			return &authService{
				client:             client,
				logger:             log,
				identityClient:     identityClient,
				authzClient:        authzClient,
				jwtSecret:          []byte(jwtSecret),
				accessTokenExpiry:  accessTokenLifetime,
				refreshTokenExpiry: refreshTokenLifetime,
			}, nil
		}),
		// gRPC server implementation. The handlers get their own *zap.Logger;
		// the injected logger.Logger is accepted but not used here.
		fx.Provide(func(authService *authService, log logger.Logger) (*authServerImpl, error) {
			zapLogger, err := zap.NewProduction()
			if err != nil {
				// Fail wiring rather than handing the server a nil logger.
				return nil, fmt.Errorf("failed to create zap logger: %w", err)
			}
			return &authServerImpl{
				service: authService,
				logger:  zapLogger,
			}, nil
		}),
		// gRPC server wrapper
		fx.Provide(func(
			serverImpl *authServerImpl,
			cfg config.ConfigProvider,
			log logger.Logger,
		) (*grpcServerWrapper, error) {
			port := cfg.GetInt("services.auth.port")
			if port == 0 {
				port = 8081
			}
			addr := fmt.Sprintf("0.0.0.0:%d", port)
			// NOTE(review): the port is bound at construction time; presumably
			// grpcServerWrapper's lifecycle hooks serve on and close this
			// listener — confirm.
			listener, err := net.Listen("tcp", addr)
			if err != nil {
				return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
			}
			grpcServer := grpc.NewServer()
			authv1.RegisterAuthServiceServer(grpcServer, serverImpl)
			// Register health service
			healthServer := health.NewServer()
			grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)
			// Set serving status for the default service (empty string) - this is what Consul checks
			healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)
			// Also set for the specific service name
			healthServer.SetServingStatus("auth.v1.AuthService", grpc_health_v1.HealthCheckResponse_SERVING)
			// Register reflection for grpcurl
			reflection.Register(grpcServer)
			return &grpcServerWrapper{
				server:   grpcServer,
				listener: listener,
				port:     port,
				logger:   log,
			}, nil
		}),
	)
}