135 lines
3.2 KiB
Go
135 lines
3.2 KiB
Go
package services
|
|
|
|
import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"database/sql"
	"encoding/base64"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"sync"
)
|
|
|
|
type CredentialsService struct {
|
|
DB *sql.DB
|
|
// Cache for decrypted keys
|
|
cache map[string]string
|
|
cacheMutex sync.RWMutex
|
|
}
|
|
|
|
func NewCredentialsService(db *sql.DB) *CredentialsService {
|
|
return &CredentialsService{
|
|
DB: db,
|
|
cache: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// SaveCredentials saves the encrypted payload for a service
|
|
func (s *CredentialsService) SaveCredentials(ctx context.Context, serviceName, encryptedPayload, updatedBy string) error {
|
|
query := `
|
|
INSERT INTO external_services_credentials (service_name, encrypted_payload, updated_by, updated_at)
|
|
VALUES ($1, $2, $3, NOW())
|
|
ON CONFLICT (service_name)
|
|
DO UPDATE SET
|
|
encrypted_payload = EXCLUDED.encrypted_payload,
|
|
updated_by = EXCLUDED.updated_by,
|
|
updated_at = NOW()
|
|
`
|
|
_, err := s.DB.ExecContext(ctx, query, serviceName, encryptedPayload, updatedBy)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Invalidate cache
|
|
s.cacheMutex.Lock()
|
|
delete(s.cache, serviceName)
|
|
s.cacheMutex.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetDecryptedKey retrieves and decrypts the key for a service
|
|
func (s *CredentialsService) GetDecryptedKey(ctx context.Context, serviceName string) (string, error) {
|
|
// Check cache first
|
|
s.cacheMutex.RLock()
|
|
if val, ok := s.cache[serviceName]; ok {
|
|
s.cacheMutex.RUnlock()
|
|
return val, nil
|
|
}
|
|
s.cacheMutex.RUnlock()
|
|
|
|
// Fetch from DB
|
|
var encryptedPayload string
|
|
query := `SELECT encrypted_payload FROM external_services_credentials WHERE service_name = $1`
|
|
err := s.DB.QueryRowContext(ctx, query, serviceName).Scan(&encryptedPayload)
|
|
if err == sql.ErrNoRows {
|
|
return "", fmt.Errorf("credentials for service %s not found", serviceName)
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Decrypt
|
|
decrypted, err := s.decryptPayload(encryptedPayload)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to decrypt credentials: %w", err)
|
|
}
|
|
|
|
// Update cache
|
|
s.cacheMutex.Lock()
|
|
s.cache[serviceName] = decrypted
|
|
s.cacheMutex.Unlock()
|
|
|
|
return decrypted, nil
|
|
}
|
|
|
|
func (s *CredentialsService) decryptPayload(encryptedPayload string) (string, error) {
|
|
// 1. Decode Private Key from Env
|
|
rawPrivateKey, err := base64.StdEncoding.DecodeString(os.Getenv("RSA_PRIVATE_KEY_BASE64"))
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to decode env RSA private key: %w", err)
|
|
}
|
|
|
|
block, _ := pem.Decode(rawPrivateKey)
|
|
if block == nil {
|
|
return "", fmt.Errorf("failed to parse PEM block containing the private key")
|
|
}
|
|
|
|
privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
// Try generic PKCS8 if PKCS1 fails
|
|
if key, err2 := x509.ParsePKCS8PrivateKey(block.Bytes); err2 == nil {
|
|
if rsaKey, ok := key.(*rsa.PrivateKey); ok {
|
|
privKey = rsaKey
|
|
} else {
|
|
return "", fmt.Errorf("key is not RSA")
|
|
}
|
|
} else {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
// 2. Decode ciphertext
|
|
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPayload)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// 3. Decrypt using RSA-OAEP
|
|
plaintext, err := rsa.DecryptOAEP(
|
|
sha256.New(),
|
|
rand.Reader,
|
|
privKey,
|
|
ciphertext,
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return string(plaintext), nil
|
|
}
|