Microservices in Golang, pt4 - Validation

Middleware#

When posting data it’s a good idea to validate it, and we’ll look into that a bit later. Before that, we need to look at middleware: a middleware is an HTTP handler that intercepts a request, does something with it, and then passes it on to the next (or final) handler. Good use cases are CORS and authentication. Read more about the Gorilla Mux middleware here

A very basic use of middleware is to log a message. Notice that the order in which middleware is assigned to the router matters:

// middlewareOne logs "MiddlewareOne" and then hands the request, unchanged,
// to next.
func middlewareOne(next http.Handler) http.Handler {
	logThenServe := func(w http.ResponseWriter, r *http.Request) {
		log.Println("MiddlewareOne")
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(logThenServe)
}

// middlewareTwo logs "MiddlewareTwo" before delegating to next; it does not
// touch the request or the response itself.
func middlewareTwo(next http.Handler) http.Handler {
	var h http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
		log.Println("MiddlewareTwo")
		next.ServeHTTP(w, r)
	}
	return h
}

// main (excerpt) registers the two logging middlewares on the api router.
func main() {
 ...
	// Middleware runs in registration order: middlewareOne sees the request
	// first, then middlewareTwo, which passes it on to the matched handler.
	api.Use(middlewareOne)
	api.Use(middlewareTwo)
}

The output is:

1970/1/1 15:59:00 MiddlewareOne
1970/1/1 15:59:00 MiddlewareTwo

In the code for our PUT and POST handlers, we can read the Song from the request context instead, and move the duplicated code, such as the JSON unmarshalling, into the middleware to keep things DRY.

package handlers

import (
	"context"
	"log"
	"net/http"
	"strconv"

	"example.com/go-intro-microservices-pt2/data"
	"github.com/gorilla/mux"
)

// KeySong is a key used for the Song object in the context
type KeySong struct{}

type Songs struct {
	l *log.Logger
}

func NewSongs(l *log.Logger) *Songs {
	return &Songs{l}
}

// Get writes every song returned by data.GetSongs to the response as JSON.
// If encoding fails, the error is logged and a 500 is sent.
func (s *Songs) Get(rw http.ResponseWriter, r *http.Request) {
	s.l.Println("Handle GET Songs")

	// Without an explicit header, net/http sniffs the body and labels the
	// JSON payload "text/plain; charset=utf-8". http.Error below overrides
	// this header with text/plain if we end up reporting an error instead.
	rw.Header().Set("Content-Type", "application/json")

	ls := data.GetSongs()
	if err := ls.ToJSON(rw); err != nil {
		s.l.Println("[ERROR] marshaling songs", err)
		http.Error(rw, "Unable to marshal json of songs!", http.StatusInternalServerError)
	}
}

// Post adds the song that MiddlewareSongValidation decoded from the request
// body and stored in the request context under KeySong{}.
// If no song is present in the context, it responds with 500.
func (s *Songs) Post(rw http.ResponseWriter, r *http.Request) {
	s.l.Println("Handle POST Songs")

	// Check the assertion: if the route is ever registered without the
	// middleware, this returns a 500 instead of panicking.
	song, ok := r.Context().Value(KeySong{}).(data.Song)
	if !ok {
		s.l.Println("[ERROR] no song in request context")
		http.Error(rw, "Unable to read song", http.StatusInternalServerError)
		return
	}

	data.AddSong(&song)
}

// Put replaces the song identified by the {id} route variable with the song
// that MiddlewareSongValidation stored in the request context under KeySong{}.
// It responds with:
//   - 400 when the id cannot be converted to an int,
//   - 404 when data.UpdateSong reports data.ErrSongNotFound,
//   - 500 when the context holds no song or the update fails otherwise.
func (s Songs) Put(rw http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	id, err := strconv.Atoi(vars["id"])
	if err != nil {
		http.Error(rw, "Unable to convert id", http.StatusBadRequest)
		return
	}

	s.l.Println("Handle PUT Songs, update song id", id)

	// The context stores an interface{}; check the assertion so a missing
	// middleware produces a 500 rather than a panic.
	song, ok := r.Context().Value(KeySong{}).(data.Song)
	if !ok {
		s.l.Println("[ERROR] no song in request context")
		http.Error(rw, "Unable to read song", http.StatusInternalServerError)
		return
	}

	err = data.UpdateSong(id, &song)
	switch {
	case err == data.ErrSongNotFound:
		http.Error(rw, "Song not found", http.StatusNotFound)
	case err != nil:
		// This is a server-side failure, not a lookup miss, so don't tell
		// the client the song is missing.
		s.l.Println("[ERROR] updating song", err)
		http.Error(rw, "Unable to update song", http.StatusInternalServerError)
	}
}

