Optimize CurrentUser to read from DB only once per request

This commit is contained in:
Noah 2022-08-21 14:17:52 -07:00
parent e42cebe4b8
commit 96b33a920f
3 changed files with 24 additions and 7 deletions

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"net/http" "net/http"
"time" "time"
@ -46,7 +47,9 @@ func LoginRequired(handler http.Handler) http.Handler {
} }
} }
handler.ServeHTTP(w, r) // Stick the CurrentUser in the request context so future calls to session.CurrentUser can read it.
ctx := context.WithValue(r.Context(), session.CurrentUserKey, user)
handler.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
@ -55,19 +58,26 @@ func AdminRequired(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// User must be logged in. // User must be logged in.
if currentUser, err := session.CurrentUser(r); err != nil { currentUser, err := session.CurrentUser(r)
if err != nil {
log.Error("AdminRequired: %s", err) log.Error("AdminRequired: %s", err)
errhandler := templates.MakeErrorPage("Login Required", "You must be signed in to view this page.", http.StatusForbidden) errhandler := templates.MakeErrorPage("Login Required", "You must be signed in to view this page.", http.StatusForbidden)
errhandler.ServeHTTP(w, r) errhandler.ServeHTTP(w, r)
return return
} else if !currentUser.IsAdmin { }
// Stick the CurrentUser in the request context so future calls to session.CurrentUser can read it.
ctx := context.WithValue(r.Context(), session.CurrentUserKey, currentUser)
// Admin required.
if !currentUser.IsAdmin {
log.Error("AdminRequired: %s", err) log.Error("AdminRequired: %s", err)
errhandler := templates.MakeErrorPage("Admin Required", "You do not have permission for this page.", http.StatusForbidden) errhandler := templates.MakeErrorPage("Admin Required", "You do not have permission for this page.", http.StatusForbidden)
errhandler.ServeHTTP(w, r) errhandler.ServeHTTP(w, r.WithContext(ctx))
return return
} }
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r.WithContext(ctx))
}) })
} }

View File

@ -11,6 +11,12 @@ import (
func CurrentUser(r *http.Request) (*models.User, error) { func CurrentUser(r *http.Request) (*models.User, error) {
sess := Get(r) sess := Get(r)
if sess.LoggedIn { if sess.LoggedIn {
// Did we already get the CurrentUser once before?
ctx := r.Context()
if user, ok := ctx.Value(CurrentUserKey).(*models.User); ok {
return user, nil
}
// Load the associated user ID. // Load the associated user ID.
return models.GetUser(sess.UserID) return models.GetUser(sess.UserID)
} }

View File

@ -27,6 +27,7 @@ type Session struct {
const ( const (
ContextKey = "session" ContextKey = "session"
CurrentUserKey = "current_user"
CSRFKey = "csrf" CSRFKey = "csrf"
) )