Golang After-Handler Middleware?
In this post we're going to look at pre-handler middleware, handler code and post-handler middleware.
Use case
Usually for API servers, we want to have middlewares that run before (pre-handler) and after (post-handler) the actual API handler.
- Some pre-handler middlewares are Trace (generate a request ID), Auth (validate the request headers) and IP check (check whether the client IP is on a whitelist).
- Some post-handler middlewares are Performance (measure how long the API took to respond) and Log (record which API was called in a database).
In Go, to pass data along the request chain, we set values on the request's context and read them back further down the chain. This post shows a problem that comes up when a post-handler middleware reads from the context, and how to solve it.
Simple Web Server
Here is a simple API server using gorilla/mux. Save the file as main.go:
```go
package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/gorilla/mux"
)

func main() {
	router := mux.NewRouter()
	router.HandleFunc("/hi", hiHandler)
	// ListenAndServe only returns on error (e.g. the port is already in use).
	log.Fatal(http.ListenAndServe(":8001", router))
}

func hiHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Println("hiHandler hits")
	fmt.Fprint(w, "hello")
}
```
Run it with go run main.go and open http://localhost:8001/hi. We'll see hello on the browser screen and hiHandler hits in the terminal console. Next, let's add some middleware.
Add some middleware
Here we add a pre-handler middleware that puts the user into the context so that the handler can read it back out. We also add a post-handler middleware that reads the same value after the handler has run:
```go
package main

import (
	"context"
	"fmt"
	"log"
	"net/http"

	"github.com/gorilla/mux"
)

// ctxKey is an unexported type, so these keys cannot collide with
// context keys defined in other packages.
type ctxKey string

const (
	contextKeyUsername ctxKey = "username"
)

func main() {
	router := mux.NewRouter()
	router.Use(preHandlerMiddleware)
	router.HandleFunc("/hi", hiHandler)
	router.Use(postHandlerMiddleware)
	log.Fatal(http.ListenAndServe(":8001", router))
}

// preHandlerMiddleware stores the (hard-coded) current user in the context.
func preHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("preHandlerMiddleware hits")
		ctx := r.Context()
		ctx = context.WithValue(ctx, contextKeyUsername, "user123")
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

func hiHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Println("hiHandler hits")
	if username, ok := r.Context().Value(contextKeyUsername).(string); ok {
		fmt.Fprintf(w, "hello %s", username)
	} else {
		fmt.Fprint(w, "hello")
	}
}

// postHandlerMiddleware does its work after the wrapped handler returns.
func postHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
		fmt.Println("postHandlerMiddleware hits")
		username := r.Context().Value(contextKeyUsername)
		fmt.Println("username", username)
	})
}
```
The current user is stored in the context with context.WithValue (the context package was added to the standard library in Go 1.7):

```go
ctx = context.WithValue(ctx, contextKeyUsername, "user123")
```

Later, in the handler, the value is read back with the context's Value method:

```go
username := ctx.Value(contextKeyUsername)
```
Open http://localhost:8001/hi again and we see hello user123 on the browser screen, and this in the terminal console:

```
preHandlerMiddleware hits
hiHandler hits
postHandlerMiddleware hits
username user123
```
Add a middleware as a handler wrapper function
So far so good. Now let's add a handler wrapper, setAction, that sets another value in the context:
```go
package main

import (
	"context"
	"fmt"
	"log"
	"net/http"

	"github.com/gorilla/mux"
)

// ctxKey is an unexported type, so these keys cannot collide with
// context keys defined in other packages.
type ctxKey string

const (
	contextKeyUsername ctxKey = "username"
	contextKeyAction   ctxKey = "action"
)

func main() {
	router := mux.NewRouter()
	router.Use(preHandlerMiddleware)
	router.HandleFunc("/hi", setAction("HI", hiHandler))
	router.Use(postHandlerMiddleware)
	log.Fatal(http.ListenAndServe(":8001", router))
}

func preHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("preHandlerMiddleware hits")
		ctx := r.Context()
		ctx = context.WithValue(ctx, contextKeyUsername, "user123")
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// setAction wraps a single handler and stores the action name in the
// context that it passes down to that handler.
func setAction(action string, next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		ctx = context.WithValue(ctx, contextKeyAction, action)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
}

func hiHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Println("hiHandler hits")
	if username, ok := r.Context().Value(contextKeyUsername).(string); ok {
		fmt.Fprintf(w, "hello %s", username)
	} else {
		fmt.Fprint(w, "hello")
	}
}

func postHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
		fmt.Println("postHandlerMiddleware hits")
		ctx := r.Context()
		fmt.Println("username", ctx.Value(contextKeyUsername))
		fmt.Println("action", ctx.Value(contextKeyAction)) // We see <nil> here
	})
}
```
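Nothing changes on the browser screen. In the terminal console we expect to see the action, but we get <nil> instead:

```
preHandlerMiddleware hits
hiHandler hits
postHandlerMiddleware hits
username user123
action <nil>
```

The reason is how the handlers are nested. gorilla/mux wraps the handler of every matched route with all middleware registered through router.Use, in the order it was added, regardless of whether Use is called before or after HandleFunc. For /hi the chain is therefore preHandlerMiddleware → postHandlerMiddleware → setAction → hiHandler. context.WithValue never modifies an existing context. It returns a new child context, and setAction passes that child only to hiHandler. When control comes back to postHandlerMiddleware, it still holds the context it was given. That context has username (set further out by preHandlerMiddleware) but not action.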
Solution
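We can get rid of postHandlerMiddleware and move its post-handler work into the handler wrapper setAction, which is a middleware too. Because setAction is the one that creates the context carrying action, the code it runs after next.ServeHTTP can still read that value. Interestingly, wrapping one handler function in another like this is how middleware is written with the plain net/http package, without any web framework: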
```go
package main

import (
	"context"
	"fmt"
	"log"
	"net/http"

	"github.com/gorilla/mux"
)

// ctxKey is an unexported type, so these keys cannot collide with
// context keys defined in other packages.
type ctxKey string

const (
	contextKeyUsername ctxKey = "username"
	contextKeyAction   ctxKey = "action"
)

func main() {
	router := mux.NewRouter()
	router.Use(preHandlerMiddleware)
	router.HandleFunc("/hi", setAction("HI", hiHandler))
	log.Fatal(http.ListenAndServe(":8001", router))
}

func preHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("preHandlerMiddleware hits")
		ctx := r.Context()
		ctx = context.WithValue(ctx, contextKeyUsername, "user123")
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// setAction now also does the post-handler work: the context it created
// (which also carries username from preHandlerMiddleware) is still in scope.
func setAction(action string, next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		ctx = context.WithValue(ctx, contextKeyAction, action)
		next.ServeHTTP(w, r.WithContext(ctx))

		fmt.Println("postHandlerMiddleware hits")
		fmt.Println("username", ctx.Value(contextKeyUsername))
		fmt.Println("action", ctx.Value(contextKeyAction))
	}
}

func hiHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Println("hiHandler hits")
	if username, ok := r.Context().Value(contextKeyUsername).(string); ok {
		fmt.Fprintf(w, "hello %s", username)
	} else {
		fmt.Fprint(w, "hello")
	}
}
```
And now we have the expected result:

```
preHandlerMiddleware hits
hiHandler hits
postHandlerMiddleware hits
username user123
action HI
```