Files

243 lines
7.6 KiB
Go

package authority
import (
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
)
// Identifiers embedded in the JSON documents handled by this package.
// Verification compares them by exact string equality.
const (
	// AlgorithmEd25519 is the only signature algorithm VerifyRaw accepts.
	AlgorithmEd25519 = "ed25519"

	// AuthoritySchemaVersion is the cluster authority schema identifier.
	AuthoritySchemaVersion = "rap.cluster_authority.v1"

	// SignatureSchemaVersion must appear in Signature.SchemaVersion.
	SignatureSchemaVersion = "rap.cluster_authority.signature.v1"

	// QuorumSchemaVersion must appear in QuorumDescriptor.SchemaVersion.
	QuorumSchemaVersion = "rap.cluster_authority.quorum.v1"

	// QuorumEnvelopeVersion must appear in QuorumEnvelope.SchemaVersion.
	QuorumEnvelopeVersion = "rap.cluster_authority.quorum_envelope.v1"
)
// Sentinel errors. Every failure returned by this package wraps exactly one
// of these, so callers can classify failures with errors.Is.
var (
	// ErrInvalidSignature reports a signature, envelope or quorum that does
	// not verify.
	ErrInvalidSignature = errors.New("invalid cluster authority signature")

	// ErrInvalidKey reports a public key that is not valid base64 or has the
	// wrong length.
	ErrInvalidKey = errors.New("invalid cluster authority key")

	// ErrInvalidPayload reports a payload that is empty or not valid JSON.
	ErrInvalidPayload = errors.New("invalid cluster authority payload")
)
// Signature is a detached signature by a single cluster authority key over
// the canonical JSON form of a payload (see CanonicalJSON).
//
// NOTE: field order and JSON tags define the serialized form of this type.
type Signature struct {
// SchemaVersion must equal SignatureSchemaVersion for VerifyRaw to accept it.
SchemaVersion string `json:"schema_version"`
// Algorithm must equal AlgorithmEd25519 for VerifyRaw to accept it.
Algorithm string `json:"algorithm"`
// KeyFingerprint names the signing key. It must equal Fingerprint of the
// public key the signature is verified against.
KeyFingerprint string `json:"key_fingerprint"`
// Signature is the base64-encoded (padded or unpadded) Ed25519 signature.
// Surrounding whitespace is ignored.
Signature string `json:"signature"`
}
// QuorumMember is one key holder that may contribute a signature to a quorum.
//
// NOTE: field order and JSON tags define the serialized form, which is hashed
// via QuorumDescriptorHash.
type QuorumMember struct {
// NodeID is descriptive only; VerifyQuorumRaw does not consult it.
NodeID string `json:"node_id,omitempty"`
// Role is descriptive only; VerifyQuorumRaw does not consult it.
Role string `json:"role,omitempty"`
// PublicKey is the member's base64-encoded 32-byte Ed25519 public key.
PublicKey string `json:"public_key"`
// PublicKeyFingerprint is expected to equal Fingerprint of the decoded
// PublicKey. When it is blank, VerifyQuorumRaw derives it from PublicKey.
PublicKeyFingerprint string `json:"public_key_fingerprint"`
// Scopes lists the scopes this member may sign for. "*" grants every scope,
// and surrounding whitespace on entries is ignored.
Scopes []string `json:"scopes,omitempty"`
}
// QuorumDescriptor defines the membership and signing threshold of a cluster
// for one epoch. A QuorumEnvelope binds to a descriptor through
// QuorumDescriptorHash.
//
// NOTE: field order and JSON tags define the serialized form of this type.
type QuorumDescriptor struct {
// SchemaVersion must equal QuorumSchemaVersion.
SchemaVersion string `json:"schema_version"`
// ClusterID must be non-blank and must equal the envelope's ClusterID.
ClusterID string `json:"cluster_id"`
// Epoch must be non-blank and must equal the envelope's Epoch.
Epoch string `json:"epoch"`
// Threshold is the minimum number of distinct valid member signatures. An
// envelope may raise it but cannot lower it.
Threshold int `json:"threshold"`
// Members are the keys allowed to sign. Each key may appear only once.
Members []QuorumMember `json:"members"`
}
// QuorumEnvelope carries the member signatures collected for one payload,
// together with the hashes that tie it to that payload and to a
// QuorumDescriptor.
//
// NOTE: field order and JSON tags define the serialized form of this type.
type QuorumEnvelope struct {
// SchemaVersion must equal QuorumEnvelopeVersion.
SchemaVersion string `json:"schema_version"`
// ClusterID must match the descriptor's ClusterID.
ClusterID string `json:"cluster_id"`
// Epoch must match the descriptor's Epoch.
Epoch string `json:"epoch"`
// Threshold raises the descriptor's threshold when it is larger. Smaller
// values are ignored.
Threshold int `json:"threshold"`
// PayloadSHA256 must equal HashRaw(payload).
PayloadSHA256 string `json:"payload_sha256"`
// QuorumSHA256 must equal QuorumDescriptorHash(descriptor).
QuorumSHA256 string `json:"quorum_sha256"`
// Signatures are the member signatures over the payload.
Signatures []Signature `json:"signatures"`
// AllowedScopes is not consulted by VerifyQuorumRaw.
AllowedScopes []string `json:"allowed_scopes,omitempty"`
// DecisionReason is not consulted by VerifyQuorumRaw.
DecisionReason string `json:"decision_reason,omitempty"`
}
// VerifyRaw checks that signature is a valid Ed25519 signature over the
// canonical JSON form of payload, made by the base64-encoded public key
// publicKeyB64. It returns nil on success.
//
// The signature must declare SignatureSchemaVersion and AlgorithmEd25519.
// Its KeyFingerprint must equal Fingerprint of the decoded key.
//
// Returned errors wrap one of:
//   - ErrInvalidSignature, for a malformed or non-verifying signature.
//   - ErrInvalidKey, when publicKeyB64 cannot be decoded.
//   - ErrInvalidPayload, when payload is empty or not valid JSON.
func VerifyRaw(publicKeyB64 string, payload json.RawMessage, signature Signature) error {
	switch {
	case signature.SchemaVersion != SignatureSchemaVersion:
		return fmt.Errorf("%w: schema_version must be %s", ErrInvalidSignature, SignatureSchemaVersion)
	case signature.Algorithm != AlgorithmEd25519:
		return fmt.Errorf("%w: algorithm must be %s", ErrInvalidSignature, AlgorithmEd25519)
	}

	key, keyErr := decodePublicKey(publicKeyB64)
	if keyErr != nil {
		return keyErr
	}
	// Bind the signature to this exact key before doing any crypto work.
	if Fingerprint(key) != signature.KeyFingerprint {
		return fmt.Errorf("%w: key fingerprint mismatch", ErrInvalidSignature)
	}

	message, canonErr := CanonicalJSON(payload)
	if canonErr != nil {
		return canonErr
	}

	sig, sigErr := decodeBase64(strings.TrimSpace(signature.Signature))
	if sigErr != nil || len(sig) != ed25519.SignatureSize {
		return fmt.Errorf("%w: signature must be base64 ed25519 signature", ErrInvalidSignature)
	}

	if ok := ed25519.Verify(key, message, sig); !ok {
		return ErrInvalidSignature
	}
	return nil
}
// VerifyQuorumRaw checks that envelope carries enough valid signatures over
// payload from distinct members of descriptor.
//
// Descriptor and envelope must:
//   - carry QuorumSchemaVersion and QuorumEnvelopeVersion respectively;
//   - agree on a non-blank ClusterID and Epoch;
//   - agree on hashes: the envelope's PayloadSHA256 must equal
//     HashRaw(payload), and its QuorumSHA256 must equal
//     QuorumDescriptorHash(descriptor).
//
// The effective threshold is the larger of the two declared thresholds. It
// must lie in [1, len(descriptor.Members)].
//
// Every member's public key is decoded up front. A declared
// PublicKeyFingerprint that is non-blank must match that key, and each key may
// appear only once.
//
// Each signature must come from a member. When requiredScope is non-blank,
// that member must hold the scope (or "*"). Repeated signatures from the same
// fingerprint are skipped after the first. Any non-member, out-of-scope or
// invalid signature rejects the whole envelope.
//
// Failures wrap ErrInvalidSignature, except undecodable keys (ErrInvalidKey)
// and payload problems (ErrInvalidPayload).
func VerifyQuorumRaw(descriptor QuorumDescriptor, payload json.RawMessage, envelope QuorumEnvelope, requiredScope string) error {
	if descriptor.SchemaVersion != QuorumSchemaVersion {
		return fmt.Errorf("%w: quorum schema_version must be %s", ErrInvalidSignature, QuorumSchemaVersion)
	}
	if envelope.SchemaVersion != QuorumEnvelopeVersion {
		return fmt.Errorf("%w: quorum envelope schema_version must be %s", ErrInvalidSignature, QuorumEnvelopeVersion)
	}
	if strings.TrimSpace(descriptor.ClusterID) == "" || descriptor.ClusterID != envelope.ClusterID {
		return fmt.Errorf("%w: quorum cluster mismatch", ErrInvalidSignature)
	}
	if strings.TrimSpace(descriptor.Epoch) == "" || descriptor.Epoch != envelope.Epoch {
		return fmt.Errorf("%w: quorum epoch mismatch", ErrInvalidSignature)
	}

	// The envelope may demand more signatures than the descriptor, never fewer.
	threshold := descriptor.Threshold
	if envelope.Threshold > threshold {
		threshold = envelope.Threshold
	}
	if threshold <= 0 || threshold > len(descriptor.Members) {
		return fmt.Errorf("%w: invalid quorum threshold", ErrInvalidSignature)
	}

	payloadHash, err := HashRaw(payload)
	if err != nil {
		return err
	}
	if envelope.PayloadSHA256 != payloadHash {
		return fmt.Errorf("%w: quorum payload hash mismatch", ErrInvalidSignature)
	}
	descriptorHash, err := HashRaw(mustMarshalQuorumDescriptor(descriptor))
	if err != nil {
		return err
	}
	if envelope.QuorumSHA256 != descriptorHash {
		return fmt.Errorf("%w: quorum descriptor hash mismatch", ErrInvalidSignature)
	}

	members, err := indexQuorumMembersByFingerprint(descriptor.Members)
	if err != nil {
		return err
	}

	seen := make(map[string]bool, len(envelope.Signatures))
	valid := 0
	for _, signature := range envelope.Signatures {
		fingerprint := strings.TrimSpace(signature.KeyFingerprint)
		if seen[fingerprint] {
			// A member counts at most once toward the threshold.
			continue
		}
		member, ok := members[fingerprint]
		if !ok {
			return fmt.Errorf("%w: quorum signer is not a member", ErrInvalidSignature)
		}
		if requiredScope != "" && !memberAllowsScope(member, requiredScope) {
			return fmt.Errorf("%w: quorum signer scope mismatch", ErrInvalidSignature)
		}
		if err := VerifyRaw(member.PublicKey, payload, signature); err != nil {
			return err
		}
		seen[fingerprint] = true
		valid++
	}
	if valid < threshold {
		return fmt.Errorf("%w: quorum threshold not met", ErrInvalidSignature)
	}
	return nil
}

