Skip to content

Commit

Permalink
Add files
Browse files Browse the repository at this point in the history
  • Loading branch information
gomezjdaniel committed Jun 22, 2023
1 parent 185a831 commit 3e124c7
Show file tree
Hide file tree
Showing 6 changed files with 659 additions and 0 deletions.
9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

FILES = $(shell find . -type f -name '*.go' -not -path './vendor/*')

gofmt:
@gofmt -s -w $(FILES)
@gofmt -r '&α{} -> new(α)' -w $(FILES)

test:
go test -v ./...
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

# idempotency

> **IMPORTANT**: This package is still under development process. Basic test cases are covered but API is not stable yet.
Middleware to make idempotent HTTP handlers.
11 changes: 11 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module github.com/gomezjdaniel/idempotency

go 1.20

require github.com/stretchr/testify v1.8.4

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
230 changes: 230 additions & 0 deletions idempotency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package idempotency

import (
"bytes"
"context"
"crypto/md5"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"time"
)

const (
RecoveryPointStart = "start"
)

const (
DefaultIdempotencyKeyHeader = "Idempotency-Key"
DefaultLockDuration = 10 * time.Second
)

type Locker interface {
Lock(key string) bool
Unlock(key string)
}

type IdempotencyKey struct {
Key string
RecoveryPoint string
RequestMethod string
RequestURLPath string
RequestURLRawQuery string
RequestHeaders http.Header
RequestBodyHash string
ResponseStatusCode int
ResponseHeaders http.Header
ResponseBody string
}

type Repository interface {
// GetOrInsertKey returns the key if it exists, or inserts it if it doesn't.
// If the key is inserted, the returned bool is true.
GetOrInsert(*IdempotencyKey) (*IdempotencyKey, bool, error)
// SetRecoveryPoint sets the recovery point for the key.
SetRecoveryPoint(key string, recoveryPoint string) error
// SetResponse sets the response fields for the key.
SetResponse(key string, statusCode int, headers http.Header, body string) error
}

type Config struct {
Locker Locker
Repository Repository
IdempotencyKeyHeader string
}

func New(config Config) func(http.Handler) http.Handler {
if config.Locker == nil {
panic("idempotency: Locker is required")
}
if config.Repository == nil {
panic("idempotency: Repository is required")
}
if config.IdempotencyKeyHeader == "" {
config.IdempotencyKeyHeader = DefaultIdempotencyKeyHeader
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get a hash of the request body to either compare with a stored request
// or store it for future requests.
hash, err := hashBody(r)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

key := &IdempotencyKey{
Key: r.Header.Get(config.IdempotencyKeyHeader),
RecoveryPoint: RecoveryPointStart,
RequestMethod: r.Method,
RequestURLPath: r.URL.Path,
RequestURLRawQuery: r.URL.RawQuery,
RequestHeaders: r.Header,
RequestBodyHash: hash,
}
key, inserted, err := config.Repository.GetOrInsert(key)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

// Compare the incoming request with the stored request.
if !inserted && !compareRequests(r, hash, key) {
w.WriteHeader(http.StatusBadRequest)
return
}

// If the request has already been processed, return the stored response.
if key.ResponseStatusCode != 0 {
for k, v := range key.ResponseHeaders {
w.Header().Add(k, strings.Join(v, ","))
}
w.WriteHeader(key.ResponseStatusCode)
w.Write([]byte(key.ResponseBody))
return
}

// Lock the key to prevent other requests from processing it.
// TODO: lock timeout.
acquired := config.Locker.Lock(key.Key)
// If the lock was not acquired, return a 409.
if !acquired {
w.WriteHeader(http.StatusConflict)
return
}
defer config.Locker.Unlock(key.Key)

// Process the request and store the response.
rec := httptest.NewRecorder()
next.ServeHTTP(rec, r)

result := rec.Result()
resbody := rec.Body.Bytes()

for k, v := range result.Header {
w.Header().Add(k, strings.Join(v, ","))
}

w.WriteHeader(result.StatusCode)
w.Write(resbody)

err = config.Repository.SetResponse(key.Key, result.StatusCode,
result.Header, string(resbody))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
})
}
}

// Returns a hash of the request body.
func hashBody(r *http.Request) (string, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}
// Restore the body so it can be read again.
r.Body = io.NopCloser(bytes.NewReader(body))
r.ContentLength = int64(len(body))
r.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(body)), nil
}
h := md5.New()
if _, err := h.Write(body); err != nil {
return "", err
}
return fmt.Sprintf("%x", h.Sum(nil)), nil
}

// Compares the incoming request with the stored request. Returns true if they
// match, false otherwise.
func compareRequests(r *http.Request,
hashedBody string, stored *IdempotencyKey) bool {
if r.Method != stored.RequestMethod {
return false
}

if r.URL.Path != stored.RequestURLPath {
return false
}

if r.URL.RawQuery != stored.RequestURLRawQuery {
return false
}

if hashedBody != stored.RequestBodyHash {
return false
}

if len(r.Header) != len(stored.RequestHeaders) {
return false
}

requestHeaders := r.Header.Clone()

for header, v := range stored.RequestHeaders {
if requestHeaders.Get(header) != v[0] {
return false
}
requestHeaders.Del(header)
}

return len(requestHeaders) == 0
}

type ctxKey struct{}

func withContext(r *http.Request, key *IdempotencyKey) *http.Request {
return r.WithContext(context.WithValue(r.Context(), ctxKey{}, key))
}

func fromContext(r *http.Request) *IdempotencyKey {
return r.Context().Value(ctxKey{}).(*IdempotencyKey)
}

var rep Repository

func RecoveryPoint(name string, r *http.Request, fn func() string) {
key := fromContext(r)
if key == nil {
panic("idempotency: RecoveryPoint must be called after idempotency middleware")
}

if key.RecoveryPoint != name {
return
}

// TODO: prevent cyclic recovery points.

next := fn()
if err := rep.SetRecoveryPoint(key.Key, next); err != nil {
panic(err)
}

key.RecoveryPoint = next
r = withContext(r, key)
}
Loading

0 comments on commit 3e124c7

Please sign in to comment.