Skip to content
phattv.dev
EmailLinkedIn

Golang After-Handler Middleware?

coding & tutorials · 4 min read

In this post we'll look at pre-handler middleware, handler code, and post-handler middleware.

Use case

API servers usually need middleware that runs before (pre-handler) and after (post-handler) the actual API handler.

  • Examples of pre-handler middleware: Trace (generate a request ID), Auth (validate request headers), and IP check (verify that the client IP is on an allowlist).

  • Examples of post-handler middleware: Performance (measure how long the API took to respond) and Log (record which API was called, e.g. in a database).

In Go, request-scoped data is passed along the handler chain by storing it in, and reading it from, the request's context.Context. This post shows a problem that arises when a post-handler middleware tries to read such data, and how to solve it.

Simple Web Server

Here is a simple API server using gorilla/mux; save the file as main.go:

package main
import (
"fmt"
"net/http"
"github.com/gorilla/mux"
)
// main registers hiHandler on /hi and serves HTTP on port 8001.
// The ListenAndServe error is checked rather than discarded so that a
// startup failure (e.g. the port is already in use) is reported instead of
// the program exiting silently.
func main() {
	router := mux.NewRouter()
	router.HandleFunc("/hi", hiHandler)
	if err := http.ListenAndServe(":8001", router); err != nil {
		panic(err)
	}
}
// hiHandler logs each call to stdout and replies with a fixed greeting.
// The request itself is not inspected.
func hiHandler(w http.ResponseWriter, r *http.Request) {
	const greeting = "hello"
	fmt.Println("hiHandler hits")
	fmt.Fprint(w, greeting)
}

Run it with go run main.go and open http://localhost:8001/hi: the browser shows hello and the terminal prints hiHandler hits. Next, let's add some middleware.

Add some middleware

Here we add a pre-handler middleware that stores the user in the context so the handler can read it back, plus a post-handler middleware that does its work after the handler returns:

package main
import (
"context"
"fmt"
"net/http"
"github.com/gorilla/mux"
)
// ctxKey is an unexported type for this program's request-context keys.
// A dedicated type keeps these entries from colliding with keys another
// package might set using the same plain string; staticcheck (SA1029)
// warns against built-in types such as string as context keys.
type ctxKey string

const (
	contextKeyUsername ctxKey = "username" // current user's name
)
// main serves /hi on port 8001 with preHandlerMiddleware and
// postHandlerMiddleware attached to the router via Use.
// The ListenAndServe error is checked rather than discarded so that a
// startup failure (e.g. the port is already in use) is reported instead of
// the program exiting silently.
func main() {
	router := mux.NewRouter()
	router.Use(preHandlerMiddleware)
	router.HandleFunc("/hi", hiHandler)
	router.Use(postHandlerMiddleware)
	if err := http.ListenAndServe(":8001", router); err != nil {
		panic(err)
	}
}
// preHandlerMiddleware stores the current username in the request context
// before calling next, so handlers further down the chain can read it with
// ctx.Value(contextKeyUsername). The username is hard-coded for this demo.
func preHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Log inside the returned handler so it fires once per request.
		// Code placed before the return runs when the middleware is built,
		// which is not tied to any particular request.
		fmt.Println("preHandlerMiddleware hits")
		ctx := context.WithValue(r.Context(), contextKeyUsername, "user123")
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// hiHandler greets the user whose name is stored in the request context
// under contextKeyUsername, or replies with a bare "hello" when no string
// username is present.
func hiHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Println("hiHandler hits")
	// The comma-ok assertion avoids a panic if the key holds a non-string.
	if username, ok := r.Context().Value(contextKeyUsername).(string); ok {
		// Pass the username as an argument, never as the format string:
		// a "%" in the name would otherwise be parsed as a verb.
		fmt.Fprintf(w, "hello %s", username)
		return
	}
	fmt.Fprint(w, "hello")
}
// postHandlerMiddleware calls next first and only then does its own work,
// so its log lines are printed after the handler's. It can only see context
// values that were already present in the request it received.
func postHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Capture the context this middleware was given. r already carries
		// it, so r is passed through as-is; r.WithContext(ctx) with the same
		// ctx would only make a needless copy of the request.
		ctx := r.Context()
		next.ServeHTTP(w, r)
		fmt.Println("postHandlerMiddleware hits")
		username := ctx.Value(contextKeyUsername)
		fmt.Println("username", username)
	})
}

The current user is stored in the context with context.WithValue (the context package has been part of the standard library since Go 1.7):

// Derive a new context that carries the username under contextKeyUsername.
ctx = context.WithValue(ctx, contextKeyUsername, "user123")

Later, in the handler, the value is read back with ctx.Value:

// Read the username back; the result is nil if the key was never set.
username := ctx.Value(contextKeyUsername)

Open http://localhost:8001/hi again: the browser shows hello user123, and the terminal console prints:

preHandlerMiddleware hits
hiHandler hits
postHandlerMiddleware hits
username user123

Add a middleware as a handler wrapper function

So far so good. Now let's add a handler wrapper that stores another value, the action, in the context:

package main
import (
"context"
"fmt"
"net/http"
"github.com/gorilla/mux"
)
// ctxKey is an unexported type for this program's request-context keys.
// A dedicated type keeps these entries from colliding with keys another
// package might set using the same plain strings; staticcheck (SA1029)
// warns against built-in types such as string as context keys.
type ctxKey string

const (
	contextKeyUsername ctxKey = "username" // current user's name
	contextKeyAction   ctxKey = "action"   // name of the action being handled
)
// main serves /hi on port 8001. The route handler is hiHandler wrapped by
// setAction("HI", ...), and preHandlerMiddleware and postHandlerMiddleware
// are attached to the router via Use.
// The ListenAndServe error is checked rather than discarded so that a
// startup failure (e.g. the port is already in use) is reported instead of
// the program exiting silently.
func main() {
	router := mux.NewRouter()
	router.Use(preHandlerMiddleware)
	router.HandleFunc("/hi", setAction("HI", hiHandler))
	router.Use(postHandlerMiddleware)
	if err := http.ListenAndServe(":8001", router); err != nil {
		panic(err)
	}
}
// preHandlerMiddleware stores the current username in the request context
// before calling next, so handlers further down the chain can read it with
// ctx.Value(contextKeyUsername). The username is hard-coded for this demo.
func preHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Log inside the returned handler so it fires once per request.
		// Code placed before the return runs when the middleware is built,
		// which is not tied to any particular request.
		fmt.Println("preHandlerMiddleware hits")
		ctx := context.WithValue(r.Context(), contextKeyUsername, "user123")
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// setAction returns a handler that records action in the request context
// under contextKeyAction and then delegates to next with that enriched
// request. The derived context is visible only to next and anything it
// calls, not to the middleware that called setAction's handler.
func setAction(action string, next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		enriched := r.WithContext(context.WithValue(r.Context(), contextKeyAction, action))
		next(w, enriched)
	}
}
// hiHandler greets the user whose name is stored in the request context
// under contextKeyUsername, or replies with a bare "hello" when no string
// username is present.
func hiHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Println("hiHandler hits")
	// The comma-ok assertion avoids a panic if the key holds a non-string.
	if username, ok := r.Context().Value(contextKeyUsername).(string); ok {
		// Pass the username as an argument, never as the format string:
		// a "%" in the name would otherwise be parsed as a verb.
		fmt.Fprintf(w, "hello %s", username)
		return
	}
	fmt.Fprint(w, "hello")
}
// postHandlerMiddleware calls next first and then logs the username and
// action it finds in the context of the request it received.
//
// Only values already present in that context are visible here. setAction
// stores the action in a new, derived context that it passes down to the
// handler; that derived context never flows back up to this middleware,
// so the action lookup below yields nil.
func postHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Capture the context this middleware was given. r already carries
		// it, so r is passed through as-is; r.WithContext(ctx) with the same
		// ctx would only make a needless copy of the request.
		ctx := r.Context()
		next.ServeHTTP(w, r)
		fmt.Println("postHandlerMiddleware hits")
		username := ctx.Value(contextKeyUsername)
		fmt.Println("username", username)
		action := ctx.Value(contextKeyAction)
		fmt.Println("action", action) // <nil>: the action lives in a downstream context
	})
}

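Nothing changes in the browser. In the terminal, however, we expected to see the action but get <nil> instead. The reason is that context values only flow inward. context.WithValue returns a new, derived context, and setAction hands that new context to hiHandler via r.WithContext. postHandlerMiddleware still holds the outer context it was given, which never received the action key, so the value it reads is nil: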

preHandlerMiddleware hits
hiHandler hits
postHandlerMiddleware hits
username user123
action <nil>

Solution

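We can drop postHandlerMiddleware and move its post-handler logic into the handler wrapper setAction, which is itself a middleware. Because setAction is the one that attaches the action to the request, it still has access to the action after the handler returns. This wrap-and-call-next pattern is also how middleware is written with plain net/http, without any web framework: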

package main
import (
"context"
"fmt"
"net/http"
"github.com/gorilla/mux"
)
// ctxKey is an unexported type for this program's request-context keys.
// A dedicated type keeps these entries from colliding with keys another
// package might set using the same plain strings; staticcheck (SA1029)
// warns against built-in types such as string as context keys.
type ctxKey string

// contextKeyAction must be declared here as well: setAction reads and writes
// it, and this program no longer shares a const block with the earlier one.
const (
	contextKeyUsername ctxKey = "username" // current user's name
	contextKeyAction   ctxKey = "action"   // name of the action being handled
)
// main serves /hi on port 8001. preHandlerMiddleware is attached to the
// router via Use; setAction wraps hiHandler, tagging the request with the
// "HI" action and doing the post-handler logging itself, so no separate
// postHandlerMiddleware is registered.
// The ListenAndServe error is checked rather than discarded so that a
// startup failure (e.g. the port is already in use) is reported instead of
// the program exiting silently.
func main() {
	router := mux.NewRouter()
	router.Use(preHandlerMiddleware)
	router.HandleFunc("/hi", setAction("HI", hiHandler))
	if err := http.ListenAndServe(":8001", router); err != nil {
		panic(err)
	}
}
// preHandlerMiddleware stores the current username in the request context
// before calling next, so handlers further down the chain can read it with
// ctx.Value(contextKeyUsername). The username is hard-coded for this demo.
func preHandlerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Log inside the returned handler so it fires once per request.
		// Code placed before the return runs when the middleware is built,
		// which is not tied to any particular request.
		fmt.Println("preHandlerMiddleware hits")
		ctx := context.WithValue(r.Context(), contextKeyUsername, "user123")
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// setAction wraps next. It stores action in the request context under
// contextKeyAction before next runs, then does post-handler work after next
// returns.
//
// Because setAction itself derives the context that carries the action, it
// still holds that context once the handler has finished. It can therefore
// log both the username (added upstream, by preHandlerMiddleware in main) and
// the action. This is what lets it replace a separate post-handler
// middleware, which could not see the action.
func setAction(action string, next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), contextKeyAction, action)
		next.ServeHTTP(w, r.WithContext(ctx))

		// Post-handler work: runs after the handler has written its response.
		fmt.Println("postHandlerMiddleware hits")
		// Log the username as well, so this prints the same lines the
		// removed postHandlerMiddleware did.
		fmt.Println("username", ctx.Value(contextKeyUsername))
		fmt.Println("action", action)
	}
}
// hiHandler greets the user whose name is stored in the request context
// under contextKeyUsername, or replies with a bare "hello" when no string
// username is present.
func hiHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Println("hiHandler hits")
	// The comma-ok assertion avoids a panic if the key holds a non-string.
	if username, ok := r.Context().Value(contextKeyUsername).(string); ok {
		// Pass the username as an argument, never as the format string:
		// a "%" in the name would otherwise be parsed as a verb.
		fmt.Fprintf(w, "hello %s", username)
		return
	}
	fmt.Fprint(w, "hello")
}

And now we get the expected result:

preHandlerMiddleware hits
hiHandler hits
postHandlerMiddleware hits
username user123
action HI