package models

import (
	"strings"
	"time"

	"code.nonshy.com/nonshy/website/pkg/config"
)

// Message table: one direct message sent from SourceUserID to TargetUserID.
type Message struct {
	ID           uint64 `gorm:"primaryKey"`
	SourceUserID uint64 `gorm:"index"`
	TargetUserID uint64 `gorm:"index"`
	Read         bool   `gorm:"index"`
	Message      string
	CreatedAt    time.Time
	UpdatedAt    time.Time
}

// GetMessage looks up a single Message by its primary key.
//
// Any error from the database layer is returned as-is; when it is non-nil the
// returned *Message should not be relied upon.
func GetMessage(id uint64) (*Message, error) {
	m := &Message{}
	// m is already a *Message; pass it directly rather than &m (a **Message).
	result := DB.First(m, id)
	return m, result.Error
}

// GetMessages for a user, e-mail style for the inbox or sent box view.
//
// When sent is true it returns messages the user sent (source_user_id matches);
// otherwise messages the user received (target_user_id matches). In both cases
// messages exchanged with users returned by BlockedUserIDs(user) are excluded,
// as are messages where either party's users.status is not 'active'.
//
// pager.Total is set to the number of matching messages, and the returned
// slice holds the page selected by pager, ordered by pager.Sort.
func GetMessages(user *User, sent bool, pager *Pagination) ([]*Message, error) {
	var (
		m              = []*Message{}
		blockedUserIDs = BlockedUserIDs(user)

		// ownColumn is this user's side of each message, otherColumn the
		// conversation partner's side. The inbox view is the default.
		ownColumn   = "target_user_id"
		otherColumn = "source_user_id"
	)
	if sent {
		ownColumn, otherColumn = otherColumn, ownColumn
	}

	// Both column names are fixed literals above (never user input), so
	// concatenating them into the SQL fragment is safe.
	where := []string{ownColumn + " = ?"}
	placeholders := []interface{}{user.ID}
	if len(blockedUserIDs) > 0 {
		where = append(where, otherColumn+" NOT IN ?")
		placeholders = append(placeholders, blockedUserIDs)
	}

	// Don't show messages from banned or disabled accounts.
	where = append(where, `
		NOT EXISTS (
			SELECT 1
			FROM users
			WHERE users.id IN (messages.target_user_id, messages.source_user_id)
			AND users.status <> 'active'
		)
	`)

	query := DB.Where(
		strings.Join(where, " AND "),
		placeholders...,
	).Order(pager.Sort)

	query.Model(&Message{}).Count(&pager.Total)
	result := query.Offset(pager.GetOffset()).Limit(pager.PerPage).Find(&m)
	return m, result.Error
}

