website/pkg/models/two_factor.go

247 lines
5.7 KiB
Go
Raw Normal View History

2023-09-19 00:22:50 +00:00
package models
import (
	"bytes"
	crand "crypto/rand"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
	"image/png"
	"math/rand"
	"strings"
	"time"

	"code.nonshy.com/nonshy/website/pkg/config"
	"code.nonshy.com/nonshy/website/pkg/encryption"
	"code.nonshy.com/nonshy/website/pkg/log"
	"github.com/pquerna/otp"
	"github.com/pquerna/otp/totp"
)
// TwoFactor is the database model holding a user's 2FA (TOTP) settings for more secure login.
//
// The exported fields are persisted. The TOTP secret and the backup codes are
// stored encrypted (see SetSecret and SetBackupCodes).
type TwoFactor struct {
UserID uint64 `gorm:"primaryKey"` // owner's user ID; also the primary key
Enabled bool // NOTE(review): presumably whether 2FA is active for the user; not read anywhere in this file
EncryptedSecret []byte // encrypted OTP secret (URL format)
HashedSecret string // hash of the plaintext secret, used to verify EncryptedSecret was decrypted correctly
BackupCodes []byte // encrypted, comma-joined backup codes
CreatedAt time.Time
UpdatedAt time.Time
// Private vars
isNew bool // true when built by New2FA: needs creation, doesn't exist in DB yet
}
// IsNew reports whether this 2FA record was freshly generated by New2FA
// rather than loaded from the database (i.e. it is not in the DB yet).
func (tf *TwoFactor) IsNew() bool {
return tf.isNew
}
// New2FA builds an unsaved TwoFactor record for the given user and fills in a
// fresh set of (encrypted) backup codes.
//
// The returned record is flagged as new so that Save will create it. A failure
// to generate the backup codes is logged but does not prevent the record from
// being returned.
func New2FA(userID uint64) *TwoFactor {
	tf := &TwoFactor{UserID: userID, isNew: true}

	// Backup codes are generated eagerly; an error here is only logged.
	err := tf.GenerateBackupCodes()
	if err != nil {
		log.Error("New2FA(%d): GenerateBackupCodes: %s", userID, err)
	}

	return tf
}
// Get2FA loads the TwoFactor record for a user by primary key.
//
// If the lookup returns any error, a brand new (unsaved) record from New2FA is
// returned instead, ready to be initialized.
//
// NOTE(review): every lookup error is treated the same way here, so a DB
// failure is indistinguishable from "no record" — confirm that is intended.
func Get2FA(userID uint64) *TwoFactor {
	tf := &TwoFactor{}
	if err := DB.First(&tf, userID).Error; err != nil {
		return New2FA(userID)
	}
	return tf
}
// SetSecret encrypts the TOTP secret URL and stores it on the record together
// with a hash of the plaintext, which GetSecret later uses for verification.
//
// On an encryption error the record is left unmodified and the error is returned.
func (tf *TwoFactor) SetSecret(url string) error {
	encrypted, err := encryption.EncryptString(url)
	if err != nil {
		return err
	}

	// Keep the ciphertext and the verification hash of the original in sync.
	tf.EncryptedSecret = encrypted
	tf.HashedSecret = encryption.Hash([]byte(url))
	return nil
}
// GetSecret decrypts the stored TOTP secret (URL) and checks it against the
// stored verification hash.
//
// It returns the decryption error if decrypting fails, or an error if the
// decrypted value does not match HashedSecret.
func (tf *TwoFactor) GetSecret() (string, error) {
	secret, err := encryption.DecryptString(tf.EncryptedSecret)
	if err != nil {
		return "", err
	}

	// Only hand back the secret when it matches the hash taken at SetSecret time.
	if ok := encryption.VerifyHash([]byte(secret), tf.HashedSecret); ok {
		return secret, nil
	}
	return "", errors.New("hash of secret did not match: the site AES key may be wrong")
}
// Validate checks a user-supplied code against the TOTP secret first and, if
// that fails, against the backup codes (a matching backup code is burned).
//
// It returns nil when the code is accepted, the underlying error when the
// stored secret cannot be recovered or parsed, and a generic error otherwise.
func (tf *TwoFactor) Validate(code string) error {
	// Recover the stored otpauth URL.
	url, err := tf.GetSecret()
	if err != nil {
		return err
	}

	// Parse it back into an OTP key object.
	key, err := otp.NewKeyFromURL(url)
	if err != nil {
		return err
	}

	// Cases are tried in order: the backup codes are only consulted (and
	// possibly burned) when the TOTP check did not succeed.
	switch {
	case totp.Validate(code, key.Secret()):
		return nil
	case tf.ValidateBackupCode(code):
		return nil
	}
	return errors.New("not a valid code")
}
// backupCodeSource adapts crypto/rand to the math/rand Source64 interface so
// that rand.Rand's unbiased Intn can be driven by cryptographically secure
// randomness. The first read failure is remembered in err (the Source
// interface has no way to return one); callers must check it after drawing.
type backupCodeSource struct {
	err error
}