// MiddlewareSongValidation decodes the JSON body of POST and PUT requests
// into a data.Song and stores it in the request context under KeySong{}, so
// the handlers no longer unmarshal the body themselves. A body that fails to
// decode is logged and answered with 400 Bad Request without calling next.
// For other methods, the context receives a zero-value data.Song.
func (s Songs) MiddlewareSongValidation(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		var song data.Song

		switch r.Method {
		case http.MethodPost, http.MethodPut:
			if err := song.FromJSON(r.Body); err != nil {
				s.l.Println("[ERROR] deserializing song", err)
				http.Error(rw, "Error reading song", http.StatusBadRequest)
				return
			}
		}

		// Attach the song under its typed key, then continue down the
		// chain (another middleware or the final handler).
		withSong := context.WithValue(r.Context(), KeySong{}, song)
		next.ServeHTTP(rw, r.WithContext(withSong))
	})
}

Here’s the code we’ve removed, or moved to the middleware

// Inline decoding previously repeated in the POST and PUT handlers, now done
// once in MiddlewareSongValidation:
song := &data.Song{}

err = song.FromJSON(r.Body)
if err != nil {
	http.Error(rw, "Unable to unmarshal json of song", http.StatusBadRequest)
	// NOTE(review): there is no return here, so a handler using this snippet
	// would carry on with a partially decoded song after writing the 400.
}

JSON Validation#

We’re going to start validating our structs with the Validator package. It’s a handy tool that, based on struct tags, lets us check things such as the minimum length of a field, whether required fields are present, or whether a field’s value has a certain format. The documentation is available here, and a list explaining why we should sanitize data is here.

Start by adding a validation method to our data model. To do the validation we need to construct a validator object and add `validate` tags to our struct.

In our data/songs.go

// Song is the data model for a song. The `validate` struct tags are read by
// the validator package when a Song is validated.
type Song struct {
	ID        int     `json:"id"`
	// Band is required: validation fails when it is empty.
	Band      string  `json:"band" validate:"required"`
	...
}

func (s *Song) Validator() error {
	validate := validator.New()
	return validate.Struct(s)
}

To test we create a basic unit test, as such:

package data

import "testing"

// TestChecksValidation validates a zero-value Song. Because Band is tagged
// `validate:"required"`, this first version of the test is expected to FAIL;
// it is here to show what a validation error looks like.
func TestChecksValidation(t *testing.T) {
	var s Song
	if err := s.Validate(); err != nil {
		t.Fatal(err)
	}
}

And then execute:

go test -timeout 30s example.com/go-intro-microservices-pt2/data -run ^TestChecksValidation

That should fail with:

--- FAIL: TestChecksValidation (0.00s)
    songs_test.go:11: Key: 'Song.Band' Error:Field validation for 'Band' failed on the 'required' tag
FAIL
FAIL	example.com/go-intro-microservices-pt2/data	0.055s
FAIL

So, let’s add a new validation tag to Price, `validate:"gt=0"`, and also fill in the required fields in our test.

package data

import "testing"

// TestChecksValidation checks Song.Validate against the struct's validation
// tags. A song with a band and a positive price must pass. A missing band
// (`required`) or a zero price (`gt=0`) must be rejected.
func TestChecksValidation(t *testing.T) {
	tests := []struct {
		name    string
		song    Song
		wantErr bool
	}{
		{name: "valid", song: Song{Band: "Mad Funk", Price: 9.90}},
		{name: "missing band", song: Song{Price: 9.90}, wantErr: true},
		{name: "zero price", song: Song{Band: "Mad Funk"}, wantErr: true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.song.Validate()
			if tc.wantErr && err == nil {
				t.Fatal("expected a validation error, got nil")
			}
			if !tc.wantErr && err != nil {
				t.Fatalf("unexpected validation error: %v", err)
			}
		})
	}
}

Now that we understand how to validate, we should wire it to our API, through the middleware we’ve just created.

// MiddlewareSongValidation decodes the JSON body of POST and PUT requests
// into a data.Song, validates it, and stores it in the request context under
// KeySong{} for the next handler. A decoding or validation failure is logged
// and answered with 400 Bad Request, and next is not called. For other
// methods, the context receives a zero-value data.Song.
func (s Songs) MiddlewareSongValidation(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		var song data.Song

		switch r.Method {
		case http.MethodPost, http.MethodPut:
			if err := song.FromJSON(r.Body); err != nil {
				s.l.Println("[ERROR] deserializing song", err)
				http.Error(rw, "Error reading song", http.StatusBadRequest)
				return
			}

			if err := song.Validate(); err != nil {
				s.l.Println("[ERROR] validating song", err)
				msg := fmt.Sprintf("Error validating song: %s", err)
				http.Error(rw, msg, http.StatusBadRequest)
				return
			}
		}

		// Attach the song under its typed key, then continue down the
		// chain (another middleware or the final handler).
		withSong := context.WithValue(r.Context(), KeySong{}, song)
		next.ServeHTTP(rw, r.WithContext(withSong))
	})
}