// GetMessageThreads for a user: for every other user who has sent this user a
// message, the message with the highest ID in that conversation.
//
// NOTE(review): although historically described as a combined inbox/sent view,
// the query only matches messages where the user is the target, so a partner
// the user has only ever written to (never received from) gets no entry here —
// confirm whether that is intended.
func GetMessageThreads(user *User, pager *Pagination) ([]*Message, error) { var ( m = []*Message{} blockedUserIDs = BlockedUserIDs(user) where = []string{} placeholders = []interface{}{} ) where = append(where, "target_user_id = ?") placeholders = append(placeholders, user.ID) if len(blockedUserIDs) > 0 { where = append(where, "source_user_id NOT IN ?") placeholders = append(placeholders, blockedUserIDs) } // Don't show messages from banned or disabled accounts. where = append(where, ` NOT EXISTS ( SELECT 1 FROM users WHERE users.id IN (messages.target_user_id, messages.source_user_id) AND users.status <> 'active' ) `) type newest struct { ID uint64 SourceUserID uint64 TargetUserID uint64 } var scan []newest // Get the newest message IDs grouped by username for everyone we are chatting with. query := DB.Model(&Message{}).Select( "max(id) AS id", "source_user_id", "target_user_id", ).Where( strings.Join(where, " AND "), placeholders..., ).Group( "source_user_id, target_user_id", ).Order("id desc").Scan(&scan) if query.Error != nil { return nil, query.Error } pager.Total = int64(len(scan)) // Get the details from these message IDs. var messageIDs = []uint64{} for _, row := range scan { messageIDs = append(messageIDs, row.ID) } query = DB.Where( "id IN ?", messageIDs, ).Order(pager.Sort) query.Model(&Message{}).Count(&pager.Total) result := query.Offset(pager.GetOffset()).Limit(pager.PerPage).Find(&m) return m, result.Error } // GetMessageThread returns paginated message history between two people. func GetMessageThread(sourceUserID, targetUserID uint64, pager *Pagination) ([]*Message, error) { var m = []*Message{} query := DB.Where( "(source_user_id = ? AND target_user_id = ?) OR (source_user_id = ? 
AND target_user_id = ?)", sourceUserID, targetUserID, targetUserID, sourceUserID, ).Order(pager.Sort) query.Model(&Message{}).Count(&pager.Total) result := query.Offset(pager.GetOffset()).Limit(pager.PerPage).Find(&m) return m, result.Error } // HasMessageThread returns if a message thread exists between two users (either direction). // Returns the ID of the thread and a boolean OK that it existed. func HasMessageThread(a, b *User) (uint64, bool) { var pager = &Pagination{ Page: 1, PerPage: 1, Sort: "updated_at desc", } messages, err := GetMessageThread(a.ID, b.ID, pager) if err == nil && len(messages) > 0 { return messages[0].ID, true } return 0, false } // HasSentAMessage tells if the source user has sent a DM to the target user. func HasSentAMessage(sourceUser, targetUser *User) bool { var count int64 DB.Model(&Message{}).Where( "source_user_id = ? AND target_user_id = ?", sourceUser.ID, targetUser.ID, ).Count(&count) return count > 0 } // DeleteMessageThread removes all message history between two people. func DeleteMessageThread(message *Message) error { return DB.Where( "(source_user_id = ? AND target_user_id = ?) OR (source_user_id = ? AND target_user_id = ?)", message.SourceUserID, message.TargetUserID, message.TargetUserID, message.SourceUserID, ).Delete(&Message{}).Error } // CountUnreadMessages gets the count of unread messages for a user. func CountUnreadMessages(user *User) (int64, error) { var ( blockedUserIDs = BlockedUserIDs(user) where = []string{ "target_user_id = ? AND read = ?", } placeholders = []interface{}{ user.ID, false, } ) // Blocking user IDs? if len(blockedUserIDs) > 0 { where = append(where, "source_user_id NOT IN ?") placeholders = append(placeholders, blockedUserIDs) } // Don't show messages from banned or disabled accounts. 
where = append(where, ` NOT EXISTS ( SELECT 1 FROM users WHERE users.id IN (messages.target_user_id, messages.source_user_id) AND users.status <> 'active' ) `) query := DB.Where( strings.Join(where, " AND "), placeholders..., ) var count int64 result := query.Model(&Message{}).Count(&count) return count, result.Error } // SendMessage from a source to a target user. func SendMessage(sourceUserID, targetUserID uint64, message string) (*Message, error) { m := &Message{ SourceUserID: sourceUserID, TargetUserID: targetUserID, Message: message, Read: false, } result := DB.Create(m) return m, result.Error } // IsLikelySpam checks if a DM message is likely to be spam so that the front-end can warn the recipient. // // This happens e.g. when the sender asks to switch to Telegram or WhatsApp. func (m *Message) IsLikelySpam() bool { body := strings.ToLower(m.Message) for _, re := range config.DirectMessageSpamKeywords { if idx := re.FindStringIndex(body); len(idx) > 0 { return true } } return false } // Save message. func (m *Message) Save() error { result := DB.Save(m) return result.Error } // Delete a message. func (m *Message) Delete() error { return DB.Delete(m).Error }