96 lines
2.5 KiB
Go
96 lines
2.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// TokenManager issues signed access tokens and opaque refresh tokens for a
// single issuer. Construct it with NewTokenManager.
type TokenManager struct {
	// issuer is written to the "iss" claim of every access token.
	issuer string

	// accessSecret is the HMAC key used to sign access tokens (HS256).
	accessSecret []byte

	// refreshHashSecret is the HMAC-SHA256 key used by HashRefreshToken.
	refreshHashSecret []byte

	// accessTTL is added to the issue time to compute access-token expiry.
	accessTTL time.Duration

	// refreshTTL is added to the issue time to compute refresh-token expiry.
	refreshTTL time.Duration
}
|
|
|
|
type AccessClaims struct {
|
|
AuthSessionID string `json:"sid"`
|
|
DeviceID string `json:"did"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
func NewTokenManager(cfg TokenConfig) *TokenManager {
|
|
return &TokenManager{
|
|
issuer: cfg.Issuer,
|
|
accessSecret: []byte(cfg.AccessTokenSecret),
|
|
refreshHashSecret: []byte(cfg.RefreshHashSecret),
|
|
accessTTL: cfg.AccessTokenTTL,
|
|
refreshTTL: cfg.RefreshTokenTTL,
|
|
}
|
|
}
|
|
|
|
// TokenConfig carries the settings NewTokenManager needs to build a
// TokenManager.
type TokenConfig struct {
	// Issuer becomes the "iss" claim of issued access tokens.
	Issuer string

	// AccessTokenSecret is the HMAC key for signing access tokens.
	AccessTokenSecret string

	// RefreshHashSecret is the HMAC key for hashing refresh tokens.
	RefreshHashSecret string

	// AccessTokenTTL is the lifetime of an access token.
	AccessTokenTTL time.Duration

	// RefreshTokenTTL is the lifetime of a refresh token.
	RefreshTokenTTL time.Duration
}
|
|
|
|
func (m *TokenManager) IssueAccessToken(userID, authSessionID, deviceID string, now time.Time) (string, time.Time, error) {
|
|
expiresAt := now.Add(m.accessTTL)
|
|
claims := AccessClaims{
|
|
AuthSessionID: authSessionID,
|
|
DeviceID: deviceID,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: m.issuer,
|
|
Subject: userID,
|
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signed, err := token.SignedString(m.accessSecret)
|
|
if err != nil {
|
|
return "", time.Time{}, fmt.Errorf("sign access token: %w", err)
|
|
}
|
|
|
|
return signed, expiresAt, nil
|
|
}
|
|
|
|
func (m *TokenManager) IssueRefreshToken(authSessionID string, now time.Time) (raw string, hash string, expiresAt time.Time, err error) {
|
|
secret := make([]byte, 32)
|
|
if _, err = rand.Read(secret); err != nil {
|
|
return "", "", time.Time{}, fmt.Errorf("read random refresh secret: %w", err)
|
|
}
|
|
|
|
encodedSecret := base64.RawURLEncoding.EncodeToString(secret)
|
|
raw = authSessionID + "." + encodedSecret
|
|
hash = m.HashRefreshToken(raw)
|
|
expiresAt = now.Add(m.refreshTTL)
|
|
return raw, hash, expiresAt, nil
|
|
}
|
|
|
|
func (m *TokenManager) HashRefreshToken(token string) string {
|
|
mac := hmac.New(sha256.New, m.refreshHashSecret)
|
|
_, _ = mac.Write([]byte(token))
|
|
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
|
}
|
|
|
|
func (m *TokenManager) ParseRefreshToken(token string) (string, error) {
|
|
sessionID, _, ok := strings.Cut(token, ".")
|
|
if !ok || sessionID == "" {
|
|
return "", ErrInvalidRefreshToken
|
|
}
|
|
return sessionID, nil
|
|
}
|