So, if we execute:

curl -v -X PUT -d '{"band": "Zipzags", "title": "Songalicious", "price": 0, "sku": "abz1"}'  localhost:9000/api/v1/songs/3

We’d get the output:

*   Trying 127.0.0.1...
* TCP_NODELAY set
* Connected to localhost (127.0.0.1) port 9000 (#0)
> PUT /api/v1/songs/3 HTTP/1.1
> Host: localhost:9000
> User-Agent: curl/7.64.1
> Accept: */*
> Content-Length: 71
> Content-Type: application/x-www-form-urlencoded
>
* upload completely sent off: 71 out of 71 bytes
< HTTP/1.1 400 Bad Request
< Content-Type: text/plain; charset=utf-8
< X-Content-Type-Options: nosniff
< Date: Thu, 29 Oct 2020 18:21:30 GMT
< Content-Length: 99
<
Error validating song: Key: 'Song.Price' Error:Field validation for 'Price' failed on the 'gt' tag
* Connection #0 to host localhost left intact
* Closing connection 0

As you can see, we return a useful error message and also improve the security of our application.

CORS#

To enable CORS see the Package CORS

As an example, for our use case we’ll only allow requests from the origin localhost:9000. Notice that we’ve created an alias, gorHandlers, for the Gorilla handlers package, since the name handlers is already taken by our own package.

package main

import (
	"context"
	"errors"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"example.com/go-intro-microservices-pt2/handlers"
	gorHandlers "github.com/gorilla/handlers"
	"github.com/gorilla/mux"
)

// bindAddress is the address the server listens on, declared through
// env.String with the variable name BIND_ADDRESS and the default ":9000".
// NOTE(review): no `env` package appears in this file's import block, and
// main never calls env.Parse(). Confirm the import, and whether the value is
// populated without an explicit Parse.
var bindAddress = env.String("BIND_ADDRESS", false, ":9000", "Bind address for the server")

// main wires the hello and songs handlers onto a Gorilla mux router and wraps
// the router in a CORS handler that only allows http://localhost:9000. It then
// serves HTTP on *bindAddress until an interrupt or termination signal
// arrives, and shuts down gracefully within 30 seconds.
func main() {
	// The trailing space keeps the prefix from running into the timestamp.
	l := log.New(os.Stdout, "rest-api ", log.LstdFlags)

	hh := handlers.NewHello(l)
	sh := handlers.NewSongs(l)

	sm := mux.NewRouter()

	api := sm.PathPrefix("/api/v1").Subrouter()
	api.HandleFunc("", hh.Get).Methods(http.MethodGet)
	api.HandleFunc("", hh.Post).Methods(http.MethodPost)
	api.HandleFunc("", hh.Put).Methods(http.MethodPut)
	api.HandleFunc("", hh.Delete).Methods(http.MethodDelete)

	// Song routes get their own subrouter. The song middleware consumes and
	// validates POST/PUT bodies, so it must not run for the hello handlers
	// registered above.
	songs := api.PathPrefix("/songs").Subrouter()
	songs.HandleFunc("", sh.Get).Methods(http.MethodGet)
	songs.HandleFunc("", sh.Post).Methods(http.MethodPost)
	songs.HandleFunc("/{id:[0-9]+}", sh.Put).Methods(http.MethodPut)
	songs.Use(sh.MiddlewareSongValidation)

	api.HandleFunc("", hh.NotFound)

	// CORS
	ch := gorHandlers.CORS(
		gorHandlers.AllowedOrigins(
			[]string{"http://localhost:9000"},
		),
	)

	s := &http.Server{
		Addr:         *bindAddress,      // configure the bind address
		Handler:      ch(sm),            // the default handlers
		IdleTimeout:  120 * time.Second, // max time for connections using TCP keep-alive
		ReadTimeout:  20 * time.Second,  // max time to read request from client
		WriteTimeout: 30 * time.Second,  // max time to write response to the client
	}

	go func() {
		// ListenAndServe returns http.ErrServerClosed as soon as Shutdown is
		// called. Treating that as fatal would exit the process before the
		// graceful shutdown below completes.
		if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			l.Fatal(err)
		}
	}()

	// signal.Notify does not block when delivering, so the channel must be
	// buffered or a signal can be dropped. os.Kill (SIGKILL) can never be
	// caught, so listen for SIGTERM, which process managers send to request a
	// graceful stop.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

	sig := <-sigChan
	l.Println("Terminate received, gracefully shuttingdown...", sig)

	tc, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	if err := s.Shutdown(tc); err != nil {
		l.Println("[ERROR] graceful shutdown failed:", err)
	}
}
comments powered by Disqus