Files
sbv/internal/auth.go
T
2025-11-11 16:40:10 -07:00

206 lines
5.1 KiB
Go

package internal
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
// authDB is the package-level database handle shared by every function in
// this file. It is assigned by InitAuthDB; the other functions use it without
// a nil check, so InitAuthDB must have succeeded before they are called.
var authDB *sql.DB
// InitAuthDB opens the authentication database at filepath, verifies that a
// connection can be established, and creates the users and sessions tables
// (plus the two session indexes) if they do not already exist.
//
// The opened handle is stored in the package-level authDB. If the connection
// check or the schema creation fails, the handle is closed before the error
// is returned so that a failed initialization does not leak it.
//
// NOTE(review): no "sqlite3" driver is imported in this file's import block;
// presumably it is registered elsewhere in the package — confirm.
// NOTE(review): SQLite only enforces FOREIGN KEY ... ON DELETE CASCADE when
// the foreign_keys pragma is enabled on the connection, and nothing here
// enables it — confirm the driver/DSN does.
func InitAuthDB(filepath string) error {
	var err error
	authDB, err = sql.Open("sqlite3", filepath)
	if err != nil {
		return err
	}

	// sql.Open does not necessarily connect; Ping forces a real connection.
	if err = authDB.Ping(); err != nil {
		// Best-effort cleanup: the ping failure is the error worth reporting.
		_ = authDB.Close()
		return err
	}

	createTableSQL := `
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at);
`
	if _, err = authDB.Exec(createTableSQL); err != nil {
		// Best-effort cleanup: the schema failure is the error worth reporting.
		_ = authDB.Close()
		return err
	}
	return nil
}
// CreateUser stores a new user row and returns the corresponding User.
//
// The user ID is a freshly generated UUID string and the password is stored
// only as a bcrypt hash (bcrypt.DefaultCost). The creation time is persisted
// as whole Unix seconds, and the returned CreatedAt is rebuilt from that same
// value. Hashing and insert failures are returned wrapped; the users.username
// column is declared UNIQUE in the schema, so a duplicate username surfaces
// as an insert failure.
func CreateUser(username, password string) (*User, error) {
	id := uuid.New().String()

	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("failed to hash password: %w", err)
	}
	hashText := string(hash)

	nowUnix := time.Now().Unix()

	const insertUser = "INSERT INTO users (id, username, password_hash, created_at) VALUES (?, ?, ?, ?)"
	if _, err := authDB.Exec(insertUser, id, username, hashText, nowUnix); err != nil {
		return nil, fmt.Errorf("failed to create user: %w", err)
	}

	created := &User{
		ID:           id,
		Username:     username,
		PasswordHash: hashText,
		CreatedAt:    time.Unix(nowUnix, 0),
	}
	return created, nil
}
// GetUserByUsername looks up the user row whose username matches exactly.
//
// It returns the populated User on success, a "user not found" error when no
// row matches, and the unmodified database error for any other failure. The
// stored created_at value (Unix seconds) is converted back to a time.Time.
func GetUserByUsername(username string) (*User, error) {
	const selectUser = "SELECT id, username, password_hash, created_at FROM users WHERE username = ?"

	var (
		found      User
		createdSec int64
	)

	row := authDB.QueryRow(selectUser, username)
	switch err := row.Scan(&found.ID, &found.Username, &found.PasswordHash, &createdSec); err {
	case nil:
		found.CreatedAt = time.Unix(createdSec, 0)
		return &found, nil
	case sql.ErrNoRows:
		return nil, fmt.Errorf("user not found")
	default:
		return nil, err
	}
}
// VerifyPassword reports whether password matches the bcrypt hash stored in
// user.PasswordHash.
//
// A nil user never matches and yields false rather than a nil-pointer panic;
// GetUserByUsername returns a nil *User on every error path, so a caller that
// forwards its result unchecked would otherwise crash here. Any comparison
// error (mismatch or malformed hash) is likewise reported as false.
func VerifyPassword(user *User, password string) bool {
	if user == nil {
		return false
	}
	err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
	return err == nil
}
// GenerateSessionID returns a new session identifier: 32 bytes read from
// crypto/rand, rendered as a 64-character lowercase hexadecimal string.
// If the random source fails, the empty string and that error are returned.
func GenerateSessionID() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
// CreateSession issues a new session for userID and persists it.
//
// The session ID comes from GenerateSessionID and the session expires 30 days
// after creation. Both timestamps are stored as whole Unix seconds, while the
// returned Session carries the original time.Time values, so its timestamps
// may hold sub-second precision that the stored row does not. username is not
// written to the sessions table; it is only copied into the returned struct.
// ID-generation and insert failures are returned wrapped.
func CreateSession(userID string, username string) (*Session, error) {
	id, err := GenerateSessionID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate session ID: %w", err)
	}

	const sessionLifetime = 30 * 24 * time.Hour // 30 days
	issued := time.Now()
	expiry := issued.Add(sessionLifetime)

	const insertSession = "INSERT INTO sessions (id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)"
	if _, err = authDB.Exec(insertSession, id, userID, issued.Unix(), expiry.Unix()); err != nil {
		return nil, fmt.Errorf("failed to create session: %w", err)
	}

	created := &Session{
		ID:        id,
		UserID:    userID,
		Username:  username,
		CreatedAt: issued,
		ExpiresAt: expiry,
	}
	return created, nil
}
// GetSession loads the session with the given ID, joined with its owning
// user's username.
//
// It returns "session not found" when no row matches (which includes a
// session whose user row no longer exists, because of the inner join), and
// the unmodified database error for any other query failure. A session whose
// expiry lies in the past is deleted and reported as "session expired".
func GetSession(sessionID string) (*Session, error) {
	const selectSession = `SELECT s.id, s.user_id, u.username, s.created_at, s.expires_at
FROM sessions s
JOIN users u ON s.user_id = u.id
WHERE s.id = ?`

	var (
		found                  Session
		createdSec, expiresSec int64
	)

	row := authDB.QueryRow(selectSession, sessionID)
	switch err := row.Scan(&found.ID, &found.UserID, &found.Username, &createdSec, &expiresSec); err {
	case nil:
		// Row found; fall through to the expiry check below.
	case sql.ErrNoRows:
		return nil, fmt.Errorf("session not found")
	default:
		return nil, err
	}

	found.CreatedAt = time.Unix(createdSec, 0)
	found.ExpiresAt = time.Unix(expiresSec, 0)

	if expired := time.Now().After(found.ExpiresAt); expired {
		// Best-effort cleanup: the caller is told the session expired whether
		// or not the delete succeeds, so its error is deliberately discarded.
		_ = DeleteSession(sessionID)
		return nil, fmt.Errorf("session expired")
	}
	return &found, nil
}
// DeleteSession removes the session row with the given ID. Deleting an ID
// that matches no row is not reported as an error; only a failure of the
// DELETE statement itself is returned.
func DeleteSession(sessionID string) error {
	if _, err := authDB.Exec("DELETE FROM sessions WHERE id = ?", sessionID); err != nil {
		return err
	}
	return nil
}
// CleanExpiredSessions deletes every session whose expires_at (Unix seconds)
// is strictly earlier than the current time, returning any error from the
// DELETE statement.
func CleanExpiredSessions() error {
	cutoff := time.Now().Unix()
	_, err := authDB.Exec("DELETE FROM sessions WHERE expires_at < ?", cutoff)
	return err
}
// UpdatePassword replaces the stored password hash of the user identified by
// userID with a bcrypt hash (bcrypt.DefaultCost) of newPassword.
//
// It returns a wrapped error if hashing or the UPDATE statement fails, and a
// "user not found" error when the statement succeeds but matches no row, so
// that updating a non-existent user is not silently reported as success.
//
// NOTE(review): existing sessions for the user are left in place; confirm
// whether a password change is expected to revoke them.
func UpdatePassword(userID string, newPassword string) error {
	// Hash the new password; only the hash is ever written to the database.
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash password: %w", err)
	}

	result, err := authDB.Exec(
		"UPDATE users SET password_hash = ? WHERE id = ?",
		string(hashedPassword), userID,
	)
	if err != nil {
		return fmt.Errorf("failed to update password: %w", err)
	}

	// An UPDATE that matches no row is not an SQL error, so detect it
	// explicitly. If the driver cannot report the count, keep the previous
	// behavior and treat the update as successful.
	if affected, countErr := result.RowsAffected(); countErr == nil && affected == 0 {
		return fmt.Errorf("user not found")
	}
	return nil
}