package models import ( "errors" "fmt" "strings" "time" "code.nonshy.com/nonshy/website/pkg/log" "gorm.io/gorm" ) // Thread table - a post within a Forum. type Thread struct { ID uint64 `gorm:"primaryKey"` ForumID uint64 `gorm:"index"` Forum Forum Pinned bool `gorm:"index"` Explicit bool `gorm:"index"` NoReply bool Title string CommentID uint64 `gorm:"index"` Comment Comment // first comment of the thread Views uint64 CreatedAt time.Time UpdatedAt time.Time } // Preload related tables for the forum (classmethod). func (f *Thread) Preload() *gorm.DB { return DB.Preload("Forum").Preload("Comment.User.ProfilePhoto") } // GetThread by ID. func GetThread(id uint64) (*Thread, error) { t := &Thread{} result := t.Preload().First(&t, id) return t, result.Error } // GetThreads queries a set of thread IDs and returns them mapped. func GetThreads(IDs []uint64) (map[uint64]*Thread, error) { var ( mt = map[uint64]*Thread{} ts = []*Thread{} ) result := (&Thread{}).Preload().Where("id IN ?", IDs).Find(&ts) for _, row := range ts { mt[row.ID] = row } return mt, result.Error } // CreateThread creates a new thread with proper Comment structure. func CreateThread(user *User, forumID uint64, title, message string, pinned, explicit, noReply bool) (*Thread, error) { thread := &Thread{ ForumID: forumID, Title: title, Pinned: pinned, Explicit: explicit, NoReply: noReply && user.IsAdmin, Comment: Comment{ User: *user, Message: message, }, } log.Error("CreateThread: Going to post %+v", thread) // Create the thread & comment first... result := DB.Create(thread) if result.Error != nil { return nil, result.Error } // Fill out the Comment with proper reverse foreign keys. thread.Comment.TableName = "threads" thread.Comment.TableID = thread.ID log.Error("Saving updated comment: %+v", thread) result = DB.Save(&thread.Comment) return thread, result.Error } // Reply to a thread, adding an additional comment. func (t *Thread) Reply(user *User, message string) (*Comment, error) { // Save the thread on reply, updating its timestamp. if err := t.Save(); err != nil { log.Error("Thread.Reply: couldn't ping UpdatedAt on thread: %s", err) } return AddComment(user, "threads", t.ID, message) } // DeleteReply removes a comment from a thread. If it is the primary comment, deletes the whole thread. func (t *Thread) DeleteReply(comment *Comment) error { // Sanity check that this reply is one of ours. if !(comment.TableName == "threads" && comment.TableID == t.ID) { return errors.New("that comment doesn't belong to this thread") } // Is this the primary comment that started the thread? If so, delete the whole thread. if comment.ID == t.CommentID { log.Error("DeleteReply(%d): this is the parent comment of a thread (%d '%s'), remove the whole thread", comment.ID, t.ID, t.Title) return t.Delete() } // Remove just this comment. return comment.Delete() } // PinnedThreads returns all pinned threads in a forum (there should generally be few of these). func PinnedThreads(forum *Forum) ([]*Thread, error) { var ( ts = []*Thread{} query = (&Thread{}).Preload().Where( "forum_id = ? AND pinned IS TRUE", forum.ID, ).Order("updated_at desc") ) result := query.Find(&ts) return ts, result.Error } // PaginateThreads provides a forum index view of posts, minus pinned posts. func PaginateThreads(user *User, forum *Forum, pager *Pagination) ([]*Thread, error) { var ( ts = []*Thread{} query = (&Thread{}).Preload() wheres = []string{} placeholders = []interface{}{} ) // Always filters. wheres = append(wheres, "forum_id = ? AND pinned IS NOT TRUE") placeholders = append(placeholders, forum.ID) // If the user hasn't opted in for Explicit, hide NSFW threads. if !user.Explicit && !user.IsAdmin { wheres = append(wheres, "explicit IS NOT TRUE") } query = query.Where( strings.Join(wheres, " AND "), placeholders..., ).Order(pager.Sort) query.Model(&Thread{}).Count(&pager.Total) result := query.Offset(pager.GetOffset()).Limit(pager.PerPage).Find(&ts) return ts, result.Error } // View a thread, incrementing its View count but not its UpdatedAt. func (t *Thread) View() error { return DB.Model(&Thread{}).Where( "id = ?", t.ID, ).Updates(map[string]interface{}{ "views": t.Views + 1, "updated_at": t.UpdatedAt, }).Error } // Save a thread, updating its timestamp. func (t *Thread) Save() error { return DB.Save(t).Error } // Delete a thread and all of its comments. func (t *Thread) Delete() error { // Remove all comments. result := DB.Where( "table_name = ? AND table_id = ?", "threads", t.ID, ).Delete(&Comment{}) if result.Error != nil { return fmt.Errorf("deleting comments for thread: %s", result.Error) } // Remove the thread itself. return DB.Delete(t).Error } // ThreadStatistics queries for reply/view count for threads. type ThreadStatistics struct { Replies uint64 Views uint64 } type ThreadStatsMap map[uint64]*ThreadStatistics // MapThreadStatistics looks up statistics for a set of threads. func MapThreadStatistics(threads []*Thread) ThreadStatsMap { var ( result = ThreadStatsMap{} IDs = []uint64{} ) // Collect thread IDs and initialize the map. for _, thread := range threads { IDs = append(IDs, thread.ID) result[thread.ID] = &ThreadStatistics{ Views: thread.Views, } } // Hold the result of the count/group by query. type group struct { ID uint64 Replies uint64 } var groups = []group{} // Count comments grouped by thread IDs. err := DB.Table( "comments", ).Select( "table_id AS id, count(id) AS replies", ).Where( "table_name = ? AND table_id IN ?", "threads", IDs, ).Group("table_id").Scan(&groups) if err != nil { log.Error("MapThreadStatistics: SQL error: %s") } // Map the results in. for _, row := range groups { log.Error("Got row: %+v", row) if stats, ok := result[row.ID]; ok { stats.Replies = row.Replies // Remove the OG comment from the count. if stats.Replies > 0 { stats.Replies-- } } } return result } // Has stats for this thread? (we should..) func (ts ThreadStatsMap) Has(threadID uint64) bool { _, ok := ts[threadID] return ok } // Get thread stats. func (ts ThreadStatsMap) Get(threadID uint64) *ThreadStatistics { if stats, ok := ts[threadID]; ok { return stats } return nil }