diff --git a/gock.go b/gock.go
index c3a8333..5f1e4c8 100644
--- a/gock.go
+++ b/gock.go
@@ -1,57 +1,43 @@
package gock
import (
- "fmt"
"net/http"
- "net/http/httputil"
- "net/url"
- "regexp"
"sync"
+
+ "github.com/h2non/gock/threadsafe"
)
+var g = threadsafe.NewGock()
+
+func init() {
+ g.DisableCallback = disable
+ g.InterceptCallback = intercept
+ g.InterceptingCallback = intercepting
+}
+
// mutex is used interally for locking thread-sensitive functions.
var mutex = &sync.Mutex{}
-// config global singleton store.
-var config = struct {
- Networking bool
- NetworkingFilters []FilterRequestFunc
- Observer ObserverFunc
-}{}
-
// ObserverFunc is implemented by users to inspect the outgoing intercepted HTTP traffic
-type ObserverFunc func(*http.Request, Mock)
+type ObserverFunc = threadsafe.ObserverFunc
// DumpRequest is a default implementation of ObserverFunc that dumps
// the HTTP/1.x wire representation of the http request
-var DumpRequest ObserverFunc = func(request *http.Request, mock Mock) {
- bytes, _ := httputil.DumpRequestOut(request, true)
- fmt.Println(string(bytes))
- fmt.Printf("\nMatches: %v\n---\n", mock != nil)
-}
-
-// track unmatched requests so they can be tested for
-var unmatchedRequests = []*http.Request{}
+var DumpRequest = g.DumpRequest
// New creates and registers a new HTTP mock with
// default settings and returns the Request DSL for HTTP mock
// definition and set up.
func New(uri string) *Request {
- Intercept()
-
- res := NewResponse()
- req := NewRequest()
- req.URLStruct, res.Error = url.Parse(normalizeURI(uri))
-
- // Create the new mock expectation
- exp := NewMock(req, res)
- Register(exp)
-
- return req
+ return g.New(uri)
}
// Intercepting returns true if gock is currently able to intercept.
func Intercepting() bool {
+ return g.Intercepting()
+}
+
+func intercepting() bool {
mutex.Lock()
defer mutex.Unlock()
return http.DefaultTransport == DefaultTransport
@@ -60,37 +46,33 @@ func Intercepting() bool {
// Intercept enables HTTP traffic interception via http.DefaultTransport.
// If you are using a custom HTTP transport, you have to use `gock.Transport()`
func Intercept() {
- if !Intercepting() {
- mutex.Lock()
- http.DefaultTransport = DefaultTransport
- mutex.Unlock()
- }
+ g.Intercept()
+}
+
+func intercept() {
+ mutex.Lock()
+ http.DefaultTransport = DefaultTransport
+ mutex.Unlock()
}
// InterceptClient allows the developer to intercept HTTP traffic using
// a custom http.Client who uses a non default http.Transport/http.RoundTripper implementation.
func InterceptClient(cli *http.Client) {
- _, ok := cli.Transport.(*Transport)
- if ok {
- return // if transport already intercepted, just ignore it
- }
- trans := NewTransport()
- trans.Transport = cli.Transport
- cli.Transport = trans
+ g.InterceptClient(cli)
}
// RestoreClient allows the developer to disable and restore the
// original transport in the given http.Client.
func RestoreClient(cli *http.Client) {
- trans, ok := cli.Transport.(*Transport)
- if !ok {
- return
- }
- cli.Transport = trans.Transport
+ g.RestoreClient(cli)
}
// Disable disables HTTP traffic interception by gock.
func Disable() {
+ g.Disable()
+}
+
+func disable() {
mutex.Lock()
defer mutex.Unlock()
http.DefaultTransport = NativeTransport
@@ -99,80 +81,50 @@ func Disable() {
// Off disables the default HTTP interceptors and removes
// all the registered mocks, even if they has not been intercepted yet.
func Off() {
- Flush()
- Disable()
+ g.Off()
}
// OffAll is like `Off()`, but it also removes the unmatched requests registry.
func OffAll() {
- Flush()
- Disable()
- CleanUnmatchedRequest()
+ g.OffAll()
}
// Observe provides a hook to support inspection of the request and matched mock
func Observe(fn ObserverFunc) {
- mutex.Lock()
- defer mutex.Unlock()
- config.Observer = fn
+ g.Observe(fn)
}
// EnableNetworking enables real HTTP networking
func EnableNetworking() {
- mutex.Lock()
- defer mutex.Unlock()
- config.Networking = true
+ g.EnableNetworking()
}
// DisableNetworking disables real HTTP networking
func DisableNetworking() {
- mutex.Lock()
- defer mutex.Unlock()
- config.Networking = false
+ g.DisableNetworking()
}
// NetworkingFilter determines if an http.Request should be triggered or not.
func NetworkingFilter(fn FilterRequestFunc) {
- mutex.Lock()
- defer mutex.Unlock()
- config.NetworkingFilters = append(config.NetworkingFilters, fn)
+ g.NetworkingFilter(fn)
}
// DisableNetworkingFilters disables registered networking filters.
func DisableNetworkingFilters() {
- mutex.Lock()
- defer mutex.Unlock()
- config.NetworkingFilters = []FilterRequestFunc{}
+ g.DisableNetworkingFilters()
}
// GetUnmatchedRequests returns all requests that have been received but haven't matched any mock
func GetUnmatchedRequests() []*http.Request {
- mutex.Lock()
- defer mutex.Unlock()
- return unmatchedRequests
+ return g.GetUnmatchedRequests()
}
// HasUnmatchedRequest returns true if gock has received any requests that didn't match a mock
func HasUnmatchedRequest() bool {
- return len(GetUnmatchedRequests()) > 0
+ return g.HasUnmatchedRequest()
}
// CleanUnmatchedRequest cleans the unmatched requests internal registry.
func CleanUnmatchedRequest() {
- mutex.Lock()
- defer mutex.Unlock()
- unmatchedRequests = []*http.Request{}
-}
-
-func trackUnmatchedRequest(req *http.Request) {
- mutex.Lock()
- defer mutex.Unlock()
- unmatchedRequests = append(unmatchedRequests, req)
-}
-
-func normalizeURI(uri string) string {
- if ok, _ := regexp.MatchString("^http[s]?", uri); !ok {
- return "http://" + uri
- }
- return uri
+ g.CleanUnmatchedRequest()
}
diff --git a/gock_test.go b/gock_test.go
index 7df68fe..8ad128b 100644
--- a/gock_test.go
+++ b/gock_test.go
@@ -304,7 +304,7 @@ func TestUnmatched(t *testing.T) {
defer after()
// clear out any unmatchedRequests from other tests
- unmatchedRequests = []*http.Request{}
+ CleanUnmatchedRequest()
Intercept()
diff --git a/matcher.go b/matcher.go
index 11a1d7e..633734f 100644
--- a/matcher.go
+++ b/matcher.go
@@ -1,137 +1,75 @@
package gock
-import "net/http"
+import (
+ "net/http"
+
+ "github.com/h2non/gock/threadsafe"
+)
// MatchersHeader exposes an slice of HTTP header specific mock matchers.
-var MatchersHeader = []MatchFunc{
- MatchMethod,
- MatchScheme,
- MatchHost,
- MatchPath,
- MatchHeaders,
- MatchQueryParams,
- MatchPathParams,
+func MatchersHeader() []MatchFunc {
+ return g.MatchersHeader
+}
+
+func SetMatchersHeader(matchers []MatchFunc) {
+ g.MatchersHeader = matchers
}
// MatchersBody exposes an slice of HTTP body specific built-in mock matchers.
-var MatchersBody = []MatchFunc{
- MatchBody,
+func MatchersBody() []MatchFunc {
+ return g.MatchersBody
+}
+
+func SetMatchersBody(matchers []MatchFunc) {
+ g.MatchersBody = matchers
}
// Matchers stores all the built-in mock matchers.
-var Matchers = append(MatchersHeader, MatchersBody...)
+func Matchers() []MatchFunc {
+ return g.Matchers
+}
+
+func SetMatchers(matchers []MatchFunc) {
+ g.Matchers = matchers
+}
// DefaultMatcher stores the default Matcher instance used to match mocks.
-var DefaultMatcher = NewMatcher()
+func DefaultMatcher() *MockMatcher {
+ return g.DefaultMatcher
+}
+
+func SetDefaultMatcher(matcher *MockMatcher) {
+ g.DefaultMatcher = matcher
+}
// MatchFunc represents the required function
// interface implemented by matchers.
-type MatchFunc func(*http.Request, *Request) (bool, error)
+type MatchFunc = threadsafe.MatchFunc
// Matcher represents the required interface implemented by mock matchers.
-type Matcher interface {
- // Get returns a slice of registered function matchers.
- Get() []MatchFunc
-
- // Add adds a new matcher function.
- Add(MatchFunc)
-
- // Set sets the matchers functions stack.
- Set([]MatchFunc)
-
- // Flush flushes the current matchers function stack.
- Flush()
-
- // Match matches the given http.Request with a mock Request.
- Match(*http.Request, *Request) (bool, error)
-}
+type Matcher = threadsafe.Matcher
// MockMatcher implements a mock matcher
-type MockMatcher struct {
- Matchers []MatchFunc
-}
+type MockMatcher = threadsafe.MockMatcher
// NewMatcher creates a new mock matcher
// using the default matcher functions.
func NewMatcher() *MockMatcher {
- m := NewEmptyMatcher()
- for _, matchFn := range Matchers {
- m.Add(matchFn)
- }
- return m
+ return g.NewMatcher()
}
// NewBasicMatcher creates a new matcher with header only mock matchers.
func NewBasicMatcher() *MockMatcher {
- m := NewEmptyMatcher()
- for _, matchFn := range MatchersHeader {
- m.Add(matchFn)
- }
- return m
+ return g.NewBasicMatcher()
}
// NewEmptyMatcher creates a new empty matcher without default matchers.
func NewEmptyMatcher() *MockMatcher {
- return &MockMatcher{Matchers: []MatchFunc{}}
-}
-
-// Get returns a slice of registered function matchers.
-func (m *MockMatcher) Get() []MatchFunc {
- mutex.Lock()
- defer mutex.Unlock()
- return m.Matchers
-}
-
-// Add adds a new function matcher.
-func (m *MockMatcher) Add(fn MatchFunc) {
- m.Matchers = append(m.Matchers, fn)
-}
-
-// Set sets a new stack of matchers functions.
-func (m *MockMatcher) Set(stack []MatchFunc) {
- m.Matchers = stack
-}
-
-// Flush flushes the current matcher
-func (m *MockMatcher) Flush() {
- m.Matchers = []MatchFunc{}
-}
-
-// Clone returns a separate MockMatcher instance that has a copy of the same MatcherFuncs
-func (m *MockMatcher) Clone() *MockMatcher {
- m2 := NewEmptyMatcher()
- for _, mFn := range m.Get() {
- m2.Add(mFn)
- }
- return m2
-}
-
-// Match matches the given http.Request with a mock request
-// returning true in case that the request matches, otherwise false.
-func (m *MockMatcher) Match(req *http.Request, ereq *Request) (bool, error) {
- for _, matcher := range m.Matchers {
- matches, err := matcher(req, ereq)
- if err != nil {
- return false, err
- }
- if !matches {
- return false, nil
- }
- }
- return true, nil
+ return g.NewEmptyMatcher()
}
// MatchMock is a helper function that matches the given http.Request
// in the list of registered mocks, returning it if matches or error if it fails.
func MatchMock(req *http.Request) (Mock, error) {
- for _, mock := range GetAll() {
- matches, err := mock.Match(req)
- if err != nil {
- return nil, err
- }
- if matches {
- return mock, nil
- }
- }
- return nil, nil
+ return g.MatchMock(req)
}
diff --git a/matcher_test.go b/matcher_test.go
index d96c00c..c7e842c 100644
--- a/matcher_test.go
+++ b/matcher_test.go
@@ -9,24 +9,24 @@ import (
)
func TestRegisteredMatchers(t *testing.T) {
- st.Expect(t, len(MatchersHeader), 7)
- st.Expect(t, len(MatchersBody), 1)
+ st.Expect(t, len(MatchersHeader()), 7)
+ st.Expect(t, len(MatchersBody()), 1)
}
func TestNewMatcher(t *testing.T) {
matcher := NewMatcher()
// Funcs are not comparable, checking slice length as it's better than nothing
// See https://golang.org/pkg/reflect/#DeepEqual
- st.Expect(t, len(matcher.Matchers), len(Matchers))
- st.Expect(t, len(matcher.Get()), len(Matchers))
+ st.Expect(t, len(matcher.Matchers), len(Matchers()))
+ st.Expect(t, len(matcher.Get()), len(Matchers()))
}
func TestNewBasicMatcher(t *testing.T) {
matcher := NewBasicMatcher()
// Funcs are not comparable, checking slice length as it's better than nothing
// See https://golang.org/pkg/reflect/#DeepEqual
- st.Expect(t, len(matcher.Matchers), len(MatchersHeader))
- st.Expect(t, len(matcher.Get()), len(MatchersHeader))
+ st.Expect(t, len(matcher.Matchers), len(MatchersHeader()))
+ st.Expect(t, len(matcher.Get()), len(MatchersHeader()))
}
func TestNewEmptyMatcher(t *testing.T) {
@@ -37,17 +37,17 @@ func TestNewEmptyMatcher(t *testing.T) {
func TestMatcherAdd(t *testing.T) {
matcher := NewMatcher()
- st.Expect(t, len(matcher.Matchers), len(Matchers))
+ st.Expect(t, len(matcher.Matchers), len(Matchers()))
matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
return true, nil
})
- st.Expect(t, len(matcher.Get()), len(Matchers)+1)
+ st.Expect(t, len(matcher.Get()), len(Matchers())+1)
}
func TestMatcherSet(t *testing.T) {
matcher := NewMatcher()
matchers := []MatchFunc{}
- st.Expect(t, len(matcher.Matchers), len(Matchers))
+ st.Expect(t, len(matcher.Matchers), len(Matchers()))
matcher.Set(matchers)
st.Expect(t, matcher.Matchers, matchers)
st.Expect(t, len(matcher.Get()), 0)
@@ -62,18 +62,18 @@ func TestMatcherGet(t *testing.T) {
func TestMatcherFlush(t *testing.T) {
matcher := NewMatcher()
- st.Expect(t, len(matcher.Matchers), len(Matchers))
+ st.Expect(t, len(matcher.Matchers), len(Matchers()))
matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
return true, nil
})
- st.Expect(t, len(matcher.Get()), len(Matchers)+1)
+ st.Expect(t, len(matcher.Get()), len(Matchers())+1)
matcher.Flush()
st.Expect(t, len(matcher.Get()), 0)
}
func TestMatcherClone(t *testing.T) {
- matcher := DefaultMatcher.Clone()
- st.Expect(t, len(matcher.Get()), len(DefaultMatcher.Get()))
+ matcher := DefaultMatcher().Clone()
+ st.Expect(t, len(matcher.Get()), len(DefaultMatcher().Get()))
}
func TestMatcher(t *testing.T) {
@@ -115,19 +115,20 @@ func TestMatcher(t *testing.T) {
func TestMatchMock(t *testing.T) {
cases := []struct {
- method string
- url string
- matches bool
+ method string
+ methodFn func(r *Request, path string) *Request
+ url string
+ matches bool
}{
- {"GET", "http://foo.com/bar", true},
- {"GET", "http://foo.com/baz", true},
- {"GET", "http://foo.com/foo", false},
- {"POST", "http://foo.com/bar", false},
- {"POST", "http://bar.com/bar", false},
- {"GET", "http://foo.com", false},
+ {"GET", (*Request).Get, "http://foo.com/bar", true},
+ {"GET", (*Request).Get, "http://foo.com/baz", true},
+ {"GET", (*Request).Get, "http://foo.com/foo", false},
+ {"POST", (*Request).Post, "http://foo.com/bar", false},
+ {"POST", (*Request).Post, "http://bar.com/bar", false},
+ {"GET", (*Request).Get, "http://foo.com", false},
}
- matcher := DefaultMatcher
+ matcher := DefaultMatcher()
matcher.Flush()
st.Expect(t, len(matcher.Matchers), 0)
@@ -143,7 +144,7 @@ func TestMatchMock(t *testing.T) {
for _, test := range cases {
Flush()
- mock := New(test.url).method(test.method, "").Mock
+ mock := test.methodFn(New(test.url), "").Mock
u, _ := url.Parse(test.url)
req := &http.Request{Method: test.method, URL: u}
@@ -157,5 +158,5 @@ func TestMatchMock(t *testing.T) {
}
}
- DefaultMatcher.Matchers = Matchers
+ DefaultMatcher().Matchers = Matchers()
}
diff --git a/matchers.go b/matchers.go
index 658c9a6..8c3ac77 100644
--- a/matchers.go
+++ b/matchers.go
@@ -1,266 +1,79 @@
package gock
import (
- "compress/gzip"
- "encoding/json"
- "io"
- "io/ioutil"
"net/http"
- "reflect"
- "regexp"
- "strings"
- "github.com/h2non/parth"
+ "github.com/h2non/gock/threadsafe"
)
// EOL represents the end of line character.
-const EOL = 0xa
+const EOL = threadsafe.EOL
// BodyTypes stores the supported MIME body types for matching.
// Currently only text-based types.
-var BodyTypes = []string{
- "text/html",
- "text/plain",
- "application/json",
- "application/xml",
- "multipart/form-data",
- "application/x-www-form-urlencoded",
+func BodyTypes() []string {
+ return g.BodyTypes
+}
+
+func SetBodyTypes(types []string) {
+ g.BodyTypes = types
}
// BodyTypeAliases stores a generic MIME type by alias.
-var BodyTypeAliases = map[string]string{
- "html": "text/html",
- "text": "text/plain",
- "json": "application/json",
- "xml": "application/xml",
- "form": "multipart/form-data",
- "url": "application/x-www-form-urlencoded",
+func BodyTypeAliases() map[string]string {
+ return g.BodyTypeAliases
+}
+
+func SetBodyTypeAliases(aliases map[string]string) {
+ g.BodyTypeAliases = aliases
}
// CompressionSchemes stores the supported Content-Encoding types for decompression.
-var CompressionSchemes = []string{
- "gzip",
+func CompressionSchemes() []string {
+ return g.CompressionSchemes
+}
+
+func SetCompressionSchemes(schemes []string) {
+ g.CompressionSchemes = schemes
}
// MatchMethod matches the HTTP method of the given request.
func MatchMethod(req *http.Request, ereq *Request) (bool, error) {
- return ereq.Method == "" || req.Method == ereq.Method, nil
+ return g.MatchMethod(req, ereq)
}
// MatchScheme matches the request URL protocol scheme.
func MatchScheme(req *http.Request, ereq *Request) (bool, error) {
- return ereq.URLStruct.Scheme == "" || req.URL.Scheme == "" || ereq.URLStruct.Scheme == req.URL.Scheme, nil
+ return g.MatchScheme(req, ereq)
}
// MatchHost matches the HTTP host header field of the given request.
func MatchHost(req *http.Request, ereq *Request) (bool, error) {
- url := ereq.URLStruct
- if strings.EqualFold(url.Host, req.URL.Host) {
- return true, nil
- }
- if !ereq.Options.DisableRegexpHost {
- return regexp.MatchString(url.Host, req.URL.Host)
- }
- return false, nil
+ return g.MatchHost(req, ereq)
}
// MatchPath matches the HTTP URL path of the given request.
func MatchPath(req *http.Request, ereq *Request) (bool, error) {
- if req.URL.Path == ereq.URLStruct.Path {
- return true, nil
- }
- return regexp.MatchString(ereq.URLStruct.Path, req.URL.Path)
+ return g.MatchPath(req, ereq)
}
// MatchHeaders matches the headers fields of the given request.
func MatchHeaders(req *http.Request, ereq *Request) (bool, error) {
- for key, value := range ereq.Header {
- var err error
- var match bool
- var matchEscaped bool
-
- for _, field := range req.Header[key] {
- match, err = regexp.MatchString(value[0], field)
- // Some values may contain reserved regex params e.g. "()", try matching with these escaped.
- matchEscaped, err = regexp.MatchString(regexp.QuoteMeta(value[0]), field)
-
- if err != nil {
- return false, err
- }
- if match || matchEscaped {
- break
- }
-
- }
-
- if !match && !matchEscaped {
- return false, nil
- }
- }
- return true, nil
+ return g.MatchHeaders(req, ereq)
}
// MatchQueryParams matches the URL query params fields of the given request.
func MatchQueryParams(req *http.Request, ereq *Request) (bool, error) {
- for key, value := range ereq.URLStruct.Query() {
- var err error
- var match bool
-
- for _, field := range req.URL.Query()[key] {
- match, err = regexp.MatchString(value[0], field)
- if err != nil {
- return false, err
- }
- if match {
- break
- }
- }
-
- if !match {
- return false, nil
- }
- }
- return true, nil
+ return g.MatchQueryParams(req, ereq)
}
// MatchPathParams matches the URL path parameters of the given request.
func MatchPathParams(req *http.Request, ereq *Request) (bool, error) {
- for key, value := range ereq.PathParams {
- var s string
-
- if err := parth.Sequent(req.URL.Path, key, &s); err != nil {
- return false, nil
- }
-
- if s != value {
- return false, nil
- }
- }
- return true, nil
+ return g.MatchPathParams(req, ereq)
}
// MatchBody tries to match the request body.
// TODO: not too smart now, needs several improvements.
func MatchBody(req *http.Request, ereq *Request) (bool, error) {
- // If match body is empty, just continue
- if req.Method == "HEAD" || len(ereq.BodyBuffer) == 0 {
- return true, nil
- }
-
- // Only can match certain MIME body types
- if !supportedType(req, ereq) {
- return false, nil
- }
-
- // Can only match certain compression schemes
- if !supportedCompressionScheme(req) {
- return false, nil
- }
-
- // Create a reader for the body depending on compression type
- bodyReader := req.Body
- if ereq.CompressionScheme != "" {
- if ereq.CompressionScheme != req.Header.Get("Content-Encoding") {
- return false, nil
- }
- compressedBodyReader, err := compressionReader(req.Body, ereq.CompressionScheme)
- if err != nil {
- return false, err
- }
- bodyReader = compressedBodyReader
- }
-
- // Read the whole request body
- body, err := ioutil.ReadAll(bodyReader)
- if err != nil {
- return false, err
- }
-
- // Restore body reader stream
- req.Body = createReadCloser(body)
-
- // If empty, ignore the match
- if len(body) == 0 && len(ereq.BodyBuffer) != 0 {
- return false, nil
- }
-
- // Match body by atomic string comparison
- bodyStr := castToString(body)
- matchStr := castToString(ereq.BodyBuffer)
- if bodyStr == matchStr {
- return true, nil
- }
-
- // Match request body by regexp
- match, _ := regexp.MatchString(matchStr, bodyStr)
- if match == true {
- return true, nil
- }
-
- // todo - add conditional do only perform the conversion of body bytes
- // representation of JSON to a map and then compare them for equality.
-
- // Check if the key + value pairs match
- var bodyMap map[string]interface{}
- var matchMap map[string]interface{}
-
- // Ensure that both byte bodies that that should be JSON can be converted to maps.
- umErr := json.Unmarshal(body, &bodyMap)
- umErr2 := json.Unmarshal(ereq.BodyBuffer, &matchMap)
- if umErr == nil && umErr2 == nil && reflect.DeepEqual(bodyMap, matchMap) {
- return true, nil
- }
-
- return false, nil
-}
-
-func supportedType(req *http.Request, ereq *Request) bool {
- mime := req.Header.Get("Content-Type")
- if mime == "" {
- return true
- }
-
- mimeToMatch := ereq.Header.Get("Content-Type")
- if mimeToMatch != "" {
- return mime == mimeToMatch
- }
-
- for _, kind := range BodyTypes {
- if match, _ := regexp.MatchString(kind, mime); match {
- return true
- }
- }
- return false
-}
-
-func supportedCompressionScheme(req *http.Request) bool {
- encoding := req.Header.Get("Content-Encoding")
- if encoding == "" {
- return true
- }
-
- for _, kind := range CompressionSchemes {
- if match, _ := regexp.MatchString(kind, encoding); match {
- return true
- }
- }
- return false
-}
-
-func castToString(buf []byte) string {
- str := string(buf)
- tail := len(str) - 1
- if str[tail] == EOL {
- str = str[:tail]
- }
- return str
-}
-
-func compressionReader(r io.ReadCloser, scheme string) (io.ReadCloser, error) {
- switch scheme {
- case "gzip":
- return gzip.NewReader(r)
- default:
- return r, nil
- }
+ return g.MatchBody(req, ereq)
}
diff --git a/matchers_test.go b/matchers_test.go
index 56aaa01..cbe30d6 100644
--- a/matchers_test.go
+++ b/matchers_test.go
@@ -1,6 +1,9 @@
package gock
import (
+ "bytes"
+ "io"
+ "io/ioutil"
"net/http"
"net/url"
"testing"
@@ -249,3 +252,9 @@ func TestMatchBody_MatchType(t *testing.T) {
st.Expect(t, matches, test.matches)
}
}
+
+// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an
+// http response body.
+func createReadCloser(body []byte) io.ReadCloser {
+ return ioutil.NopCloser(bytes.NewReader(body))
+}
diff --git a/mock.go b/mock.go
index d28875b..aa388ff 100644
--- a/mock.go
+++ b/mock.go
@@ -1,172 +1,19 @@
package gock
import (
- "net/http"
- "sync"
+ "github.com/h2non/gock/threadsafe"
)
// Mock represents the required interface that must
// be implemented by HTTP mock instances.
-type Mock interface {
- // Disable disables the current mock manually.
- Disable()
-
- // Done returns true if the current mock is disabled.
- Done() bool
-
- // Request returns the mock Request instance.
- Request() *Request
-
- // Response returns the mock Response instance.
- Response() *Response
-
- // Match matches the given http.Request with the current mock.
- Match(*http.Request) (bool, error)
-
- // AddMatcher adds a new matcher function.
- AddMatcher(MatchFunc)
-
- // SetMatcher uses a new matcher implementation.
- SetMatcher(Matcher)
-}
+type Mock = threadsafe.Mock
// Mocker implements a Mock capable interface providing
// a default mock configuration used internally to store mocks.
-type Mocker struct {
- // disabler stores a disabler for thread safety checking current mock is disabled
- disabler *disabler
-
- // mutex stores the mock mutex for thread safety.
- mutex sync.Mutex
-
- // matcher stores a Matcher capable instance to match the given http.Request.
- matcher Matcher
-
- // request stores the mock Request to match.
- request *Request
-
- // response stores the mock Response to use in case of match.
- response *Response
-}
-
-type disabler struct {
- // disabled stores if the current mock is disabled.
- disabled bool
-
- // mutex stores the disabler mutex for thread safety.
- mutex sync.RWMutex
-}
-
-func (d *disabler) isDisabled() bool {
- d.mutex.RLock()
- defer d.mutex.RUnlock()
- return d.disabled
-}
-
-func (d *disabler) Disable() {
- d.mutex.Lock()
- defer d.mutex.Unlock()
- d.disabled = true
-}
+type Mocker = threadsafe.Mocker
// NewMock creates a new HTTP mock based on the given request and response instances.
// It's mostly used internally.
func NewMock(req *Request, res *Response) *Mocker {
- mock := &Mocker{
- disabler: new(disabler),
- request: req,
- response: res,
- matcher: DefaultMatcher.Clone(),
- }
- res.Mock = mock
- req.Mock = mock
- req.Response = res
- return mock
-}
-
-// Disable disables the current mock manually.
-func (m *Mocker) Disable() {
- m.disabler.Disable()
-}
-
-// Done returns true in case that the current mock
-// instance is disabled and therefore must be removed.
-func (m *Mocker) Done() bool {
- // prevent deadlock with m.mutex
- if m.disabler.isDisabled() {
- return true
- }
-
- m.mutex.Lock()
- defer m.mutex.Unlock()
- return !m.request.Persisted && m.request.Counter == 0
-}
-
-// Request returns the Request instance
-// configured for the current HTTP mock.
-func (m *Mocker) Request() *Request {
- return m.request
-}
-
-// Response returns the Response instance
-// configured for the current HTTP mock.
-func (m *Mocker) Response() *Response {
- return m.response
-}
-
-// Match matches the given http.Request with the current Request
-// mock expectation, returning true if matches.
-func (m *Mocker) Match(req *http.Request) (bool, error) {
- if m.disabler.isDisabled() {
- return false, nil
- }
-
- // Filter
- for _, filter := range m.request.Filters {
- if !filter(req) {
- return false, nil
- }
- }
-
- // Map
- for _, mapper := range m.request.Mappers {
- if treq := mapper(req); treq != nil {
- req = treq
- }
- }
-
- // Match
- matches, err := m.matcher.Match(req, m.request)
- if matches {
- m.decrement()
- }
-
- return matches, err
-}
-
-// SetMatcher sets a new matcher implementation
-// for the current mock expectation.
-func (m *Mocker) SetMatcher(matcher Matcher) {
- m.matcher = matcher
-}
-
-// AddMatcher adds a new matcher function
-// for the current mock expectation.
-func (m *Mocker) AddMatcher(fn MatchFunc) {
- m.matcher.Add(fn)
-}
-
-// decrement decrements the current mock Request counter.
-func (m *Mocker) decrement() {
- if m.request.Persisted {
- return
- }
-
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.request.Counter--
- if m.request.Counter == 0 {
- m.disabler.Disable()
- }
+ return g.NewMock(req, res)
}
diff --git a/mock_test.go b/mock_test.go
index 01e6fca..70b0765 100644
--- a/mock_test.go
+++ b/mock_test.go
@@ -7,63 +7,6 @@ import (
"github.com/nbio/st"
)
-func TestNewMock(t *testing.T) {
- defer after()
-
- req := NewRequest()
- res := NewResponse()
- mock := NewMock(req, res)
- st.Expect(t, mock.disabler.isDisabled(), false)
- st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get()))
-
- st.Expect(t, mock.Request(), req)
- st.Expect(t, mock.Request().Mock, mock)
- st.Expect(t, mock.Response(), res)
- st.Expect(t, mock.Response().Mock, mock)
-}
-
-func TestMockDisable(t *testing.T) {
- defer after()
-
- req := NewRequest()
- res := NewResponse()
- mock := NewMock(req, res)
-
- st.Expect(t, mock.disabler.isDisabled(), false)
- mock.Disable()
- st.Expect(t, mock.disabler.isDisabled(), true)
-
- matches, err := mock.Match(&http.Request{})
- st.Expect(t, err, nil)
- st.Expect(t, matches, false)
-}
-
-func TestMockDone(t *testing.T) {
- defer after()
-
- req := NewRequest()
- res := NewResponse()
-
- mock := NewMock(req, res)
- st.Expect(t, mock.disabler.isDisabled(), false)
- st.Expect(t, mock.Done(), false)
-
- mock = NewMock(req, res)
- st.Expect(t, mock.disabler.isDisabled(), false)
- mock.Disable()
- st.Expect(t, mock.Done(), true)
-
- mock = NewMock(req, res)
- st.Expect(t, mock.disabler.isDisabled(), false)
- mock.request.Counter = 0
- st.Expect(t, mock.Done(), true)
-
- mock = NewMock(req, res)
- st.Expect(t, mock.disabler.isDisabled(), false)
- mock.request.Persisted = true
- st.Expect(t, mock.Done(), false)
-}
-
func TestMockSetMatcher(t *testing.T) {
defer after()
@@ -71,15 +14,12 @@ func TestMockSetMatcher(t *testing.T) {
res := NewResponse()
mock := NewMock(req, res)
- st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get()))
matcher := NewMatcher()
matcher.Flush()
matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
return true, nil
})
mock.SetMatcher(matcher)
- st.Expect(t, len(mock.matcher.Get()), 1)
- st.Expect(t, mock.disabler.isDisabled(), false)
matches, err := mock.Match(&http.Request{})
st.Expect(t, err, nil)
@@ -93,15 +33,12 @@ func TestMockAddMatcher(t *testing.T) {
res := NewResponse()
mock := NewMock(req, res)
- st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get()))
matcher := NewMatcher()
matcher.Flush()
mock.SetMatcher(matcher)
mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) {
return true, nil
})
- st.Expect(t, mock.disabler.isDisabled(), false)
- st.Expect(t, mock.matcher, matcher)
matches, err := mock.Match(&http.Request{})
st.Expect(t, err, nil)
@@ -127,8 +64,6 @@ func TestMockMatch(t *testing.T) {
calls++
return true, nil
})
- st.Expect(t, mock.disabler.isDisabled(), false)
- st.Expect(t, mock.matcher, matcher)
matches, err := mock.Match(&http.Request{})
st.Expect(t, err, nil)
diff --git a/options.go b/options.go
index 188aa58..f754563 100644
--- a/options.go
+++ b/options.go
@@ -1,8 +1,6 @@
package gock
+import "github.com/h2non/gock/threadsafe"
+
// Options represents customized option for gock
-type Options struct {
- // DisableRegexpHost stores if the host is only a plain string rather than regular expression,
- // if DisableRegexpHost is true, host sets in gock.New(...) will be treated as plain string
- DisableRegexpHost bool
-}
+type Options = threadsafe.Options
diff --git a/request.go b/request.go
index 5702417..1563c37 100644
--- a/request.go
+++ b/request.go
@@ -1,325 +1,20 @@
package gock
import (
- "encoding/base64"
- "io"
- "io/ioutil"
- "net/http"
- "net/url"
- "strings"
+ "github.com/h2non/gock/threadsafe"
)
// MapRequestFunc represents the required function interface for request mappers.
-type MapRequestFunc func(*http.Request) *http.Request
+type MapRequestFunc = threadsafe.MapRequestFunc
// FilterRequestFunc represents the required function interface for request filters.
-type FilterRequestFunc func(*http.Request) bool
+type FilterRequestFunc = threadsafe.FilterRequestFunc
// Request represents the high-level HTTP request used to store
// request fields used to match intercepted requests.
-type Request struct {
- // Mock stores the parent mock reference for the current request mock used for method delegation.
- Mock Mock
-
- // Response stores the current Response instance for the current matches Request.
- Response *Response
-
- // Error stores the latest mock request configuration error.
- Error error
-
- // Counter stores the pending times that the current mock should be active.
- Counter int
-
- // Persisted stores if the current mock should be always active.
- Persisted bool
-
- // Options stores options for current Request.
- Options Options
-
- // URLStruct stores the parsed URL as *url.URL struct.
- URLStruct *url.URL
-
- // Method stores the Request HTTP method to match.
- Method string
-
- // CompressionScheme stores the Request Compression scheme to match and use for decompression.
- CompressionScheme string
-
- // Header stores the HTTP header fields to match.
- Header http.Header
-
- // Cookies stores the Request HTTP cookies values to match.
- Cookies []*http.Cookie
-
- // PathParams stores the path parameters to match.
- PathParams map[string]string
-
- // BodyBuffer stores the body data to match.
- BodyBuffer []byte
-
- // Mappers stores the request functions mappers used for matching.
- Mappers []MapRequestFunc
-
- // Filters stores the request functions filters used for matching.
- Filters []FilterRequestFunc
-}
+type Request = threadsafe.Request
// NewRequest creates a new Request instance.
func NewRequest() *Request {
- return &Request{
- Counter: 1,
- URLStruct: &url.URL{},
- Header: make(http.Header),
- PathParams: make(map[string]string),
- }
-}
-
-// URL defines the mock URL to match.
-func (r *Request) URL(uri string) *Request {
- r.URLStruct, r.Error = url.Parse(uri)
- return r
-}
-
-// SetURL defines the url.URL struct to be used for matching.
-func (r *Request) SetURL(u *url.URL) *Request {
- r.URLStruct = u
- return r
-}
-
-// Path defines the mock URL path value to match.
-func (r *Request) Path(path string) *Request {
- r.URLStruct.Path = path
- return r
-}
-
-// Get specifies the GET method and the given URL path to match.
-func (r *Request) Get(path string) *Request {
- return r.method("GET", path)
-}
-
-// Post specifies the POST method and the given URL path to match.
-func (r *Request) Post(path string) *Request {
- return r.method("POST", path)
-}
-
-// Put specifies the PUT method and the given URL path to match.
-func (r *Request) Put(path string) *Request {
- return r.method("PUT", path)
-}
-
-// Delete specifies the DELETE method and the given URL path to match.
-func (r *Request) Delete(path string) *Request {
- return r.method("DELETE", path)
-}
-
-// Patch specifies the PATCH method and the given URL path to match.
-func (r *Request) Patch(path string) *Request {
- return r.method("PATCH", path)
-}
-
-// Head specifies the HEAD method and the given URL path to match.
-func (r *Request) Head(path string) *Request {
- return r.method("HEAD", path)
-}
-
-// method is a DRY shortcut used to declare the expected HTTP method and URL path.
-func (r *Request) method(method, path string) *Request {
- if path != "/" {
- r.URLStruct.Path = path
- }
- r.Method = strings.ToUpper(method)
- return r
-}
-
-// Body defines the body data to match based on a io.Reader interface.
-func (r *Request) Body(body io.Reader) *Request {
- r.BodyBuffer, r.Error = ioutil.ReadAll(body)
- return r
-}
-
-// BodyString defines the body to match based on a given string.
-func (r *Request) BodyString(body string) *Request {
- r.BodyBuffer = []byte(body)
- return r
-}
-
-// File defines the body to match based on the given file path string.
-func (r *Request) File(path string) *Request {
- r.BodyBuffer, r.Error = ioutil.ReadFile(path)
- return r
-}
-
-// Compression defines the request compression scheme, and enables automatic body decompression.
-// Supports only the "gzip" scheme so far.
-func (r *Request) Compression(scheme string) *Request {
- r.Header.Set("Content-Encoding", scheme)
- r.CompressionScheme = scheme
- return r
-}
-
-// JSON defines the JSON body to match based on a given structure.
-func (r *Request) JSON(data interface{}) *Request {
- if r.Header.Get("Content-Type") == "" {
- r.Header.Set("Content-Type", "application/json")
- }
- r.BodyBuffer, r.Error = readAndDecode(data, "json")
- return r
-}
-
-// XML defines the XML body to match based on a given structure.
-func (r *Request) XML(data interface{}) *Request {
- if r.Header.Get("Content-Type") == "" {
- r.Header.Set("Content-Type", "application/xml")
- }
- r.BodyBuffer, r.Error = readAndDecode(data, "xml")
- return r
-}
-
-// MatchType defines the request Content-Type MIME header field.
-// Supports custom MIME types and type aliases. E.g: json, xml, form, text...
-func (r *Request) MatchType(kind string) *Request {
- mime := BodyTypeAliases[kind]
- if mime != "" {
- kind = mime
- }
- r.Header.Set("Content-Type", kind)
- return r
-}
-
-// BasicAuth defines a username and password for HTTP Basic Authentication
-func (r *Request) BasicAuth(username, password string) *Request {
- r.Header.Set("Authorization", "Basic "+basicAuth(username, password))
- return r
-}
-
-// MatchHeader defines a new key and value header to match.
-func (r *Request) MatchHeader(key, value string) *Request {
- r.Header.Set(key, value)
- return r
-}
-
-// HeaderPresent defines that a header field must be present in the request.
-func (r *Request) HeaderPresent(key string) *Request {
- r.Header.Set(key, ".*")
- return r
-}
-
-// MatchHeaders defines a map of key-value headers to match.
-func (r *Request) MatchHeaders(headers map[string]string) *Request {
- for key, value := range headers {
- r.Header.Set(key, value)
- }
- return r
-}
-
-// MatchParam defines a new key and value URL query param to match.
-func (r *Request) MatchParam(key, value string) *Request {
- query := r.URLStruct.Query()
- query.Set(key, value)
- r.URLStruct.RawQuery = query.Encode()
- return r
-}
-
-// MatchParams defines a map of URL query param key-value to match.
-func (r *Request) MatchParams(params map[string]string) *Request {
- query := r.URLStruct.Query()
- for key, value := range params {
- query.Set(key, value)
- }
- r.URLStruct.RawQuery = query.Encode()
- return r
-}
-
-// ParamPresent matches if the given query param key is present in the URL.
-func (r *Request) ParamPresent(key string) *Request {
- r.MatchParam(key, ".*")
- return r
-}
-
-// PathParam matches if a given path parameter key is present in the URL.
-//
-// The value is representative of the restful resource the key defines, e.g.
-// // /users/123/name
-// r.PathParam("users", "123")
-// would match.
-func (r *Request) PathParam(key, val string) *Request {
- r.PathParams[key] = val
-
- return r
-}
-
-// Persist defines the current HTTP mock as persistent and won't be removed after intercepting it.
-func (r *Request) Persist() *Request {
- r.Persisted = true
- return r
-}
-
-// WithOptions sets the options for the request.
-func (r *Request) WithOptions(options Options) *Request {
- r.Options = options
- return r
-}
-
-// Times defines the number of times that the current HTTP mock should remain active.
-func (r *Request) Times(num int) *Request {
- r.Counter = num
- return r
-}
-
-// AddMatcher adds a new matcher function to match the request.
-func (r *Request) AddMatcher(fn MatchFunc) *Request {
- r.Mock.AddMatcher(fn)
- return r
-}
-
-// SetMatcher sets a new matcher function to match the request.
-func (r *Request) SetMatcher(matcher Matcher) *Request {
- r.Mock.SetMatcher(matcher)
- return r
-}
-
-// Map adds a new request mapper function to map http.Request before the matching process.
-func (r *Request) Map(fn MapRequestFunc) *Request {
- r.Mappers = append(r.Mappers, fn)
- return r
-}
-
-// Filter filters a new request filter function to filter http.Request before the matching process.
-func (r *Request) Filter(fn FilterRequestFunc) *Request {
- r.Filters = append(r.Filters, fn)
- return r
-}
-
-// EnableNetworking enables the use real networking for the current mock.
-func (r *Request) EnableNetworking() *Request {
- if r.Response != nil {
- r.Response.UseNetwork = true
- }
- return r
-}
-
-// Reply defines the Response status code and returns the mock Response DSL.
-func (r *Request) Reply(status int) *Response {
- return r.Response.Status(status)
-}
-
-// ReplyError defines the Response simulated error.
-func (r *Request) ReplyError(err error) *Response {
- return r.Response.SetError(err)
-}
-
-// ReplyFunc allows the developer to define the mock response via a custom function.
-func (r *Request) ReplyFunc(replier func(*Response)) *Response {
- replier(r.Response)
- return r.Response
-}
-
-// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt
-// "To receive authorization, the client sends the userid and password,
-// separated by a single colon (":") character, within a base64
-// encoded string in the credentials."
-// It is not meant to be urlencoded.
-func basicAuth(username, password string) string {
- auth := username + ":" + password
- return base64.StdEncoding.EncodeToString([]byte(auth))
+ return g.NewRequest()
}
diff --git a/request_test.go b/request_test.go
index 463e784..011a0f9 100644
--- a/request_test.go
+++ b/request_test.go
@@ -266,7 +266,6 @@ func TestRequestAddMatcher(t *testing.T) {
ereq := NewRequest()
mock := NewMock(ereq, &Response{})
- mock.matcher = NewMatcher()
ereq.Mock = mock
ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) {
diff --git a/responder.go b/responder.go
index f0f16bb..0ec2de5 100644
--- a/responder.go
+++ b/responder.go
@@ -1,111 +1,12 @@
package gock
import (
- "bytes"
- "io"
- "io/ioutil"
"net/http"
- "strconv"
- "time"
+
+ "github.com/h2non/gock/threadsafe"
)
// Responder builds a mock http.Response based on the given Response mock.
func Responder(req *http.Request, mock *Response, res *http.Response) (*http.Response, error) {
- // If error present, reply it
- err := mock.Error
- if err != nil {
- return nil, err
- }
-
- if res == nil {
- res = createResponse(req)
- }
-
- // Apply response filter
- for _, filter := range mock.Filters {
- if !filter(res) {
- return res, nil
- }
- }
-
- // Define mock status code
- if mock.StatusCode != 0 {
- res.Status = strconv.Itoa(mock.StatusCode) + " " + http.StatusText(mock.StatusCode)
- res.StatusCode = mock.StatusCode
- }
-
- // Define headers by merging fields
- res.Header = mergeHeaders(res, mock)
-
- // Define mock body, if present
- if len(mock.BodyBuffer) > 0 {
- res.ContentLength = int64(len(mock.BodyBuffer))
- res.Body = createReadCloser(mock.BodyBuffer)
- }
-
- // Set raw mock body, if exist
- if mock.BodyGen != nil {
- res.ContentLength = -1
- res.Body = mock.BodyGen()
- }
-
- // Apply response mappers
- for _, mapper := range mock.Mappers {
- if tres := mapper(res); tres != nil {
- res = tres
- }
- }
-
- // Sleep to simulate delay, if necessary
- if mock.ResponseDelay > 0 {
- // allow escaping from sleep due to request context expiration or cancellation
- t := time.NewTimer(mock.ResponseDelay)
- select {
- case <-t.C:
- case <-req.Context().Done():
- // cleanly stop the timer
- if !t.Stop() {
- <-t.C
- }
- }
- }
-
- // check if the request context has ended. we could put this up in the delay code above, but putting it here
- // has the added benefit of working even when there is no delay (very small timeouts, already-done contexts, etc.)
- if err = req.Context().Err(); err != nil {
- // cleanly close the response and return the context error
- io.Copy(ioutil.Discard, res.Body)
- res.Body.Close()
- return nil, err
- }
-
- return res, err
-}
-
-// createResponse creates a new http.Response with default fields.
-func createResponse(req *http.Request) *http.Response {
- return &http.Response{
- ProtoMajor: 1,
- ProtoMinor: 1,
- Proto: "HTTP/1.1",
- Request: req,
- Header: make(http.Header),
- Body: createReadCloser([]byte{}),
- }
-}
-
-// mergeHeaders copies the mock headers.
-func mergeHeaders(res *http.Response, mres *Response) http.Header {
- for key, values := range mres.Header {
- for _, value := range values {
- res.Header.Add(key, value)
- }
- }
- return res.Header
-}
-
-// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an
-// http response body.
-func createReadCloser(body []byte) io.ReadCloser {
- return ioutil.NopCloser(bytes.NewReader(body))
+ return threadsafe.Responder(req, mock, res)
}
diff --git a/response.go b/response.go
index 3e62b9e..0eeb314 100644
--- a/response.go
+++ b/response.go
@@ -1,196 +1,20 @@
package gock
import (
- "bytes"
- "encoding/json"
- "encoding/xml"
- "io"
- "io/ioutil"
- "net/http"
- "time"
+ "github.com/h2non/gock/threadsafe"
)
// MapResponseFunc represents the required function interface impletemed by response mappers.
-type MapResponseFunc func(*http.Response) *http.Response
+type MapResponseFunc = threadsafe.MapResponseFunc
// FilterResponseFunc represents the required function interface impletemed by response filters.
-type FilterResponseFunc func(*http.Response) bool
+type FilterResponseFunc = threadsafe.FilterResponseFunc
// Response represents high-level HTTP fields to configure
// and define HTTP responses intercepted by gock.
-type Response struct {
- // Mock stores the parent mock reference for the current response mock used for method delegation.
- Mock Mock
-
- // Error stores the latest response configuration or injected error.
- Error error
-
- // UseNetwork enables the use of real network for the current mock.
- UseNetwork bool
-
- // StatusCode stores the response status code.
- StatusCode int
-
- // Headers stores the response headers.
- Header http.Header
-
- // Cookies stores the response cookie fields.
- Cookies []*http.Cookie
-
- // BodyGen stores a io.ReadCloser generator to be returned.
- BodyGen func() io.ReadCloser
-
- // BodyBuffer stores the array of bytes to use as body.
- BodyBuffer []byte
-
- // ResponseDelay stores the simulated response delay.
- ResponseDelay time.Duration
-
- // Mappers stores the request functions mappers used for matching.
- Mappers []MapResponseFunc
-
- // Filters stores the request functions filters used for matching.
- Filters []FilterResponseFunc
-}
+type Response = threadsafe.Response
// NewResponse creates a new Response.
func NewResponse() *Response {
- return &Response{Header: make(http.Header)}
-}
-
-// Status defines the desired HTTP status code to reply in the current response.
-func (r *Response) Status(code int) *Response {
- r.StatusCode = code
- return r
-}
-
-// Type defines the response Content-Type MIME header field.
-// Supports type alias. E.g: json, xml, form, text...
-func (r *Response) Type(kind string) *Response {
- mime := BodyTypeAliases[kind]
- if mime != "" {
- kind = mime
- }
- r.Header.Set("Content-Type", kind)
- return r
-}
-
-// SetHeader sets a new header field in the mock response.
-func (r *Response) SetHeader(key, value string) *Response {
- r.Header.Set(key, value)
- return r
-}
-
-// AddHeader adds a new header field in the mock response
-// with out removing an existent one.
-func (r *Response) AddHeader(key, value string) *Response {
- r.Header.Add(key, value)
- return r
-}
-
-// SetHeaders sets a map of header fields in the mock response.
-func (r *Response) SetHeaders(headers map[string]string) *Response {
- for key, value := range headers {
- r.Header.Add(key, value)
- }
- return r
-}
-
-// Body sets the HTTP response body to be used.
-func (r *Response) Body(body io.Reader) *Response {
- r.BodyBuffer, r.Error = ioutil.ReadAll(body)
- return r
-}
-
-// BodyGenerator accepts a io.ReadCloser generator, returning custom io.ReadCloser
-// for every response. This will take priority than other Body methods used.
-func (r *Response) BodyGenerator(generator func() io.ReadCloser) *Response {
- r.BodyGen = generator
- return r
-}
-
-// BodyString defines the response body as string.
-func (r *Response) BodyString(body string) *Response {
- r.BodyBuffer = []byte(body)
- return r
-}
-
-// File defines the response body reading the data
-// from disk based on the file path string.
-func (r *Response) File(path string) *Response {
- r.BodyBuffer, r.Error = ioutil.ReadFile(path)
- return r
-}
-
-// JSON defines the response body based on a JSON based input.
-func (r *Response) JSON(data interface{}) *Response {
- r.Header.Set("Content-Type", "application/json")
- r.BodyBuffer, r.Error = readAndDecode(data, "json")
- return r
-}
-
-// XML defines the response body based on a XML based input.
-func (r *Response) XML(data interface{}) *Response {
- r.Header.Set("Content-Type", "application/xml")
- r.BodyBuffer, r.Error = readAndDecode(data, "xml")
- return r
-}
-
-// SetError defines the response simulated error.
-func (r *Response) SetError(err error) *Response {
- r.Error = err
- return r
-}
-
-// Delay defines the response simulated delay.
-// This feature is still experimental and will be improved in the future.
-func (r *Response) Delay(delay time.Duration) *Response {
- r.ResponseDelay = delay
- return r
-}
-
-// Map adds a new response mapper function to map http.Response before the matching process.
-func (r *Response) Map(fn MapResponseFunc) *Response {
- r.Mappers = append(r.Mappers, fn)
- return r
-}
-
-// Filter filters a new request filter function to filter http.Request before the matching process.
-func (r *Response) Filter(fn FilterResponseFunc) *Response {
- r.Filters = append(r.Filters, fn)
- return r
-}
-
-// EnableNetworking enables the use real networking for the current mock.
-func (r *Response) EnableNetworking() *Response {
- r.UseNetwork = true
- return r
-}
-
-// Done returns true if the mock was done and disabled.
-func (r *Response) Done() bool {
- return r.Mock.Done()
-}
-
-func readAndDecode(data interface{}, kind string) ([]byte, error) {
- buf := &bytes.Buffer{}
-
- switch data.(type) {
- case string:
- buf.WriteString(data.(string))
- case []byte:
- buf.Write(data.([]byte))
- default:
- var err error
- if kind == "xml" {
- err = xml.NewEncoder(buf).Encode(data)
- } else {
- err = json.NewEncoder(buf).Encode(data)
- }
- if err != nil {
- return nil, err
- }
- }
-
- return ioutil.ReadAll(buf)
+ return g.NewResponse()
}
diff --git a/response_test.go b/response_test.go
index 412ca53..c27781a 100644
--- a/response_test.go
+++ b/response_test.go
@@ -158,7 +158,7 @@ func TestResponseEnableNetworking(t *testing.T) {
func TestResponseDone(t *testing.T) {
res := NewResponse()
- res.Mock = &Mocker{request: &Request{Counter: 1}, disabler: new(disabler)}
+ res.Mock = NewMock(&Request{Counter: 1}, res)
st.Expect(t, res.Done(), false)
res.Mock.Disable()
st.Expect(t, res.Done(), true)
diff --git a/store.go b/store.go
index 7ed1316..adc5021 100644
--- a/store.go
+++ b/store.go
@@ -1,100 +1,46 @@
package gock
-import (
- "sync"
-)
-
-// storeMutex is used interally for store synchronization.
-var storeMutex = sync.RWMutex{}
-
-// mocks is internally used to store registered mocks.
-var mocks = []Mock{}
-
// Register registers a new mock in the current mocks stack.
func Register(mock Mock) {
- if Exists(mock) {
- return
- }
-
- // Make ops thread safe
- storeMutex.Lock()
- defer storeMutex.Unlock()
-
- // Expose mock in request/response for delegation
- mock.Request().Mock = mock
- mock.Response().Mock = mock
-
- // Registers the mock in the global store
- mocks = append(mocks, mock)
+ g.Register(mock)
}
// GetAll returns the current stack of registered mocks.
func GetAll() []Mock {
- storeMutex.RLock()
- defer storeMutex.RUnlock()
- return mocks
+ return g.GetAll()
}
// Exists checks if the given Mock is already registered.
func Exists(m Mock) bool {
- storeMutex.RLock()
- defer storeMutex.RUnlock()
- for _, mock := range mocks {
- if mock == m {
- return true
- }
- }
- return false
+ return g.Exists(m)
}
// Remove removes a registered mock by reference.
func Remove(m Mock) {
- for i, mock := range mocks {
- if mock == m {
- storeMutex.Lock()
- mocks = append(mocks[:i], mocks[i+1:]...)
- storeMutex.Unlock()
- }
- }
+ g.Remove(m)
}
// Flush flushes the current stack of registered mocks.
func Flush() {
- storeMutex.Lock()
- defer storeMutex.Unlock()
- mocks = []Mock{}
+ g.Flush()
}
// Pending returns an slice of pending mocks.
func Pending() []Mock {
- Clean()
- storeMutex.RLock()
- defer storeMutex.RUnlock()
- return mocks
+ return g.Pending()
}
// IsDone returns true if all the registered mocks has been triggered successfully.
func IsDone() bool {
- return !IsPending()
+ return g.IsDone()
}
// IsPending returns true if there are pending mocks.
func IsPending() bool {
- return len(Pending()) > 0
+ return g.IsPending()
}
// Clean cleans the mocks store removing disabled or obsolete mocks.
func Clean() {
- storeMutex.Lock()
- defer storeMutex.Unlock()
-
- buf := []Mock{}
- for _, mock := range mocks {
- if mock.Done() {
- continue
- }
- buf = append(buf, mock)
- }
-
- mocks = buf
+ g.Clean()
}
diff --git a/store_test.go b/store_test.go
index 4ab4c83..b40a078 100644
--- a/store_test.go
+++ b/store_test.go
@@ -8,36 +8,36 @@ import (
func TestStoreRegister(t *testing.T) {
defer after()
- st.Expect(t, len(mocks), 0)
+ st.Expect(t, len(GetAll()), 0)
mock := New("foo").Mock
Register(mock)
- st.Expect(t, len(mocks), 1)
+ st.Expect(t, len(GetAll()), 1)
st.Expect(t, mock.Request().Mock, mock)
st.Expect(t, mock.Response().Mock, mock)
}
func TestStoreGetAll(t *testing.T) {
defer after()
- st.Expect(t, len(mocks), 0)
+ st.Expect(t, len(GetAll()), 0)
mock := New("foo").Mock
store := GetAll()
- st.Expect(t, len(mocks), 1)
+ st.Expect(t, len(GetAll()), 1)
st.Expect(t, len(store), 1)
st.Expect(t, store[0], mock)
}
func TestStoreExists(t *testing.T) {
defer after()
- st.Expect(t, len(mocks), 0)
+ st.Expect(t, len(GetAll()), 0)
mock := New("foo").Mock
- st.Expect(t, len(mocks), 1)
+ st.Expect(t, len(GetAll()), 1)
st.Expect(t, Exists(mock), true)
}
func TestStorePending(t *testing.T) {
defer after()
New("foo")
- st.Expect(t, mocks, Pending())
+ st.Expect(t, GetAll(), Pending())
}
func TestStoreIsPending(t *testing.T) {
@@ -58,9 +58,9 @@ func TestStoreIsDone(t *testing.T) {
func TestStoreRemove(t *testing.T) {
defer after()
- st.Expect(t, len(mocks), 0)
+ st.Expect(t, len(GetAll()), 0)
mock := New("foo").Mock
- st.Expect(t, len(mocks), 1)
+ st.Expect(t, len(GetAll()), 1)
st.Expect(t, Exists(mock), true)
Remove(mock)
@@ -72,16 +72,16 @@ func TestStoreRemove(t *testing.T) {
func TestStoreFlush(t *testing.T) {
defer after()
- st.Expect(t, len(mocks), 0)
+ st.Expect(t, len(GetAll()), 0)
mock1 := New("foo").Mock
mock2 := New("foo").Mock
- st.Expect(t, len(mocks), 2)
+ st.Expect(t, len(GetAll()), 2)
st.Expect(t, Exists(mock1), true)
st.Expect(t, Exists(mock2), true)
Flush()
- st.Expect(t, len(mocks), 0)
+ st.Expect(t, len(GetAll()), 0)
st.Expect(t, Exists(mock1), false)
st.Expect(t, Exists(mock2), false)
}
diff --git a/threadsafe/gock.go b/threadsafe/gock.go
new file mode 100644
index 0000000..7df7b9a
--- /dev/null
+++ b/threadsafe/gock.go
@@ -0,0 +1,270 @@
+package threadsafe
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httputil"
+ "net/url"
+ "regexp"
+ "sync"
+)
+
+type Gock struct {
+ // mutex is used internally for locking thread-sensitive functions.
+ mutex sync.Mutex
+ // config global singleton store.
+ config struct {
+ Networking bool
+ NetworkingFilters []FilterRequestFunc
+ Observer ObserverFunc
+ }
+ // DumpRequest is a default implementation of ObserverFunc that dumps
+ // the HTTP/1.x wire representation of the http request
+ DumpRequest ObserverFunc
+ // track unmatched requests so they can be tested for
+ unmatchedRequests []*http.Request
+
+ // storeMutex is used internally for store synchronization.
+ storeMutex sync.RWMutex
+
+ // mocks is internally used to store registered mocks.
+ mocks []Mock
+
+ // DefaultMatcher stores the default Matcher instance used to match mocks.
+ DefaultMatcher *MockMatcher
+
+ // MatchersHeader exposes a slice of HTTP header specific mock matchers.
+ MatchersHeader []MatchFunc
+ // MatchersBody exposes a slice of HTTP body specific built-in mock matchers.
+ MatchersBody []MatchFunc
+ // Matchers stores all the built-in mock matchers.
+ Matchers []MatchFunc
+
+ // BodyTypes stores the supported MIME body types for matching.
+ // Currently only text-based types.
+ BodyTypes []string
+
+ // BodyTypeAliases stores a generic MIME type by alias.
+ BodyTypeAliases map[string]string
+
+ // CompressionSchemes stores the supported Content-Encoding types for decompression.
+ CompressionSchemes []string
+
+ intercepting bool
+
+ DisableCallback func()
+ InterceptCallback func()
+ InterceptingCallback func() bool
+}
+
+func NewGock() *Gock {
+ g := &Gock{
+ DumpRequest: defaultDumpRequest,
+
+ BodyTypes: []string{
+ "text/html",
+ "text/plain",
+ "application/json",
+ "application/xml",
+ "multipart/form-data",
+ "application/x-www-form-urlencoded",
+ },
+
+ BodyTypeAliases: map[string]string{
+ "html": "text/html",
+ "text": "text/plain",
+ "json": "application/json",
+ "xml": "application/xml",
+ "form": "multipart/form-data",
+ "url": "application/x-www-form-urlencoded",
+ },
+
+ // CompressionSchemes stores the supported Content-Encoding types for decompression.
+ CompressionSchemes: []string{
+ "gzip",
+ },
+ }
+ g.MatchersHeader = []MatchFunc{
+ g.MatchMethod,
+ g.MatchScheme,
+ g.MatchHost,
+ g.MatchPath,
+ g.MatchHeaders,
+ g.MatchQueryParams,
+ g.MatchPathParams,
+ }
+ g.MatchersBody = []MatchFunc{
+ g.MatchBody,
+ }
+ g.Matchers = append(g.MatchersHeader, g.MatchersBody...)
+
+ // DefaultMatcher stores the default Matcher instance used to match mocks.
+ g.DefaultMatcher = g.NewMatcher()
+ return g
+}
+
+// ObserverFunc is implemented by users to inspect the outgoing intercepted HTTP traffic
+type ObserverFunc func(*http.Request, Mock)
+
+func defaultDumpRequest(request *http.Request, mock Mock) {
+ bytes, _ := httputil.DumpRequestOut(request, true)
+ fmt.Println(string(bytes))
+ fmt.Printf("\nMatches: %v\n---\n", mock != nil)
+}
+
+// New creates and registers a new HTTP mock with
+// default settings and returns the Request DSL for HTTP mock
+// definition and set up.
+func (g *Gock) New(uri string) *Request {
+ g.Intercept()
+
+ res := g.NewResponse()
+ req := g.NewRequest()
+ req.URLStruct, res.Error = url.Parse(normalizeURI(uri))
+
+ // Create the new mock expectation
+ exp := g.NewMock(req, res)
+ g.Register(exp)
+
+ return req
+}
+
+// Intercepting returns true if gock is currently able to intercept.
+func (g *Gock) Intercepting() bool {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+
+ callbackResponse := true
+ if g.InterceptingCallback != nil {
+ callbackResponse = g.InterceptingCallback()
+ }
+
+ return g.intercepting && callbackResponse
+}
+
+// Intercept enables HTTP traffic interception via http.DefaultTransport.
+// If you are using a custom HTTP transport, you have to use `gock.Transport()`
+func (g *Gock) Intercept() {
+ if !g.Intercepting() {
+ g.mutex.Lock()
+ g.intercepting = true
+
+ if g.InterceptCallback != nil {
+ g.InterceptCallback()
+ }
+
+ g.mutex.Unlock()
+ }
+}
+
+// InterceptClient allows the developer to intercept HTTP traffic using
+// a custom http.Client who uses a non default http.Transport/http.RoundTripper implementation.
+func (g *Gock) InterceptClient(cli *http.Client) {
+ _, ok := cli.Transport.(*Transport)
+ if ok {
+ return // if transport already intercepted, just ignore it
+ }
+ cli.Transport = g.NewTransport(cli.Transport)
+}
+
+// RestoreClient allows the developer to disable and restore the
+// original transport in the given http.Client.
+func (g *Gock) RestoreClient(cli *http.Client) {
+ trans, ok := cli.Transport.(*Transport)
+ if !ok {
+ return
+ }
+ cli.Transport = trans.Transport
+}
+
+// Disable disables HTTP traffic interception by gock.
+func (g *Gock) Disable() {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ g.intercepting = false
+
+ if g.DisableCallback != nil {
+ g.DisableCallback()
+ }
+}
+
+// Off disables the default HTTP interceptors and removes
+// all the registered mocks, even if they has not been intercepted yet.
+func (g *Gock) Off() {
+ g.Flush()
+ g.Disable()
+}
+
+// OffAll is like `Off()`, but it also removes the unmatched requests registry.
+func (g *Gock) OffAll() {
+ g.Flush()
+ g.Disable()
+ g.CleanUnmatchedRequest()
+}
+
+// Observe provides a hook to support inspection of the request and matched mock
+func (g *Gock) Observe(fn ObserverFunc) {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ g.config.Observer = fn
+}
+
+// EnableNetworking enables real HTTP networking
+func (g *Gock) EnableNetworking() {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ g.config.Networking = true
+}
+
+// DisableNetworking disables real HTTP networking
+func (g *Gock) DisableNetworking() {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ g.config.Networking = false
+}
+
+// NetworkingFilter determines if an http.Request should be triggered or not.
+func (g *Gock) NetworkingFilter(fn FilterRequestFunc) {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ g.config.NetworkingFilters = append(g.config.NetworkingFilters, fn)
+}
+
+// DisableNetworkingFilters disables registered networking filters.
+func (g *Gock) DisableNetworkingFilters() {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ g.config.NetworkingFilters = []FilterRequestFunc{}
+}
+
+// GetUnmatchedRequests returns all requests that have been received but haven't matched any mock
+func (g *Gock) GetUnmatchedRequests() []*http.Request {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ return g.unmatchedRequests
+}
+
+// HasUnmatchedRequest returns true if gock has received any requests that didn't match a mock
+func (g *Gock) HasUnmatchedRequest() bool {
+ return len(g.GetUnmatchedRequests()) > 0
+}
+
+// CleanUnmatchedRequest cleans the unmatched requests internal registry.
+func (g *Gock) CleanUnmatchedRequest() {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ g.unmatchedRequests = []*http.Request{}
+}
+
+func (g *Gock) trackUnmatchedRequest(req *http.Request) {
+ g.mutex.Lock()
+ defer g.mutex.Unlock()
+ g.unmatchedRequests = append(g.unmatchedRequests, req)
+}
+
+func normalizeURI(uri string) string {
+ if ok, _ := regexp.MatchString("^http[s]?", uri); !ok {
+ return "http://" + uri
+ }
+ return uri
+}
diff --git a/threadsafe/gock_test.go b/threadsafe/gock_test.go
new file mode 100644
index 0000000..a503a28
--- /dev/null
+++ b/threadsafe/gock_test.go
@@ -0,0 +1,517 @@
+package threadsafe
+
+import (
+ "bytes"
+ "compress/gzip"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/nbio/st"
+)
+
+func TestMockSimple(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").Reply(201).JSON(map[string]string{"foo": "bar"})
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Get("http://foo.com")
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 201)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:13], `{"foo":"bar"}`)
+}
+
+func TestMockOff(t *testing.T) {
+ g := NewGock()
+ g.New("http://foo.com").Reply(201).JSON(map[string]string{"foo": "bar"})
+ g.Off()
+ c := &http.Client{}
+ g.InterceptClient(c)
+ _, err := c.Get("http://127.0.0.1:3123")
+ st.Reject(t, err, nil)
+}
+
+func TestMockBodyStringResponse(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").Reply(200).BodyString("foo bar")
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Get("http://foo.com")
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 200)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo bar")
+}
+
+func TestMockBodyMatch(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").BodyString("foo bar").Reply(201).BodyString("foo foo")
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar")))
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 201)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo foo")
+}
+
+func TestMockBodyCannotMatch(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").BodyString("foo foo").Reply(201).BodyString("foo foo")
+ c := &http.Client{}
+ g.InterceptClient(c)
+ _, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar")))
+ st.Reject(t, err, nil)
+}
+
+func TestMockBodyMatchCompressed(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").Compression("gzip").BodyString("foo bar").Reply(201).BodyString("foo foo")
+
+ var compressed bytes.Buffer
+ w := gzip.NewWriter(&compressed)
+ w.Write([]byte("foo bar"))
+ w.Close()
+ c := &http.Client{}
+ g.InterceptClient(c)
+ req, err := http.NewRequest("POST", "http://foo.com", &compressed)
+ st.Expect(t, err, nil)
+ req.Header.Set("Content-Encoding", "gzip")
+ req.Header.Set("Content-Type", "text/plain")
+ res, err := c.Do(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 201)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo foo")
+}
+
+func TestMockBodyCannotMatchCompressed(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").Compression("gzip").BodyString("foo bar").Reply(201).BodyString("foo foo")
+ c := &http.Client{}
+ g.InterceptClient(c)
+ _, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar")))
+ st.Reject(t, err, nil)
+}
+
+func TestMockBodyMatchJSON(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").
+ Post("/bar").
+ JSON(map[string]string{"foo": "bar"}).
+ Reply(201).
+ JSON(map[string]string{"bar": "foo"})
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Post("http://foo.com/bar", "application/json", bytes.NewBuffer([]byte(`{"foo":"bar"}`)))
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 201)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:13], `{"bar":"foo"}`)
+}
+
+func TestMockBodyCannotMatchJSON(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").
+ Post("/bar").
+ JSON(map[string]string{"bar": "bar"}).
+ Reply(201).
+ JSON(map[string]string{"bar": "foo"})
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ _, err := c.Post("http://foo.com/bar", "application/json", bytes.NewBuffer([]byte(`{"foo":"bar"}`)))
+ st.Reject(t, err, nil)
+}
+
+func TestMockBodyMatchCompressedJSON(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").
+ Post("/bar").
+ Compression("gzip").
+ JSON(map[string]string{"foo": "bar"}).
+ Reply(201).
+ JSON(map[string]string{"bar": "foo"})
+
+ var compressed bytes.Buffer
+ w := gzip.NewWriter(&compressed)
+ w.Write([]byte(`{"foo":"bar"}`))
+ w.Close()
+ c := &http.Client{}
+ g.InterceptClient(c)
+ req, err := http.NewRequest("POST", "http://foo.com/bar", &compressed)
+ st.Expect(t, err, nil)
+ req.Header.Set("Content-Encoding", "gzip")
+ req.Header.Set("Content-Type", "application/json")
+ res, err := c.Do(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 201)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:13], `{"bar":"foo"}`)
+}
+
+func TestMockBodyCannotMatchCompressedJSON(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").
+ Post("/bar").
+ JSON(map[string]string{"bar": "bar"}).
+ Reply(201).
+ JSON(map[string]string{"bar": "foo"})
+
+ var compressed bytes.Buffer
+ w := gzip.NewWriter(&compressed)
+ w.Write([]byte(`{"foo":"bar"}`))
+ w.Close()
+ c := &http.Client{}
+ g.InterceptClient(c)
+ req, err := http.NewRequest("POST", "http://foo.com/bar", &compressed)
+ st.Expect(t, err, nil)
+ req.Header.Set("Content-Encoding", "gzip")
+ req.Header.Set("Content-Type", "application/json")
+ _, err = c.Do(req)
+ st.Reject(t, err, nil)
+}
+
+func TestMockMatchHeaders(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").
+ MatchHeader("Content-Type", "(.*)/plain").
+ Reply(200).
+ BodyString("foo foo")
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar")))
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 200)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo foo")
+}
+
+func TestMockMap(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ mock := g.New("http://bar.com")
+ mock.Map(func(req *http.Request) *http.Request {
+ req.URL.Host = "bar.com"
+ return req
+ })
+ mock.Reply(201).JSON(map[string]string{"foo": "bar"})
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Get("http://foo.com")
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 201)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:13], `{"foo":"bar"}`)
+}
+
+func TestMockFilter(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ mock := g.New("http://foo.com")
+ mock.Filter(func(req *http.Request) bool {
+ return req.URL.Host == "foo.com"
+ })
+ mock.Reply(201).JSON(map[string]string{"foo": "bar"})
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Get("http://foo.com")
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 201)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:13], `{"foo":"bar"}`)
+}
+
+func TestMockCounterDisabled(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").Reply(204)
+ st.Expect(t, len(g.GetAll()), 1)
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Get("http://foo.com")
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 204)
+ st.Expect(t, len(g.GetAll()), 0)
+}
+
+func TestMockEnableNetwork(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, world")
+ }))
+ defer ts.Close()
+
+ g.EnableNetworking()
+ defer g.DisableNetworking()
+
+ g.New(ts.URL).Reply(204)
+ st.Expect(t, len(g.GetAll()), 1)
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Get(ts.URL)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 204)
+ st.Expect(t, len(g.GetAll()), 0)
+
+ res, err = c.Get(ts.URL)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 200)
+}
+
+func TestMockEnableNetworkFilter(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, world")
+ }))
+ defer ts.Close()
+
+ g.EnableNetworking()
+ defer g.DisableNetworking()
+
+ g.NetworkingFilter(func(req *http.Request) bool {
+ return strings.Contains(req.URL.Host, "127.0.0.1")
+ })
+ defer g.DisableNetworkingFilters()
+
+ g.New(ts.URL).Reply(0).SetHeader("foo", "bar")
+ st.Expect(t, len(g.GetAll()), 1)
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Get(ts.URL)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 200)
+ st.Expect(t, res.Header.Get("foo"), "bar")
+ st.Expect(t, len(g.GetAll()), 0)
+}
+
+func TestMockPersistent(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").
+ Get("/bar").
+ Persist().
+ Reply(200).
+ JSON(map[string]string{"foo": "bar"})
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ for i := 0; i < 5; i++ {
+ res, err := c.Get("http://foo.com/bar")
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 200)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:13], `{"foo":"bar"}`)
+ }
+}
+
+func TestMockPersistTimes(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://127.0.0.1:1234").
+ Get("/bar").
+ Times(4).
+ Reply(200).
+ JSON(map[string]string{"foo": "bar"})
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ for i := 0; i < 5; i++ {
+ res, err := c.Get("http://127.0.0.1:1234/bar")
+ if i == 4 {
+ st.Reject(t, err, nil)
+ break
+ }
+
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 200)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:13], `{"foo":"bar"}`)
+ }
+}
+
+func TestUnmatched(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ // clear out any unmatchedRequests from other tests
+ g.unmatchedRequests = []*http.Request{}
+
+ g.Intercept()
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ _, err := c.Get("http://server.com/unmatched")
+ st.Reject(t, err, nil)
+
+ unmatched := g.GetUnmatchedRequests()
+ st.Expect(t, len(unmatched), 1)
+ st.Expect(t, unmatched[0].URL.Host, "server.com")
+ st.Expect(t, unmatched[0].URL.Path, "/unmatched")
+ st.Expect(t, g.HasUnmatchedRequest(), true)
+}
+
+func TestMultipleMocks(t *testing.T) {
+ g := NewGock()
+ defer g.Disable()
+
+ g.New("http://server.com").
+ Get("/foo").
+ Reply(200).
+ JSON(map[string]string{"value": "foo"})
+
+ g.New("http://server.com").
+ Get("/bar").
+ Reply(200).
+ JSON(map[string]string{"value": "bar"})
+
+ g.New("http://server.com").
+ Get("/baz").
+ Reply(200).
+ JSON(map[string]string{"value": "baz"})
+
+ tests := []struct {
+ path string
+ }{
+ {"/foo"},
+ {"/bar"},
+ {"/baz"},
+ }
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ for _, test := range tests {
+ res, err := c.Get("http://server.com" + test.path)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 200)
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:15], `{"value":"`+test.path[1:]+`"}`)
+ }
+
+ _, err := c.Get("http://server.com/foo")
+ st.Reject(t, err, nil)
+}
+
+func TestInterceptClient(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ g.New("http://foo.com").Reply(204)
+ st.Expect(t, len(g.GetAll()), 1)
+
+ req, err := http.NewRequest("GET", "http://foo.com", nil)
+ client := &http.Client{Transport: &http.Transport{}}
+ g.InterceptClient(client)
+
+ res, err := client.Do(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 204)
+}
+
+func TestRestoreClient(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ g.New("http://foo.com").Reply(204)
+ st.Expect(t, len(g.GetAll()), 1)
+
+ req, err := http.NewRequest("GET", "http://foo.com", nil)
+ client := &http.Client{Transport: &http.Transport{}}
+ g.InterceptClient(client)
+ trans := client.Transport
+
+ res, err := client.Do(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 204)
+
+ g.RestoreClient(client)
+ st.Reject(t, trans, client.Transport)
+}
+
+func TestMockRegExpMatching(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").
+ Post("/bar").
+ MatchHeader("Authorization", "Bearer (.*)").
+ BodyString(`{"foo":".*"}`).
+ Reply(200).
+ SetHeader("Server", "gock").
+ JSON(map[string]string{"foo": "bar"})
+
+ req, _ := http.NewRequest("POST", "http://foo.com/bar", bytes.NewBuffer([]byte(`{"foo":"baz"}`)))
+ req.Header.Set("Authorization", "Bearer s3cr3t")
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ res, err := c.Do(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 200)
+ st.Expect(t, res.Header.Get("Server"), "gock")
+
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body)[:13], `{"foo":"bar"}`)
+}
+
+func TestObserve(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ var observedRequest *http.Request
+ var observedMock Mock
+ g.Observe(func(request *http.Request, mock Mock) {
+ observedRequest = request
+ observedMock = mock
+ })
+ g.New("http://observe-foo.com").Reply(200)
+ req, _ := http.NewRequest("POST", "http://observe-foo.com", nil)
+
+ c := &http.Client{}
+ g.InterceptClient(c)
+ c.Do(req)
+
+ st.Expect(t, observedRequest.Host, "observe-foo.com")
+ st.Expect(t, observedMock.Request().URLStruct.Host, "observe-foo.com")
+}
+
+func TestTryCreatingRacesInNew(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ for i := 0; i < 10; i++ {
+ go func() {
+ g.New("http://example.com")
+ }()
+ }
+}
+
+func after(g *Gock) {
+ g.Flush()
+ g.Disable()
+}
diff --git a/threadsafe/matcher.go b/threadsafe/matcher.go
new file mode 100644
index 0000000..ad10a84
--- /dev/null
+++ b/threadsafe/matcher.go
@@ -0,0 +1,116 @@
+package threadsafe
+
+import "net/http"
+
+// MatchFunc represents the required function
+// interface implemented by matchers.
+type MatchFunc func(*http.Request, *Request) (bool, error)
+
+// Matcher represents the required interface implemented by mock matchers.
+type Matcher interface {
+ // Get returns a slice of registered function matchers.
+ Get() []MatchFunc
+
+ // Add adds a new matcher function.
+ Add(MatchFunc)
+
+ // Set sets the matchers functions stack.
+ Set([]MatchFunc)
+
+ // Flush flushes the current matchers function stack.
+ Flush()
+
+ // Match matches the given http.Request with a mock Request.
+ Match(*http.Request, *Request) (bool, error)
+}
+
+// MockMatcher implements a mock matcher
+type MockMatcher struct {
+ Matchers []MatchFunc
+ g *Gock
+}
+
+// NewMatcher creates a new mock matcher
+// using the default matcher functions.
+func (g *Gock) NewMatcher() *MockMatcher {
+ m := g.NewEmptyMatcher()
+ for _, matchFn := range g.Matchers {
+ m.Add(matchFn)
+ }
+ return m
+}
+
+// NewBasicMatcher creates a new matcher with header only mock matchers.
+func (g *Gock) NewBasicMatcher() *MockMatcher {
+ m := g.NewEmptyMatcher()
+ for _, matchFn := range g.MatchersHeader {
+ m.Add(matchFn)
+ }
+ return m
+}
+
+// NewEmptyMatcher creates a new empty matcher without default matchers.
+func (g *Gock) NewEmptyMatcher() *MockMatcher {
+ return &MockMatcher{g: g, Matchers: []MatchFunc{}}
+}
+
+// Get returns a slice of registered function matchers.
+func (m *MockMatcher) Get() []MatchFunc {
+ m.g.mutex.Lock()
+ defer m.g.mutex.Unlock()
+ return m.Matchers
+}
+
+// Add adds a new function matcher.
+func (m *MockMatcher) Add(fn MatchFunc) {
+ m.Matchers = append(m.Matchers, fn)
+}
+
+// Set sets a new stack of matchers functions.
+func (m *MockMatcher) Set(stack []MatchFunc) {
+ m.Matchers = stack
+}
+
+// Flush flushes the current matcher
+func (m *MockMatcher) Flush() {
+ m.Matchers = []MatchFunc{}
+}
+
+// Clone returns a separate MockMatcher instance that has a copy of the same MatcherFuncs
+func (m *MockMatcher) Clone() *MockMatcher {
+ m2 := m.g.NewEmptyMatcher()
+ for _, mFn := range m.Get() {
+ m2.Add(mFn)
+ }
+ return m2
+}
+
+// Match matches the given http.Request with a mock request
+// returning true in case that the request matches, otherwise false.
+func (m *MockMatcher) Match(req *http.Request, ereq *Request) (bool, error) {
+ for _, matcher := range m.Matchers {
+ matches, err := matcher(req, ereq)
+ if err != nil {
+ return false, err
+ }
+ if !matches {
+ return false, nil
+ }
+ }
+ return true, nil
+}
+
+// MatchMock is a helper function that matches the given http.Request
+// in the list of registered mocks, returning it if matches or error if it fails.
+func (g *Gock) MatchMock(req *http.Request) (Mock, error) {
+ for _, mock := range g.GetAll() {
+ matches, err := mock.Match(req)
+ if err != nil {
+ return nil, err
+ }
+ if matches {
+ return mock, nil
+ }
+ }
+ return nil, nil
+}
diff --git a/threadsafe/matcher_test.go b/threadsafe/matcher_test.go
new file mode 100644
index 0000000..c33a475
--- /dev/null
+++ b/threadsafe/matcher_test.go
@@ -0,0 +1,172 @@
+package threadsafe
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/nbio/st"
+)
+
+func TestRegisteredMatchers(t *testing.T) {
+ g := NewGock()
+ st.Expect(t, len(g.MatchersHeader), 7)
+ st.Expect(t, len(g.MatchersBody), 1)
+}
+
+func TestNewMatcher(t *testing.T) {
+ g := NewGock()
+ matcher := g.NewMatcher()
+ // Funcs are not comparable, checking slice length as it's better than nothing
+ // See https://golang.org/pkg/reflect/#DeepEqual
+ st.Expect(t, len(matcher.Matchers), len(g.Matchers))
+ st.Expect(t, len(matcher.Get()), len(g.Matchers))
+}
+
+func TestNewBasicMatcher(t *testing.T) {
+ g := NewGock()
+ matcher := g.NewBasicMatcher()
+ // Funcs are not comparable, checking slice length as it's better than nothing
+ // See https://golang.org/pkg/reflect/#DeepEqual
+ st.Expect(t, len(matcher.Matchers), len(g.MatchersHeader))
+ st.Expect(t, len(matcher.Get()), len(g.MatchersHeader))
+}
+
+func TestNewEmptyMatcher(t *testing.T) {
+ g := NewGock()
+ matcher := g.NewEmptyMatcher()
+ st.Expect(t, len(matcher.Matchers), 0)
+ st.Expect(t, len(matcher.Get()), 0)
+}
+
+func TestMatcherAdd(t *testing.T) {
+ g := NewGock()
+ matcher := g.NewMatcher()
+ st.Expect(t, len(matcher.Matchers), len(g.Matchers))
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return true, nil
+ })
+ st.Expect(t, len(matcher.Get()), len(g.Matchers)+1)
+}
+
+func TestMatcherSet(t *testing.T) {
+ g := NewGock()
+ matcher := g.NewMatcher()
+ matchers := []MatchFunc{}
+ st.Expect(t, len(matcher.Matchers), len(g.Matchers))
+ matcher.Set(matchers)
+ st.Expect(t, matcher.Matchers, matchers)
+ st.Expect(t, len(matcher.Get()), 0)
+}
+
+func TestMatcherGet(t *testing.T) {
+ g := NewGock()
+ matcher := g.NewMatcher()
+ matchers := []MatchFunc{}
+ matcher.Set(matchers)
+ st.Expect(t, matcher.Get(), matchers)
+}
+
+func TestMatcherFlush(t *testing.T) {
+ g := NewGock()
+ matcher := g.NewMatcher()
+ st.Expect(t, len(matcher.Matchers), len(g.Matchers))
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return true, nil
+ })
+ st.Expect(t, len(matcher.Get()), len(g.Matchers)+1)
+ matcher.Flush()
+ st.Expect(t, len(matcher.Get()), 0)
+}
+
+func TestMatcherClone(t *testing.T) {
+ g := NewGock()
+ matcher := g.DefaultMatcher.Clone()
+ st.Expect(t, len(matcher.Get()), len(g.DefaultMatcher.Get()))
+}
+
+func TestMatcher(t *testing.T) {
+ cases := []struct {
+ method string
+ url string
+ matches bool
+ }{
+ {"GET", "http://foo.com/bar", true},
+ {"GET", "http://foo.com/baz", true},
+ {"GET", "http://foo.com/foo", false},
+ {"POST", "http://foo.com/bar", false},
+ {"POST", "http://bar.com/bar", false},
+ {"GET", "http://foo.com", false},
+ }
+
+ g := NewGock()
+ matcher := g.NewMatcher()
+ matcher.Flush()
+ st.Expect(t, len(matcher.Matchers), 0)
+
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.Method == "GET", nil
+ })
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.URL.Host == "foo.com", nil
+ })
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.URL.Path == "/baz" || req.URL.Path == "/bar", nil
+ })
+
+ for _, test := range cases {
+ u, _ := url.Parse(test.url)
+ req := &http.Request{Method: test.method, URL: u}
+ matches, err := matcher.Match(req, nil)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ }
+}
+
+func TestMatchMock(t *testing.T) {
+ cases := []struct {
+ method string
+ url string
+ matches bool
+ }{
+ {"GET", "http://foo.com/bar", true},
+ {"GET", "http://foo.com/baz", true},
+ {"GET", "http://foo.com/foo", false},
+ {"POST", "http://foo.com/bar", false},
+ {"POST", "http://bar.com/bar", false},
+ {"GET", "http://foo.com", false},
+ }
+
+ g := NewGock()
+ matcher := g.DefaultMatcher
+ matcher.Flush()
+ st.Expect(t, len(matcher.Matchers), 0)
+
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.Method == "GET", nil
+ })
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.URL.Host == "foo.com", nil
+ })
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.URL.Path == "/baz" || req.URL.Path == "/bar", nil
+ })
+
+ for _, test := range cases {
+ g.Flush()
+ mock := g.New(test.url).method(test.method, "").Mock
+
+ u, _ := url.Parse(test.url)
+ req := &http.Request{Method: test.method, URL: u}
+
+ match, err := g.MatchMock(req)
+ st.Expect(t, err, nil)
+ if test.matches {
+ st.Expect(t, match, mock)
+ } else {
+ st.Expect(t, match, nil)
+ }
+ }
+
+ g.DefaultMatcher.Matchers = g.Matchers
+}
diff --git a/threadsafe/matchers.go b/threadsafe/matchers.go
new file mode 100644
index 0000000..9b1c0b3
--- /dev/null
+++ b/threadsafe/matchers.go
@@ -0,0 +1,240 @@
+package threadsafe
+
+import (
+ "compress/gzip"
+ "encoding/json"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "reflect"
+ "regexp"
+ "strings"
+
+ "github.com/h2non/parth"
+)
+
+// EOL represents the end of line character.
+const EOL = 0xa
+
+// MatchMethod matches the HTTP method of the given request.
+func (g *Gock) MatchMethod(req *http.Request, ereq *Request) (bool, error) {
+ return ereq.Method == "" || req.Method == ereq.Method, nil
+}
+
+// MatchScheme matches the request URL protocol scheme.
+func (g *Gock) MatchScheme(req *http.Request, ereq *Request) (bool, error) {
+ return ereq.URLStruct.Scheme == "" || req.URL.Scheme == "" || ereq.URLStruct.Scheme == req.URL.Scheme, nil
+}
+
+// MatchHost matches the HTTP host header field of the given request.
+func (g *Gock) MatchHost(req *http.Request, ereq *Request) (bool, error) {
+ url := ereq.URLStruct
+ if strings.EqualFold(url.Host, req.URL.Host) {
+ return true, nil
+ }
+ if !ereq.Options.DisableRegexpHost {
+ return regexp.MatchString(url.Host, req.URL.Host)
+ }
+ return false, nil
+}
+
+// MatchPath matches the HTTP URL path of the given request.
+func (g *Gock) MatchPath(req *http.Request, ereq *Request) (bool, error) {
+ if req.URL.Path == ereq.URLStruct.Path {
+ return true, nil
+ }
+ return regexp.MatchString(ereq.URLStruct.Path, req.URL.Path)
+}
+
+// MatchHeaders matches the headers fields of the given request.
+func (g *Gock) MatchHeaders(req *http.Request, ereq *Request) (bool, error) {
+ for key, value := range ereq.Header {
+ var err error
+ var match bool
+ var matchEscaped bool
+
+ for _, field := range req.Header[key] {
+ match, err = regexp.MatchString(value[0], field)
+ // Some values may contain reserved regex params e.g. "()", try matching with these escaped.
+ matchEscaped, err = regexp.MatchString(regexp.QuoteMeta(value[0]), field)
+
+ if err != nil {
+ return false, err
+ }
+ if match || matchEscaped {
+ break
+ }
+
+ }
+
+ if !match && !matchEscaped {
+ return false, nil
+ }
+ }
+ return true, nil
+}
+
+// MatchQueryParams matches the URL query params fields of the given request.
+func (g *Gock) MatchQueryParams(req *http.Request, ereq *Request) (bool, error) {
+ for key, value := range ereq.URLStruct.Query() {
+ var err error
+ var match bool
+
+ for _, field := range req.URL.Query()[key] {
+ match, err = regexp.MatchString(value[0], field)
+ if err != nil {
+ return false, err
+ }
+ if match {
+ break
+ }
+ }
+
+ if !match {
+ return false, nil
+ }
+ }
+ return true, nil
+}
+
+// MatchPathParams matches the URL path parameters of the given request.
+func (g *Gock) MatchPathParams(req *http.Request, ereq *Request) (bool, error) {
+ for key, value := range ereq.PathParams {
+ var s string
+
+ if err := parth.Sequent(req.URL.Path, key, &s); err != nil {
+ return false, nil
+ }
+
+ if s != value {
+ return false, nil
+ }
+ }
+ return true, nil
+}
+
+// MatchBody tries to match the request body.
+// TODO: not too smart now, needs several improvements.
+func (g *Gock) MatchBody(req *http.Request, ereq *Request) (bool, error) {
+ // If match body is empty, just continue
+ if req.Method == "HEAD" || len(ereq.BodyBuffer) == 0 {
+ return true, nil
+ }
+
+ // Only can match certain MIME body types
+ if !g.supportedType(req, ereq) {
+ return false, nil
+ }
+
+ // Can only match certain compression schemes
+ if !g.supportedCompressionScheme(req) {
+ return false, nil
+ }
+
+ // Create a reader for the body depending on compression type
+ bodyReader := req.Body
+ if ereq.CompressionScheme != "" {
+ if ereq.CompressionScheme != req.Header.Get("Content-Encoding") {
+ return false, nil
+ }
+ compressedBodyReader, err := compressionReader(req.Body, ereq.CompressionScheme)
+ if err != nil {
+ return false, err
+ }
+ bodyReader = compressedBodyReader
+ }
+
+ // Read the whole request body
+ body, err := ioutil.ReadAll(bodyReader)
+ if err != nil {
+ return false, err
+ }
+
+ // Restore body reader stream
+ req.Body = createReadCloser(body)
+
+ // If empty, ignore the match
+ if len(body) == 0 && len(ereq.BodyBuffer) != 0 {
+ return false, nil
+ }
+
+ // Match body by atomic string comparison
+ bodyStr := castToString(body)
+ matchStr := castToString(ereq.BodyBuffer)
+ if bodyStr == matchStr {
+ return true, nil
+ }
+
+ // Match request body by regexp
+ match, _ := regexp.MatchString(matchStr, bodyStr)
+ if match == true {
+ return true, nil
+ }
+
+ // todo - add conditional do only perform the conversion of body bytes
+ // representation of JSON to a map and then compare them for equality.
+
+ // Check if the key + value pairs match
+ var bodyMap map[string]interface{}
+ var matchMap map[string]interface{}
+
+ // Ensure that both byte bodies that that should be JSON can be converted to maps.
+ umErr := json.Unmarshal(body, &bodyMap)
+ umErr2 := json.Unmarshal(ereq.BodyBuffer, &matchMap)
+ if umErr == nil && umErr2 == nil && reflect.DeepEqual(bodyMap, matchMap) {
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func (g *Gock) supportedType(req *http.Request, ereq *Request) bool {
+ mime := req.Header.Get("Content-Type")
+ if mime == "" {
+ return true
+ }
+
+ mimeToMatch := ereq.Header.Get("Content-Type")
+ if mimeToMatch != "" {
+ return mime == mimeToMatch
+ }
+
+ for _, kind := range g.BodyTypes {
+ if match, _ := regexp.MatchString(kind, mime); match {
+ return true
+ }
+ }
+ return false
+}
+
+func (g *Gock) supportedCompressionScheme(req *http.Request) bool {
+ encoding := req.Header.Get("Content-Encoding")
+ if encoding == "" {
+ return true
+ }
+
+ for _, kind := range g.CompressionSchemes {
+ if match, _ := regexp.MatchString(kind, encoding); match {
+ return true
+ }
+ }
+ return false
+}
+
+func castToString(buf []byte) string {
+ str := string(buf)
+ tail := len(str) - 1
+ if str[tail] == EOL {
+ str = str[:tail]
+ }
+ return str
+}
+
+func compressionReader(r io.ReadCloser, scheme string) (io.ReadCloser, error) {
+ switch scheme {
+ case "gzip":
+ return gzip.NewReader(r)
+ default:
+ return r, nil
+ }
+}
diff --git a/threadsafe/matchers_test.go b/threadsafe/matchers_test.go
new file mode 100644
index 0000000..6db6c08
--- /dev/null
+++ b/threadsafe/matchers_test.go
@@ -0,0 +1,253 @@
+package threadsafe
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/nbio/st"
+)
+
+func TestMatchMethod(t *testing.T) {
+ cases := []struct {
+ value string
+ method string
+ matches bool
+ }{
+ {"GET", "GET", true},
+ {"POST", "POST", true},
+ {"", "POST", true},
+ {"POST", "GET", false},
+ {"PUT", "GET", false},
+ }
+
+ for _, test := range cases {
+ req := &http.Request{Method: test.method}
+ ereq := &Request{Method: test.value}
+ matches, err := NewGock().MatchMethod(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ }
+}
+
+func TestMatchScheme(t *testing.T) {
+ cases := []struct {
+ value string
+ scheme string
+ matches bool
+ }{
+ {"http", "http", true},
+ {"https", "https", true},
+ {"http", "https", false},
+ {"", "https", true},
+ {"https", "", true},
+ }
+
+ for _, test := range cases {
+ req := &http.Request{URL: &url.URL{Scheme: test.scheme}}
+ ereq := &Request{URLStruct: &url.URL{Scheme: test.value}}
+ matches, err := NewGock().MatchScheme(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ }
+}
+
+func TestMatchHost(t *testing.T) {
+ cases := []struct {
+ value string
+ url string
+ matches bool
+ matchesNonRegexp bool
+ }{
+ {"foo.com", "foo.com", true, true},
+ {"FOO.com", "foo.com", true, true},
+ {"foo.net", "foo.com", false, false},
+ {"foo.bar.net", "foo-bar.net", true, false},
+ {"foo", "foo.com", true, false},
+ {"(.*).com", "foo.com", true, false},
+ {"127.0.0.1", "127.0.0.1", true, true},
+ {"127.0.0.2", "127.0.0.1", false, false},
+ {"127.0.0.*", "127.0.0.1", true, false},
+ {"127.0.0.[0-9]", "127.0.0.7", true, false},
+ }
+
+ for _, test := range cases {
+ req := &http.Request{URL: &url.URL{Host: test.url}}
+ ereq := &Request{URLStruct: &url.URL{Host: test.value}}
+ matches, err := NewGock().MatchHost(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ ereq.WithOptions(Options{DisableRegexpHost: true})
+ matches, err = NewGock().MatchHost(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matchesNonRegexp)
+ }
+}
+
+func TestMatchPath(t *testing.T) {
+ cases := []struct {
+ value string
+ path string
+ matches bool
+ }{
+ {"/foo", "/foo", true},
+ {"/foo", "/foo/bar", true},
+ {"bar", "/foo/bar", true},
+ {"foo", "/foo/bar", true},
+ {"bar$", "/foo/bar", true},
+ {"/foo/*", "/foo/bar", true},
+ {"/foo/[a-z]+", "/foo/bar", true},
+ {"/foo/baz", "/foo/bar", false},
+ {"/foo/baz", "/foo/bar", false},
+ {"/foo/bar%3F+%C3%A9", "/foo/bar%3F+%C3%A9", true},
+ }
+
+ for _, test := range cases {
+ u, _ := url.Parse("http://foo.com" + test.path)
+ mu, _ := url.Parse("http://foo.com" + test.value)
+ req := &http.Request{URL: u}
+ ereq := &Request{URLStruct: mu}
+ matches, err := NewGock().MatchPath(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ }
+}
+
+func TestMatchHeaders(t *testing.T) {
+ cases := []struct {
+ values http.Header
+ headers http.Header
+ matches bool
+ }{
+ {http.Header{"foo": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, true},
+ {http.Header{"foo": []string{"bar"}}, http.Header{"foo": []string{"barbar"}}, true},
+ {http.Header{"bar": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, false},
+ {http.Header{"foofoo": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, false},
+ {http.Header{"foo": []string{"bar(.*)"}}, http.Header{"foo": []string{"barbar"}}, true},
+ {http.Header{"foo": []string{"b(.*)"}}, http.Header{"foo": []string{"barbar"}}, true},
+ {http.Header{"foo": []string{"^bar$"}}, http.Header{"foo": []string{"bar"}}, true},
+ {http.Header{"foo": []string{"^bar$"}}, http.Header{"foo": []string{"barbar"}}, false},
+ {http.Header{"UPPERCASE": []string{"bar"}}, http.Header{"UPPERCASE": []string{"bar"}}, true},
+ {http.Header{"Mixed-CASE": []string{"bar"}}, http.Header{"Mixed-CASE": []string{"bar"}}, true},
+ {http.Header{"User-Agent": []string{"Agent (version1.0)"}}, http.Header{"User-Agent": []string{"Agent (version1.0)"}}, true},
+ {http.Header{"Content-Type": []string{"(.*)/plain"}}, http.Header{"Content-Type": []string{"text/plain"}}, true},
+ }
+
+ for _, test := range cases {
+ req := &http.Request{Header: test.headers}
+ ereq := &Request{Header: test.values}
+ matches, err := NewGock().MatchHeaders(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ }
+}
+
+func TestMatchQueryParams(t *testing.T) {
+ cases := []struct {
+ value string
+ path string
+ matches bool
+ }{
+ {"foo=bar", "foo=bar", true},
+ {"foo=bar", "foo=foo&foo=bar", true},
+ {"foo=b*", "foo=bar", true},
+ {"foo=.*", "foo=bar", true},
+ {"foo=f[o]{2}", "foo=foo", true},
+ {"foo=bar&bar=foo", "foo=bar&foo=foo&bar=foo", true},
+ {"foo=", "foo=bar", true},
+ {"foo=foo", "foo=bar", false},
+ {"bar=bar", "foo=bar bar", false},
+ }
+
+ for _, test := range cases {
+ u, _ := url.Parse("http://foo.com/?" + test.path)
+ mu, _ := url.Parse("http://foo.com/?" + test.value)
+ req := &http.Request{URL: u}
+ ereq := &Request{URLStruct: mu}
+ matches, err := NewGock().MatchQueryParams(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ }
+}
+
+func TestMatchPathParams(t *testing.T) {
+ cases := []struct {
+ key string
+ value string
+ path string
+ matches bool
+ }{
+ {"foo", "bar", "/foo/bar", true},
+ {"foo", "bar", "/foo/test/bar", false},
+ {"foo", "bar", "/test/foo/bar/ack", true},
+ {"foo", "bar", "/foo", false},
+ }
+
+ for i, test := range cases {
+ u, _ := url.Parse("http://foo.com" + test.path)
+ mu, _ := url.Parse("http://foo.com" + test.path)
+ req := &http.Request{URL: u}
+ ereq := &Request{
+ URLStruct: mu,
+ PathParams: map[string]string{test.key: test.value},
+ }
+ matches, err := NewGock().MatchPathParams(req, ereq)
+ st.Expect(t, err, nil, i)
+ st.Expect(t, matches, test.matches, i)
+ }
+}
+
+func TestMatchBody(t *testing.T) {
+ cases := []struct {
+ value string
+ body string
+ matches bool
+ }{
+ {"foo bar", "foo bar\n", true},
+ {"foo", "foo bar\n", true},
+ {"f[o]+", "foo\n", true},
+ {`"foo"`, `{"foo":"bar"}\n`, true},
+ {`{"foo":"bar"}`, `{"foo":"bar"}\n`, true},
+ {`{"foo":"foo"}`, `{"foo":"bar"}\n`, false},
+
+ {`{"foo":"bar","bar":"foo"}`, `{"bar":"foo","foo":"bar"}`, true},
+ {`{"bar":"foo","foo":{"two":"three","three":"two"}}`, `{"foo":{"three":"two","two":"three"},"bar":"foo"}`, true},
+ }
+
+ g := NewGock()
+ for _, test := range cases {
+ req := &http.Request{Body: createReadCloser([]byte(test.body))}
+ ereq := &Request{BodyBuffer: []byte(test.value)}
+ matches, err := g.MatchBody(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ }
+}
+
+func TestMatchBody_MatchType(t *testing.T) {
+ body := `{"foo":"bar"}`
+ cases := []struct {
+ body string
+ requestContentType string
+ customBodyType string
+ matches bool
+ }{
+ {body, "application/vnd.apiname.v1+json", "foobar", false},
+ {body, "application/vnd.apiname.v1+json", "application/vnd.apiname.v1+json", true},
+ {body, "application/json", "foobar", false},
+ {body, "application/json", "", true},
+ {"", "", "", true},
+ }
+
+ g := NewGock()
+ for _, test := range cases {
+ req := &http.Request{
+ Header: http.Header{"Content-Type": []string{test.requestContentType}},
+ Body: createReadCloser([]byte(test.body)),
+ }
+ ereq := g.NewRequest().BodyString(test.body).MatchType(test.customBodyType)
+ matches, err := g.MatchBody(req, ereq)
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, test.matches)
+ }
+}
diff --git a/threadsafe/mock.go b/threadsafe/mock.go
new file mode 100644
index 0000000..004263a
--- /dev/null
+++ b/threadsafe/mock.go
@@ -0,0 +1,172 @@
+package threadsafe
+
+import (
+ "net/http"
+ "sync"
+)
+
+// Mock represents the required interface that must
+// be implemented by HTTP mock instances.
+type Mock interface {
+ // Disable disables the current mock manually.
+ Disable()
+
+ // Done returns true if the current mock is disabled.
+ Done() bool
+
+ // Request returns the mock Request instance.
+ Request() *Request
+
+ // Response returns the mock Response instance.
+ Response() *Response
+
+ // Match matches the given http.Request with the current mock.
+ Match(*http.Request) (bool, error)
+
+ // AddMatcher adds a new matcher function.
+ AddMatcher(MatchFunc)
+
+ // SetMatcher uses a new matcher implementation.
+ SetMatcher(Matcher)
+}
+
+// Mocker implements a Mock capable interface providing
+// a default mock configuration used internally to store mocks.
+type Mocker struct {
+ // disabler stores a disabler for thread safety checking current mock is disabled
+ disabler *disabler
+
+ // mutex stores the mock mutex for thread safety.
+ mutex sync.Mutex
+
+ // matcher stores a Matcher capable instance to match the given http.Request.
+ matcher Matcher
+
+ // request stores the mock Request to match.
+ request *Request
+
+ // response stores the mock Response to use in case of match.
+ response *Response
+}
+
+type disabler struct {
+ // disabled stores if the current mock is disabled.
+ disabled bool
+
+ // mutex stores the disabler mutex for thread safety.
+ mutex sync.RWMutex
+}
+
+func (d *disabler) isDisabled() bool {
+ d.mutex.RLock()
+ defer d.mutex.RUnlock()
+ return d.disabled
+}
+
+func (d *disabler) Disable() {
+ d.mutex.Lock()
+ defer d.mutex.Unlock()
+ d.disabled = true
+}
+
+// NewMock creates a new HTTP mock based on the given request and response instances.
+// It's mostly used internally.
+func (g *Gock) NewMock(req *Request, res *Response) *Mocker {
+ mock := &Mocker{
+ disabler: new(disabler),
+ request: req,
+ response: res,
+ matcher: g.DefaultMatcher.Clone(),
+ }
+ res.Mock = mock
+ req.Mock = mock
+ req.Response = res
+ return mock
+}
+
+// Disable disables the current mock manually.
+func (m *Mocker) Disable() {
+ m.disabler.Disable()
+}
+
+// Done returns true in case that the current mock
+// instance is disabled and therefore must be removed.
+func (m *Mocker) Done() bool {
+ // prevent deadlock with m.mutex
+ if m.disabler.isDisabled() {
+ return true
+ }
+
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+ return !m.request.Persisted && m.request.Counter == 0
+}
+
+// Request returns the Request instance
+// configured for the current HTTP mock.
+func (m *Mocker) Request() *Request {
+ return m.request
+}
+
+// Response returns the Response instance
+// configured for the current HTTP mock.
+func (m *Mocker) Response() *Response {
+ return m.response
+}
+
+// Match matches the given http.Request with the current Request
+// mock expectation, returning true if matches.
+func (m *Mocker) Match(req *http.Request) (bool, error) {
+ if m.disabler.isDisabled() {
+ return false, nil
+ }
+
+ // Filter
+ for _, filter := range m.request.Filters {
+ if !filter(req) {
+ return false, nil
+ }
+ }
+
+ // Map
+ for _, mapper := range m.request.Mappers {
+ if treq := mapper(req); treq != nil {
+ req = treq
+ }
+ }
+
+ // Match
+ matches, err := m.matcher.Match(req, m.request)
+ if matches {
+ m.decrement()
+ }
+
+ return matches, err
+}
+
+// SetMatcher sets a new matcher implementation
+// for the current mock expectation.
+func (m *Mocker) SetMatcher(matcher Matcher) {
+ m.matcher = matcher
+}
+
+// AddMatcher adds a new matcher function
+// for the current mock expectation.
+func (m *Mocker) AddMatcher(fn MatchFunc) {
+ m.matcher.Add(fn)
+}
+
+// decrement decrements the current mock Request counter.
+func (m *Mocker) decrement() {
+ if m.request.Persisted {
+ return
+ }
+
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ m.request.Counter--
+ if m.request.Counter == 0 {
+ m.disabler.Disable()
+ }
+}
diff --git a/threadsafe/mock_test.go b/threadsafe/mock_test.go
new file mode 100644
index 0000000..f277842
--- /dev/null
+++ b/threadsafe/mock_test.go
@@ -0,0 +1,143 @@
+package threadsafe
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/nbio/st"
+)
+
+func TestNewMock(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ req := g.NewRequest()
+ res := g.NewResponse()
+ mock := g.NewMock(req, res)
+ st.Expect(t, mock.disabler.isDisabled(), false)
+ st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get()))
+
+ st.Expect(t, mock.Request(), req)
+ st.Expect(t, mock.Request().Mock, mock)
+ st.Expect(t, mock.Response(), res)
+ st.Expect(t, mock.Response().Mock, mock)
+}
+
+func TestMockDisable(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ req := g.NewRequest()
+ res := g.NewResponse()
+ mock := g.NewMock(req, res)
+
+ st.Expect(t, mock.disabler.isDisabled(), false)
+ mock.Disable()
+ st.Expect(t, mock.disabler.isDisabled(), true)
+
+ matches, err := mock.Match(&http.Request{})
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, false)
+}
+
+func TestMockDone(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ req := g.NewRequest()
+ res := g.NewResponse()
+
+ mock := g.NewMock(req, res)
+ st.Expect(t, mock.disabler.isDisabled(), false)
+ st.Expect(t, mock.Done(), false)
+
+ mock = g.NewMock(req, res)
+ st.Expect(t, mock.disabler.isDisabled(), false)
+ mock.Disable()
+ st.Expect(t, mock.Done(), true)
+
+ mock = g.NewMock(req, res)
+ st.Expect(t, mock.disabler.isDisabled(), false)
+ mock.request.Counter = 0
+ st.Expect(t, mock.Done(), true)
+
+ mock = g.NewMock(req, res)
+ st.Expect(t, mock.disabler.isDisabled(), false)
+ mock.request.Persisted = true
+ st.Expect(t, mock.Done(), false)
+}
+
+func TestMockSetMatcher(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ req := g.NewRequest()
+ res := g.NewResponse()
+ mock := g.NewMock(req, res)
+
+ st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get()))
+ matcher := g.NewMatcher()
+ matcher.Flush()
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return true, nil
+ })
+ mock.SetMatcher(matcher)
+ st.Expect(t, len(mock.matcher.Get()), 1)
+ st.Expect(t, mock.disabler.isDisabled(), false)
+
+ matches, err := mock.Match(&http.Request{})
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, true)
+}
+
+func TestMockAddMatcher(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ req := g.NewRequest()
+ res := g.NewResponse()
+ mock := g.NewMock(req, res)
+
+ st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get()))
+ matcher := g.NewMatcher()
+ matcher.Flush()
+ mock.SetMatcher(matcher)
+ mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) {
+ return true, nil
+ })
+ st.Expect(t, mock.disabler.isDisabled(), false)
+ st.Expect(t, mock.matcher, matcher)
+
+ matches, err := mock.Match(&http.Request{})
+ st.Expect(t, err, nil)
+ st.Expect(t, matches, true)
+}
+
+func TestMockMatch(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ req := g.NewRequest()
+ res := g.NewResponse()
+ mock := g.NewMock(req, res)
+
+ matcher := g.NewMatcher()
+ matcher.Flush()
+ mock.SetMatcher(matcher)
+ calls := 0
+ mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) {
+ calls++
+ return true, nil
+ })
+ mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) {
+ calls++
+ return true, nil
+ })
+ st.Expect(t, mock.disabler.isDisabled(), false)
+ st.Expect(t, mock.matcher, matcher)
+
+ matches, err := mock.Match(&http.Request{})
+ st.Expect(t, err, nil)
+ st.Expect(t, calls, 2)
+ st.Expect(t, matches, true)
+}
diff --git a/threadsafe/options.go b/threadsafe/options.go
new file mode 100644
index 0000000..98497f9
--- /dev/null
+++ b/threadsafe/options.go
@@ -0,0 +1,8 @@
+package threadsafe
+
+// Options represents customized option for gock
+type Options struct {
+ // DisableRegexpHost stores if the host is only a plain string rather than regular expression,
+ // if DisableRegexpHost is true, host sets in gock.New(...) will be treated as plain string
+ DisableRegexpHost bool
+}
diff --git a/threadsafe/request.go b/threadsafe/request.go
new file mode 100644
index 0000000..3508bbb
--- /dev/null
+++ b/threadsafe/request.go
@@ -0,0 +1,330 @@
+package threadsafe
+
+import (
+ "encoding/base64"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "strings"
+)
+
+// MapRequestFunc represents the required function interface for request mappers.
+type MapRequestFunc func(*http.Request) *http.Request
+
+// FilterRequestFunc represents the required function interface for request filters.
+type FilterRequestFunc func(*http.Request) bool
+
+// Request represents the high-level HTTP request used to store
+// request fields used to match intercepted requests.
+type Request struct {
+ g *Gock
+
+ // Mock stores the parent mock reference for the current request mock used for method delegation.
+ Mock Mock
+
+ // Response stores the current Response instance for the current matches Request.
+ Response *Response
+
+ // Error stores the latest mock request configuration error.
+ Error error
+
+ // Counter stores the pending times that the current mock should be active.
+ Counter int
+
+ // Persisted stores if the current mock should be always active.
+ Persisted bool
+
+ // Options stores options for current Request.
+ Options Options
+
+ // URLStruct stores the parsed URL as *url.URL struct.
+ URLStruct *url.URL
+
+ // Method stores the Request HTTP method to match.
+ Method string
+
+ // CompressionScheme stores the Request Compression scheme to match and use for decompression.
+ CompressionScheme string
+
+ // Header stores the HTTP header fields to match.
+ Header http.Header
+
+ // Cookies stores the Request HTTP cookies values to match.
+ Cookies []*http.Cookie
+
+ // PathParams stores the path parameters to match.
+ PathParams map[string]string
+
+ // BodyBuffer stores the body data to match.
+ BodyBuffer []byte
+
+ // Mappers stores the request functions mappers used for matching.
+ Mappers []MapRequestFunc
+
+ // Filters stores the request functions filters used for matching.
+ Filters []FilterRequestFunc
+}
+
+// NewRequest creates a new Request instance.
+func (g *Gock) NewRequest() *Request {
+ return &Request{
+ g: g,
+ Counter: 1,
+ URLStruct: &url.URL{},
+ Header: make(http.Header),
+ PathParams: make(map[string]string),
+ }
+}
+
+// URL defines the mock URL to match.
+func (r *Request) URL(uri string) *Request {
+ r.URLStruct, r.Error = url.Parse(uri)
+ return r
+}
+
+// SetURL defines the url.URL struct to be used for matching.
+func (r *Request) SetURL(u *url.URL) *Request {
+ r.URLStruct = u
+ return r
+}
+
+// Path defines the mock URL path value to match.
+func (r *Request) Path(path string) *Request {
+ r.URLStruct.Path = path
+ return r
+}
+
+// Get specifies the GET method and the given URL path to match.
+func (r *Request) Get(path string) *Request {
+ return r.method("GET", path)
+}
+
+// Post specifies the POST method and the given URL path to match.
+func (r *Request) Post(path string) *Request {
+ return r.method("POST", path)
+}
+
+// Put specifies the PUT method and the given URL path to match.
+func (r *Request) Put(path string) *Request {
+ return r.method("PUT", path)
+}
+
+// Delete specifies the DELETE method and the given URL path to match.
+func (r *Request) Delete(path string) *Request {
+ return r.method("DELETE", path)
+}
+
+// Patch specifies the PATCH method and the given URL path to match.
+func (r *Request) Patch(path string) *Request {
+ return r.method("PATCH", path)
+}
+
+// Head specifies the HEAD method and the given URL path to match.
+func (r *Request) Head(path string) *Request {
+ return r.method("HEAD", path)
+}
+
+// method is a DRY shortcut used to declare the expected HTTP method and URL path.
+func (r *Request) method(method, path string) *Request {
+ if path != "/" {
+ r.URLStruct.Path = path
+ }
+ r.Method = strings.ToUpper(method)
+ return r
+}
+
+// Body defines the body data to match based on a io.Reader interface.
+func (r *Request) Body(body io.Reader) *Request {
+ r.BodyBuffer, r.Error = ioutil.ReadAll(body)
+ return r
+}
+
+// BodyString defines the body to match based on a given string.
+func (r *Request) BodyString(body string) *Request {
+ r.BodyBuffer = []byte(body)
+ return r
+}
+
+// File defines the body to match based on the given file path string.
+func (r *Request) File(path string) *Request {
+ r.BodyBuffer, r.Error = ioutil.ReadFile(path)
+ return r
+}
+
+// Compression defines the request compression scheme, and enables automatic body decompression.
+// Supports only the "gzip" scheme so far.
+func (r *Request) Compression(scheme string) *Request {
+ r.Header.Set("Content-Encoding", scheme)
+ r.CompressionScheme = scheme
+ return r
+}
+
+// JSON defines the JSON body to match based on a given structure.
+func (r *Request) JSON(data interface{}) *Request {
+ if r.Header.Get("Content-Type") == "" {
+ r.Header.Set("Content-Type", "application/json")
+ }
+ r.BodyBuffer, r.Error = readAndDecode(data, "json")
+ return r
+}
+
+// XML defines the XML body to match based on a given structure.
+func (r *Request) XML(data interface{}) *Request {
+ if r.Header.Get("Content-Type") == "" {
+ r.Header.Set("Content-Type", "application/xml")
+ }
+ r.BodyBuffer, r.Error = readAndDecode(data, "xml")
+ return r
+}
+
+// MatchType defines the request Content-Type MIME header field.
+// Supports custom MIME types and type aliases. E.g: json, xml, form, text...
+func (r *Request) MatchType(kind string) *Request {
+ mime := r.g.BodyTypeAliases[kind]
+ if mime != "" {
+ kind = mime
+ }
+ r.Header.Set("Content-Type", kind)
+ return r
+}
+
+// BasicAuth defines a username and password for HTTP Basic Authentication
+func (r *Request) BasicAuth(username, password string) *Request {
+ r.Header.Set("Authorization", "Basic "+basicAuth(username, password))
+ return r
+}
+
+// MatchHeader defines a new key and value header to match.
+func (r *Request) MatchHeader(key, value string) *Request {
+ r.Header.Set(key, value)
+ return r
+}
+
+// HeaderPresent defines that a header field must be present in the request.
+func (r *Request) HeaderPresent(key string) *Request {
+ r.Header.Set(key, ".*")
+ return r
+}
+
+// MatchHeaders defines a map of key-value headers to match.
+func (r *Request) MatchHeaders(headers map[string]string) *Request {
+ for key, value := range headers {
+ r.Header.Set(key, value)
+ }
+ return r
+}
+
+// MatchParam defines a new key and value URL query param to match.
+func (r *Request) MatchParam(key, value string) *Request {
+ query := r.URLStruct.Query()
+ query.Set(key, value)
+ r.URLStruct.RawQuery = query.Encode()
+ return r
+}
+
+// MatchParams defines a map of URL query param key-value to match.
+func (r *Request) MatchParams(params map[string]string) *Request {
+ query := r.URLStruct.Query()
+ for key, value := range params {
+ query.Set(key, value)
+ }
+ r.URLStruct.RawQuery = query.Encode()
+ return r
+}
+
+// ParamPresent matches if the given query param key is present in the URL.
+func (r *Request) ParamPresent(key string) *Request {
+ r.MatchParam(key, ".*")
+ return r
+}
+
+// PathParam matches if a given path parameter key is present in the URL.
+//
+// The value is representative of the restful resource the key defines, e.g.
+//
+// // /users/123/name
+// r.PathParam("users", "123")
+//
+// would match.
+func (r *Request) PathParam(key, val string) *Request {
+ r.PathParams[key] = val
+
+ return r
+}
+
+// Persist defines the current HTTP mock as persistent and won't be removed after intercepting it.
+func (r *Request) Persist() *Request {
+ r.Persisted = true
+ return r
+}
+
+// WithOptions sets the options for the request.
+func (r *Request) WithOptions(options Options) *Request {
+ r.Options = options
+ return r
+}
+
+// Times defines the number of times that the current HTTP mock should remain active.
+func (r *Request) Times(num int) *Request {
+ r.Counter = num
+ return r
+}
+
+// AddMatcher adds a new matcher function to match the request.
+func (r *Request) AddMatcher(fn MatchFunc) *Request {
+ r.Mock.AddMatcher(fn)
+ return r
+}
+
+// SetMatcher sets a new matcher function to match the request.
+func (r *Request) SetMatcher(matcher Matcher) *Request {
+ r.Mock.SetMatcher(matcher)
+ return r
+}
+
+// Map adds a new request mapper function to map http.Request before the matching process.
+func (r *Request) Map(fn MapRequestFunc) *Request {
+ r.Mappers = append(r.Mappers, fn)
+ return r
+}
+
+// Filter filters a new request filter function to filter http.Request before the matching process.
+func (r *Request) Filter(fn FilterRequestFunc) *Request {
+ r.Filters = append(r.Filters, fn)
+ return r
+}
+
+// EnableNetworking enables the use real networking for the current mock.
+func (r *Request) EnableNetworking() *Request {
+ if r.Response != nil {
+ r.Response.UseNetwork = true
+ }
+ return r
+}
+
+// Reply defines the Response status code and returns the mock Response DSL.
+func (r *Request) Reply(status int) *Response {
+ return r.Response.Status(status)
+}
+
+// ReplyError defines the Response simulated error.
+func (r *Request) ReplyError(err error) *Response {
+ return r.Response.SetError(err)
+}
+
+// ReplyFunc allows the developer to define the mock response via a custom function.
+func (r *Request) ReplyFunc(replier func(*Response)) *Response {
+ replier(r.Response)
+ return r.Response
+}
+
+// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt
+// "To receive authorization, the client sends the userid and password,
+// separated by a single colon (":") character, within a base64
+// encoded string in the credentials."
+// It is not meant to be urlencoded.
+func basicAuth(username, password string) string {
+ auth := username + ":" + password
+ return base64.StdEncoding.EncodeToString([]byte(auth))
+}
diff --git a/threadsafe/request_test.go b/threadsafe/request_test.go
new file mode 100644
index 0000000..e67614f
--- /dev/null
+++ b/threadsafe/request_test.go
@@ -0,0 +1,318 @@
+package threadsafe
+
+import (
+ "bytes"
+ "net/http"
+ "net/url"
+ "path/filepath"
+ "testing"
+
+ "github.com/nbio/st"
+)
+
+func TestNewRequest(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.URL("http://foo.com")
+ st.Expect(t, req.URLStruct.Host, "foo.com")
+ st.Expect(t, req.URLStruct.Scheme, "http")
+ req.MatchHeader("foo", "bar")
+ st.Expect(t, req.Header.Get("foo"), "bar")
+}
+
+func TestRequestSetURL(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.URL("http://foo.com")
+ req.SetURL(&url.URL{Host: "bar.com", Path: "/foo"})
+ st.Expect(t, req.URLStruct.Host, "bar.com")
+ st.Expect(t, req.URLStruct.Path, "/foo")
+}
+
+func TestRequestPath(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.URL("http://foo.com")
+ req.Path("/foo")
+ st.Expect(t, req.URLStruct.Scheme, "http")
+ st.Expect(t, req.URLStruct.Host, "foo.com")
+ st.Expect(t, req.URLStruct.Path, "/foo")
+}
+
+func TestRequestBody(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.Body(bytes.NewBuffer([]byte("foo bar")))
+ st.Expect(t, string(req.BodyBuffer), "foo bar")
+}
+
+func TestRequestBodyString(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.BodyString("foo bar")
+ st.Expect(t, string(req.BodyBuffer), "foo bar")
+}
+
+func TestRequestFile(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ absPath, err := filepath.Abs("../version.go")
+ st.Expect(t, err, nil)
+ req.File(absPath)
+ st.Expect(t, string(req.BodyBuffer)[:12], "package gock")
+}
+
+func TestRequestJSON(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.JSON(map[string]string{"foo": "bar"})
+ st.Expect(t, string(req.BodyBuffer)[:13], `{"foo":"bar"}`)
+ st.Expect(t, req.Header.Get("Content-Type"), "application/json")
+}
+
+func TestRequestXML(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ type xml struct {
+ Data string `xml:"data"`
+ }
+ req.XML(xml{Data: "foo"})
+ st.Expect(t, string(req.BodyBuffer), `foo`)
+ st.Expect(t, req.Header.Get("Content-Type"), "application/xml")
+}
+
+func TestRequestMatchType(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.MatchType("json")
+ st.Expect(t, req.Header.Get("Content-Type"), "application/json")
+
+ req = g.NewRequest()
+ req.MatchType("html")
+ st.Expect(t, req.Header.Get("Content-Type"), "text/html")
+
+ req = g.NewRequest()
+ req.MatchType("foo/bar")
+ st.Expect(t, req.Header.Get("Content-Type"), "foo/bar")
+}
+
+func TestRequestBasicAuth(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.BasicAuth("bob", "qwerty")
+ st.Expect(t, req.Header.Get("Authorization"), "Basic Ym9iOnF3ZXJ0eQ==")
+}
+
+func TestRequestMatchHeader(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.MatchHeader("foo", "bar")
+ req.MatchHeader("bar", "baz")
+ req.MatchHeader("UPPERCASE", "bat")
+ req.MatchHeader("Mixed-CASE", "foo")
+
+ st.Expect(t, req.Header.Get("foo"), "bar")
+ st.Expect(t, req.Header.Get("bar"), "baz")
+ st.Expect(t, req.Header.Get("UPPERCASE"), "bat")
+ st.Expect(t, req.Header.Get("Mixed-CASE"), "foo")
+}
+
+func TestRequestHeaderPresent(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.HeaderPresent("foo")
+ req.HeaderPresent("bar")
+ req.HeaderPresent("UPPERCASE")
+ req.HeaderPresent("Mixed-CASE")
+ st.Expect(t, req.Header.Get("foo"), ".*")
+ st.Expect(t, req.Header.Get("bar"), ".*")
+ st.Expect(t, req.Header.Get("UPPERCASE"), ".*")
+ st.Expect(t, req.Header.Get("Mixed-CASE"), ".*")
+}
+
+func TestRequestMatchParam(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.MatchParam("foo", "bar")
+ req.MatchParam("bar", "baz")
+ st.Expect(t, req.URLStruct.Query().Get("foo"), "bar")
+ st.Expect(t, req.URLStruct.Query().Get("bar"), "baz")
+}
+
+func TestRequestMatchParams(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.MatchParams(map[string]string{"foo": "bar", "bar": "baz"})
+ st.Expect(t, req.URLStruct.Query().Get("foo"), "bar")
+ st.Expect(t, req.URLStruct.Query().Get("bar"), "baz")
+}
+
+func TestRequestPresentParam(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.ParamPresent("key")
+ st.Expect(t, req.URLStruct.Query().Get("key"), ".*")
+}
+
+func TestRequestPathParam(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.PathParam("key", "value")
+ st.Expect(t, req.PathParams["key"], "value")
+}
+
+func TestRequestPersist(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ st.Expect(t, req.Persisted, false)
+ req.Persist()
+ st.Expect(t, req.Persisted, true)
+}
+
+func TestRequestTimes(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ st.Expect(t, req.Counter, 1)
+ req.Times(3)
+ st.Expect(t, req.Counter, 3)
+}
+
+func TestRequestMap(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ st.Expect(t, len(req.Mappers), 0)
+ req.Map(func(req *http.Request) *http.Request {
+ return req
+ })
+ st.Expect(t, len(req.Mappers), 1)
+}
+
+func TestRequestFilter(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ st.Expect(t, len(req.Filters), 0)
+ req.Filter(func(req *http.Request) bool {
+ return true
+ })
+ st.Expect(t, len(req.Filters), 1)
+}
+
+func TestRequestEnableNetworking(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.Response = &Response{}
+ st.Expect(t, req.Response.UseNetwork, false)
+ req.EnableNetworking()
+ st.Expect(t, req.Response.UseNetwork, true)
+}
+
+func TestRequestResponse(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ res := g.NewResponse()
+ req.Response = res
+ chain := req.Reply(200)
+ st.Expect(t, chain, res)
+ st.Expect(t, chain.StatusCode, 200)
+}
+
+func TestRequestReplyFunc(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ res := g.NewResponse()
+ req.Response = res
+ chain := req.ReplyFunc(func(r *Response) {
+ r.Status(204)
+ })
+ st.Expect(t, chain, res)
+ st.Expect(t, chain.StatusCode, 204)
+}
+
+func TestRequestMethods(t *testing.T) {
+ g := NewGock()
+ req := g.NewRequest()
+ req.Get("/foo")
+ st.Expect(t, req.Method, "GET")
+ st.Expect(t, req.URLStruct.Path, "/foo")
+
+ req = g.NewRequest()
+ req.Post("/foo")
+ st.Expect(t, req.Method, "POST")
+ st.Expect(t, req.URLStruct.Path, "/foo")
+
+ req = g.NewRequest()
+ req.Put("/foo")
+ st.Expect(t, req.Method, "PUT")
+ st.Expect(t, req.URLStruct.Path, "/foo")
+
+ req = g.NewRequest()
+ req.Delete("/foo")
+ st.Expect(t, req.Method, "DELETE")
+ st.Expect(t, req.URLStruct.Path, "/foo")
+
+ req = g.NewRequest()
+ req.Patch("/foo")
+ st.Expect(t, req.Method, "PATCH")
+ st.Expect(t, req.URLStruct.Path, "/foo")
+
+ req = g.NewRequest()
+ req.Head("/foo")
+ st.Expect(t, req.Method, "HEAD")
+ st.Expect(t, req.URLStruct.Path, "/foo")
+}
+
+func TestRequestSetMatcher(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ matcher := g.NewEmptyMatcher()
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.URL.Host == "foo.com", nil
+ })
+ matcher.Add(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.Header.Get("foo") == "bar", nil
+ })
+ ereq := g.NewRequest()
+ mock := g.NewMock(ereq, &Response{})
+ mock.SetMatcher(matcher)
+ ereq.Mock = mock
+
+ headers := make(http.Header)
+ headers.Set("foo", "bar")
+ req := &http.Request{
+ URL: &url.URL{Host: "foo.com", Path: "/bar"},
+ Header: headers,
+ }
+
+ match, err := ereq.Mock.Match(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, match, true)
+}
+
+func TestRequestAddMatcher(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ ereq := g.NewRequest()
+ mock := g.NewMock(ereq, &Response{})
+ mock.matcher = g.NewMatcher()
+ ereq.Mock = mock
+
+ ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.URL.Host == "foo.com", nil
+ })
+ ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) {
+ return req.Header.Get("foo") == "bar", nil
+ })
+
+ headers := make(http.Header)
+ headers.Set("foo", "bar")
+ req := &http.Request{
+ URL: &url.URL{Host: "foo.com", Path: "/bar"},
+ Header: headers,
+ }
+
+ match, err := ereq.Mock.Match(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, match, true)
+}
diff --git a/threadsafe/responder.go b/threadsafe/responder.go
new file mode 100644
index 0000000..5dc4a7d
--- /dev/null
+++ b/threadsafe/responder.go
@@ -0,0 +1,111 @@
+package threadsafe
+
+import (
+ "bytes"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "strconv"
+ "time"
+)
+
+// Responder builds a mock http.Response based on the given Response mock.
+func Responder(req *http.Request, mock *Response, res *http.Response) (*http.Response, error) {
+ // If error present, reply it
+ err := mock.Error
+ if err != nil {
+ return nil, err
+ }
+
+ if res == nil {
+ res = createResponse(req)
+ }
+
+ // Apply response filter
+ for _, filter := range mock.Filters {
+ if !filter(res) {
+ return res, nil
+ }
+ }
+
+ // Define mock status code
+ if mock.StatusCode != 0 {
+ res.Status = strconv.Itoa(mock.StatusCode) + " " + http.StatusText(mock.StatusCode)
+ res.StatusCode = mock.StatusCode
+ }
+
+ // Define headers by merging fields
+ res.Header = mergeHeaders(res, mock)
+
+ // Define mock body, if present
+ if len(mock.BodyBuffer) > 0 {
+ res.ContentLength = int64(len(mock.BodyBuffer))
+ res.Body = createReadCloser(mock.BodyBuffer)
+ }
+
+ // Set raw mock body, if exist
+ if mock.BodyGen != nil {
+ res.ContentLength = -1
+ res.Body = mock.BodyGen()
+ }
+
+ // Apply response mappers
+ for _, mapper := range mock.Mappers {
+ if tres := mapper(res); tres != nil {
+ res = tres
+ }
+ }
+
+ // Sleep to simulate delay, if necessary
+ if mock.ResponseDelay > 0 {
+ // allow escaping from sleep due to request context expiration or cancellation
+ t := time.NewTimer(mock.ResponseDelay)
+ select {
+ case <-t.C:
+ case <-req.Context().Done():
+ // cleanly stop the timer
+ if !t.Stop() {
+ <-t.C
+ }
+ }
+ }
+
+ // check if the request context has ended. we could put this up in the delay code above, but putting it here
+ // has the added benefit of working even when there is no delay (very small timeouts, already-done contexts, etc.)
+ if err = req.Context().Err(); err != nil {
+ // cleanly close the response and return the context error
+ io.Copy(ioutil.Discard, res.Body)
+ res.Body.Close()
+ return nil, err
+ }
+
+ return res, err
+}
+
+// createResponse creates a new http.Response with default fields.
+func createResponse(req *http.Request) *http.Response {
+ return &http.Response{
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Proto: "HTTP/1.1",
+ Request: req,
+ Header: make(http.Header),
+ Body: createReadCloser([]byte{}),
+ }
+}
+
+// mergeHeaders copies the mock headers.
+func mergeHeaders(res *http.Response, mres *Response) http.Header {
+ for key, values := range mres.Header {
+ for _, value := range values {
+ res.Header.Add(key, value)
+ }
+ }
+ return res.Header
+}
+
+// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an
+// http response body.
+func createReadCloser(body []byte) io.ReadCloser {
+ return ioutil.NopCloser(bytes.NewReader(body))
+}
diff --git a/threadsafe/responder_test.go b/threadsafe/responder_test.go
new file mode 100644
index 0000000..7d18820
--- /dev/null
+++ b/threadsafe/responder_test.go
@@ -0,0 +1,191 @@
+package threadsafe
+
+import (
+ "context"
+ "errors"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/nbio/st"
+)
+
+func TestResponder(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ mres := g.New("http://foo.com").Reply(200).BodyString("foo")
+ req := &http.Request{}
+
+ res, err := Responder(req, mres, nil)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.Status, "200 OK")
+ st.Expect(t, res.StatusCode, 200)
+
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo")
+}
+
+func TestResponder_ReadTwice(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ mres := g.New("http://foo.com").Reply(200).BodyString("foo")
+ req := &http.Request{}
+
+ res, err := Responder(req, mres, nil)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.Status, "200 OK")
+ st.Expect(t, res.StatusCode, 200)
+
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo")
+
+ body, err = ioutil.ReadAll(res.Body)
+ st.Expect(t, err, nil)
+ st.Expect(t, body, []byte{})
+}
+
+func TestResponderBodyGenerator(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ generator := func() io.ReadCloser {
+ return io.NopCloser(strings.NewReader("foo"))
+ }
+ mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator)
+ req := &http.Request{}
+
+ res, err := Responder(req, mres, nil)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.Status, "200 OK")
+ st.Expect(t, res.StatusCode, 200)
+
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo")
+}
+
+func TestResponderBodyGenerator_ReadTwice(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ generator := func() io.ReadCloser {
+ return io.NopCloser(strings.NewReader("foo"))
+ }
+ mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator)
+ req := &http.Request{}
+
+ res, err := Responder(req, mres, nil)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.Status, "200 OK")
+ st.Expect(t, res.StatusCode, 200)
+
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo")
+
+ body, err = ioutil.ReadAll(res.Body)
+ st.Expect(t, err, nil)
+ st.Expect(t, body, []byte{})
+}
+
+func TestResponderBodyGenerator_Override(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ generator := func() io.ReadCloser {
+ return io.NopCloser(strings.NewReader("foo"))
+ }
+ mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator).BodyString("bar")
+ req := &http.Request{}
+
+ res, err := Responder(req, mres, nil)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.Status, "200 OK")
+ st.Expect(t, res.StatusCode, 200)
+
+ body, _ := ioutil.ReadAll(res.Body)
+ st.Expect(t, string(body), "foo")
+
+ body, err = ioutil.ReadAll(res.Body)
+ st.Expect(t, err, nil)
+ st.Expect(t, body, []byte{})
+}
+
+func TestResponderSupportsMultipleHeadersWithSameKey(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ mres := g.New("http://foo").
+ Reply(200).
+ AddHeader("Set-Cookie", "a=1").
+ AddHeader("Set-Cookie", "b=2")
+ req := &http.Request{}
+
+ res, err := Responder(req, mres, nil)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.Header, http.Header{"Set-Cookie": []string{"a=1", "b=2"}})
+}
+
+func TestResponderError(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ mres := g.New("http://foo.com").ReplyError(errors.New("error"))
+ req := &http.Request{}
+
+ res, err := Responder(req, mres, nil)
+ st.Expect(t, err.Error(), "error")
+ st.Expect(t, res == nil, true)
+}
+
+func TestResponderCancelledContext(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ mres := g.New("http://foo.com").Get("").Reply(200).Delay(20 * time.Millisecond).BodyString("foo")
+
+ // create a context and schedule a call to cancel in 10ms
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ time.Sleep(10 * time.Millisecond)
+ cancel()
+ }()
+
+ req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil)
+
+ res, err := Responder(req, mres, nil)
+
+ // verify that we got a context cancellation error and nil response
+ st.Expect(t, err, context.Canceled)
+ st.Expect(t, res == nil, true)
+}
+
+func TestResponderExpiredContext(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ mres := g.New("http://foo.com").Get("").Reply(200).Delay(20 * time.Millisecond).BodyString("foo")
+
+ // create a context that is set to expire in 10ms
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+ defer cancel()
+ req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil)
+
+ res, err := Responder(req, mres, nil)
+
+ // verify that we got a context cancellation error and nil response
+ st.Expect(t, err, context.DeadlineExceeded)
+ st.Expect(t, res == nil, true)
+}
+
+func TestResponderPreExpiredContext(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ mres := g.New("http://foo.com").Get("").Reply(200).BodyString("foo")
+
+ // create a context and wait to ensure it is expired
+ ctx, cancel := context.WithTimeout(context.Background(), 500*time.Microsecond)
+ defer cancel()
+ time.Sleep(1 * time.Millisecond)
+ req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil)
+
+ res, err := Responder(req, mres, nil)
+
+ // verify that we got a context cancellation error and nil response
+ st.Expect(t, err, context.DeadlineExceeded)
+ st.Expect(t, res == nil, true)
+}
diff --git a/threadsafe/response.go b/threadsafe/response.go
new file mode 100644
index 0000000..04de096
--- /dev/null
+++ b/threadsafe/response.go
@@ -0,0 +1,198 @@
+package threadsafe
+
+import (
+ "bytes"
+ "encoding/json"
+ "encoding/xml"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "time"
+)
+
+// MapResponseFunc represents the required function interface impletemed by response mappers.
+type MapResponseFunc func(*http.Response) *http.Response
+
+// FilterResponseFunc represents the required function interface impletemed by response filters.
+type FilterResponseFunc func(*http.Response) bool
+
+// Response represents high-level HTTP fields to configure
+// and define HTTP responses intercepted by gock.
+type Response struct {
+ g *Gock
+
+ // Mock stores the parent mock reference for the current response mock used for method delegation.
+ Mock Mock
+
+ // Error stores the latest response configuration or injected error.
+ Error error
+
+ // UseNetwork enables the use of real network for the current mock.
+ UseNetwork bool
+
+ // StatusCode stores the response status code.
+ StatusCode int
+
+ // Headers stores the response headers.
+ Header http.Header
+
+ // Cookies stores the response cookie fields.
+ Cookies []*http.Cookie
+
+ // BodyGen stores a io.ReadCloser generator to be returned.
+ BodyGen func() io.ReadCloser
+
+ // BodyBuffer stores the array of bytes to use as body.
+ BodyBuffer []byte
+
+ // ResponseDelay stores the simulated response delay.
+ ResponseDelay time.Duration
+
+ // Mappers stores the request functions mappers used for matching.
+ Mappers []MapResponseFunc
+
+ // Filters stores the request functions filters used for matching.
+ Filters []FilterResponseFunc
+}
+
+// NewResponse creates a new Response.
+func (g *Gock) NewResponse() *Response {
+ return &Response{g: g, Header: make(http.Header)}
+}
+
+// Status defines the desired HTTP status code to reply in the current response.
+func (r *Response) Status(code int) *Response {
+ r.StatusCode = code
+ return r
+}
+
+// Type defines the response Content-Type MIME header field.
+// Supports type alias. E.g: json, xml, form, text...
+func (r *Response) Type(kind string) *Response {
+ mime := r.g.BodyTypeAliases[kind]
+ if mime != "" {
+ kind = mime
+ }
+ r.Header.Set("Content-Type", kind)
+ return r
+}
+
+// SetHeader sets a new header field in the mock response.
+func (r *Response) SetHeader(key, value string) *Response {
+ r.Header.Set(key, value)
+ return r
+}
+
+// AddHeader adds a new header field in the mock response
+// with out removing an existent one.
+func (r *Response) AddHeader(key, value string) *Response {
+ r.Header.Add(key, value)
+ return r
+}
+
+// SetHeaders sets a map of header fields in the mock response.
+func (r *Response) SetHeaders(headers map[string]string) *Response {
+ for key, value := range headers {
+ r.Header.Add(key, value)
+ }
+ return r
+}
+
+// Body sets the HTTP response body to be used.
+func (r *Response) Body(body io.Reader) *Response {
+ r.BodyBuffer, r.Error = ioutil.ReadAll(body)
+ return r
+}
+
+// BodyGenerator accepts a io.ReadCloser generator, returning custom io.ReadCloser
+// for every response. This will take priority than other Body methods used.
+func (r *Response) BodyGenerator(generator func() io.ReadCloser) *Response {
+ r.BodyGen = generator
+ return r
+}
+
+// BodyString defines the response body as string.
+func (r *Response) BodyString(body string) *Response {
+ r.BodyBuffer = []byte(body)
+ return r
+}
+
+// File defines the response body reading the data
+// from disk based on the file path string.
+func (r *Response) File(path string) *Response {
+ r.BodyBuffer, r.Error = ioutil.ReadFile(path)
+ return r
+}
+
+// JSON defines the response body based on a JSON based input.
+func (r *Response) JSON(data interface{}) *Response {
+ r.Header.Set("Content-Type", "application/json")
+ r.BodyBuffer, r.Error = readAndDecode(data, "json")
+ return r
+}
+
+// XML defines the response body based on a XML based input.
+func (r *Response) XML(data interface{}) *Response {
+ r.Header.Set("Content-Type", "application/xml")
+ r.BodyBuffer, r.Error = readAndDecode(data, "xml")
+ return r
+}
+
+// SetError defines the response simulated error.
+func (r *Response) SetError(err error) *Response {
+ r.Error = err
+ return r
+}
+
+// Delay defines the response simulated delay.
+// This feature is still experimental and will be improved in the future.
+func (r *Response) Delay(delay time.Duration) *Response {
+ r.ResponseDelay = delay
+ return r
+}
+
+// Map adds a new response mapper function to map http.Response before the matching process.
+func (r *Response) Map(fn MapResponseFunc) *Response {
+ r.Mappers = append(r.Mappers, fn)
+ return r
+}
+
+// Filter filters a new request filter function to filter http.Request before the matching process.
+func (r *Response) Filter(fn FilterResponseFunc) *Response {
+ r.Filters = append(r.Filters, fn)
+ return r
+}
+
+// EnableNetworking enables the use real networking for the current mock.
+func (r *Response) EnableNetworking() *Response {
+ r.UseNetwork = true
+ return r
+}
+
+// Done returns true if the mock was done and disabled.
+func (r *Response) Done() bool {
+ return r.Mock.Done()
+}
+
+func readAndDecode(data interface{}, kind string) ([]byte, error) {
+ buf := &bytes.Buffer{}
+
+ switch data.(type) {
+ case string:
+ buf.WriteString(data.(string))
+ case []byte:
+ buf.Write(data.([]byte))
+ default:
+ var err error
+ if kind == "xml" {
+ err = xml.NewEncoder(buf).Encode(data)
+ } else {
+ err = json.NewEncoder(buf).Encode(data)
+ }
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return ioutil.ReadAll(buf)
+}
diff --git a/threadsafe/response_test.go b/threadsafe/response_test.go
new file mode 100644
index 0000000..0a44c63
--- /dev/null
+++ b/threadsafe/response_test.go
@@ -0,0 +1,186 @@
+package threadsafe
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net/http"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/nbio/st"
+)
+
+func TestNewResponse(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+
+ res.Status(200)
+ st.Expect(t, res.StatusCode, 200)
+
+ res.SetHeader("foo", "bar")
+ st.Expect(t, res.Header.Get("foo"), "bar")
+
+ res.Delay(1000 * time.Millisecond)
+ st.Expect(t, res.ResponseDelay, 1000*time.Millisecond)
+
+ res.EnableNetworking()
+ st.Expect(t, res.UseNetwork, true)
+}
+
+func TestResponseStatus(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ st.Expect(t, res.StatusCode, 0)
+ res.Status(200)
+ st.Expect(t, res.StatusCode, 200)
+}
+
+func TestResponseType(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ res.Type("json")
+ st.Expect(t, res.Header.Get("Content-Type"), "application/json")
+
+ res = g.NewResponse()
+ res.Type("xml")
+ st.Expect(t, res.Header.Get("Content-Type"), "application/xml")
+
+ res = g.NewResponse()
+ res.Type("foo/bar")
+ st.Expect(t, res.Header.Get("Content-Type"), "foo/bar")
+}
+
+func TestResponseSetHeader(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ res.SetHeader("foo", "bar")
+ res.SetHeader("bar", "baz")
+ st.Expect(t, res.Header.Get("foo"), "bar")
+ st.Expect(t, res.Header.Get("bar"), "baz")
+}
+
+func TestResponseAddHeader(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ res.AddHeader("foo", "bar")
+ res.AddHeader("foo", "baz")
+ st.Expect(t, res.Header.Get("foo"), "bar")
+ st.Expect(t, res.Header["Foo"][1], "baz")
+}
+
+func TestResponseSetHeaders(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ res.SetHeaders(map[string]string{"foo": "bar", "bar": "baz"})
+ st.Expect(t, res.Header.Get("foo"), "bar")
+ st.Expect(t, res.Header.Get("bar"), "baz")
+}
+
+func TestResponseBody(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ res.Body(bytes.NewBuffer([]byte("foo bar")))
+ st.Expect(t, string(res.BodyBuffer), "foo bar")
+}
+
+func TestResponseBodyGenerator(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ generator := func() io.ReadCloser {
+ return io.NopCloser(bytes.NewBuffer([]byte("foo bar")))
+ }
+ res.BodyGenerator(generator)
+ bytes, err := io.ReadAll(res.BodyGen())
+ st.Expect(t, err, nil)
+ st.Expect(t, string(bytes), "foo bar")
+}
+
+func TestResponseBodyString(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ res.BodyString("foo bar")
+ st.Expect(t, string(res.BodyBuffer), "foo bar")
+}
+
+func TestResponseFile(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ absPath, err := filepath.Abs("../version.go")
+ st.Expect(t, err, nil)
+ res.File(absPath)
+ st.Expect(t, string(res.BodyBuffer)[:12], "package gock")
+}
+
+func TestResponseJSON(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ res.JSON(map[string]string{"foo": "bar"})
+ st.Expect(t, string(res.BodyBuffer)[:13], `{"foo":"bar"}`)
+ st.Expect(t, res.Header.Get("Content-Type"), "application/json")
+}
+
+func TestResponseXML(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ type xml struct {
+ Data string `xml:"data"`
+ }
+ res.XML(xml{Data: "foo"})
+ st.Expect(t, string(res.BodyBuffer), `foo`)
+ st.Expect(t, res.Header.Get("Content-Type"), "application/xml")
+}
+
+func TestResponseMap(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ st.Expect(t, len(res.Mappers), 0)
+ res.Map(func(res *http.Response) *http.Response {
+ return res
+ })
+ st.Expect(t, len(res.Mappers), 1)
+}
+
+func TestResponseFilter(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ st.Expect(t, len(res.Filters), 0)
+ res.Filter(func(res *http.Response) bool {
+ return true
+ })
+ st.Expect(t, len(res.Filters), 1)
+}
+
+func TestResponseSetError(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ st.Expect(t, res.Error, nil)
+ res.SetError(errors.New("foo error"))
+ st.Expect(t, res.Error.Error(), "foo error")
+}
+
+func TestResponseDelay(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ st.Expect(t, res.ResponseDelay, 0*time.Microsecond)
+ res.Delay(100 * time.Millisecond)
+ st.Expect(t, res.ResponseDelay, 100*time.Millisecond)
+}
+
+func TestResponseEnableNetworking(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ st.Expect(t, res.UseNetwork, false)
+ res.EnableNetworking()
+ st.Expect(t, res.UseNetwork, true)
+}
+
+func TestResponseDone(t *testing.T) {
+ g := NewGock()
+ res := g.NewResponse()
+ res.Mock = &Mocker{request: &Request{Counter: 1}, disabler: new(disabler)}
+ st.Expect(t, res.Done(), false)
+ res.Mock.Disable()
+ st.Expect(t, res.Done(), true)
+}
diff --git a/threadsafe/store.go b/threadsafe/store.go
new file mode 100644
index 0000000..d22a02e
--- /dev/null
+++ b/threadsafe/store.go
@@ -0,0 +1,90 @@
+package threadsafe
+
+// Register registers a new mock in the current mocks stack.
+func (g *Gock) Register(mock Mock) {
+ if g.Exists(mock) {
+ return
+ }
+
+ // Make ops thread safe
+ g.storeMutex.Lock()
+ defer g.storeMutex.Unlock()
+
+ // Expose mock in request/response for delegation
+ mock.Request().Mock = mock
+ mock.Response().Mock = mock
+
+ // Registers the mock in the global store
+ g.mocks = append(g.mocks, mock)
+}
+
+// GetAll returns the current stack of registered mocks.
+func (g *Gock) GetAll() []Mock {
+ g.storeMutex.RLock()
+ defer g.storeMutex.RUnlock()
+ return g.mocks
+}
+
+// Exists checks if the given Mock is already registered.
+func (g *Gock) Exists(m Mock) bool {
+ g.storeMutex.RLock()
+ defer g.storeMutex.RUnlock()
+ for _, mock := range g.mocks {
+ if mock == m {
+ return true
+ }
+ }
+ return false
+}
+
+// Remove removes a registered mock by reference.
+func (g *Gock) Remove(m Mock) {
+ for i, mock := range g.mocks {
+ if mock == m {
+ g.storeMutex.Lock()
+ g.mocks = append(g.mocks[:i], g.mocks[i+1:]...)
+ g.storeMutex.Unlock()
+ }
+ }
+}
+
+// Flush flushes the current stack of registered mocks.
+func (g *Gock) Flush() {
+ g.storeMutex.Lock()
+ defer g.storeMutex.Unlock()
+ g.mocks = []Mock{}
+}
+
+// Pending returns an slice of pending mocks.
+func (g *Gock) Pending() []Mock {
+ g.Clean()
+ g.storeMutex.RLock()
+ defer g.storeMutex.RUnlock()
+ return g.mocks
+}
+
+// IsDone returns true if all the registered mocks has been triggered successfully.
+func (g *Gock) IsDone() bool {
+ return !g.IsPending()
+}
+
+// IsPending returns true if there are pending mocks.
+func (g *Gock) IsPending() bool {
+ return len(g.Pending()) > 0
+}
+
+// Clean cleans the mocks store removing disabled or obsolete mocks.
+func (g *Gock) Clean() {
+ g.storeMutex.Lock()
+ defer g.storeMutex.Unlock()
+
+ buf := []Mock{}
+ for _, mock := range g.mocks {
+ if mock.Done() {
+ continue
+ }
+ buf = append(buf, mock)
+ }
+
+ g.mocks = buf
+}
diff --git a/threadsafe/store_test.go b/threadsafe/store_test.go
new file mode 100644
index 0000000..e4081bb
--- /dev/null
+++ b/threadsafe/store_test.go
@@ -0,0 +1,95 @@
+package threadsafe
+
+import (
+ "testing"
+
+ "github.com/nbio/st"
+)
+
+func TestStoreRegister(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ st.Expect(t, len(g.mocks), 0)
+ mock := g.New("foo").Mock
+ g.Register(mock)
+ st.Expect(t, len(g.mocks), 1)
+ st.Expect(t, mock.Request().Mock, mock)
+ st.Expect(t, mock.Response().Mock, mock)
+}
+
+func TestStoreGetAll(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ st.Expect(t, len(g.mocks), 0)
+ mock := g.New("foo").Mock
+ store := g.GetAll()
+ st.Expect(t, len(g.mocks), 1)
+ st.Expect(t, len(store), 1)
+ st.Expect(t, store[0], mock)
+}
+
+func TestStoreExists(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ st.Expect(t, len(g.mocks), 0)
+ mock := g.New("foo").Mock
+ st.Expect(t, len(g.mocks), 1)
+ st.Expect(t, g.Exists(mock), true)
+}
+
+func TestStorePending(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("foo")
+ st.Expect(t, g.mocks, g.Pending())
+}
+
+func TestStoreIsPending(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("foo")
+ st.Expect(t, g.IsPending(), true)
+ g.Flush()
+ st.Expect(t, g.IsPending(), false)
+}
+
+func TestStoreIsDone(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("foo")
+ st.Expect(t, g.IsDone(), false)
+ g.Flush()
+ st.Expect(t, g.IsDone(), true)
+}
+
+func TestStoreRemove(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ st.Expect(t, len(g.mocks), 0)
+ mock := g.New("foo").Mock
+ st.Expect(t, len(g.mocks), 1)
+ st.Expect(t, g.Exists(mock), true)
+
+ g.Remove(mock)
+ st.Expect(t, g.Exists(mock), false)
+
+ g.Remove(mock)
+ st.Expect(t, g.Exists(mock), false)
+}
+
+func TestStoreFlush(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ st.Expect(t, len(g.mocks), 0)
+
+ mock1 := g.New("foo").Mock
+ mock2 := g.New("foo").Mock
+ st.Expect(t, len(g.mocks), 2)
+ st.Expect(t, g.Exists(mock1), true)
+ st.Expect(t, g.Exists(mock2), true)
+
+ g.Flush()
+ st.Expect(t, len(g.mocks), 0)
+ st.Expect(t, g.Exists(mock1), false)
+ st.Expect(t, g.Exists(mock2), false)
+}
diff --git a/threadsafe/transport.go b/threadsafe/transport.go
new file mode 100644
index 0000000..ee1af5e
--- /dev/null
+++ b/threadsafe/transport.go
@@ -0,0 +1,112 @@
+package threadsafe
+
+import (
+ "errors"
+ "net/http"
+ "sync"
+)
+
+var (
+ // ErrCannotMatch store the error returned in case of no matches.
+ ErrCannotMatch = errors.New("gock: cannot match any request")
+)
+
+// Transport implements http.RoundTripper, which fulfills single http requests issued by
+// an http.Client.
+//
+// gock's Transport encapsulates a given or default http.Transport for further
+// delegation, if needed.
+type Transport struct {
+ g *Gock
+
+ // mutex is used to make transport thread-safe of concurrent uses across goroutines.
+ mutex sync.Mutex
+
+ // Transport encapsulates the original http.RoundTripper transport interface for delegation.
+ Transport http.RoundTripper
+}
+
+// NewTransport creates a new *Transport with no responders.
+func (g *Gock) NewTransport(transport http.RoundTripper) *Transport {
+ return &Transport{g: g, Transport: transport}
+}
+
+// transport is used to always return a non-nil transport. This is the same as `(http.Client).transport`, and is what
+// would be invoked if gock's transport were not present.
+func (m *Transport) transport() http.RoundTripper {
+ if m.Transport != nil {
+ return m.Transport
+ }
+ return http.DefaultTransport
+}
+
+// RoundTrip receives HTTP requests and routes them to the appropriate responder. It is required to
+// implement the http.RoundTripper interface. You will not interact with this directly, instead
+// the *http.Client you are using will call it for you.
+func (m *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
+ // Just act as a proxy if not intercepting
+ if !m.g.Intercepting() {
+ return m.transport().RoundTrip(req)
+ }
+
+ m.mutex.Lock()
+ defer m.g.Clean()
+
+ var err error
+ var res *http.Response
+
+ // Match mock for the incoming http.Request
+ mock, err := m.g.MatchMock(req)
+ if err != nil {
+ m.mutex.Unlock()
+ return nil, err
+ }
+
+ // Invoke the observer with the intercepted http.Request and matched mock
+ if m.g.config.Observer != nil {
+ m.g.config.Observer(req, mock)
+ }
+
+ // Verify if should use real networking
+ networking := shouldUseNetwork(m.g, req, mock)
+ if !networking && mock == nil {
+ m.mutex.Unlock()
+ m.g.trackUnmatchedRequest(req)
+ return nil, ErrCannotMatch
+ }
+
+ // Ensure me unlock the mutex before building the response
+ m.mutex.Unlock()
+
+ // Perform real networking via original transport
+ if networking {
+ res, err = m.transport().RoundTrip(req)
+ // In no mock matched, continue with the response
+ if err != nil || mock == nil {
+ return res, err
+ }
+ }
+
+ return Responder(req, mock.Response(), res)
+}
+
+// CancelRequest is a no-op function.
+func (m *Transport) CancelRequest(req *http.Request) {}
+
+func shouldUseNetwork(g *Gock, req *http.Request, mock Mock) bool {
+ if mock != nil && mock.Response().UseNetwork {
+ return true
+ }
+ if !g.config.Networking {
+ return false
+ }
+ if len(g.config.NetworkingFilters) == 0 {
+ return true
+ }
+ for _, filter := range g.config.NetworkingFilters {
+ if !filter(req) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/threadsafe/transport_test.go b/threadsafe/transport_test.go
new file mode 100644
index 0000000..5215da6
--- /dev/null
+++ b/threadsafe/transport_test.go
@@ -0,0 +1,55 @@
+package threadsafe
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
+
+ "github.com/nbio/st"
+)
+
+func TestTransportMatch(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ const uri = "http://foo.com"
+ g.New(uri).Reply(204)
+ u, _ := url.Parse(uri)
+ req := &http.Request{URL: u}
+ res, err := g.NewTransport(http.DefaultTransport).RoundTrip(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, res.StatusCode, 204)
+ st.Expect(t, res.Request, req)
+}
+
+func TestTransportCannotMatch(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+ g.New("http://foo.com").Reply(204)
+ u, _ := url.Parse("http://127.0.0.1:1234")
+ req := &http.Request{URL: u}
+ _, err := g.NewTransport(http.DefaultTransport).RoundTrip(req)
+ st.Expect(t, err, ErrCannotMatch)
+}
+
+func TestTransportNotIntercepting(t *testing.T) {
+ g := NewGock()
+ defer after(g)
+
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, world")
+ }))
+ defer ts.Close()
+
+ g.New(ts.URL).Reply(200)
+ g.Disable()
+
+ u, _ := url.Parse(ts.URL)
+ req := &http.Request{URL: u, Header: make(http.Header)}
+
+ res, err := g.NewTransport(http.DefaultTransport).RoundTrip(req)
+ st.Expect(t, err, nil)
+ st.Expect(t, g.Intercepting(), false)
+ st.Expect(t, res.StatusCode, 200)
+}
diff --git a/transport.go b/transport.go
index 5b2bba2..985e3ad 100644
--- a/transport.go
+++ b/transport.go
@@ -1,9 +1,9 @@
package gock
import (
- "errors"
"net/http"
- "sync"
+
+ "github.com/h2non/gock/threadsafe"
)
// var mutex *sync.Mutex = &sync.Mutex{}
@@ -19,7 +19,7 @@ var (
var (
// ErrCannotMatch store the error returned in case of no matches.
- ErrCannotMatch = errors.New("gock: cannot match any request")
+ ErrCannotMatch = threadsafe.ErrCannotMatch
)
// Transport implements http.RoundTripper, which fulfills single http requests issued by
@@ -27,86 +27,9 @@ var (
//
// gock's Transport encapsulates a given or default http.Transport for further
// delegation, if needed.
-type Transport struct {
- // mutex is used to make transport thread-safe of concurrent uses across goroutines.
- mutex sync.Mutex
-
- // Transport encapsulates the original http.RoundTripper transport interface for delegation.
- Transport http.RoundTripper
-}
+type Transport = threadsafe.Transport
// NewTransport creates a new *Transport with no responders.
func NewTransport() *Transport {
- return &Transport{Transport: NativeTransport}
-}
-
-// RoundTrip receives HTTP requests and routes them to the appropriate responder. It is required to
-// implement the http.RoundTripper interface. You will not interact with this directly, instead
-// the *http.Client you are using will call it for you.
-func (m *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
- // Just act as a proxy if not intercepting
- if !Intercepting() {
- return m.Transport.RoundTrip(req)
- }
-
- m.mutex.Lock()
- defer Clean()
-
- var err error
- var res *http.Response
-
- // Match mock for the incoming http.Request
- mock, err := MatchMock(req)
- if err != nil {
- m.mutex.Unlock()
- return nil, err
- }
-
- // Invoke the observer with the intercepted http.Request and matched mock
- if config.Observer != nil {
- config.Observer(req, mock)
- }
-
- // Verify if should use real networking
- networking := shouldUseNetwork(req, mock)
- if !networking && mock == nil {
- m.mutex.Unlock()
- trackUnmatchedRequest(req)
- return nil, ErrCannotMatch
- }
-
- // Ensure me unlock the mutex before building the response
- m.mutex.Unlock()
-
- // Perform real networking via original transport
- if networking {
- res, err = m.Transport.RoundTrip(req)
- // In no mock matched, continue with the response
- if err != nil || mock == nil {
- return res, err
- }
- }
-
- return Responder(req, mock.Response(), res)
-}
-
-// CancelRequest is a no-op function.
-func (m *Transport) CancelRequest(req *http.Request) {}
-
-func shouldUseNetwork(req *http.Request, mock Mock) bool {
- if mock != nil && mock.Response().UseNetwork {
- return true
- }
- if !config.Networking {
- return false
- }
- if len(config.NetworkingFilters) == 0 {
- return true
- }
- for _, filter := range config.NetworkingFilters {
- if !filter(req) {
- return false
- }
- }
- return true
+ return g.NewTransport(NativeTransport)
}