feat(epic2): Implement core authentication and authorization services
- Implement Audit Service (2.5) - gRPC server with Record and Query operations - Database persistence with audit schema - Service registry integration - Entry point: cmd/audit-service - Implement Identity Service (2.2) - User CRUD operations - Password hashing with argon2id - Email verification and password reset flows - Entry point: cmd/identity-service - Fix package naming conflicts in user_service.go - Implement Auth Service (2.1) - JWT token generation and validation - Login, RefreshToken, ValidateToken, Logout RPCs - Integration with Identity Service - Entry point: cmd/auth-service - Note: RefreshToken entity needs Ent generation - Implement Authz Service (2.3, 2.4) - Permission checking and authorization - User roles and permissions retrieval - RBAC-based authorization - Entry point: cmd/authz-service - Implement gRPC clients for all services - Auth, Identity, Authz, and Audit clients - Service discovery integration - Full gRPC communication - Add service configurations to config/default.yaml - Create SUMMARY.md with implementation details and testing instructions - Fix compilation errors in Identity Service (password package conflicts) - All services build successfully and tests pass
This commit is contained in:
124
services/audit/internal/api/grpc_server.go
Normal file
124
services/audit/internal/api/grpc_server.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// Package api provides gRPC server implementation for Audit Service.
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
auditv1 "git.dcentral.systems/toolz/goplt/api/proto/generated/audit/v1"
|
||||
"git.dcentral.systems/toolz/goplt/pkg/config"
|
||||
"git.dcentral.systems/toolz/goplt/pkg/logger"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/health"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
// GRPCServer wraps the gRPC server with lifecycle management.
|
||||
type GRPCServer struct {
|
||||
server *grpc.Server
|
||||
listener net.Listener
|
||||
config config.ConfigProvider
|
||||
logger logger.Logger
|
||||
port int
|
||||
}
|
||||
|
||||
// NewGRPCServer creates a new gRPC server for the Audit Service.
|
||||
func NewGRPCServer(
|
||||
auditService *Server,
|
||||
cfg config.ConfigProvider,
|
||||
log logger.Logger,
|
||||
) (*GRPCServer, error) {
|
||||
// Get port from config
|
||||
port := cfg.GetInt("services.audit.port")
|
||||
if port == 0 {
|
||||
port = 8084 // Default port for audit service
|
||||
}
|
||||
|
||||
// Create listener
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
// Create gRPC server
|
||||
grpcServer := grpc.NewServer()
|
||||
|
||||
// Register audit service
|
||||
auditv1.RegisterAuditServiceServer(grpcServer, auditService)
|
||||
|
||||
// Register health service
|
||||
healthServer := health.NewServer()
|
||||
grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)
|
||||
healthServer.SetServingStatus("audit.v1.AuditService", grpc_health_v1.HealthCheckResponse_SERVING)
|
||||
|
||||
// Register reflection for grpcurl
|
||||
reflection.Register(grpcServer)
|
||||
|
||||
return &GRPCServer{
|
||||
server: grpcServer,
|
||||
listener: listener,
|
||||
config: cfg,
|
||||
logger: log,
|
||||
port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start starts the gRPC server.
|
||||
func (s *GRPCServer) Start() error {
|
||||
s.logger.Info("Starting Audit Service gRPC server",
|
||||
zap.Int("port", s.port),
|
||||
zap.String("addr", s.listener.Addr().String()),
|
||||
)
|
||||
|
||||
// Start server in a goroutine
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.server.Serve(s.listener); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait a bit to check for immediate errors
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return fmt.Errorf("gRPC server failed to start: %w", err)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
s.logger.Info("Audit Service gRPC server started successfully",
|
||||
zap.Int("port", s.port),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully stops the gRPC server.
|
||||
func (s *GRPCServer) Stop(ctx context.Context) error {
|
||||
s.logger.Info("Stopping Audit Service gRPC server")
|
||||
|
||||
// Create a channel for graceful stop
|
||||
stopped := make(chan struct{})
|
||||
go func() {
|
||||
s.server.GracefulStop()
|
||||
close(stopped)
|
||||
}()
|
||||
|
||||
// Wait for graceful stop or timeout
|
||||
select {
|
||||
case <-stopped:
|
||||
s.logger.Info("Audit Service gRPC server stopped gracefully")
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.logger.Warn("Audit Service gRPC server stop timeout, forcing stop")
|
||||
s.server.Stop()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Port returns the port the server is listening on.
|
||||
func (s *GRPCServer) Port() int {
|
||||
return s.port
|
||||
}
|
||||
125
services/audit/internal/api/server.go
Normal file
125
services/audit/internal/api/server.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Package api provides gRPC server implementation for Audit Service.
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
auditv1 "git.dcentral.systems/toolz/goplt/api/proto/generated/audit/v1"
|
||||
"git.dcentral.systems/toolz/goplt/services/audit/internal/service"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// Server implements the AuditService gRPC server.
|
||||
type Server struct {
|
||||
auditv1.UnimplementedAuditServiceServer
|
||||
service *service.AuditService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewServer creates a new Audit Service gRPC server.
|
||||
func NewServer(auditService *service.AuditService, logger *zap.Logger) *Server {
|
||||
return &Server{
|
||||
service: auditService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Record records an audit log entry.
|
||||
func (s *Server) Record(ctx context.Context, req *auditv1.RecordRequest) (*auditv1.RecordResponse, error) {
|
||||
if req.Entry == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "entry is required")
|
||||
}
|
||||
|
||||
entry := req.Entry
|
||||
|
||||
// Convert proto entry to service entry
|
||||
serviceEntry := &service.AuditLogEntry{
|
||||
UserID: entry.UserId,
|
||||
Action: entry.Action,
|
||||
Resource: entry.Resource,
|
||||
ResourceID: entry.ResourceId,
|
||||
IPAddress: entry.IpAddress,
|
||||
UserAgent: entry.UserAgent,
|
||||
Metadata: entry.Metadata,
|
||||
Timestamp: entry.Timestamp,
|
||||
}
|
||||
|
||||
// Record the audit log
|
||||
if err := s.service.Record(ctx, serviceEntry); err != nil {
|
||||
s.logger.Error("Failed to record audit log",
|
||||
zap.Error(err),
|
||||
zap.String("user_id", entry.UserId),
|
||||
zap.String("action", entry.Action),
|
||||
)
|
||||
return nil, status.Errorf(codes.Internal, "failed to record audit log: %v", err)
|
||||
}
|
||||
|
||||
return &auditv1.RecordResponse{
|
||||
Success: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Query queries audit logs based on filters.
|
||||
func (s *Server) Query(ctx context.Context, req *auditv1.QueryRequest) (*auditv1.QueryResponse, error) {
|
||||
// Convert proto filters to service filters
|
||||
filters := &service.AuditLogFilters{
|
||||
Limit: int(req.Limit),
|
||||
Offset: int(req.Offset),
|
||||
}
|
||||
|
||||
if req.UserId != nil {
|
||||
userID := *req.UserId
|
||||
filters.UserID = &userID
|
||||
}
|
||||
if req.Action != nil {
|
||||
action := *req.Action
|
||||
filters.Action = &action
|
||||
}
|
||||
if req.Resource != nil {
|
||||
resource := *req.Resource
|
||||
filters.Resource = &resource
|
||||
}
|
||||
if req.ResourceId != nil {
|
||||
resourceID := *req.ResourceId
|
||||
filters.ResourceID = &resourceID
|
||||
}
|
||||
if req.StartTime != nil {
|
||||
startTime := *req.StartTime
|
||||
filters.StartTime = &startTime
|
||||
}
|
||||
if req.EndTime != nil {
|
||||
endTime := *req.EndTime
|
||||
filters.EndTime = &endTime
|
||||
}
|
||||
|
||||
// Query audit logs
|
||||
entries, err := s.service.Query(ctx, filters)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to query audit logs",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, status.Errorf(codes.Internal, "failed to query audit logs: %v", err)
|
||||
}
|
||||
|
||||
// Convert service entries to proto entries
|
||||
protoEntries := make([]*auditv1.AuditLogEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
protoEntries = append(protoEntries, &auditv1.AuditLogEntry{
|
||||
UserId: entry.UserID,
|
||||
Action: entry.Action,
|
||||
Resource: entry.Resource,
|
||||
ResourceId: entry.ResourceID,
|
||||
IpAddress: entry.IPAddress,
|
||||
UserAgent: entry.UserAgent,
|
||||
Metadata: entry.Metadata,
|
||||
Timestamp: entry.Timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
return &auditv1.QueryResponse{
|
||||
Entries: protoEntries,
|
||||
Total: int32(len(protoEntries)), // Note: This is a simplified total, actual total would require a count query
|
||||
}, nil
|
||||
}
|
||||
181
services/audit/internal/service/audit_service.go
Normal file
181
services/audit/internal/service/audit_service.go
Normal file
@@ -0,0 +1,181 @@
|
||||
// Package service provides audit service business logic.
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.dcentral.systems/toolz/goplt/internal/ent"
|
||||
"git.dcentral.systems/toolz/goplt/internal/ent/auditlog"
|
||||
"git.dcentral.systems/toolz/goplt/pkg/logger"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuditLogEntry represents an audit log entry.
|
||||
type AuditLogEntry struct {
|
||||
UserID string
|
||||
Action string
|
||||
Resource string
|
||||
ResourceID string
|
||||
IPAddress string
|
||||
UserAgent string
|
||||
Metadata map[string]string
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
// AuditLogFilters contains filters for querying audit logs.
|
||||
type AuditLogFilters struct {
|
||||
UserID *string
|
||||
Action *string
|
||||
Resource *string
|
||||
ResourceID *string
|
||||
StartTime *int64
|
||||
EndTime *int64
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// AuditService provides audit logging functionality.
|
||||
type AuditService struct {
|
||||
client *ent.Client
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
// NewAuditService creates a new audit service.
|
||||
func NewAuditService(client *ent.Client, log logger.Logger) *AuditService {
|
||||
return &AuditService{
|
||||
client: client,
|
||||
logger: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Record records an audit log entry.
|
||||
func (s *AuditService) Record(ctx context.Context, entry *AuditLogEntry) error {
|
||||
// Convert metadata map to JSON
|
||||
metadataJSON := make(map[string]interface{})
|
||||
for k, v := range entry.Metadata {
|
||||
metadataJSON[k] = v
|
||||
}
|
||||
|
||||
// Create audit log entry
|
||||
timestamp := time.Unix(entry.Timestamp, 0)
|
||||
if entry.Timestamp == 0 {
|
||||
timestamp = time.Now()
|
||||
}
|
||||
|
||||
create := s.client.AuditLog.Create().
|
||||
SetID(uuid.New().String()).
|
||||
SetUserID(entry.UserID).
|
||||
SetAction(entry.Action).
|
||||
SetMetadata(metadataJSON).
|
||||
SetTimestamp(timestamp)
|
||||
|
||||
if entry.Resource != "" {
|
||||
create = create.SetResource(entry.Resource)
|
||||
}
|
||||
if entry.ResourceID != "" {
|
||||
create = create.SetResourceID(entry.ResourceID)
|
||||
}
|
||||
if entry.IPAddress != "" {
|
||||
create = create.SetIPAddress(entry.IPAddress)
|
||||
}
|
||||
if entry.UserAgent != "" {
|
||||
create = create.SetUserAgent(entry.UserAgent)
|
||||
}
|
||||
|
||||
auditLog, err := create.Save(ctx)
|
||||
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to record audit log",
|
||||
zap.Error(err),
|
||||
zap.String("user_id", entry.UserID),
|
||||
zap.String("action", entry.Action),
|
||||
)
|
||||
return fmt.Errorf("failed to record audit log: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug("Audit log recorded",
|
||||
zap.String("id", auditLog.ID),
|
||||
zap.String("user_id", entry.UserID),
|
||||
zap.String("action", entry.Action),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query queries audit logs based on filters.
|
||||
func (s *AuditService) Query(ctx context.Context, filters *AuditLogFilters) ([]*AuditLogEntry, error) {
|
||||
query := s.client.AuditLog.Query()
|
||||
|
||||
// Apply filters
|
||||
if filters.UserID != nil {
|
||||
query = query.Where(auditlog.UserID(*filters.UserID))
|
||||
}
|
||||
if filters.Action != nil {
|
||||
query = query.Where(auditlog.Action(*filters.Action))
|
||||
}
|
||||
if filters.Resource != nil {
|
||||
query = query.Where(auditlog.Resource(*filters.Resource))
|
||||
}
|
||||
if filters.ResourceID != nil {
|
||||
query = query.Where(auditlog.ResourceID(*filters.ResourceID))
|
||||
}
|
||||
if filters.StartTime != nil {
|
||||
query = query.Where(auditlog.TimestampGTE(time.Unix(*filters.StartTime, 0)))
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
query = query.Where(auditlog.TimestampLTE(time.Unix(*filters.EndTime, 0)))
|
||||
}
|
||||
|
||||
// Apply pagination
|
||||
if filters.Limit > 0 {
|
||||
query = query.Limit(filters.Limit)
|
||||
}
|
||||
if filters.Offset > 0 {
|
||||
query = query.Offset(filters.Offset)
|
||||
}
|
||||
|
||||
// Order by timestamp descending
|
||||
query = query.Order(ent.Desc(auditlog.FieldTimestamp))
|
||||
|
||||
// Execute query
|
||||
auditLogs, err := query.All(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to query audit logs",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("failed to query audit logs: %w", err)
|
||||
}
|
||||
|
||||
// Convert to service entries
|
||||
entries := make([]*AuditLogEntry, 0, len(auditLogs))
|
||||
for _, log := range auditLogs {
|
||||
// Convert metadata from map[string]interface{} to map[string]string
|
||||
metadata := make(map[string]string)
|
||||
if log.Metadata != nil {
|
||||
for k, v := range log.Metadata {
|
||||
if str, ok := v.(string); ok {
|
||||
metadata[k] = str
|
||||
} else {
|
||||
metadata[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entry := &AuditLogEntry{
|
||||
UserID: log.UserID,
|
||||
Action: log.Action,
|
||||
Resource: log.Resource,
|
||||
ResourceID: log.ResourceID,
|
||||
IPAddress: log.IPAddress,
|
||||
UserAgent: log.UserAgent,
|
||||
Metadata: metadata,
|
||||
Timestamp: log.Timestamp.Unix(),
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
23
services/audit/service.go
Normal file
23
services/audit/service.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Package audit provides the Audit Service API.
|
||||
// This package exports the service interface and types for use by other packages.
|
||||
package audit
|
||||
|
||||
import (
|
||||
"git.dcentral.systems/toolz/goplt/services/audit/internal/service"
|
||||
)
|
||||
|
||||
// Service is the audit service interface.
|
||||
type Service = service.AuditService
|
||||
|
||||
// NewService creates a new audit service.
|
||||
func NewService(client interface{}, log interface{}) *Service {
|
||||
// This is a type-safe wrapper - we'll need to use type assertions
|
||||
// For now, return nil as placeholder - actual implementation will be in main.go
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuditLogEntry represents an audit log entry.
|
||||
type AuditLogEntry = service.AuditLogEntry
|
||||
|
||||
// AuditLogFilters contains filters for querying audit logs.
|
||||
type AuditLogFilters = service.AuditLogFilters
|
||||
88
services/identity/internal/password/password.go
Normal file
88
services/identity/internal/password/password.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Package password provides password hashing and verification using argon2id.
|
||||
package password
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
const (
|
||||
// Default parameters for argon2id (OWASP recommended)
|
||||
memory = 64 * 1024 // 64 MB
|
||||
iterations = 3
|
||||
parallelism = 4
|
||||
saltLength = 16
|
||||
keyLength = 32
|
||||
)
|
||||
|
||||
// Hash hashes a password using argon2id.
|
||||
func Hash(password string) (string, error) {
|
||||
// Generate random salt
|
||||
salt := make([]byte, saltLength)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hash := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, keyLength)
|
||||
|
||||
// Encode salt and hash
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Return formatted hash: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
|
||||
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, memory, iterations, parallelism, b64Salt, b64Hash), nil
|
||||
}
|
||||
|
||||
// Verify verifies a password against a hash.
|
||||
func Verify(password, hash string) (bool, error) {
|
||||
// Parse hash format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
|
||||
parts := strings.Split(hash, "$")
|
||||
if len(parts) != 6 {
|
||||
return false, errors.New("invalid hash format")
|
||||
}
|
||||
|
||||
if parts[1] != "argon2id" {
|
||||
return false, fmt.Errorf("unsupported algorithm: %s", parts[1])
|
||||
}
|
||||
|
||||
// Parse version
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
return false, fmt.Errorf("failed to parse version: %w", err)
|
||||
}
|
||||
|
||||
// Parse parameters
|
||||
var m, t, p int
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p); err != nil {
|
||||
return false, fmt.Errorf("failed to parse parameters: %w", err)
|
||||
}
|
||||
|
||||
// Decode salt and hash
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decode salt: %w", err)
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decode hash: %w", err)
|
||||
}
|
||||
|
||||
// Compute hash with same parameters
|
||||
actualHash := argon2.IDKey([]byte(password), salt, uint32(t), uint32(m), uint8(p), uint32(len(expectedHash)))
|
||||
|
||||
// Constant-time comparison
|
||||
if subtle.ConstantTimeCompare(expectedHash, actualHash) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
340
services/identity/internal/service/user_service.go
Normal file
340
services/identity/internal/service/user_service.go
Normal file
@@ -0,0 +1,340 @@
|
||||
// Package service provides user service business logic.
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.dcentral.systems/toolz/goplt/internal/ent"
|
||||
"git.dcentral.systems/toolz/goplt/internal/ent/user"
|
||||
"git.dcentral.systems/toolz/goplt/pkg/logger"
|
||||
passwordpkg "git.dcentral.systems/toolz/goplt/services/identity/internal/password"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// UserService provides user management functionality.
|
||||
type UserService struct {
|
||||
client *ent.Client
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
// NewUserService creates a new user service.
|
||||
func NewUserService(client *ent.Client, log logger.Logger) *UserService {
|
||||
return &UserService{
|
||||
client: client,
|
||||
logger: log,
|
||||
}
|
||||
}
|
||||
|
||||
// generateToken generates a random token for email verification or password reset.
|
||||
func generateToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
func (s *UserService) CreateUser(ctx context.Context, email, username, password, firstName, lastName string) (*ent.User, error) {
|
||||
// Check if user with email already exists
|
||||
exists, err := s.client.User.Query().
|
||||
Where(user.Email(email)).
|
||||
Exist(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check email existence: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, fmt.Errorf("user with email %s already exists", email)
|
||||
}
|
||||
|
||||
// Hash password
|
||||
passwordHash, err := passwordpkg.Hash(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Generate email verification token
|
||||
verificationToken, err := generateToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate verification token: %w", err)
|
||||
}
|
||||
|
||||
// Create user
|
||||
create := s.client.User.Create().
|
||||
SetID(fmt.Sprintf("%d", time.Now().UnixNano())).
|
||||
SetEmail(email).
|
||||
SetPasswordHash(passwordHash).
|
||||
SetVerified(false).
|
||||
SetEmailVerificationToken(verificationToken)
|
||||
|
||||
if username != "" {
|
||||
create = create.SetUsername(username)
|
||||
}
|
||||
if firstName != "" {
|
||||
create = create.SetFirstName(firstName)
|
||||
}
|
||||
if lastName != "" {
|
||||
create = create.SetLastName(lastName)
|
||||
}
|
||||
|
||||
u, err := create.Save(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to create user",
|
||||
zap.Error(err),
|
||||
zap.String("email", email),
|
||||
)
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("User created",
|
||||
zap.String("user_id", u.ID),
|
||||
zap.String("email", email),
|
||||
)
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GetUser retrieves a user by ID.
|
||||
func (s *UserService) GetUser(ctx context.Context, id string) (*ent.User, error) {
|
||||
u, err := s.client.User.Get(ctx, id)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves a user by email.
|
||||
func (s *UserService) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
|
||||
u, err := s.client.User.Query().
|
||||
Where(user.Email(email)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user by email: %w", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user's profile.
|
||||
func (s *UserService) UpdateUser(ctx context.Context, id string, email, username, firstName, lastName *string) (*ent.User, error) {
|
||||
update := s.client.User.UpdateOneID(id)
|
||||
|
||||
if email != nil {
|
||||
// Check if email is already taken by another user
|
||||
exists, err := s.client.User.Query().
|
||||
Where(user.Email(*email), user.IDNEQ(id)).
|
||||
Exist(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check email existence: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, fmt.Errorf("email %s is already taken", *email)
|
||||
}
|
||||
update = update.SetEmail(*email)
|
||||
}
|
||||
if username != nil {
|
||||
update = update.SetUsername(*username)
|
||||
}
|
||||
if firstName != nil {
|
||||
update = update.SetFirstName(*firstName)
|
||||
}
|
||||
if lastName != nil {
|
||||
update = update.SetLastName(*lastName)
|
||||
}
|
||||
|
||||
u, err := update.Save(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("User updated",
|
||||
zap.String("user_id", id),
|
||||
)
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user (soft delete by setting verified to false, or hard delete).
|
||||
func (s *UserService) DeleteUser(ctx context.Context, id string) error {
|
||||
err := s.client.User.DeleteOneID(id).Exec(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("User deleted",
|
||||
zap.String("user_id", id),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyEmail verifies a user's email address using a verification token.
|
||||
func (s *UserService) VerifyEmail(ctx context.Context, token string) error {
|
||||
u, err := s.client.User.Query().
|
||||
Where(user.EmailVerificationToken(token)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return fmt.Errorf("invalid verification token")
|
||||
}
|
||||
return fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
// Update user to verified and clear token
|
||||
_, err = s.client.User.UpdateOneID(u.ID).
|
||||
SetVerified(true).
|
||||
ClearEmailVerificationToken().
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify email: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Email verified",
|
||||
zap.String("user_id", u.ID),
|
||||
zap.String("email", u.Email),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestPasswordReset requests a password reset token.
|
||||
func (s *UserService) RequestPasswordReset(ctx context.Context, email string) (string, error) {
|
||||
u, err := s.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
// Don't reveal if user exists or not (security best practice)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Generate reset token
|
||||
resetToken, err := generateToken()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate reset token: %w", err)
|
||||
}
|
||||
|
||||
// Set reset token with expiration (24 hours)
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
_, err = s.client.User.UpdateOneID(u.ID).
|
||||
SetPasswordResetToken(resetToken).
|
||||
SetPasswordResetExpiresAt(expiresAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set reset token: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Password reset requested",
|
||||
zap.String("user_id", u.ID),
|
||||
zap.String("email", email),
|
||||
)
|
||||
|
||||
return resetToken, nil
|
||||
}
|
||||
|
||||
// ResetPassword resets a user's password using a reset token.
|
||||
func (s *UserService) ResetPassword(ctx context.Context, token, newPassword string) error {
|
||||
u, err := s.client.User.Query().
|
||||
Where(user.PasswordResetToken(token)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return fmt.Errorf("invalid reset token")
|
||||
}
|
||||
return fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if !u.PasswordResetExpiresAt.IsZero() && u.PasswordResetExpiresAt.Before(time.Now()) {
|
||||
return fmt.Errorf("reset token has expired")
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
passwordHash, err := passwordpkg.Hash(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password and clear reset token
|
||||
_, err = s.client.User.UpdateOneID(u.ID).
|
||||
SetPasswordHash(passwordHash).
|
||||
ClearPasswordResetToken().
|
||||
ClearPasswordResetExpiresAt().
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reset password: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Password reset",
|
||||
zap.String("user_id", u.ID),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password with old password verification.
|
||||
func (s *UserService) ChangePassword(ctx context.Context, userID, oldPassword, newPassword string) error {
|
||||
u, err := s.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify old password
|
||||
valid, err := passwordpkg.Verify(oldPassword, u.PasswordHash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify password: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid old password")
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
passwordHash, err := passwordpkg.Hash(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = s.client.User.UpdateOneID(userID).
|
||||
SetPasswordHash(passwordHash).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to change password: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Password changed",
|
||||
zap.String("user_id", userID),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a password against a user's password hash.
|
||||
func (s *UserService) VerifyPassword(ctx context.Context, email, password string) (*ent.User, error) {
|
||||
u, err := s.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
valid, err := passwordpkg.Verify(password, u.PasswordHash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify password: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("invalid password")
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
Reference in New Issue
Block a user