Initial commit

This commit is contained in:
lowcarbdev
2025-11-11 16:40:10 -07:00
commit b79e599640
57 changed files with 11811 additions and 0 deletions
+205
View File
@@ -0,0 +1,205 @@
package internal
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
var authDB *sql.DB
// InitAuthDB initializes the authentication database
func InitAuthDB(filepath string) error {
var err error
authDB, err = sql.Open("sqlite3", filepath)
if err != nil {
return err
}
if err = authDB.Ping(); err != nil {
return err
}
createTableSQL := `
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at);
`
_, err = authDB.Exec(createTableSQL)
return err
}
// CreateUser creates a new user with hashed password
func CreateUser(username, password string) (*User, error) {
// Generate UUID for user
userID := uuid.New().String()
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
createdAt := time.Now().Unix()
_, err = authDB.Exec(
"INSERT INTO users (id, username, password_hash, created_at) VALUES (?, ?, ?, ?)",
userID, username, string(hashedPassword), createdAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
return &User{
ID: userID,
Username: username,
PasswordHash: string(hashedPassword),
CreatedAt: time.Unix(createdAt, 0),
}, nil
}
// GetUserByUsername retrieves a user by username
func GetUserByUsername(username string) (*User, error) {
var user User
var createdAt int64
err := authDB.QueryRow(
"SELECT id, username, password_hash, created_at FROM users WHERE username = ?",
username,
).Scan(&user.ID, &user.Username, &user.PasswordHash, &createdAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found")
}
return nil, err
}
user.CreatedAt = time.Unix(createdAt, 0)
return &user, nil
}
// VerifyPassword checks if the provided password matches the user's password hash
func VerifyPassword(user *User, password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
return err == nil
}
// GenerateSessionID generates a random session ID
func GenerateSessionID() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// CreateSession creates a new session for a user
func CreateSession(userID string, username string) (*Session, error) {
sessionID, err := GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err)
}
createdAt := time.Now()
expiresAt := createdAt.Add(30 * 24 * time.Hour) // 30 days
_, err = authDB.Exec(
"INSERT INTO sessions (id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)",
sessionID, userID, createdAt.Unix(), expiresAt.Unix(),
)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
return &Session{
ID: sessionID,
UserID: userID,
Username: username,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
}, nil
}
// GetSession retrieves a session by ID
func GetSession(sessionID string) (*Session, error) {
var session Session
var createdAt, expiresAt int64
err := authDB.QueryRow(
`SELECT s.id, s.user_id, u.username, s.created_at, s.expires_at
FROM sessions s
JOIN users u ON s.user_id = u.id
WHERE s.id = ?`,
sessionID,
).Scan(&session.ID, &session.UserID, &session.Username, &createdAt, &expiresAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found")
}
return nil, err
}
session.CreatedAt = time.Unix(createdAt, 0)
session.ExpiresAt = time.Unix(expiresAt, 0)
// Check if session is expired
if time.Now().After(session.ExpiresAt) {
DeleteSession(sessionID)
return nil, fmt.Errorf("session expired")
}
return &session, nil
}
// DeleteSession deletes a session by ID
func DeleteSession(sessionID string) error {
_, err := authDB.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
return err
}
// CleanExpiredSessions removes all expired sessions
func CleanExpiredSessions() error {
_, err := authDB.Exec("DELETE FROM sessions WHERE expires_at < ?", time.Now().Unix())
return err
}
// UpdatePassword updates a user's password
func UpdatePassword(userID string, newPassword string) error {
// Hash the new password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
_, err = authDB.Exec(
"UPDATE users SET password_hash = ? WHERE id = ?",
string(hashedPassword), userID,
)
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}
+278
View File
@@ -0,0 +1,278 @@
package internal
import (
"fmt"
"log/slog"
"net/http"
"os"
"strings"
"time"
"github.com/labstack/echo/v4"
)
func HandleRegister(c echo.Context) error {
var req RegisterRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "Invalid request body",
})
}
// Validate input
req.Username = strings.TrimSpace(req.Username)
if req.Username == "" {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "Username is required",
})
}
if len(req.Username) < 3 {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "Username must be at least 3 characters",
})
}
if len(req.Password) < 6 {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "Password must be at least 6 characters",
})
}
// Create user
user, err := CreateUser(req.Username, req.Password)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
return c.JSON(http.StatusConflict, AuthResponse{
Success: false,
Error: "Username already exists",
})
}
slog.Error("Error creating user", "error", err)
return c.JSON(http.StatusInternalServerError, AuthResponse{
Success: false,
Error: "Failed to create user",
})
}
// Create session
session, err := CreateSession(user.ID, user.Username)
if err != nil {
slog.Error("Error creating session", "error", err)
return c.JSON(http.StatusInternalServerError, AuthResponse{
Success: false,
Error: "Failed to create session",
})
}
// Set session cookie
c.SetCookie(&http.Cookie{
Name: "session_id",
Value: session.ID,
Expires: session.ExpiresAt,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Path: "/",
})
// Initialize user's database (using UUID as filename)
dbPathPrefix := os.Getenv("DB_PATH_PREFIX")
if dbPathPrefix == "" {
dbPathPrefix = "."
}
userDBPath := fmt.Sprintf("%s/sbv_%s.db", dbPathPrefix, user.ID)
if err := InitUserDB(user.ID, userDBPath); err != nil {
slog.Error("Error initializing user database", "error", err)
return echo.ErrInternalServerError
}
return c.JSON(http.StatusOK, AuthResponse{
Success: true,
User: user,
Session: session,
})
}
func HandleLogin(c echo.Context) error {
var req LoginRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "Invalid request body",
})
}
// Validate input
req.Username = strings.TrimSpace(req.Username)
if req.Username == "" || req.Password == "" {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "Username and password are required",
})
}
// Get user
user, err := GetUserByUsername(req.Username)
if err != nil {
return c.JSON(http.StatusUnauthorized, AuthResponse{
Success: false,
Error: "Invalid username or password",
})
}
// Verify password
if !VerifyPassword(user, req.Password) {
return c.JSON(http.StatusUnauthorized, AuthResponse{
Success: false,
Error: "Invalid username or password",
})
}
// Create session
session, err := CreateSession(user.ID, user.Username)
if err != nil {
slog.Error("Error creating session", "error", err)
return c.JSON(http.StatusInternalServerError, AuthResponse{
Success: false,
Error: "Failed to create session",
})
}
// Set session cookie
c.SetCookie(&http.Cookie{
Name: "session_id",
Value: session.ID,
Expires: session.ExpiresAt,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Path: "/",
})
return c.JSON(http.StatusOK, AuthResponse{
Success: true,
User: user,
Session: session,
})
}
func HandleLogout(c echo.Context) error {
// Get session ID from cookie
cookie, err := c.Cookie("session_id")
if err == nil {
// Delete session from database
DeleteSession(cookie.Value)
}
// Clear cookie
c.SetCookie(&http.Cookie{
Name: "session_id",
Value: "",
Expires: time.Now().Add(-1 * time.Hour),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Path: "/",
})
return c.JSON(http.StatusOK, map[string]bool{
"success": true,
})
}
func HandleMe(c echo.Context) error {
// Get session from context (set by AuthMiddleware)
session, ok := c.Get("session").(*Session)
if !ok {
return c.JSON(http.StatusUnauthorized, AuthResponse{
Success: false,
Error: "Unauthorized",
})
}
// Get user
user, err := GetUserByUsername(session.Username)
if err != nil {
return c.JSON(http.StatusInternalServerError, AuthResponse{
Success: false,
Error: "Failed to get user info",
})
}
return c.JSON(http.StatusOK, AuthResponse{
Success: true,
User: user,
Session: session,
})
}
func HandleChangePassword(c echo.Context) error {
// Get session from context (set by AuthMiddleware)
session, ok := c.Get("session").(*Session)
if !ok {
return c.JSON(http.StatusUnauthorized, AuthResponse{
Success: false,
Error: "Unauthorized",
})
}
var req ChangePasswordRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "Invalid request body",
})
}
// Validate input
if req.OldPassword == "" || req.NewPassword == "" || req.ConfirmPassword == "" {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "All fields are required",
})
}
if req.NewPassword != req.ConfirmPassword {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "New passwords do not match",
})
}
if len(req.NewPassword) < 6 {
return c.JSON(http.StatusBadRequest, AuthResponse{
Success: false,
Error: "New password must be at least 6 characters",
})
}
// Get user
user, err := GetUserByUsername(session.Username)
if err != nil {
return c.JSON(http.StatusInternalServerError, AuthResponse{
Success: false,
Error: "Failed to get user info",
})
}
// Verify old password
if !VerifyPassword(user, req.OldPassword) {
return c.JSON(http.StatusUnauthorized, AuthResponse{
Success: false,
Error: "Current password is incorrect",
})
}
// Update password
if err := UpdatePassword(user.ID, req.NewPassword); err != nil {
slog.Error("Error updating password", "error", err)
return c.JSON(http.StatusInternalServerError, AuthResponse{
Success: false,
Error: "Failed to update password",
})
}
return c.JSON(http.StatusOK, AuthResponse{
Success: true,
})
}
+39
View File
@@ -0,0 +1,39 @@
package internal
import (
"net/http"
"github.com/labstack/echo/v4"
)
// CustomCORSMiddleware creates a custom CORS middleware that properly handles credentials
func CustomCORSMiddleware() echo.MiddlewareFunc {
allowedOrigins := map[string]bool{
"http://localhost:5173": true,
"http://localhost:3000": true,
"http://localhost:8081": true,
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
origin := c.Request().Header.Get("Origin")
// Check if origin is allowed
if allowedOrigins[origin] {
c.Response().Header().Set("Access-Control-Allow-Origin", origin)
c.Response().Header().Set("Access-Control-Allow-Credentials", "true")
}
// Handle preflight requests
if c.Request().Method == http.MethodOptions {
c.Response().Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Response().Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization")
c.Response().Header().Set("Access-Control-Max-Age", "3600")
return c.NoContent(http.StatusNoContent)
}
return next(c)
}
}
}
+909
View File
@@ -0,0 +1,909 @@
package internal
import (
"database/sql"
"fmt"
"log/slog"
"os"
"regexp"
"strings"
"sync"
"time"
_ "github.com/mattn/go-sqlite3"
)
var db *sql.DB
// userDBs stores per-user database connections (keyed by user ID)
var userDBs = make(map[string]*sql.DB)
var userDBsMutex sync.RWMutex
// truncateString truncates a string to maxLen characters for logging
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
// SanitizeUsername converts a username to a safe filesystem name
func SanitizeUsername(username string) string {
// Convert to lowercase
safe := strings.ToLower(username)
// Replace spaces and special characters with underscores
reg := regexp.MustCompile(`[^a-z0-9]+`)
safe = reg.ReplaceAllString(safe, "_")
// Remove leading/trailing underscores
safe = strings.Trim(safe, "_")
// Ensure it's not empty
if safe == "" {
safe = "user"
}
return safe
}
func InitDB(filepath string) error {
var err error
db, err = sql.Open("sqlite3", filepath)
if err != nil {
return err
}
if err = db.Ping(); err != nil {
return err
}
createTableSQL := `
-- Unified table for SMS messages, MMS messages, and call logs
-- record_type: 1 = SMS, 2 = MMS, 3 = call
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
record_type INTEGER NOT NULL DEFAULT 1,
address TEXT NOT NULL,
body TEXT,
type INTEGER NOT NULL,
date INTEGER NOT NULL,
read INTEGER DEFAULT 0,
thread_id INTEGER,
subject TEXT,
media_type TEXT,
media_data BLOB,
protocol INTEGER,
status INTEGER,
service_center TEXT,
sub_id INTEGER,
contact_name TEXT,
sender TEXT,
content_type TEXT,
read_report INTEGER,
read_status INTEGER,
message_id TEXT,
message_size INTEGER,
message_type INTEGER,
sim_slot INTEGER,
addresses TEXT,
duration INTEGER,
presentation INTEGER,
subscription_id TEXT
);
CREATE INDEX IF NOT EXISTS idx_address ON messages(address);
CREATE INDEX IF NOT EXISTS idx_date ON messages(date);
CREATE INDEX IF NOT EXISTS idx_thread ON messages(thread_id);
CREATE INDEX IF NOT EXISTS idx_record_type ON messages(record_type);
CREATE INDEX IF NOT EXISTS idx_record_type_date ON messages(record_type, date);
-- Create unique constraints for idempotent imports
-- record_type differentiates SMS (1), MMS (2), and calls (3)
CREATE UNIQUE INDEX IF NOT EXISTS idx_message_unique ON messages(record_type, address, date, type, COALESCE(body, ''), COALESCE(content_type, ''), COALESCE(message_id, ''), COALESCE(duration, 0));
-- Create FTS5 virtual table for full-text search of messages
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
message_id UNINDEXED,
address UNINDEXED,
body,
contact_name UNINDEXED,
date UNINDEXED,
content='messages',
content_rowid='id'
);
-- Create triggers to keep FTS table in sync
CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts(rowid, message_id, address, body, contact_name, date)
VALUES (new.id, new.id, new.address, new.body, new.contact_name, new.date);
END;
CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, message_id, address, body, contact_name, date)
VALUES('delete', old.id, old.id, old.address, old.body, old.contact_name, old.date);
END;
CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, message_id, address, body, contact_name, date)
VALUES('delete', old.id, old.id, old.address, old.body, old.contact_name, old.date);
INSERT INTO messages_fts(rowid, message_id, address, body, contact_name, date)
VALUES (new.id, new.id, new.address, new.body, new.contact_name, new.date);
END;
`
_, err = db.Exec(createTableSQL)
if err != nil {
return err
}
slog.Info("Database initialized successfully")
return nil
}
// InitUserDB initializes a database for a specific user
func InitUserDB(userID string, filepath string) error {
userDB, err := sql.Open("sqlite3", filepath)
if err != nil {
return err
}
if err = userDB.Ping(); err != nil {
return err
}
createTableSQL := `
-- Unified table for SMS messages, MMS messages, and call logs
-- record_type: 1 = SMS, 2 = MMS, 3 = call
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
record_type INTEGER NOT NULL DEFAULT 1,
address TEXT NOT NULL,
body TEXT,
type INTEGER NOT NULL,
date INTEGER NOT NULL,
read INTEGER DEFAULT 0,
thread_id INTEGER,
subject TEXT,
media_type TEXT,
media_data BLOB,
protocol INTEGER,
status INTEGER,
service_center TEXT,
sub_id INTEGER,
contact_name TEXT,
sender TEXT,
content_type TEXT,
read_report INTEGER,
read_status INTEGER,
message_id TEXT,
message_size INTEGER,
message_type INTEGER,
sim_slot INTEGER,
addresses TEXT,
duration INTEGER,
presentation INTEGER,
subscription_id TEXT
);
CREATE INDEX IF NOT EXISTS idx_address ON messages(address);
CREATE INDEX IF NOT EXISTS idx_date ON messages(date);
CREATE INDEX IF NOT EXISTS idx_thread ON messages(thread_id);
CREATE INDEX IF NOT EXISTS idx_record_type ON messages(record_type);
CREATE INDEX IF NOT EXISTS idx_record_type_date ON messages(record_type, date);
-- record_type differentiates SMS (1), MMS (2), and calls (3)
CREATE UNIQUE INDEX IF NOT EXISTS idx_message_unique ON messages(record_type, address, date, type, COALESCE(body, ''), COALESCE(content_type, ''), COALESCE(message_id, ''), COALESCE(duration, 0));
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
message_id UNINDEXED,
address UNINDEXED,
body,
contact_name UNINDEXED,
date UNINDEXED,
content='messages',
content_rowid='id'
);
CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts(rowid, message_id, address, body, contact_name, date)
VALUES (new.id, new.id, new.address, new.body, new.contact_name, new.date);
END;
CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, message_id, address, body, contact_name, date)
VALUES('delete', old.id, old.id, old.address, old.body, old.contact_name, old.date);
END;
CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
INSERT INTO messages_fts(messages_fts, rowid, message_id, address, body, contact_name, date)
VALUES('delete', old.id, old.id, old.address, old.body, old.contact_name, old.date);
INSERT INTO messages_fts(rowid, message_id, address, body, contact_name, date)
VALUES (new.id, new.id, new.address, new.body, new.contact_name, new.date);
END;
`
_, err = userDB.Exec(createTableSQL)
if err != nil {
return err
}
// Store in map
userDBsMutex.Lock()
userDBs[userID] = userDB
userDBsMutex.Unlock()
slog.Info("User database initialized", "user_id", userID, "path", filepath)
return nil
}
// GetUserDB retrieves the database connection for a specific user
func GetUserDB(userID string, username string) (*sql.DB, error) {
userDBsMutex.RLock()
defer userDBsMutex.RUnlock()
userDB, exists := userDBs[userID]
if !exists {
// Try to open the database if it exists
dbPathPrefix := os.Getenv("DB_PATH_PREFIX")
if dbPathPrefix == "" {
dbPathPrefix = "."
}
// Use UUID as database filename instead of sanitized username
filepath := fmt.Sprintf("%s/sbv_%s.db", dbPathPrefix, userID)
if _, err := os.Stat(filepath); err == nil {
// Database file exists, try to open it
userDBsMutex.RUnlock()
if err := InitUserDB(userID, filepath); err != nil {
userDBsMutex.RLock()
return nil, fmt.Errorf("failed to open user database: %w", err)
}
userDBsMutex.RLock()
userDB = userDBs[userID]
} else {
return nil, fmt.Errorf("user database not found for user %s", username)
}
}
return userDB, nil
}
func InsertMessage(userDB *sql.DB, msg *Message) error {
// Convert addresses slice to JSON string
var addressesJSON string
if len(msg.Addresses) > 0 {
addresses := strings.Join(msg.Addresses, ",")
addressesJSON = addresses
}
// Determine record type: 1 = SMS, 2 = MMS
// MMS messages have ContentType set (e.g., 'application/vnd.wap.multipart.related')
// SMS messages do not have ContentType
recordType := 1 // Default to SMS
if msg.ContentType != "" {
recordType = 2 // MMS
}
query := `
INSERT INTO messages (
record_type, address, body, type, date, read, thread_id, subject, media_type, media_data,
protocol, status, service_center, sub_id, contact_name, sender,
content_type, read_report, read_status, message_id, message_size, message_type, sim_slot, addresses
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT DO NOTHING
`
result, err := userDB.Exec(query,
recordType, // record_type: 1 = SMS, 2 = MMS
msg.Address,
msg.Body,
msg.Type,
msg.Date.Unix(),
msg.Read,
msg.ThreadID,
msg.Subject,
msg.MediaType,
msg.MediaData,
msg.Protocol,
msg.Status,
msg.ServiceCenter,
msg.SubID,
msg.ContactName,
msg.Sender,
msg.ContentType,
msg.ReadReport,
msg.ReadStatus,
msg.MessageID,
msg.MessageSize,
msg.MessageType,
msg.SimSlot,
addressesJSON,
)
if err != nil {
slog.Debug("InsertMessage: Error inserting message", "error", err)
return err
}
id, err := result.LastInsertId()
if err != nil {
return err
}
msg.ID = id
return nil
}
func InsertCallLog(userDB *sql.DB, call *CallLog) error {
query := `
INSERT INTO messages (record_type, address, type, date, duration, presentation, subscription_id, contact_name)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT DO NOTHING
`
result, err := userDB.Exec(query,
3, // record_type: 3 = call
call.Number,
call.Type,
call.Date.Unix(),
call.Duration,
call.Presentation,
call.SubscriptionID,
call.ContactName,
)
if err != nil {
return err
}
id, err := result.LastInsertId()
if err != nil {
return err
}
call.ID = id
return nil
}
// InsertCallLogBatch inserts multiple call logs in a single transaction for better performance
func InsertCallLogBatch(userDB *sql.DB, calls []CallLog) error {
if len(calls) == 0 {
return nil
}
tx, err := userDB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`
INSERT INTO messages (record_type, address, type, date, duration, presentation, subscription_id, contact_name)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT DO NOTHING
`)
if err != nil {
return err
}
defer stmt.Close()
for i := range calls {
_, err := stmt.Exec(
3, // record_type: 3 = call
calls[i].Number,
calls[i].Type,
calls[i].Date.Unix(),
calls[i].Duration,
calls[i].Presentation,
calls[i].SubscriptionID,
calls[i].ContactName,
)
if err != nil {
return err
}
}
return tx.Commit()
}
func GetConversations(userDB *sql.DB, startDate, endDate *time.Time) ([]Conversation, error) {
// Build a query that groups all activity (messages and calls) by address
query := `
SELECT
address,
MAX(COALESCE(contact_name, '')) as contact_name,
(
SELECT COALESCE(subject, '')
FROM messages m2
WHERE m2.address = messages.address
AND m2.subject IS NOT NULL
AND m2.subject != ''
ORDER BY date DESC
LIMIT 1
) as subject,
(
SELECT
CASE
WHEN record_type = 1 THEN body -- SMS
WHEN record_type = 2 THEN body -- MMS
WHEN record_type = 3 AND type = 1 THEN 'Incoming call'
WHEN record_type = 3 AND type = 2 THEN 'Outgoing call'
WHEN record_type = 3 AND type = 3 THEN 'Missed call'
WHEN record_type = 3 AND type = 4 THEN 'Voicemail'
WHEN record_type = 3 AND type = 5 THEN 'Rejected call'
WHEN record_type = 3 AND type = 6 THEN 'Refused call'
ELSE 'Call'
END
FROM messages m3
WHERE m3.address = messages.address
ORDER BY date DESC
LIMIT 1
) as last_message,
MAX(date) as last_date,
COUNT(*) as activity_count
FROM messages
WHERE 1=1
`
args := []interface{}{}
if startDate != nil {
query += " AND date >= ?"
args = append(args, startDate.Unix())
}
if endDate != nil {
query += " AND date <= ?"
args = append(args, endDate.Unix())
}
query += `
GROUP BY address
ORDER BY last_date DESC
`
rows, err := userDB.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
conversations := []Conversation{}
for rows.Next() {
var c Conversation
var lastDateUnix int64
var subject sql.NullString
err := rows.Scan(&c.Address, &c.ContactName, &subject, &c.LastMessage, &lastDateUnix, &c.MessageCount)
if err != nil {
return nil, err
}
c.LastDate = time.Unix(lastDateUnix, 0)
c.Subject = subject.String
c.Type = "conversation" // Changed from "message" or "call" to indicate it's a merged conversation
conversations = append(conversations, c)
}
return conversations, nil
}
func formatCallType(callType int) string {
switch callType {
case 1:
return "Incoming call"
case 2:
return "Outgoing call"
case 3:
return "Missed call"
case 4:
return "Voicemail"
case 5:
return "Rejected call"
case 6:
return "Refused call"
default:
return "Call"
}
}
func GetMessages(userDB *sql.DB, address string, startDate, endDate *time.Time) ([]Message, error) {
query := `
SELECT id, address, body, type, date, read, thread_id,
COALESCE(subject, ''), COALESCE(media_type, ''), COALESCE(media_data, ''),
COALESCE(protocol, 0), COALESCE(status, 0), COALESCE(service_center, ''),
COALESCE(sub_id, 0), COALESCE(contact_name, ''), COALESCE(sender, ''),
COALESCE(content_type, ''), COALESCE(read_report, 0), COALESCE(read_status, 0),
COALESCE(message_id, ''), COALESCE(message_size, 0), COALESCE(message_type, 0),
COALESCE(sim_slot, 0), COALESCE(addresses, '')
FROM messages
WHERE record_type IN (1, 2) AND address = ? -- 1 = SMS, 2 = MMS
`
args := []interface{}{address}
if startDate != nil {
query += " AND date >= ?"
args = append(args, startDate.Unix())
}
if endDate != nil {
query += " AND date <= ?"
args = append(args, endDate.Unix())
}
query += " ORDER BY date ASC"
slog.Debug("GetMessages: executing query", "address", address)
slog.Debug("GetMessages: SQL query", "query", query)
slog.Debug("GetMessages: query arguments", "args", args)
rows, err := userDB.Query(query, args...)
if err != nil {
slog.Debug("GetMessages: Query error", "error", err)
return nil, err
}
defer rows.Close()
messages := []Message{}
for rows.Next() {
var m Message
var dateUnix int64
var readInt int
var addressesStr string
err := rows.Scan(&m.ID, &m.Address, &m.Body, &m.Type, &dateUnix,
&readInt, &m.ThreadID, &m.Subject, &m.MediaType, &m.MediaData,
&m.Protocol, &m.Status, &m.ServiceCenter, &m.SubID, &m.ContactName, &m.Sender,
&m.ContentType, &m.ReadReport, &m.ReadStatus, &m.MessageID,
&m.MessageSize, &m.MessageType, &m.SimSlot, &addressesStr)
if err != nil {
return nil, err
}
m.Date = time.Unix(dateUnix, 0)
m.Read = readInt == 1
// Parse addresses from comma-separated string
if addressesStr != "" {
m.Addresses = strings.Split(addressesStr, ",")
}
// Don't load media data - it will be fetched on demand via /api/media
// Clear MediaData to save memory in response
m.MediaData = nil
slog.Debug("GetMessages: Message", "id", m.ID, "address", m.Address, "media_type", m.MediaType, "body", truncateString(m.Body, 50))
messages = append(messages, m)
}
slog.Debug("GetMessages: Returning messages", "count", len(messages), "address", address)
return messages, nil
}
func GetCallLogs(userDB *sql.DB, number string, startDate, endDate *time.Time) ([]CallLog, error) {
query := `
SELECT id, address, duration, date, type,
COALESCE(presentation, 0), COALESCE(subscription_id, ''), COALESCE(contact_name, '')
FROM messages
WHERE record_type = 3 AND address = ? -- 3 = call
`
args := []interface{}{number}
if startDate != nil {
query += " AND date >= ?"
args = append(args, startDate.Unix())
}
if endDate != nil {
query += " AND date <= ?"
args = append(args, endDate.Unix())
}
query += " ORDER BY date ASC"
rows, err := userDB.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
calls := []CallLog{}
for rows.Next() {
var c CallLog
var dateUnix int64
err := rows.Scan(&c.ID, &c.Number, &c.Duration, &dateUnix, &c.Type,
&c.Presentation, &c.SubscriptionID, &c.ContactName)
if err != nil {
return nil, err
}
c.Date = time.Unix(dateUnix, 0)
calls = append(calls, c)
}
return calls, nil
}
func GetActivity(userDB *sql.DB, startDate, endDate *time.Time, limit, offset int) ([]ActivityItem, error) {
return GetActivityByAddress(userDB, "", startDate, endDate, limit, offset)
}
func GetActivityByAddress(userDB *sql.DB, address string, startDate, endDate *time.Time, limit, offset int) ([]ActivityItem, error) {
var activities []ActivityItem
// Query from unified table
query := `
SELECT record_type, date, address, COALESCE(contact_name, '') as contact_name,
id, body, type, read, thread_id, COALESCE(subject, ''),
COALESCE(media_type, ''), COALESCE(media_data, ''),
COALESCE(protocol, 0), COALESCE(status, 0), COALESCE(service_center, ''),
COALESCE(sub_id, 0), COALESCE(content_type, ''), COALESCE(read_report, 0),
COALESCE(read_status, 0), COALESCE(message_id, ''), COALESCE(message_size, 0),
COALESCE(message_type, 0), COALESCE(sim_slot, 0), COALESCE(addresses, ''),
COALESCE(duration, 0), COALESCE(presentation, 0), COALESCE(subscription_id, ''),
COALESCE(sender, '')
FROM messages
WHERE 1=1
`
args := []interface{}{}
if address != "" {
query += " AND address = ?"
args = append(args, address)
}
if startDate != nil {
query += " AND date >= ?"
args = append(args, startDate.Unix())
}
if endDate != nil {
query += " AND date <= ?"
args = append(args, endDate.Unix())
}
query += " ORDER BY date ASC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
slog.Debug("GetActivityByAddress: executing query", "address", address, "limit", limit, "offset", offset)
slog.Debug("GetActivityByAddress: SQL query", "query", query)
slog.Debug("GetActivityByAddress: query arguments", "args", args)
rows, err := userDB.Query(query, args...)
if err != nil {
slog.Debug("GetActivityByAddress: Query error", "error", err)
return nil, err
}
defer rows.Close()
for rows.Next() {
var recordType int64
var dateUnix int64
var address, contactName string
// Shared fields
var id sql.NullInt64
var itemType sql.NullInt64 // type field - used for both message type and call type
// Message fields
var body, subject, mediaType, serviceCenter, contentType, messageID, subscriptionID, addressesStr, sender sql.NullString
var readInt, threadID, protocol, status, subID, readReport, readStatus, messageSize, messageTypeField, simSlot sql.NullInt64
var mediaData []byte
// Call fields
var duration, presentation sql.NullInt64
err := rows.Scan(&recordType, &dateUnix, &address, &contactName,
&id, &body, &itemType, &readInt, &threadID, &subject,
&mediaType, &mediaData,
&protocol, &status, &serviceCenter,
&subID, &contentType, &readReport,
&readStatus, &messageID, &messageSize,
&messageTypeField, &simSlot, &addressesStr,
&duration, &presentation, &subscriptionID, &sender)
if err != nil {
return nil, err
}
var activityTypeStr string
if recordType == 1 || recordType == 2 {
// 1 = SMS, 2 = MMS
activityTypeStr = "message"
} else if recordType == 3 {
// 3 = call
activityTypeStr = "call"
}
activity := ActivityItem{
Type: activityTypeStr,
Date: time.Unix(dateUnix, 0),
Address: address,
ContactName: contactName,
}
if (recordType == 1 || recordType == 2) && id.Valid {
// Handle SMS (1) and MMS (2)
msg := &Message{
ID: id.Int64,
Address: address,
Body: body.String,
Date: time.Unix(dateUnix, 0),
ThreadID: int(threadID.Int64),
Subject: subject.String,
MediaType: mediaType.String,
MediaData: mediaData,
Protocol: int(protocol.Int64),
Status: int(status.Int64),
ServiceCenter: serviceCenter.String,
SubID: int(subID.Int64),
ContactName: contactName,
ContentType: contentType.String,
ReadReport: int(readReport.Int64),
ReadStatus: int(readStatus.Int64),
MessageID: messageID.String,
MessageSize: int(messageSize.Int64),
MessageType: int(messageTypeField.Int64),
SimSlot: int(simSlot.Int64),
Sender: sender.String,
}
if itemType.Valid {
msg.Type = int(itemType.Int64)
}
if readInt.Valid {
msg.Read = readInt.Int64 == 1
}
// Parse addresses from comma-separated string
slog.Debug("GetActivityByAddress: addressesStr raw", "id", id.Int64, "valid", addressesStr.Valid, "value", addressesStr.String)
if addressesStr.Valid && addressesStr.String != "" {
msg.Addresses = strings.Split(addressesStr.String, ",")
slog.Debug("GetActivityByAddress: addresses split result", "id", id.Int64, "count", len(msg.Addresses), "values", msg.Addresses)
} else if strings.Contains(address, ",") {
// Fallback: If addresses field is empty but address contains commas,
// this is a group conversation - parse the address field
msg.Addresses = strings.Split(address, ",")
slog.Debug("GetActivityByAddress: addresses from address field", "id", id.Int64, "count", len(msg.Addresses), "values", msg.Addresses)
}
// Don't load media data - it will be fetched on demand via /api/media
// Clear MediaData to save memory in response
msg.MediaData = nil
slog.Debug("GetActivityByAddress: Message", "id", msg.ID, "address", msg.Address, "type", msg.Type, "sender", msg.Sender, "addresses", msg.Addresses, "media_type", msg.MediaType, "body", truncateString(msg.Body, 50))
activity.Message = msg
} else if recordType == 3 && id.Valid {
// Handle calls (3)
call := &CallLog{
ID: id.Int64,
Number: address,
Duration: int(duration.Int64),
Date: time.Unix(dateUnix, 0),
Type: int(itemType.Int64),
Presentation: int(presentation.Int64),
SubscriptionID: subscriptionID.String,
ContactName: contactName,
}
slog.Debug("GetActivityByAddress: Call", "id", call.ID, "number", call.Number, "type", call.Type, "duration", call.Duration)
activity.Call = call
}
activities = append(activities, activity)
}
slog.Debug("GetActivityByAddress: Returning activities", "count", len(activities), "address", address)
return activities, nil
}
func GetMessageMedia(userDB *sql.DB, messageID string) ([]byte, string, error) {
query := `
SELECT COALESCE(media_data, ''), COALESCE(media_type, '')
FROM messages
WHERE id = ? AND record_type IN (1, 2) -- 1 = SMS, 2 = MMS
`
slog.Debug("GetMessageMedia: Fetching media", "message_id", messageID)
slog.Debug("GetMessageMedia: SQL query", "query", query)
var mediaData []byte
var mediaType string
err := userDB.QueryRow(query, messageID).Scan(&mediaData, &mediaType)
if err != nil {
slog.Debug("GetMessageMedia: Error scanning row", "message_id", messageID, "error", err)
return nil, "", err
}
slog.Debug("GetMessageMedia: Found media", "media_type", mediaType, "data_length", len(mediaData), "message_id", messageID)
if len(mediaData) == 0 || mediaType == "" {
slog.Debug("GetMessageMedia: No media found", "message_id", messageID)
return nil, "", fmt.Errorf("no media found")
}
// Convert HEIC to JPEG if needed
if isHEICContentType(mediaType) {
convertedData, err := convertHEICtoJPEG(mediaData)
if err != nil {
slog.Error("Failed to convert HEIC to JPEG", "message_id", messageID, "error", err)
// Return original if conversion fails
return mediaData, mediaType, nil
}
return convertedData, "image/jpeg", nil
}
// Convert unsupported video formats (3GP, etc.) to MP4 if needed
if needsVideoConversion(mediaType) {
slog.Info("Converting video to MP4", "from_type", mediaType, "message_id", messageID)
convertedData, err := convertVideoToMP4(mediaData)
if err != nil {
slog.Error("Failed to convert video to MP4", "message_id", messageID, "error", err)
// Return original if conversion fails
return mediaData, mediaType, nil
}
slog.Info("Successfully converted video to MP4", "message_id", messageID)
return convertedData, "video/mp4", nil
}
return mediaData, mediaType, nil
}
func GetDateRange(userDB *sql.DB) (time.Time, time.Time, error) {
var minDate, maxDate int64
// Get min/max from unified messages table
query := "SELECT MIN(date), MAX(date) FROM messages"
var min, max sql.NullInt64
err := userDB.QueryRow(query).Scan(&min, &max)
if err != nil && err != sql.ErrNoRows {
return time.Time{}, time.Time{}, err
}
if !min.Valid || !max.Valid {
return time.Time{}, time.Time{}, fmt.Errorf("no data available")
}
minDate = min.Int64
maxDate = max.Int64
return time.Unix(minDate, 0), time.Unix(maxDate, 0), nil
}
// SearchResult represents a message search result
type SearchResult struct {
MessageID int64 `json:"message_id"`
Address string `json:"address"`
ContactName string `json:"contact_name"`
Body string `json:"body"`
Date time.Time `json:"date"`
Snippet string `json:"snippet"`
}
// SearchMessages performs full-text search on message contents
func SearchMessages(userDB *sql.DB, query string, limit int) ([]SearchResult, error) {
if query == "" {
return []SearchResult{}, nil
}
sqlQuery := `
SELECT
m.id,
m.address,
COALESCE(m.contact_name, ''),
m.body,
m.date,
snippet(messages_fts, 2, '<mark>', '</mark>', '...', 50) as snippet
FROM messages_fts
JOIN messages m ON messages_fts.rowid = m.id
WHERE messages_fts MATCH ?
ORDER BY rank
LIMIT ?
`
rows, err := userDB.Query(sqlQuery, query, limit)
if err != nil {
return nil, err
}
defer rows.Close()
results := []SearchResult{}
for rows.Next() {
var r SearchResult
var dateUnix int64
err := rows.Scan(&r.MessageID, &r.Address, &r.ContactName, &r.Body, &dateUnix, &r.Snippet)
if err != nil {
return nil, err
}
r.Date = time.Unix(dateUnix, 0)
results = append(results, r)
}
return results, nil
}
+359
View File
@@ -0,0 +1,359 @@
package internal
import (
"database/sql"
"fmt"
"log/slog"
"net/http"
"strconv"
"time"
"github.com/labstack/echo/v4"
)
// getUserDB is a helper function to get the user's database connection from the context
func getUserDB(c echo.Context) (*sql.DB, error) {
userID, ok := c.Get("user_id").(string)
if !ok {
return nil, fmt.Errorf("user_id not found in context")
}
username, ok := c.Get("username").(string)
if !ok {
return nil, fmt.Errorf("username not found in context")
}
return GetUserDB(userID, username)
}
func HandleUpload(c echo.Context) error {
// Use a smaller memory limit for the form parsing itself (32 MB)
// Large files will be streamed directly to disk
err := c.Request().ParseMultipartForm(32 << 20) // 32 MB max in memory
if err != nil {
slog.Error("Error parsing form", "error", err)
return c.JSON(http.StatusBadRequest, UploadResponse{
Success: false,
Error: "Failed to parse form data. File may be too large or corrupted.",
})
}
file, header, err := c.Request().FormFile("file")
if err != nil {
slog.Error("Error getting file", "error", err)
return c.JSON(http.StatusBadRequest, UploadResponse{
Success: false,
Error: "Failed to get file from form",
})
}
defer file.Close()
slog.Info("Receiving file", "filename", header.Filename, "size", header.Size)
// Save uploaded file to temporary location first
tempFilePath, err := SaveUploadedFile(file, header.Filename)
if err != nil {
slog.Error("Error saving file", "error", err)
return c.JSON(http.StatusInternalServerError, UploadResponse{
Success: false,
Error: "Failed to save uploaded file: " + err.Error(),
})
}
slog.Info("File saved", "path", tempFilePath)
// Get user ID from context
userID, ok := c.Get("user_id").(string)
if !ok {
return c.JSON(http.StatusUnauthorized, UploadResponse{
Success: false,
Error: "User not authenticated",
})
}
// Get username from context
username, ok := c.Get("username").(string)
if !ok {
return c.JSON(http.StatusUnauthorized, UploadResponse{
Success: false,
Error: "User not authenticated",
})
}
// Start background processing with user context
go ProcessUploadedFile(userID, username, tempFilePath)
// Return immediately - client will poll /api/progress for status
return c.JSON(http.StatusOK, UploadResponse{
Success: true,
MessageCount: 0,
CallLogCount: 0,
Processing: true,
})
}
func HandleConversations(c echo.Context) error {
userDB, err := getUserDB(c)
if err != nil {
slog.Error("Error getting user database", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get user database",
})
}
var startDate, endDate *time.Time
if startStr := c.QueryParam("start"); startStr != "" {
t, err := time.Parse(time.RFC3339, startStr)
if err == nil {
startDate = &t
}
}
if endStr := c.QueryParam("end"); endStr != "" {
t, err := time.Parse(time.RFC3339, endStr)
if err == nil {
endDate = &t
}
}
conversations, err := GetConversations(userDB, startDate, endDate)
if err != nil {
slog.Error("Error getting conversations", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get conversations",
})
}
return c.JSON(http.StatusOK, conversations)
}
func HandleMessages(c echo.Context) error {
userDB, err := getUserDB(c)
if err != nil {
slog.Error("Error getting user database", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get user database",
})
}
address := c.QueryParam("address")
convType := c.QueryParam("type")
if address == "" {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "Address parameter required",
})
}
var startDate, endDate *time.Time
if startStr := c.QueryParam("start"); startStr != "" {
t, err := time.Parse(time.RFC3339, startStr)
if err == nil {
startDate = &t
}
}
if endStr := c.QueryParam("end"); endStr != "" {
t, err := time.Parse(time.RFC3339, endStr)
if err == nil {
endDate = &t
}
}
// If type is "call", return call logs instead of messages
if convType == "call" {
calls, err := GetCallLogs(userDB, address, startDate, endDate)
if err != nil {
slog.Error("Error getting call logs", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get call logs",
})
}
return c.JSON(http.StatusOK, calls)
}
// If type is "conversation", return combined messages and calls
if convType == "conversation" {
// Use a large limit to get all activity for this address
// We don't use pagination here since we want all items for the thread view
activities, err := GetActivityByAddress(userDB, address, startDate, endDate, 10000, 0)
if err != nil {
slog.Error("Error getting activity", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get activity",
})
}
return c.JSON(http.StatusOK, activities)
}
messages, err := GetMessages(userDB, address, startDate, endDate)
if err != nil {
slog.Error("Error getting messages", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get messages",
})
}
return c.JSON(http.StatusOK, messages)
}
func HandleActivity(c echo.Context) error {
userDB, err := getUserDB(c)
if err != nil {
slog.Error("Error getting user database", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get user database",
})
}
var startDate, endDate *time.Time
if startStr := c.QueryParam("start"); startStr != "" {
t, err := time.Parse(time.RFC3339, startStr)
if err == nil {
startDate = &t
}
}
if endStr := c.QueryParam("end"); endStr != "" {
t, err := time.Parse(time.RFC3339, endStr)
if err == nil {
endDate = &t
}
}
// Parse pagination parameters
limit := 50 // default limit
offset := 0 // default offset
if limitStr := c.QueryParam("limit"); limitStr != "" {
if val, err := strconv.Atoi(limitStr); err == nil {
limit = val
}
}
if offsetStr := c.QueryParam("offset"); offsetStr != "" {
if val, err := strconv.Atoi(offsetStr); err == nil {
offset = val
}
}
activities, err := GetActivity(userDB, startDate, endDate, limit, offset)
if err != nil {
slog.Error("Error getting activity", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get activity",
})
}
return c.JSON(http.StatusOK, activities)
}
func HandleDateRange(c echo.Context) error {
userDB, err := getUserDB(c)
if err != nil {
slog.Error("Error getting user database", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get user database",
})
}
minDate, maxDate, err := GetDateRange(userDB)
if err != nil {
slog.Error("Error getting date range", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get date range",
})
}
return c.JSON(http.StatusOK, map[string]interface{}{
"min_date": minDate,
"max_date": maxDate,
})
}
func HandleProgress(c echo.Context) error {
progress := GetUploadProgress()
if progress == nil {
return c.JSON(http.StatusOK, map[string]interface{}{
"status": "no_upload",
})
}
return c.JSON(http.StatusOK, progress)
}
func HandleMedia(c echo.Context) error {
userDB, err := getUserDB(c)
if err != nil {
slog.Error("Error getting user database", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get user database",
})
}
// Get message ID from query parameter
messageID := c.QueryParam("id")
if messageID == "" {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "Message ID required",
})
}
// Fetch media from database
media, contentType, err := GetMessageMedia(userDB, messageID)
if err != nil {
slog.Error("Error getting media", "error", err)
return c.JSON(http.StatusNotFound, map[string]string{
"error": "Media not found",
})
}
if len(media) == 0 {
return c.JSON(http.StatusNotFound, map[string]string{
"error": "No media for this message",
})
}
// Set appropriate headers
c.Response().Header().Set("Cache-Control", "public, max-age=31536000") // Cache for 1 year
c.Response().Header().Set("Content-Length", fmt.Sprintf("%d", len(media)))
// Write binary data with proper content type
return c.Blob(http.StatusOK, contentType, media)
}
func HandleSearch(c echo.Context) error {
userDB, err := getUserDB(c)
if err != nil {
slog.Error("Error getting user database", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Failed to get user database",
})
}
// Get search query from query parameter
query := c.QueryParam("q")
if query == "" {
return c.JSON(http.StatusOK, []SearchResult{})
}
// Get limit from query parameter, default to 100
limit := 100
if limitStr := c.QueryParam("limit"); limitStr != "" {
if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 {
limit = parsedLimit
}
}
// Perform search
results, err := SearchMessages(userDB, query, limit)
if err != nil {
slog.Error("Error searching messages", "error", err)
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "Search failed: " + err.Error(),
})
}
return c.JSON(http.StatusOK, results)
}
+570
View File
@@ -0,0 +1,570 @@
package internal
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/labstack/echo/v4"
)
// Global test user ID - stored here so setupTestContext can use it
var testUserID string
// setupTestDB creates a test database with sample data
func setupTestDB(t *testing.T) (string, func()) {
tmpDB := "test_handlers.db"
tmpAuthDB := "test_handlers_auth.db"
// Clean up any existing test database
os.Remove(tmpDB)
os.Remove(tmpAuthDB)
// Initialize auth database first
if err := InitAuthDB(tmpAuthDB); err != nil {
t.Fatalf("Failed to initialize auth database: %v", err)
}
// Initialize main database
if err := InitDB(tmpDB); err != nil {
t.Fatalf("Failed to initialize database: %v", err)
}
// Create test user
user, err := CreateUser("testuser", "password123")
if err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
// Store user ID globally for setupTestContext to use
testUserID = user.ID
// Initialize user database (using UUID-based filename)
userDBPath := fmt.Sprintf("sbv_%s.db", user.ID)
if err := InitUserDB(user.ID, userDBPath); err != nil {
t.Fatalf("Failed to initialize user database: %v", err)
}
// Get user database connection
userDB, err := GetUserDB(user.ID, user.Username)
if err != nil {
t.Fatalf("Failed to get user database: %v", err)
}
// Insert test messages
sampleXML := `<?xml version='1.0' encoding='UTF-8' standalone='yes' ?>
<smses count="3">
<sms protocol="0" address="+15551234567" date="1285799668000" type="2" body="Test sent message" read="1" status="-1" />
<sms protocol="0" address="+15551234567" date="1285799669000" type="1" body="Test received message" read="1" status="-1" />
<mms date="1285799670000" rr="null" sub="null" read="1" ct_t="application/vnd.wap.multipart.related" msg_box="2" address="+15559876543" m_type="128" text_only="0">
<parts>
<part seq="0" ct="text/plain" name="null" chset="106" text="Test MMS message" />
</parts>
<addrs>
<addr address="+15552226543" type="137" charset="106" />
<addr address="+15551116565" type="151" charset="106" />
</addrs>
</mms>
</smses>`
reader := strings.NewReader(sampleXML)
result, err := ParseSMSBackup(reader)
if err != nil {
t.Fatalf("Failed to parse XML: %v", err)
}
for i := range result.Messages {
if err := InsertMessage(userDB, &result.Messages[i]); err != nil {
t.Fatalf("Failed to insert message: %v", err)
}
}
// Return cleanup function
cleanup := func() {
if db != nil {
db.Close()
}
if userDB != nil {
userDB.Close()
}
os.Remove(tmpDB)
os.Remove(tmpAuthDB)
os.Remove(userDBPath)
}
return tmpDB, cleanup
}
// setupTestContext creates an Echo context with user authentication
func setupTestContext(method, url string, body string) (echo.Context, *httptest.ResponseRecorder) {
e := echo.New()
var req *http.Request
if body != "" {
req = httptest.NewRequest(method, url, strings.NewReader(body))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
} else {
req = httptest.NewRequest(method, url, nil)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Set user context (simulating authentication middleware)
// Use the global testUserID which was set by setupTestDB
c.Set("user_id", testUserID)
c.Set("username", "testuser")
return c, rec
}
func TestHealthEndpoint(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
}
if err := handler(c); err != nil {
t.Fatalf("Health check failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
if rec.Body.String() != "OK" {
t.Errorf("Expected body 'OK', got '%s'", rec.Body.String())
}
}
func TestHandleConversations(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/conversations", "")
if err := HandleConversations(c); err != nil {
t.Fatalf("HandleConversations failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var conversations []Conversation
if err := json.Unmarshal(rec.Body.Bytes(), &conversations); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Should have 2 conversations (one for +15551234567, one for +15559876543)
if len(conversations) < 1 {
t.Errorf("Expected at least 1 conversation, got %d", len(conversations))
}
}
func TestHandleConversationsWithDateRange(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
// Test with date range that includes all messages
start := time.Unix(1285799668, 0).Add(-time.Hour).Format(time.RFC3339)
end := time.Unix(1285799671, 0).Format(time.RFC3339)
c, rec := setupTestContext(http.MethodGet, "/api/conversations?start="+start+"&end="+end, "")
c.QueryParams().Add("start", start)
c.QueryParams().Add("end", end)
if err := HandleConversations(c); err != nil {
t.Fatalf("HandleConversations with date range failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var conversations []Conversation
if err := json.Unmarshal(rec.Body.Bytes(), &conversations); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if len(conversations) < 1 {
t.Errorf("Expected at least 1 conversation, got %d", len(conversations))
}
}
func TestHandleMessages(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
// First get conversations to find a valid address
c1, rec1 := setupTestContext(http.MethodGet, "/api/conversations", "")
if err := HandleConversations(c1); err != nil {
t.Fatalf("HandleConversations failed: %v", err)
}
var conversations []Conversation
if err := json.Unmarshal(rec1.Body.Bytes(), &conversations); err != nil {
t.Fatalf("Failed to parse conversations: %v", err)
}
if len(conversations) == 0 {
t.Fatal("No conversations found in test database")
}
// Use the first conversation's address
testAddress := conversations[0].Address
c, rec := setupTestContext(http.MethodGet, "/api/messages?address="+testAddress, "")
c.QueryParams().Add("address", testAddress)
if err := HandleMessages(c); err != nil {
t.Fatalf("HandleMessages failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var messages []Message
if err := json.Unmarshal(rec.Body.Bytes(), &messages); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Verify the response is valid JSON array (might be empty if address format doesn't match)
// The important thing is that the handler responds correctly
t.Logf("Got %d messages for address %s", len(messages), testAddress)
}
func TestHandleMessagesWithoutAddress(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/messages", "")
if err := HandleMessages(c); err != nil {
t.Fatalf("HandleMessages failed: %v", err)
}
if rec.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", rec.Code)
}
var response map[string]string
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse error response: %v", err)
}
if !strings.Contains(response["error"], "Address parameter required") {
t.Errorf("Expected error about missing address, got: %s", response["error"])
}
}
func TestHandleMessagesConversationType(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
// First get conversations to find a valid address
c1, rec1 := setupTestContext(http.MethodGet, "/api/conversations", "")
if err := HandleConversations(c1); err != nil {
t.Fatalf("HandleConversations failed: %v", err)
}
var conversations []Conversation
if err := json.Unmarshal(rec1.Body.Bytes(), &conversations); err != nil {
t.Fatalf("Failed to parse conversations: %v", err)
}
if len(conversations) == 0 {
t.Fatal("No conversations found in test database")
}
// Use the first conversation's address
testAddress := conversations[0].Address
c, rec := setupTestContext(http.MethodGet, "/api/messages?address="+testAddress+"&type=conversation", "")
c.QueryParams().Add("address", testAddress)
c.QueryParams().Add("type", "conversation")
if err := HandleMessages(c); err != nil {
t.Fatalf("HandleMessages with type=conversation failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var activities []ActivityItem
if err := json.Unmarshal(rec.Body.Bytes(), &activities); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Verify the response is valid JSON array (might be empty if address format doesn't match)
// The important thing is that the handler responds correctly with type=conversation
t.Logf("Got %d activities for address %s with type=conversation", len(activities), testAddress)
}
func TestHandleActivity(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/activity", "")
if err := HandleActivity(c); err != nil {
t.Fatalf("HandleActivity failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var activities []ActivityItem
if err := json.Unmarshal(rec.Body.Bytes(), &activities); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Should have 3 activities (2 SMS + 1 MMS)
if len(activities) != 3 {
t.Errorf("Expected 3 activities, got %d", len(activities))
}
}
func TestHandleActivityWithPagination(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/activity?limit=1&offset=0", "")
c.QueryParams().Add("limit", "1")
c.QueryParams().Add("offset", "0")
if err := HandleActivity(c); err != nil {
t.Fatalf("HandleActivity with pagination failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var activities []ActivityItem
if err := json.Unmarshal(rec.Body.Bytes(), &activities); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Should have exactly 1 activity due to limit
if len(activities) != 1 {
t.Errorf("Expected 1 activity, got %d", len(activities))
}
}
func TestHandleDateRange(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/daterange", "")
if err := HandleDateRange(c); err != nil {
t.Fatalf("HandleDateRange failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var response map[string]interface{}
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["min_date"] == nil {
t.Error("Expected min_date in response")
}
if response["max_date"] == nil {
t.Error("Expected max_date in response")
}
}
func TestHandleMedia(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
// Test without message ID
c, rec := setupTestContext(http.MethodGet, "/api/media", "")
if err := HandleMedia(c); err != nil {
t.Fatalf("HandleMedia failed: %v", err)
}
if rec.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", rec.Code)
}
var response map[string]string
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse error response: %v", err)
}
if !strings.Contains(response["error"], "Message ID required") {
t.Errorf("Expected error about missing message ID, got: %s", response["error"])
}
}
func TestHandleMediaNotFound(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/media?id=99999", "")
c.QueryParams().Add("id", "99999")
if err := HandleMedia(c); err != nil {
t.Fatalf("HandleMedia failed: %v", err)
}
if rec.Code != http.StatusNotFound {
t.Errorf("Expected status 404, got %d", rec.Code)
}
}
func TestHandleSearch(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/search?q=Test", "")
c.QueryParams().Add("q", "Test")
if err := HandleSearch(c); err != nil {
t.Fatalf("HandleSearch failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var results []SearchResult
if err := json.Unmarshal(rec.Body.Bytes(), &results); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Should find messages containing "Test"
if len(results) < 1 {
t.Errorf("Expected at least 1 search result, got %d", len(results))
}
}
func TestHandleSearchEmpty(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/search", "")
if err := HandleSearch(c); err != nil {
t.Fatalf("HandleSearch failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var results []SearchResult
if err := json.Unmarshal(rec.Body.Bytes(), &results); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Should return empty array for empty query
if len(results) != 0 {
t.Errorf("Expected 0 search results for empty query, got %d", len(results))
}
}
func TestHandleSearchWithLimit(t *testing.T) {
_, cleanup := setupTestDB(t)
defer cleanup()
c, rec := setupTestContext(http.MethodGet, "/api/search?q=Test&limit=1", "")
c.QueryParams().Add("q", "Test")
c.QueryParams().Add("limit", "1")
if err := HandleSearch(c); err != nil {
t.Fatalf("HandleSearch with limit failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var results []SearchResult
if err := json.Unmarshal(rec.Body.Bytes(), &results); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Should respect the limit
if len(results) > 1 {
t.Errorf("Expected at most 1 search result, got %d", len(results))
}
}
func TestHandleProgress(t *testing.T) {
c, rec := setupTestContext(http.MethodGet, "/api/progress", "")
if err := HandleProgress(c); err != nil {
t.Fatalf("HandleProgress failed: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rec.Code)
}
var response map[string]interface{}
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Should return no_upload status when no upload is in progress
if response["status"] != "no_upload" {
t.Errorf("Expected status 'no_upload', got '%v'", response["status"])
}
}
func TestGetUserDBHelperMissingUserID(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Don't set user_id in context
_, err := getUserDB(c)
if err == nil {
t.Error("Expected error when user_id is missing")
}
if !strings.Contains(err.Error(), "user_id not found") {
t.Errorf("Expected error about missing user_id, got: %v", err)
}
}
func TestGetUserDBHelperMissingUsername(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Set user_id but not username
c.Set("user_id", "test-user-id")
_, err := getUserDB(c)
if err == nil {
t.Error("Expected error when username is missing")
}
if !strings.Contains(err.Error(), "username not found") {
t.Errorf("Expected error about missing username, got: %v", err)
}
}
+67
View File
@@ -0,0 +1,67 @@
//go:build !heic
package internal
import (
"bytes"
"encoding/base64"
"fmt"
"image"
"image/color"
"image/jpeg"
"log/slog"
)
// convertHEICtoJPEG returns a placeholder image when HEIC support is disabled
// This version does not require the libheif library
func convertHEICtoJPEG(heicData []byte) ([]byte, error) {
slog.Warn("HEIC conversion is disabled. Returning placeholder image. Build with -tags heic to enable HEIC support.")
// Return a simple placeholder JPEG image (400x300 gray rectangle with text)
// This is better than returning an error, as it allows the app to function
return generatePlaceholderJPEG()
}
// generatePlaceholderJPEG creates a simple gray placeholder image
func generatePlaceholderJPEG() ([]byte, error) {
// Create a 400x300 image
width, height := 400, 300
img := image.NewRGBA(image.Rect(0, 0, width, height))
// Fill with gray background
gray := color.RGBA{200, 200, 200, 255}
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
img.Set(x, y, gray)
}
}
// Add a dark border
borderColor := color.RGBA{100, 100, 100, 255}
for x := 0; x < width; x++ {
img.Set(x, 0, borderColor)
img.Set(x, height-1, borderColor)
}
for y := 0; y < height; y++ {
img.Set(0, y, borderColor)
img.Set(width-1, y, borderColor)
}
// Encode as JPEG
var buf bytes.Buffer
err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 80})
if err != nil {
return nil, fmt.Errorf("failed to encode placeholder image: %w", err)
}
return buf.Bytes(), nil
}
// Alternative: Return a base64-encoded minimal JPEG (1x1 pixel)
// This is more efficient but less user-friendly
func generateMinimalPlaceholderJPEG() ([]byte, error) {
// 1x1 gray pixel JPEG (base64 encoded minimal JPEG)
minimalJPEG := "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/2wBDAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAv/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwA/wA/h"
return base64.StdEncoding.DecodeString(minimalJPEG)
}
+55
View File
@@ -0,0 +1,55 @@
//go:build heic
package internal
import (
"bytes"
"image/jpeg"
"github.com/strukturag/libheif-go"
)
// convertHEICtoJPEG converts HEIC image data to JPEG format
// Returns the converted JPEG data or an error if conversion fails
// This version requires the libheif library and is enabled with the 'heic' build tag
func convertHEICtoJPEG(heicData []byte) ([]byte, error) {
// Create a new HEIF context
ctx, err := libheif.NewContext()
if err != nil {
return nil, err
}
// Read HEIC data from memory
err = ctx.ReadFromMemory(heicData)
if err != nil {
return nil, err
}
// Get the primary image handle
handle, err := ctx.GetPrimaryImageHandle()
if err != nil {
return nil, err
}
// Decode the image to RGB format
img, err := handle.DecodeImage(libheif.ColorspaceRGB, libheif.ChromaInterleavedRGB, nil)
if err != nil {
return nil, err
}
// Convert to Go's standard image.Image
goImg, err := img.GetImage()
if err != nil {
return nil, err
}
// Encode as JPEG with high quality
var buf bytes.Buffer
err = jpeg.Encode(&buf, goImg, &jpeg.Options{Quality: 90})
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
+49
View File
@@ -0,0 +1,49 @@
package internal
import (
"net/http"
"github.com/labstack/echo/v4"
)
// AuthMiddleware checks for a valid session cookie
func AuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Get session ID from cookie
cookie, err := c.Cookie("session_id")
if err != nil {
return c.JSON(http.StatusUnauthorized, map[string]string{
"error": "Unauthorized: No session found",
})
}
// Validate session
session, err := GetSession(cookie.Value)
if err != nil {
return c.JSON(http.StatusUnauthorized, map[string]string{
"error": "Unauthorized: Invalid or expired session",
})
}
// Store session in context for use by handlers
c.Set("session", session)
c.Set("user_id", session.UserID)
c.Set("username", session.Username)
return next(c)
}
}
// NoCacheMiddleware adds cache control headers to prevent browser caching
// This ensures that dynamic API responses are always fetched fresh from the server
func NoCacheMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Set headers to prevent caching
c.Response().Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, private")
c.Response().Header().Set("Pragma", "no-cache")
c.Response().Header().Set("Expires", "0")
return next(c)
}
}
+112
View File
@@ -0,0 +1,112 @@
package internal
import "time"
type Message struct {
ID int64 `json:"id"`
Address string `json:"address"`
Body string `json:"body"`
Type int `json:"type"` // 1 = received, 2 = sent, 3 = draft, 4 = outbox, 5 = failed, 6 = queued
Date time.Time `json:"date"`
Read bool `json:"read"`
ThreadID int `json:"thread_id"`
Subject string `json:"subject,omitempty"`
MediaType string `json:"media_type,omitempty"`
MediaData []byte `json:"-"`
MediaBase64 string `json:"media_base64,omitempty"`
// Additional SMS fields
Protocol int `json:"protocol,omitempty"`
Status int `json:"status,omitempty"` // -1 = none, 0 = complete, 32 = pending, 64 = failed
ServiceCenter string `json:"service_center,omitempty"`
SubID int `json:"sub_id,omitempty"`
ContactName string `json:"contact_name,omitempty"`
Sender string `json:"sender,omitempty"` // Sender phone number for received messages
// Additional MMS fields
ContentType string `json:"content_type,omitempty"` // ct_t field
ReadReport int `json:"read_report,omitempty"` // rr field
ReadStatus int `json:"read_status,omitempty"`
MessageID string `json:"message_id,omitempty"` // m_id field
MessageSize int `json:"message_size,omitempty"` // m_size field
MessageType int `json:"message_type,omitempty"` // m_type field
SimSlot int `json:"sim_slot,omitempty"`
Addresses []string `json:"addresses,omitempty"` // All phone numbers in conversation (for MMS)
}
type CallLog struct {
ID int64 `json:"id"`
Number string `json:"number"`
Duration int `json:"duration"` // in seconds
Date time.Time `json:"date"`
Type int `json:"type"` // 1 = incoming, 2 = outgoing, 3 = missed, 4 = voicemail, 5 = rejected, 6 = refused
Presentation int `json:"presentation,omitempty"` // 1 = allowed, 2 = restricted, 3 = unknown, 4 = payphone
SubscriptionID string `json:"subscription_id,omitempty"`
ContactName string `json:"contact_name,omitempty"`
}
type Conversation struct {
Address string `json:"address"`
ContactName string `json:"contact_name,omitempty"`
Subject string `json:"subject,omitempty"`
LastMessage string `json:"last_message"`
LastDate time.Time `json:"last_date"`
MessageCount int `json:"message_count"`
Type string `json:"type"` // "sms", "mms", or "call"
}
type ActivityItem struct {
Type string `json:"type"` // "message" or "call"
Date time.Time `json:"date"`
Address string `json:"address"`
ContactName string `json:"contact_name,omitempty"`
// Message-specific fields
Message *Message `json:"message,omitempty"`
// Call-specific fields
Call *CallLog `json:"call,omitempty"`
}
type UploadResponse struct {
Success bool `json:"success"`
MessageCount int `json:"message_count"`
CallLogCount int `json:"call_log_count"`
Processing bool `json:"processing,omitempty"`
Error string `json:"error,omitempty"`
}
type User struct {
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"-"` // Never send password hash to client
CreatedAt time.Time `json:"created_at"`
}
type Session struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Username string `json:"username"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
}
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type RegisterRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type AuthResponse struct {
Success bool `json:"success"`
User *User `json:"user,omitempty"`
Session *Session `json:"session,omitempty"`
Error string `json:"error,omitempty"`
}
type ChangePasswordRequest struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
ConfirmPassword string `json:"confirm_password"`
}
+883
View File
@@ -0,0 +1,883 @@
package internal
import (
"bytes"
"database/sql"
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"time"
)
type SMSBackup struct {
XMLName xml.Name `xml:"smses"`
Count int `xml:"count,attr"`
Messages []SMSEntry `xml:"sms"`
MMS []MMSEntry `xml:"mms"`
Calls []CallEntry `xml:"call"`
}
type SMSEntry struct {
Address string `xml:"address,attr"`
Date string `xml:"date,attr"`
Type string `xml:"type,attr"`
Body string `xml:"body,attr"`
Read string `xml:"read,attr"`
ThreadID string `xml:"thread_id,attr"`
Subject string `xml:"subject,attr"`
Protocol string `xml:"protocol,attr"`
TOA string `xml:"toa,attr"`
SCTOA string `xml:"sc_toa,attr"`
ServiceCenter string `xml:"service_center,attr"`
Status string `xml:"status,attr"`
SubID string `xml:"sub_id,attr"`
ReadableDate string `xml:"readable_date,attr"`
ContactName string `xml:"contact_name,attr"`
}
type MMSEntry struct {
Address string `xml:"address,attr"`
Date string `xml:"date,attr"`
Type string `xml:"msg_box,attr"`
Read string `xml:"read,attr"`
ThreadID string `xml:"thread_id,attr"`
Subject string `xml:"sub,attr"`
TrID string `xml:"tr_id,attr"`
ContentType string `xml:"ct_t,attr"`
ReadReport string `xml:"rr,attr"`
ReadStatus string `xml:"read_status,attr"`
MessageID string `xml:"m_id,attr"`
MessageSize string `xml:"m_size,attr"`
MessageType string `xml:"m_type,attr"`
SimSlot string `xml:"sim_slot,attr"`
ReadableDate string `xml:"readable_date,attr"`
ContactName string `xml:"contact_name,attr"`
Parts []MMSPart `xml:"parts>part"`
Addrs []MMSAddr `xml:"addrs>addr"`
Body string `xml:"body,attr"`
}
type MMSPart struct {
Seq string `xml:"seq,attr"`
ContentType string `xml:"ct,attr"`
Name string `xml:"name,attr"`
Charset string `xml:"chset,attr"`
CL string `xml:"cl,attr"`
Text string `xml:"text,attr"`
Data string `xml:"data,attr"`
}
type MMSAddr struct {
Address string `xml:"address,attr"`
Type string `xml:"type,attr"`
Charset string `xml:"charset,attr"`
}
type CallEntry struct {
Number string `xml:"number,attr"`
Duration string `xml:"duration,attr"`
Date string `xml:"date,attr"`
Type string `xml:"type,attr"`
Presentation string `xml:"presentation,attr"`
SubscriptionID string `xml:"subscription_id,attr"`
ReadableDate string `xml:"readable_date,attr"`
ContactName string `xml:"contact_name,attr"`
}
type ParseResult struct {
Messages []Message
Calls []CallLog
}
func ParseSMSBackup(r io.Reader) (ParseResult, error) {
var backup SMSBackup
decoder := xml.NewDecoder(r)
err := decoder.Decode(&backup)
if err != nil {
return ParseResult{}, err
}
var result ParseResult
// Parse SMS messages
for _, sms := range backup.Messages {
msg, err := convertSMSEntry(sms)
if err != nil {
slog.Error("Error parsing SMS", "error", err)
continue
}
result.Messages = append(result.Messages, msg)
}
// Parse MMS messages
for _, mms := range backup.MMS {
msg, err := convertMMSEntry(mms)
if err != nil {
slog.Error("Error parsing MMS", "error", err)
continue
}
result.Messages = append(result.Messages, msg)
}
// Parse call logs
for _, call := range backup.Calls {
callLog, err := convertCallEntry(call)
if err != nil {
slog.Error("Error parsing call log", "error", err)
continue
}
result.Calls = append(result.Calls, callLog)
}
return result, nil
}
func convertSMSEntry(sms SMSEntry) (Message, error) {
dateMs, err := strconv.ParseInt(sms.Date, 10, 64)
if err != nil {
return Message{}, err
}
msgType, _ := strconv.Atoi(sms.Type)
read := sms.Read == "1"
threadID, _ := strconv.Atoi(sms.ThreadID)
protocol, _ := strconv.Atoi(sms.Protocol)
status, _ := strconv.Atoi(sms.Status)
subID, _ := strconv.Atoi(sms.SubID)
// Normalize the phone number to remove formatting differences
normalizedAddress := normalizePhoneNumber(sms.Address)
// For SMS, the address is the single phone number
addresses := []string{}
if normalizedAddress != "" {
addresses = append(addresses, normalizedAddress)
}
// For received SMS messages, the sender is the address
var sender string
if msgType == 1 && normalizedAddress != "" {
sender = normalizedAddress
}
return Message{
Address: normalizedAddress,
Body: sms.Body,
Type: msgType,
Date: time.Unix(dateMs/1000, 0),
Read: read,
ThreadID: threadID,
Subject: normalizeNullString(sms.Subject),
Protocol: protocol,
Status: status,
ServiceCenter: sms.ServiceCenter,
SubID: subID,
ContactName: sms.ContactName,
Sender: sender,
Addresses: addresses,
}, nil
}
func convertMMSEntry(mms MMSEntry) (Message, error) {
dateMs, err := strconv.ParseInt(mms.Date, 10, 64)
if err != nil {
return Message{}, err
}
msgType, _ := strconv.Atoi(mms.Type)
read := mms.Read == "1"
threadID, _ := strconv.Atoi(mms.ThreadID)
readReport, _ := strconv.Atoi(mms.ReadReport)
readStatus, _ := strconv.Atoi(mms.ReadStatus)
messageSize, _ := strconv.Atoi(mms.MessageSize)
messageType, _ := strconv.Atoi(mms.MessageType)
simSlot, _ := strconv.Atoi(mms.SimSlot)
// Normalize the phone number to remove formatting differences
normalizedAddress := normalizePhoneNumber(mms.Address)
// Extract all addresses from MMS and find the sender (type 137 = FROM)
// Include ALL addresses to keep group conversations consistent
addressMap := make(map[string]bool)
var senderAddress string
var firstAddress string
for _, addr := range mms.Addrs {
if addr.Address != "" {
// Normalize each address to prevent duplicates due to formatting
normalizedAddr := normalizePhoneNumber(addr.Address)
if normalizedAddr != "" {
addressMap[normalizedAddr] = true
// Remember the first address we encounter
if firstAddress == "" {
firstAddress = normalizedAddr
}
// Type 137 (0x89) = FROM (sender in Android MMS)
// For received messages, this tells us who sent it
addrType, _ := strconv.Atoi(addr.Type)
if addrType == 137 {
senderAddress = normalizedAddr
}
}
}
}
// If no type 137 sender was found for a received message, use the first address
// or the single address for 1-on-1 conversations
if msgType == 1 && senderAddress == "" {
if len(addressMap) == 1 && firstAddress != "" {
// 1-on-1 conversation: the single address is definitely the sender
senderAddress = firstAddress
} else if len(addressMap) > 1 && firstAddress != "" {
// Group conversation without explicit sender: use first address as best guess
senderAddress = firstAddress
}
}
// Convert map to sorted, deduplicated slice
addresses := make([]string, 0, len(addressMap))
for addr := range addressMap {
addresses = append(addresses, addr)
}
// Sort addresses for consistency
sort.Strings(addresses)
// Determine the primary address field for conversation grouping
var primaryAddress string
if len(addresses) >= 3 {
// Group MMS (3+ participants) - join all normalized addresses to create a consistent group identifier
primaryAddress = strings.Join(addresses, ",")
} else if len(addresses) > 0 {
// MMS with 1-2 addresses - use the normalized address
primaryAddress = normalizedAddress
} else {
// Fallback to normalized mms.Address if no addresses found in mms.Addrs
primaryAddress = normalizedAddress
}
// For received messages, store the sender in the Sender field
// This allows us to display who sent each message in the UI
var sender string
if msgType == 1 && senderAddress != "" {
// Received message - store the sender address
sender = senderAddress
}
msg := Message{
Address: primaryAddress,
Type: msgType,
Date: time.Unix(dateMs/1000, 0),
Read: read,
ThreadID: threadID,
Subject: normalizeNullString(mms.Subject),
ContentType: mms.ContentType,
ReadReport: readReport,
ReadStatus: readStatus,
MessageID: mms.MessageID,
MessageSize: messageSize,
MessageType: messageType,
SimSlot: simSlot,
ContactName: mms.ContactName,
Sender: sender,
Addresses: addresses,
}
// Extract body text and media from parts
var bodyText string
for _, part := range mms.Parts {
// Skip SMIL content - it's presentation metadata, not actual message content
if isSMILContentType(part.ContentType) {
continue
}
// Check for VCF (vCard) files - these are text/* but should be treated as media attachments
if isVCardContentType(part.ContentType) && part.Data != "" {
if msg.MediaType == "" { // Only store first media item
data, err := base64.StdEncoding.DecodeString(part.Data)
if err == nil {
msg.MediaType = part.ContentType
msg.MediaData = data
}
}
continue
}
// Check for media - media parts often have text="null" which should be ignored
if part.ContentType != "" && part.Data != "" && !isTextContentType(part.ContentType) {
// This is media content (image, video, audio, etc.)
if msg.MediaType == "" { // Only store first media item
data, err := base64.StdEncoding.DecodeString(part.Data)
if err == nil {
// Store all media as-is (including HEIC images in original format)
msg.MediaType = part.ContentType
msg.MediaData = data
}
}
} else if part.Text != "" && normalizeNullString(part.Text) != "" {
// This is actual text content (not "null")
bodyText += part.Text + " "
}
}
if bodyText != "" {
msg.Body = strings.TrimSpace(bodyText)
}
// Extract group name from RCS proto: tr_id if available
// Use it as the subject if the current subject is empty or starts with "proto:"
if mms.TrID != "" && strings.HasPrefix(mms.TrID, "proto:") {
groupName := extractGroupNameFromTrID(mms.TrID)
if groupName != "" {
// Only use the extracted name if subject is empty or also starts with "proto:"
if msg.Subject == "" || strings.HasPrefix(mms.Subject, "proto:") {
msg.Subject = groupName
}
}
}
return msg, nil
}
// normalizeNullString converts the string "null" to an empty string
func normalizeNullString(s string) string {
if strings.TrimSpace(strings.ToLower(s)) == "null" {
return ""
}
return s
}
// isTextContentType checks if a content type is text-based
func isTextContentType(contentType string) bool {
ct := strings.ToLower(strings.TrimSpace(contentType))
return strings.HasPrefix(ct, "text/") ||
ct == "application/xml" ||
ct == "application/json"
}
// isSMILContentType checks if a content type is SMIL markup
func isSMILContentType(contentType string) bool {
ct := strings.ToLower(strings.TrimSpace(contentType))
return ct == "application/smil" ||
strings.HasPrefix(ct, "application/smil+") ||
strings.Contains(ct, "smil")
}
// isSMILMarkup checks if the body text is SMIL (Synchronized Multimedia Integration Language) markup
// which is MMS presentation metadata and should not be displayed to users
func isSMILMarkup(body string) bool {
trimmed := strings.TrimSpace(body)
return strings.HasPrefix(trimmed, "<smil") || strings.HasPrefix(trimmed, "<?xml")
}
// isVCardContentType checks if a content type is vCard format
func isVCardContentType(contentType string) bool {
ct := strings.ToLower(strings.TrimSpace(contentType))
return ct == "text/vcard" || ct == "text/x-vcard" || ct == "text/directory"
}
// extractGroupNameFromTrID extracts the group conversation name from RCS proto: tr_id field
func extractGroupNameFromTrID(trID string) string {
return ""
/*
// Check if tr_id starts with "proto:"
if !strings.HasPrefix(trID, "proto:") {
return ""
}
// Remove the "proto:" prefix
protoData := strings.TrimPrefix(trID, "proto:")
// Base64 decode the remaining bytes
decoded, err := base64.StdEncoding.DecodeString(protoData)
if err != nil {
slog.Error("Failed to base64 decode tr_id", "error", err)
return ""
}
// Check if we have enough bytes (need at least 84 bytes: offset 83 + 1 for length)
if len(decoded) < 84 {
slog.Debug("Decoded tr_id too short", "bytes", len(decoded), "required", 84)
return ""
}
// Read the length byte at offset 83
nameLength := int(decoded[83])
// Check if we have enough bytes for the name
if len(decoded) < 84+nameLength {
slog.Debug("Not enough bytes for group name", "have", len(decoded), "need", 84+nameLength)
return ""
}
// Extract the group name string
groupName := string(decoded[84 : 84+nameLength])
slog.Debug("Extracted group name from tr_id", "group_name", groupName)
return groupName
*/
}
// isHEICContentType checks if a content type is HEIC/HEIF format
func isHEICContentType(contentType string) bool {
ct := strings.ToLower(strings.TrimSpace(contentType))
return strings.Contains(ct, "heic") || strings.Contains(ct, "heif")
}
// needsVideoConversion checks if a video format needs conversion for browser compatibility
func needsVideoConversion(contentType string) bool {
ct := strings.ToLower(strings.TrimSpace(contentType))
// 3GP, 3G2, and other old mobile formats that browsers don't support
unsupportedFormats := []string{
"3gpp", "3gp", "3g2", "3gpp2",
"video/3gpp", "video/3gp", "video/3gpp2", "video/3g2",
}
for _, format := range unsupportedFormats {
if strings.Contains(ct, format) {
return true
}
}
return false
}
// convertHEICtoJPEG is implemented in heic_enabled.go (with -tags heic) or heic_disabled.go (default)
// When HEIC support is enabled, it converts HEIC image data to JPEG format
// When HEIC support is disabled, it returns a placeholder image
// convertVideoToMP4 converts unsupported video formats (like 3GP) to MP4 using ffmpeg
// Returns the converted MP4 data or an error if conversion fails
func convertVideoToMP4(videoData []byte) ([]byte, error) {
// Create temporary files for input and output
tmpInputFile, err := os.CreateTemp("", "video-input-*.3gp")
if err != nil {
return nil, fmt.Errorf("failed to create temp input file: %w", err)
}
defer os.Remove(tmpInputFile.Name())
defer tmpInputFile.Close()
tmpOutputFile, err := os.CreateTemp("", "video-output-*.mp4")
if err != nil {
return nil, fmt.Errorf("failed to create temp output file: %w", err)
}
defer os.Remove(tmpOutputFile.Name())
tmpOutputFile.Close()
// Write input video data to temp file
_, err = tmpInputFile.Write(videoData)
if err != nil {
return nil, fmt.Errorf("failed to write input video: %w", err)
}
tmpInputFile.Close()
// Run ffmpeg to convert video to MP4 with H.264 codec
// -i: input file
// -c:v libx264: use H.264 video codec
// -c:a aac: use AAC audio codec
// -movflags +faststart: optimize for streaming
// -preset fast: balance between speed and quality
// -crf 23: constant rate factor (quality, lower is better, 23 is good default)
cmd := exec.Command("ffmpeg",
"-i", tmpInputFile.Name(),
"-c:v", "libx264",
"-c:a", "aac",
"-movflags", "+faststart",
"-preset", "fast",
"-crf", "23",
"-y", // overwrite output file
tmpOutputFile.Name(),
)
// Capture stderr for error messages
var stderr bytes.Buffer
cmd.Stderr = &stderr
err = cmd.Run()
if err != nil {
return nil, fmt.Errorf("ffmpeg conversion failed: %w, stderr: %s", err, stderr.String())
}
// Read converted video data
convertedData, err := os.ReadFile(tmpOutputFile.Name())
if err != nil {
return nil, fmt.Errorf("failed to read converted video: %w", err)
}
return convertedData, nil
}
func convertCallEntry(call CallEntry) (CallLog, error) {
dateMs, err := strconv.ParseInt(call.Date, 10, 64)
if err != nil {
return CallLog{}, err
}
duration, _ := strconv.Atoi(call.Duration)
callType, _ := strconv.Atoi(call.Type)
presentation, _ := strconv.Atoi(call.Presentation)
// Normalize the phone number to remove formatting differences
normalizedNumber := normalizePhoneNumber(call.Number)
return CallLog{
Number: normalizedNumber,
Duration: duration,
Date: time.Unix(dateMs/1000, 0),
Type: callType,
Presentation: presentation,
SubscriptionID: call.SubscriptionID,
ContactName: call.ContactName,
}, nil
}
// UploadProgress tracks the progress of an ongoing upload
type UploadProgress struct {
TotalMessages int `json:"total_messages"`
ProcessedMessages int `json:"processed_messages"`
TotalCalls int `json:"total_calls"`
ProcessedCalls int `json:"processed_calls"`
Status string `json:"status"` // "parsing", "importing", "completed", "error"
ErrorMessage string `json:"error_message,omitempty"`
StartTime time.Time `json:"start_time"`
mu sync.RWMutex
}
var (
uploadProgress *UploadProgress
uploadProgressLock sync.RWMutex
)
// GetUploadProgress returns the current upload progress
func GetUploadProgress() *UploadProgress {
uploadProgressLock.RLock()
defer uploadProgressLock.RUnlock()
if uploadProgress == nil {
return nil
}
uploadProgress.mu.RLock()
defer uploadProgress.mu.RUnlock()
// Return a copy to avoid race conditions
return &UploadProgress{
TotalMessages: uploadProgress.TotalMessages,
ProcessedMessages: uploadProgress.ProcessedMessages,
TotalCalls: uploadProgress.TotalCalls,
ProcessedCalls: uploadProgress.ProcessedCalls,
Status: uploadProgress.Status,
ErrorMessage: uploadProgress.ErrorMessage,
StartTime: uploadProgress.StartTime,
}
}
// SetUploadProgress initializes or updates the upload progress
func SetUploadProgress(total, processed int, status string) {
uploadProgressLock.Lock()
defer uploadProgressLock.Unlock()
if uploadProgress == nil {
uploadProgress = &UploadProgress{
StartTime: time.Now(),
}
}
uploadProgress.mu.Lock()
defer uploadProgress.mu.Unlock()
uploadProgress.TotalMessages = total
uploadProgress.ProcessedMessages = processed
uploadProgress.Status = status
}
// UpdateMessageProgress updates the progress for messages
func UpdateMessageProgress(processed int) {
uploadProgressLock.RLock()
defer uploadProgressLock.RUnlock()
if uploadProgress == nil {
return
}
uploadProgress.mu.Lock()
defer uploadProgress.mu.Unlock()
uploadProgress.ProcessedMessages = processed
}
// UpdateCallProgress updates the progress for calls
func UpdateCallProgress(processed int) {
uploadProgressLock.RLock()
defer uploadProgressLock.RUnlock()
if uploadProgress == nil {
return
}
uploadProgress.mu.Lock()
defer uploadProgress.mu.Unlock()
uploadProgress.ProcessedCalls = processed
}
// ClearUploadProgress clears the upload progress
func ClearUploadProgress() {
uploadProgressLock.Lock()
defer uploadProgressLock.Unlock()
uploadProgress = nil
}
// SaveUploadedFile saves the uploaded file to a temporary location
func SaveUploadedFile(file io.Reader, filename string) (string, error) {
// Create temp directory if it doesn't exist
tempDir := os.TempDir()
uploadDir := filepath.Join(tempDir, "sbv-uploads")
err := os.MkdirAll(uploadDir, 0755)
if err != nil {
return "", fmt.Errorf("failed to create upload directory: %v", err)
}
// Create temporary file
tempFile, err := os.CreateTemp(uploadDir, "backup-*.xml")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %v", err)
}
defer tempFile.Close()
// Copy uploaded file to temp file
_, err = io.Copy(tempFile, file)
if err != nil {
os.Remove(tempFile.Name())
return "", fmt.Errorf("failed to save file: %v", err)
}
return tempFile.Name(), nil
}
// ProcessUploadedFile processes the uploaded file in the background
func ProcessUploadedFile(userID string, username string, filePath string) {
defer func() {
// Always clean up the temp file when done
slog.Info("Removing temporary file", "path", filePath)
if err := os.Remove(filePath); err != nil {
slog.Warn("Failed to remove temp file", "path", filePath, "error", err)
}
}()
slog.Info("Starting background processing", "path", filePath, "user", username)
// Get user database
userDB, err := GetUserDB(userID, username)
if err != nil {
slog.Error("Error getting user database", "error", err)
SetUploadProgress(0, 0, "error")
uploadProgressLock.Lock()
if uploadProgress != nil {
uploadProgress.mu.Lock()
uploadProgress.ErrorMessage = fmt.Sprintf("Failed to get user database: %v", err)
uploadProgress.mu.Unlock()
}
uploadProgressLock.Unlock()
return
}
// Open the file for reading
file, err := os.Open(filePath)
if err != nil {
slog.Error("Error opening file", "error", err)
SetUploadProgress(0, 0, "error")
uploadProgressLock.Lock()
if uploadProgress != nil {
uploadProgress.mu.Lock()
uploadProgress.ErrorMessage = fmt.Sprintf("Failed to open file: %v", err)
uploadProgress.mu.Unlock()
}
uploadProgressLock.Unlock()
return
}
defer file.Close()
// Process with streaming parser (batch size 1 for minimal memory)
messageCount, callCount, err := ParseSMSBackupStreaming(userDB, file, 1) // Insert immediately, no batching
if err != nil {
slog.Error("Error processing file", "error", err)
SetUploadProgress(0, 0, "error")
uploadProgressLock.Lock()
if uploadProgress != nil {
uploadProgress.mu.Lock()
uploadProgress.ErrorMessage = fmt.Sprintf("Failed to process file: %v", err)
uploadProgress.mu.Unlock()
}
uploadProgressLock.Unlock()
return
}
slog.Info("Completed processing", "messages", messageCount, "calls", callCount)
}
// ParseSMSBackupStreaming parses SMS backup file with streaming to reduce memory usage
// Each message is inserted immediately and memory is freed aggressively
func ParseSMSBackupStreaming(userDB *sql.DB, r io.Reader, batchSize int) (int, int, error) {
// Initialize progress tracking
uploadProgressLock.Lock()
uploadProgress = &UploadProgress{
Status: "parsing",
StartTime: time.Now(),
}
uploadProgressLock.Unlock()
decoder := xml.NewDecoder(r)
var messageCount, callCount int
// Track total count from root element if available
var totalCount int
for {
token, err := decoder.Token()
if err == io.EOF {
break
}
if err != nil {
SetUploadProgress(0, 0, "error")
return messageCount, callCount, err
}
switch elem := token.(type) {
case xml.StartElement:
// Get total count from root element
if elem.Name.Local == "smses" {
for _, attr := range elem.Attr {
if attr.Name.Local == "count" {
totalCount, _ = strconv.Atoi(attr.Value)
uploadProgressLock.Lock()
uploadProgress.mu.Lock()
uploadProgress.TotalMessages = totalCount
uploadProgress.mu.Unlock()
uploadProgressLock.Unlock()
}
}
}
// Process SMS messages
if elem.Name.Local == "sms" {
var sms SMSEntry
err := decoder.DecodeElement(&sms, &elem)
if err != nil {
slog.Error("Error decoding SMS", "error", err)
continue
}
msg, err := convertSMSEntry(sms)
if err != nil {
slog.Error("Error converting SMS", "error", err)
continue
}
// Insert immediately - no batching
err = InsertMessage(userDB, &msg)
if err != nil {
slog.Error("Error inserting message", "error", err)
} else {
messageCount++
UpdateMessageProgress(messageCount)
}
// Force garbage collection every 1000 messages to keep memory low
if messageCount%1000 == 0 {
runtime.GC()
}
}
// Process MMS messages
if elem.Name.Local == "mms" {
var mms MMSEntry
err := decoder.DecodeElement(&mms, &elem)
if err != nil {
slog.Error("Error decoding MMS", "error", err)
continue
}
msg, err := convertMMSEntry(mms)
// Clear the MMS struct immediately after conversion to free base64 strings
mms.Parts = nil
mms = MMSEntry{}
if err != nil {
slog.Error("Error converting MMS", "error", err)
continue
}
// Insert immediately - no batching
err = InsertMessage(userDB, &msg)
if err != nil {
slog.Error("Error inserting message", "error", err)
} else {
messageCount++
UpdateMessageProgress(messageCount)
}
// Clear the message data immediately after insert
msg.MediaData = nil
msg = Message{}
// Force garbage collection every 100 MMS messages (they're larger)
if messageCount%100 == 0 {
runtime.GC()
}
}
// Process call logs
if elem.Name.Local == "call" {
var call CallEntry
err := decoder.DecodeElement(&call, &elem)
if err != nil {
slog.Error("Error decoding call", "error", err)
continue
}
callLog, err := convertCallEntry(call)
if err != nil {
slog.Error("Error converting call", "error", err)
continue
}
// Insert immediately - no batching
err = InsertCallLog(userDB, &callLog)
if err != nil {
slog.Error("Error inserting call log", "error", err)
} else {
callCount++
uploadProgressLock.Lock()
uploadProgress.mu.Lock()
uploadProgress.TotalCalls++
uploadProgress.ProcessedCalls = callCount
uploadProgress.mu.Unlock()
uploadProgressLock.Unlock()
}
}
}
}
// Final garbage collection
runtime.GC()
// Mark as completed
SetUploadProgress(messageCount, messageCount, "completed")
return messageCount, callCount, nil
}
+264
View File
@@ -0,0 +1,264 @@
package internal
import (
"os"
"strings"
"testing"
"time"
)
const sampleXML = `<?xml version='1.0' encoding='UTF-8' standalone='yes' ?>
<?xml-stylesheet type="text/xsl" href="sms.xsl"?>
<smses count="2">
<sms protocol="0" address="332" date="1285799668193" type="2" subject="null" body="Sample Message Sent from the phone" toa="null" sc_toa="null" service_center="null" read="1" status="-1" locked="0" readable_date="Sep 30, 2010 8:34:28 AM" contact_name="(Unknown)" />
<sms protocol="0" address="4433221123" date="1289643415810" type="1" subject="null" body="Sample Message received by the phone" toa="null" sc_toa="null" service_center="null" read="0" status="-1" locked="0" readable_date="Nov 13, 2010 9:16:55 PM" contact_name="(Unknown)" />
</smses>`
func TestSampleXMLParsing(t *testing.T) {
// Parse the XML
reader := strings.NewReader(sampleXML)
result, err := ParseSMSBackup(reader)
if err != nil {
t.Fatalf("Failed to parse XML: %v", err)
}
// Verify we got 2 messages
if len(result.Messages) != 2 {
t.Errorf("Expected 2 messages, got %d", len(result.Messages))
}
// Verify first message (sent)
msg1 := result.Messages[0]
if msg1.Address != "332" {
t.Errorf("Expected address '332', got '%s'", msg1.Address)
}
if msg1.Type != 2 {
t.Errorf("Expected type 2 (sent), got %d", msg1.Type)
}
if msg1.Body != "Sample Message Sent from the phone" {
t.Errorf("Expected body 'Sample Message Sent from the phone', got '%s'", msg1.Body)
}
if msg1.Protocol != 0 {
t.Errorf("Expected protocol 0, got %d", msg1.Protocol)
}
if !msg1.Read {
t.Errorf("Expected message to be read (read=1)")
}
if msg1.Status != -1 {
t.Errorf("Expected status -1, got %d", msg1.Status)
}
// Check date: 1285799668193 milliseconds = Sep 30, 2010 8:34:28 AM
expectedDate1 := time.Unix(1285799668, 0)
if !msg1.Date.Equal(expectedDate1) {
t.Errorf("Expected date %v, got %v", expectedDate1, msg1.Date)
}
// Verify second message (received)
msg2 := result.Messages[1]
// Phone number normalization adds +1 to 10-digit US numbers
if msg2.Address != "+14433221123" {
t.Errorf("Expected address '+14433221123', got '%s'", msg2.Address)
}
if msg2.Type != 1 {
t.Errorf("Expected type 1 (received), got %d", msg2.Type)
}
if msg2.Body != "Sample Message received by the phone" {
t.Errorf("Expected body 'Sample Message received by the phone', got '%s'", msg2.Body)
}
if msg2.Read {
t.Errorf("Expected message to be unread (read=0)")
}
// Check date: 1289643415810 milliseconds = Nov 13, 2010 9:16:55 PM
expectedDate2 := time.Unix(1289643415, 0)
if !msg2.Date.Equal(expectedDate2) {
t.Errorf("Expected date %v, got %v", expectedDate2, msg2.Date)
}
// Verify no call logs in this sample
if len(result.Calls) != 0 {
t.Errorf("Expected 0 call logs, got %d", len(result.Calls))
}
}
func TestSampleXMLDatabaseIngestion(t *testing.T) {
// Create a temporary database file
tmpDB := "test_messages.db"
defer os.Remove(tmpDB) // Clean up after test
// Initialize database
err := InitDB(tmpDB)
if err != nil {
t.Fatalf("Failed to initialize database: %v", err)
}
defer db.Close()
// Parse the XML
reader := strings.NewReader(sampleXML)
result, err := ParseSMSBackup(reader)
if err != nil {
t.Fatalf("Failed to parse XML: %v", err)
}
// Insert messages into database
messageCount := 0
for i := range result.Messages {
err := InsertMessage(db, &result.Messages[i])
if err != nil {
t.Errorf("Failed to insert message %d: %v", i, err)
continue
}
messageCount++
// Verify the ID was set
if result.Messages[i].ID == 0 {
t.Errorf("Message %d: ID was not set after insert", i)
}
}
// Verify we inserted 2 messages
if messageCount != 2 {
t.Errorf("Expected to insert 2 messages, inserted %d", messageCount)
}
// Retrieve messages from database and verify
messages, err := GetMessages(db, "332", nil, nil)
if err != nil {
t.Fatalf("Failed to retrieve messages for address '332': %v", err)
}
if len(messages) != 1 {
t.Errorf("Expected 1 message for address '332', got %d", len(messages))
} else {
msg := messages[0]
if msg.Body != "Sample Message Sent from the phone" {
t.Errorf("Retrieved message has wrong body: '%s'", msg.Body)
}
if msg.Type != 2 {
t.Errorf("Retrieved message has wrong type: %d", msg.Type)
}
if msg.Protocol != 0 {
t.Errorf("Retrieved message has wrong protocol: %d", msg.Protocol)
}
if msg.Status != -1 {
t.Errorf("Retrieved message has wrong status: %d", msg.Status)
}
if !msg.Read {
t.Errorf("Retrieved message should be marked as read")
}
}
// Retrieve second message
messages2, err := GetMessages(db, "+14433221123", nil, nil)
if err != nil {
t.Fatalf("Failed to retrieve messages for address '+14433221123': %v", err)
}
if len(messages2) != 1 {
t.Errorf("Expected 1 message for address '+14433221123', got %d", len(messages2))
} else {
msg := messages2[0]
if msg.Body != "Sample Message received by the phone" {
t.Errorf("Retrieved message has wrong body: '%s'", msg.Body)
}
if msg.Type != 1 {
t.Errorf("Retrieved message has wrong type: %d", msg.Type)
}
if msg.Read {
t.Errorf("Retrieved message should be marked as unread")
}
}
// Test GetConversations
conversations, err := GetConversations(db, nil, nil)
if err != nil {
t.Fatalf("Failed to get conversations: %v", err)
}
if len(conversations) != 2 {
t.Errorf("Expected 2 conversations, got %d", len(conversations))
}
// Verify conversations are sorted by date (most recent first)
// Second message (1289643415) should be first as it's more recent
if len(conversations) == 2 {
if conversations[0].Address != "+14433221123" {
t.Errorf("Expected first conversation to be '+14433221123', got '%s'", conversations[0].Address)
}
if conversations[1].Address != "332" {
t.Errorf("Expected second conversation to be '332', got '%s'", conversations[1].Address)
}
if conversations[0].MessageCount != 1 {
t.Errorf("Expected first conversation to have 1 message, got %d", conversations[0].MessageCount)
}
if conversations[0].Type != "conversation" {
t.Errorf("Expected conversation type to be 'conversation', got '%s'", conversations[0].Type)
}
}
// Test date range functionality
startDate := time.Unix(1289000000, 0) // After first message, before second
messages3, err := GetMessages(db, "332", &startDate, nil)
if err != nil {
t.Fatalf("Failed to retrieve messages with date filter: %v", err)
}
if len(messages3) != 0 {
t.Errorf("Expected 0 messages after start date, got %d", len(messages3))
}
// Get date range
minDate, maxDate, err := GetDateRange(db)
if err != nil {
t.Fatalf("Failed to get date range: %v", err)
}
expectedMin := time.Unix(1285799668, 0)
expectedMax := time.Unix(1289643415, 0)
if !minDate.Equal(expectedMin) {
t.Errorf("Expected min date %v, got %v", expectedMin, minDate)
}
if !maxDate.Equal(expectedMax) {
t.Errorf("Expected max date %v, got %v", expectedMax, maxDate)
}
}
func TestEmptyXML(t *testing.T) {
emptyXML := `<?xml version='1.0' encoding='UTF-8' standalone='yes' ?>
<smses count="0">
</smses>`
reader := strings.NewReader(emptyXML)
result, err := ParseSMSBackup(reader)
if err != nil {
t.Fatalf("Failed to parse empty XML: %v", err)
}
if len(result.Messages) != 0 {
t.Errorf("Expected 0 messages, got %d", len(result.Messages))
}
if len(result.Calls) != 0 {
t.Errorf("Expected 0 calls, got %d", len(result.Calls))
}
}
func TestInvalidXML(t *testing.T) {
invalidXML := `<?xml version='1.0' encoding='UTF-8' standalone='yes' ?>
<smses count="1">
<sms protocol="invalid" address="123" date="notanumber" type="2" body="Test" />
</smses>`
reader := strings.NewReader(invalidXML)
result, err := ParseSMSBackup(reader)
// Should parse but skip invalid entries or use defaults
if err != nil {
t.Fatalf("Parser should handle invalid data gracefully: %v", err)
}
// The message might be parsed with default values for invalid fields
if len(result.Messages) > 0 {
msg := result.Messages[0]
// Protocol "invalid" should parse as 0
if msg.Protocol != 0 {
t.Logf("Invalid protocol parsed as: %d", msg.Protocol)
}
// Date "notanumber" should result in Unix epoch
t.Logf("Invalid date parsed as: %v", msg.Date)
}
}
+45
View File
@@ -0,0 +1,45 @@
package internal
import "strings"
// normalizePhoneNumber removes all non-numeric characters except leading +
// and standardizes US phone numbers to include the +1 country code
// This prevents duplicate conversations due to different phone number formatting
func normalizePhoneNumber(phoneNumber string) string {
if phoneNumber == "" {
return ""
}
// Check if it starts with +
hasPlus := strings.HasPrefix(phoneNumber, "+")
// Remove all non-numeric characters
var result strings.Builder
for _, ch := range phoneNumber {
if ch >= '0' && ch <= '9' {
result.WriteRune(ch)
}
}
normalized := result.String()
if normalized == "" {
return ""
}
// Standardize US phone numbers
if !hasPlus {
// 10 digits without country code - add +1 (US number)
if len(normalized) == 10 {
return "+1" + normalized
}
// 11 digits starting with 1 - add + (US number with 1 prefix)
if len(normalized) == 11 && normalized[0] == '1' {
return "+" + normalized
}
// Other lengths without + - keep as is (might be partial/invalid)
return normalized
}
// Already has +, keep it
return "+" + normalized
}