Initial commit
This commit is contained in:
@@ -0,0 +1,205 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var authDB *sql.DB
|
||||
|
||||
// InitAuthDB initializes the authentication database
|
||||
func InitAuthDB(filepath string) error {
|
||||
var err error
|
||||
authDB, err = sql.Open("sqlite3", filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = authDB.Ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createTableSQL := `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
expires_at INTEGER NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at);
|
||||
`
|
||||
|
||||
_, err = authDB.Exec(createTableSQL)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateUser creates a new user with hashed password
|
||||
func CreateUser(username, password string) (*User, error) {
|
||||
// Generate UUID for user
|
||||
userID := uuid.New().String()
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
createdAt := time.Now().Unix()
|
||||
|
||||
_, err = authDB.Exec(
|
||||
"INSERT INTO users (id, username, password_hash, created_at) VALUES (?, ?, ?, ?)",
|
||||
userID, username, string(hashedPassword), createdAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return &User{
|
||||
ID: userID,
|
||||
Username: username,
|
||||
PasswordHash: string(hashedPassword),
|
||||
CreatedAt: time.Unix(createdAt, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername retrieves a user by username
|
||||
func GetUserByUsername(username string) (*User, error) {
|
||||
var user User
|
||||
var createdAt int64
|
||||
|
||||
err := authDB.QueryRow(
|
||||
"SELECT id, username, password_hash, created_at FROM users WHERE username = ?",
|
||||
username,
|
||||
).Scan(&user.ID, &user.Username, &user.PasswordHash, &createdAt)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user.CreatedAt = time.Unix(createdAt, 0)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks if the provided password matches the user's password hash
|
||||
func VerifyPassword(user *User, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a random session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// CreateSession creates a new session for a user
|
||||
func CreateSession(userID string, username string) (*Session, error) {
|
||||
sessionID, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
createdAt := time.Now()
|
||||
expiresAt := createdAt.Add(30 * 24 * time.Hour) // 30 days
|
||||
|
||||
_, err = authDB.Exec(
|
||||
"INSERT INTO sessions (id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)",
|
||||
sessionID, userID, createdAt.Unix(), expiresAt.Unix(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
return &Session{
|
||||
ID: sessionID,
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
CreatedAt: createdAt,
|
||||
ExpiresAt: expiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID
|
||||
func GetSession(sessionID string) (*Session, error) {
|
||||
var session Session
|
||||
var createdAt, expiresAt int64
|
||||
|
||||
err := authDB.QueryRow(
|
||||
`SELECT s.id, s.user_id, u.username, s.created_at, s.expires_at
|
||||
FROM sessions s
|
||||
JOIN users u ON s.user_id = u.id
|
||||
WHERE s.id = ?`,
|
||||
sessionID,
|
||||
).Scan(&session.ID, &session.UserID, &session.Username, &createdAt, &expiresAt)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.CreatedAt = time.Unix(createdAt, 0)
|
||||
session.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
|
||||
// Check if session is expired
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
DeleteSession(sessionID)
|
||||
return nil, fmt.Errorf("session expired")
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session by ID
|
||||
func DeleteSession(sessionID string) error {
|
||||
_, err := authDB.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||
return err
|
||||
}
|
||||
|
||||
// CleanExpiredSessions removes all expired sessions
|
||||
func CleanExpiredSessions() error {
|
||||
_, err := authDB.Exec("DELETE FROM sessions WHERE expires_at < ?", time.Now().Unix())
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdatePassword updates a user's password
|
||||
func UpdatePassword(userID string, newPassword string) error {
|
||||
// Hash the new password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
_, err = authDB.Exec(
|
||||
"UPDATE users SET password_hash = ? WHERE id = ?",
|
||||
string(hashedPassword), userID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user