// Seed is a no-op: the operating system's entropy source cannot be seeded.
func (s *backupCodeSource) Seed(int64) {}

// Uint64 returns 64 random bits from crypto/rand, or 0 after recording the
// error if the read fails.
func (s *backupCodeSource) Uint64() uint64 {
	var buf [8]byte
	if _, err := crand.Read(buf[:]); err != nil {
		if s.err == nil {
			s.err = err
		}
		return 0
	}
	return binary.BigEndian.Uint64(buf[:])
}

// Int63 returns a non-negative 63-bit random integer.
func (s *backupCodeSource) Int63() int64 {
	return int64(s.Uint64() >> 1)
}

// GenerateBackupCodes will generate and reset the backup codes (encrypted).
//
// It creates config.TwoFactorBackupCodeCount distinct codes, each
// config.TwoFactorBackupCodeLength characters drawn from [a-z0-9], and stores
// them via SetBackupCodes. The record is not saved to the DB here.
//
// Backup codes are login credentials, so the characters come from crypto/rand
// rather than the predictable math/rand generator. An error is returned if the
// system's random source fails or if the codes cannot be encrypted.
func (tf *TwoFactor) GenerateBackupCodes() error {
	const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"

	var (
		src      = &backupCodeSource{}
		rng      = rand.New(src)
		codes    = []string{}
		distinct = map[string]struct{}{}
	)

	for i := 0; i < config.TwoFactorBackupCodeCount; i++ {
		for {
			var code []byte
			for j := 0; j < config.TwoFactorBackupCodeLength; j++ {
				code = append(code, alphabet[rng.Intn(len(alphabet))])
			}

			// A failed read yields zeros; bail out before the distinctness
			// check so we neither keep a weak code nor retry forever.
			if src.err != nil {
				return fmt.Errorf("generating backup codes: %w", src.err)
			}

			// Check for distinctness; regenerate on a collision.
			codeStr := string(code)
			if _, ok := distinct[codeStr]; ok {
				continue
			}

			distinct[codeStr] = struct{}{}
			codes = append(codes, codeStr)
			break
		}
	}

	// Encrypt the codes.
	return tf.SetBackupCodes(codes)
}
// SetBackupCodes joins the codes with commas, encrypts the result and stores
// the ciphertext on the record.
//
// It only updates the in-memory struct; call Save to persist the change.
// On an encryption error the record is left unmodified.
func (tf *TwoFactor) SetBackupCodes(codes []string) error {
	joined := strings.Join(codes, ",")

	encrypted, err := encryption.EncryptString(joined)
	if err == nil {
		tf.BackupCodes = encrypted
	}
	return err
}
// GetBackupCodes returns the list of still-valid backup codes.
//
// The stored ciphertext is decrypted and split on commas. When no codes remain
// (the decrypted text is empty) an empty, non-nil slice is returned. The
// decryption error is returned if the stored value cannot be decrypted.
func (tf *TwoFactor) GetBackupCodes() ([]string, error) {
	// Decrypt the backup codes.
	plaintext, err := encryption.DecryptString(tf.BackupCodes)
	if err != nil {
		return nil, err
	}

	// strings.Split("", ",") yields [""], i.e. one phantom empty "code".
	// Report an exhausted list as genuinely empty instead.
	if plaintext == "" {
		return []string{}, nil
	}

	return strings.Split(plaintext, ","), nil
}
// ValidateBackupCode will check if the code is a backup code and burn it if so.
//
// The code is compared case-insensitively (stored codes are lowercase) and
// with surrounding whitespace ignored. On a match the code is removed from the
// list and the record is saved to the DB immediately. It returns true only
// when the code matched AND the updated list was persisted; any failure is
// logged and reported as false.
func (tf *TwoFactor) ValidateBackupCode(code string) bool {
	// Normalize the input. An empty code can never be valid: without this
	// guard it would match the empty entry produced by splitting an exhausted
	// (empty) code list, letting a blank code pass validation.
	code = strings.ToLower(strings.TrimSpace(code))
	if code == "" {
		return false
	}

	codes, err := tf.GetBackupCodes()
	if err != nil {
		log.Error("ValidateBackupCode: %s", err)
		return false
	}

	// Check for a match to our backup codes, collecting the ones to keep.
	var (
		matched   bool
		remaining = make([]string, 0, len(codes))
	)
	for _, check := range codes {
		switch check {
		case "":
			// Artifact of an empty list; never keep or match it.
		case code:
			// Successful match! Leave it out of the remaining codes.
			matched = true
		default:
			remaining = append(remaining, check)
		}
	}
	if !matched {
		return false
	}

	// Burn the code. Remember the previous ciphertext so the in-memory record
	// is not left out of sync with the DB if saving fails.
	previous := tf.BackupCodes
	if err := tf.SetBackupCodes(remaining); err != nil {
		log.Error("ValidateBackupCode: SetBackupCodes: %s", err)
		return false
	}

	// Save it to DB.
	if err := tf.Save(); err != nil {
		tf.BackupCodes = previous
		log.Error("ValidateBackupCode: saving changes to DB: %s", err)
		return false
	}

	return true
}
// QRCodeAsDataURL returns an HTML img tag that embeds the 2FA QR code as a PNG data URL.
//
// The key is rendered as a 200x200 image, PNG-encoded and base64-embedded in
// the tag's src attribute. An error is returned if the image cannot be
// generated or PNG-encoded.
func (tf *TwoFactor) QRCodeAsDataURL(key *otp.Key) (string, error) {
	img, err := key.Image(200, 200)
	if err != nil {
		return "", err
	}

	// Surface encoding failures instead of emitting a tag with a broken image.
	var buf bytes.Buffer
	if err := png.Encode(&buf, img); err != nil {
		return "", fmt.Errorf("encoding QR code as PNG: %w", err)
	}

	dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes())
	return fmt.Sprintf(`<img src="%s" alt="QR Code">`, dataURL), nil
}
// Save persists the 2FA record to the DB.
//
// A record that is still flagged as new is inserted with Create; once that
// succeeds the flag is cleared so that later calls on the same struct update
// the existing row instead of attempting a second insert. Records loaded from
// the DB are written with Save.
//
// The record is deliberately not logged here: it carries the encrypted secret,
// its verification hash and the backup codes.
func (tf *TwoFactor) Save() error {
	if tf.isNew {
		if err := DB.Create(tf).Error; err != nil {
			return err
		}
		tf.isNew = false
		return nil
	}
	return DB.Save(tf).Error
}
// Delete removes the record from the DB.
//
// A record that is still flagged as new was never stored, so there is nothing
// to delete and nil is returned without touching the DB.
func (tf *TwoFactor) Delete() error {
	if !tf.isNew {
		return DB.Delete(tf).Error
	}
	return nil
}