// indexQuorumMembersByFingerprint maps each member's key fingerprint, computed
// from its decoded public key, to the member. It rejects three cases: an
// undecodable key, a non-blank declared fingerprint that disagrees with the
// key, and the same key listed more than once.
func indexQuorumMembersByFingerprint(list []QuorumMember) (map[string]QuorumMember, error) {
	members := make(map[string]QuorumMember, len(list))
	for _, member := range list {
		publicKey, err := decodePublicKey(member.PublicKey)
		if err != nil {
			return nil, err
		}
		fingerprint := Fingerprint(publicKey)
		// A declared fingerprint that disagrees with the key would make the
		// member unable to sign, and would let one key hide behind two entries.
		if declared := strings.TrimSpace(member.PublicKeyFingerprint); declared != "" && declared != fingerprint {
			return nil, fmt.Errorf("%w: quorum member fingerprint does not match public key", ErrInvalidSignature)
		}
		if _, exists := members[fingerprint]; exists {
			return nil, fmt.Errorf("%w: duplicate quorum member", ErrInvalidSignature)
		}
		member.PublicKeyFingerprint = fingerprint
		members[fingerprint] = member
	}
	return members, nil
}
// QuorumDescriptorHash returns the hex-encoded SHA-256 of descriptor's
// canonical JSON encoding. A QuorumEnvelope must carry this value in
// QuorumSHA256 for VerifyQuorumRaw to accept it.
func QuorumDescriptorHash(descriptor QuorumDescriptor) (string, error) {
	encoded := mustMarshalQuorumDescriptor(descriptor)
	return HashRaw(encoded)
}
// Fingerprint returns the identifier used for a cluster authority key. It is
// "rap-ca-ed25519-" followed by the lowercase hex encoding of the first 16
// bytes (128 bits) of SHA-256(publicKey).
func Fingerprint(publicKey ed25519.PublicKey) string {
	const prefix = "rap-ca-ed25519-"
	digest := sha256.Sum256(publicKey)
	truncated := digest[:16]
	return prefix + hex.EncodeToString(truncated)
}
// HashRaw returns the lowercase hex SHA-256 digest of raw's canonical JSON
// form, as produced by CanonicalJSON. Canonicalization errors are returned
// unchanged and wrap ErrInvalidPayload.
func HashRaw(raw json.RawMessage) (string, error) {
	body, canonErr := CanonicalJSON(raw)
	if canonErr != nil {
		return "", canonErr
	}
	sum := sha256.Sum256(body)
	encoded := hex.EncodeToString(sum[:])
	return encoded, nil
}
// mustMarshalQuorumDescriptor encodes descriptor with encoding/json.
//
// Despite the "must" prefix it never panics. If json.Marshal fails it returns
// nil, which a later CanonicalJSON call reports as an empty-payload error.
//
// NOTE(review): QuorumDescriptor holds only strings, an int and slices of
// such values, so a marshal failure is not expected in practice.
func mustMarshalQuorumDescriptor(descriptor QuorumDescriptor) json.RawMessage {
	if raw, err := json.Marshal(descriptor); err == nil {
		return raw
	}
	return nil
}
// memberAllowsScope reports whether member may sign for requiredScope.
//
// A requiredScope that is blank after trimming is always allowed. Otherwise
// the member needs the wildcard scope "*" or an exact match. Surrounding
// whitespace is ignored on both the required scope and the member's scopes.
func memberAllowsScope(member QuorumMember, requiredScope string) bool {
	want := strings.TrimSpace(requiredScope)
	if want == "" {
		return true
	}
	for i := range member.Scopes {
		switch strings.TrimSpace(member.Scopes[i]) {
		case "*", want:
			return true
		}
	}
	return false
}
// CanonicalJSON re-encodes raw into the byte form that this package's
// signatures and hashes are computed over.
//
// raw is decoded into generic values (maps, slices, strings, float64 numbers,
// bools, nil) and marshaled again. Re-marshaling sorts object keys and drops
// insignificant whitespace.
//
// NOTE(review): numbers pass through float64, so distinct numeric literals
// that round to the same float64 (e.g. integers above 2^53) canonicalize
// identically. Duplicate object keys also collapse to a single entry.
// Confirm that signed payloads cannot carry such values.
//
// Empty or malformed input yields an error wrapping ErrInvalidPayload.
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
	if len(raw) == 0 {
		return nil, fmt.Errorf("%w: empty payload", ErrInvalidPayload)
	}

	var tree any
	decodeErr := json.Unmarshal(raw, &tree)
	if decodeErr != nil {
		return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidPayload, decodeErr)
	}

	canonical, encodeErr := json.Marshal(tree)
	if encodeErr != nil {
		return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidPayload, encodeErr)
	}
	return canonical, nil
}
// decodePublicKey parses a base64-encoded Ed25519 public key, ignoring
// surrounding whitespace. Padded and unpadded standard-alphabet encodings
// are both accepted. The decoded key must be exactly ed25519.PublicKeySize
// bytes. Errors wrap ErrInvalidKey.
func decodePublicKey(value string) (ed25519.PublicKey, error) {
	raw, err := decodeBase64(strings.TrimSpace(value))
	switch {
	case err != nil:
		return nil, fmt.Errorf("%w: public key must be base64 encoded", ErrInvalidKey)
	case len(raw) != ed25519.PublicKeySize:
		return nil, fmt.Errorf("%w: public key must decode to %d bytes", ErrInvalidKey, ed25519.PublicKeySize)
	default:
		return ed25519.PublicKey(raw), nil
	}
}
// decodeBase64 decodes value using the standard base64 alphabet and accepts
// both padded and unpadded input. An empty string is rejected. If neither
// form parses, the result and error of the unpadded attempt are returned.
func decodeBase64(value string) ([]byte, error) {
	if len(value) == 0 {
		return nil, errors.New("empty base64 value")
	}
	if padded, err := base64.StdEncoding.DecodeString(value); err == nil {
		return padded, nil
	}
	// Fall back to the unpadded form and hand back its outcome as-is.
	return base64.RawStdEncoding.DecodeString(value)
}