diff --git a/gock.go b/gock.go index c3a8333..5f1e4c8 100644 --- a/gock.go +++ b/gock.go @@ -1,57 +1,43 @@ package gock import ( - "fmt" "net/http" - "net/http/httputil" - "net/url" - "regexp" "sync" + + "github.com/h2non/gock/threadsafe" ) +var g = threadsafe.NewGock() + +func init() { + g.DisableCallback = disable + g.InterceptCallback = intercept + g.InterceptingCallback = intercepting +} + // mutex is used interally for locking thread-sensitive functions. var mutex = &sync.Mutex{} -// config global singleton store. -var config = struct { - Networking bool - NetworkingFilters []FilterRequestFunc - Observer ObserverFunc -}{} - // ObserverFunc is implemented by users to inspect the outgoing intercepted HTTP traffic -type ObserverFunc func(*http.Request, Mock) +type ObserverFunc = threadsafe.ObserverFunc // DumpRequest is a default implementation of ObserverFunc that dumps // the HTTP/1.x wire representation of the http request -var DumpRequest ObserverFunc = func(request *http.Request, mock Mock) { - bytes, _ := httputil.DumpRequestOut(request, true) - fmt.Println(string(bytes)) - fmt.Printf("\nMatches: %v\n---\n", mock != nil) -} - -// track unmatched requests so they can be tested for -var unmatchedRequests = []*http.Request{} +var DumpRequest = g.DumpRequest // New creates and registers a new HTTP mock with // default settings and returns the Request DSL for HTTP mock // definition and set up. func New(uri string) *Request { - Intercept() - - res := NewResponse() - req := NewRequest() - req.URLStruct, res.Error = url.Parse(normalizeURI(uri)) - - // Create the new mock expectation - exp := NewMock(req, res) - Register(exp) - - return req + return g.New(uri) } // Intercepting returns true if gock is currently able to intercept. func Intercepting() bool { + return g.Intercepting() +} + +func intercepting() bool { mutex.Lock() defer mutex.Unlock() return http.DefaultTransport == DefaultTransport @@ -60,37 +46,33 @@ func Intercepting() bool { // Intercept enables HTTP traffic interception via http.DefaultTransport. // If you are using a custom HTTP transport, you have to use `gock.Transport()` func Intercept() { - if !Intercepting() { - mutex.Lock() - http.DefaultTransport = DefaultTransport - mutex.Unlock() - } + g.Intercept() +} + +func intercept() { + mutex.Lock() + http.DefaultTransport = DefaultTransport + mutex.Unlock() } // InterceptClient allows the developer to intercept HTTP traffic using // a custom http.Client who uses a non default http.Transport/http.RoundTripper implementation. func InterceptClient(cli *http.Client) { - _, ok := cli.Transport.(*Transport) - if ok { - return // if transport already intercepted, just ignore it - } - trans := NewTransport() - trans.Transport = cli.Transport - cli.Transport = trans + g.InterceptClient(cli) } // RestoreClient allows the developer to disable and restore the // original transport in the given http.Client. func RestoreClient(cli *http.Client) { - trans, ok := cli.Transport.(*Transport) - if !ok { - return - } - cli.Transport = trans.Transport + g.RestoreClient(cli) } // Disable disables HTTP traffic interception by gock. func Disable() { + g.Disable() +} + +func disable() { mutex.Lock() defer mutex.Unlock() http.DefaultTransport = NativeTransport @@ -99,80 +81,50 @@ func Disable() { // Off disables the default HTTP interceptors and removes // all the registered mocks, even if they has not been intercepted yet. func Off() { - Flush() - Disable() + g.Off() } // OffAll is like `Off()`, but it also removes the unmatched requests registry. func OffAll() { - Flush() - Disable() - CleanUnmatchedRequest() + g.OffAll() } // Observe provides a hook to support inspection of the request and matched mock func Observe(fn ObserverFunc) { - mutex.Lock() - defer mutex.Unlock() - config.Observer = fn + g.Observe(fn) } // EnableNetworking enables real HTTP networking func EnableNetworking() { - mutex.Lock() - defer mutex.Unlock() - config.Networking = true + g.EnableNetworking() } // DisableNetworking disables real HTTP networking func DisableNetworking() { - mutex.Lock() - defer mutex.Unlock() - config.Networking = false + g.DisableNetworking() } // NetworkingFilter determines if an http.Request should be triggered or not. func NetworkingFilter(fn FilterRequestFunc) { - mutex.Lock() - defer mutex.Unlock() - config.NetworkingFilters = append(config.NetworkingFilters, fn) + g.NetworkingFilter(fn) } // DisableNetworkingFilters disables registered networking filters. func DisableNetworkingFilters() { - mutex.Lock() - defer mutex.Unlock() - config.NetworkingFilters = []FilterRequestFunc{} + g.DisableNetworkingFilters() } // GetUnmatchedRequests returns all requests that have been received but haven't matched any mock func GetUnmatchedRequests() []*http.Request { - mutex.Lock() - defer mutex.Unlock() - return unmatchedRequests + return g.GetUnmatchedRequests() } // HasUnmatchedRequest returns true if gock has received any requests that didn't match a mock func HasUnmatchedRequest() bool { - return len(GetUnmatchedRequests()) > 0 + return g.HasUnmatchedRequest() } // CleanUnmatchedRequest cleans the unmatched requests internal registry. func CleanUnmatchedRequest() { - mutex.Lock() - defer mutex.Unlock() - unmatchedRequests = []*http.Request{} -} - -func trackUnmatchedRequest(req *http.Request) { - mutex.Lock() - defer mutex.Unlock() - unmatchedRequests = append(unmatchedRequests, req) -} - -func normalizeURI(uri string) string { - if ok, _ := regexp.MatchString("^http[s]?", uri); !ok { - return "http://" + uri - } - return uri + g.CleanUnmatchedRequest() } diff --git a/gock_test.go b/gock_test.go index 7df68fe..8ad128b 100644 --- a/gock_test.go +++ b/gock_test.go @@ -304,7 +304,7 @@ func TestUnmatched(t *testing.T) { defer after() // clear out any unmatchedRequests from other tests - unmatchedRequests = []*http.Request{} + CleanUnmatchedRequest() Intercept() diff --git a/matcher.go b/matcher.go index 11a1d7e..633734f 100644 --- a/matcher.go +++ b/matcher.go @@ -1,137 +1,75 @@ package gock -import "net/http" +import ( + "net/http" + + "github.com/h2non/gock/threadsafe" +) // MatchersHeader exposes an slice of HTTP header specific mock matchers. -var MatchersHeader = []MatchFunc{ - MatchMethod, - MatchScheme, - MatchHost, - MatchPath, - MatchHeaders, - MatchQueryParams, - MatchPathParams, +func MatchersHeader() []MatchFunc { + return g.MatchersHeader +} + +func SetMatchersHeader(matchers []MatchFunc) { + g.MatchersHeader = matchers } // MatchersBody exposes an slice of HTTP body specific built-in mock matchers. -var MatchersBody = []MatchFunc{ - MatchBody, +func MatchersBody() []MatchFunc { + return g.MatchersBody +} + +func SetMatchersBody(matchers []MatchFunc) { + g.MatchersBody = matchers } // Matchers stores all the built-in mock matchers. -var Matchers = append(MatchersHeader, MatchersBody...) +func Matchers() []MatchFunc { + return g.Matchers +} + +func SetMatchers(matchers []MatchFunc) { + g.Matchers = matchers +} // DefaultMatcher stores the default Matcher instance used to match mocks. -var DefaultMatcher = NewMatcher() +func DefaultMatcher() *MockMatcher { + return g.DefaultMatcher +} + +func SetDefaultMatcher(matcher *MockMatcher) { + g.DefaultMatcher = matcher +} // MatchFunc represents the required function // interface implemented by matchers. -type MatchFunc func(*http.Request, *Request) (bool, error) +type MatchFunc = threadsafe.MatchFunc // Matcher represents the required interface implemented by mock matchers. -type Matcher interface { - // Get returns a slice of registered function matchers. - Get() []MatchFunc - - // Add adds a new matcher function. - Add(MatchFunc) - - // Set sets the matchers functions stack. - Set([]MatchFunc) - - // Flush flushes the current matchers function stack. - Flush() - - // Match matches the given http.Request with a mock Request. - Match(*http.Request, *Request) (bool, error) -} +type Matcher = threadsafe.Matcher // MockMatcher implements a mock matcher -type MockMatcher struct { - Matchers []MatchFunc -} +type MockMatcher = threadsafe.MockMatcher // NewMatcher creates a new mock matcher // using the default matcher functions. func NewMatcher() *MockMatcher { - m := NewEmptyMatcher() - for _, matchFn := range Matchers { - m.Add(matchFn) - } - return m + return g.NewMatcher() } // NewBasicMatcher creates a new matcher with header only mock matchers. func NewBasicMatcher() *MockMatcher { - m := NewEmptyMatcher() - for _, matchFn := range MatchersHeader { - m.Add(matchFn) - } - return m + return g.NewBasicMatcher() } // NewEmptyMatcher creates a new empty matcher without default matchers. func NewEmptyMatcher() *MockMatcher { - return &MockMatcher{Matchers: []MatchFunc{}} -} - -// Get returns a slice of registered function matchers. -func (m *MockMatcher) Get() []MatchFunc { - mutex.Lock() - defer mutex.Unlock() - return m.Matchers -} - -// Add adds a new function matcher. -func (m *MockMatcher) Add(fn MatchFunc) { - m.Matchers = append(m.Matchers, fn) -} - -// Set sets a new stack of matchers functions. -func (m *MockMatcher) Set(stack []MatchFunc) { - m.Matchers = stack -} - -// Flush flushes the current matcher -func (m *MockMatcher) Flush() { - m.Matchers = []MatchFunc{} -} - -// Clone returns a separate MockMatcher instance that has a copy of the same MatcherFuncs -func (m *MockMatcher) Clone() *MockMatcher { - m2 := NewEmptyMatcher() - for _, mFn := range m.Get() { - m2.Add(mFn) - } - return m2 -} - -// Match matches the given http.Request with a mock request -// returning true in case that the request matches, otherwise false. -func (m *MockMatcher) Match(req *http.Request, ereq *Request) (bool, error) { - for _, matcher := range m.Matchers { - matches, err := matcher(req, ereq) - if err != nil { - return false, err - } - if !matches { - return false, nil - } - } - return true, nil + return g.NewEmptyMatcher() } // MatchMock is a helper function that matches the given http.Request // in the list of registered mocks, returning it if matches or error if it fails. func MatchMock(req *http.Request) (Mock, error) { - for _, mock := range GetAll() { - matches, err := mock.Match(req) - if err != nil { - return nil, err - } - if matches { - return mock, nil - } - } - return nil, nil + return g.MatchMock(req) } diff --git a/matcher_test.go b/matcher_test.go index d96c00c..c7e842c 100644 --- a/matcher_test.go +++ b/matcher_test.go @@ -9,24 +9,24 @@ import ( ) func TestRegisteredMatchers(t *testing.T) { - st.Expect(t, len(MatchersHeader), 7) - st.Expect(t, len(MatchersBody), 1) + st.Expect(t, len(MatchersHeader()), 7) + st.Expect(t, len(MatchersBody()), 1) } func TestNewMatcher(t *testing.T) { matcher := NewMatcher() // Funcs are not comparable, checking slice length as it's better than nothing // See https://golang.org/pkg/reflect/#DeepEqual - st.Expect(t, len(matcher.Matchers), len(Matchers)) - st.Expect(t, len(matcher.Get()), len(Matchers)) + st.Expect(t, len(matcher.Matchers), len(Matchers())) + st.Expect(t, len(matcher.Get()), len(Matchers())) } func TestNewBasicMatcher(t *testing.T) { matcher := NewBasicMatcher() // Funcs are not comparable, checking slice length as it's better than nothing // See https://golang.org/pkg/reflect/#DeepEqual - st.Expect(t, len(matcher.Matchers), len(MatchersHeader)) - st.Expect(t, len(matcher.Get()), len(MatchersHeader)) + st.Expect(t, len(matcher.Matchers), len(MatchersHeader())) + st.Expect(t, len(matcher.Get()), len(MatchersHeader())) } func TestNewEmptyMatcher(t *testing.T) { @@ -37,17 +37,17 @@ func TestNewEmptyMatcher(t *testing.T) { func TestMatcherAdd(t *testing.T) { matcher := NewMatcher() - st.Expect(t, len(matcher.Matchers), len(Matchers)) + st.Expect(t, len(matcher.Matchers), len(Matchers())) matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { return true, nil }) - st.Expect(t, len(matcher.Get()), len(Matchers)+1) + st.Expect(t, len(matcher.Get()), len(Matchers())+1) } func TestMatcherSet(t *testing.T) { matcher := NewMatcher() matchers := []MatchFunc{} - st.Expect(t, len(matcher.Matchers), len(Matchers)) + st.Expect(t, len(matcher.Matchers), len(Matchers())) matcher.Set(matchers) st.Expect(t, matcher.Matchers, matchers) st.Expect(t, len(matcher.Get()), 0) @@ -62,18 +62,18 @@ func TestMatcherGet(t *testing.T) { func TestMatcherFlush(t *testing.T) { matcher := NewMatcher() - st.Expect(t, len(matcher.Matchers), len(Matchers)) + st.Expect(t, len(matcher.Matchers), len(Matchers())) matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { return true, nil }) - st.Expect(t, len(matcher.Get()), len(Matchers)+1) + st.Expect(t, len(matcher.Get()), len(Matchers())+1) matcher.Flush() st.Expect(t, len(matcher.Get()), 0) } func TestMatcherClone(t *testing.T) { - matcher := DefaultMatcher.Clone() - st.Expect(t, len(matcher.Get()), len(DefaultMatcher.Get())) + matcher := DefaultMatcher().Clone() + st.Expect(t, len(matcher.Get()), len(DefaultMatcher().Get())) } func TestMatcher(t *testing.T) { @@ -115,19 +115,20 @@ func TestMatcher(t *testing.T) { func TestMatchMock(t *testing.T) { cases := []struct { - method string - url string - matches bool + method string + methodFn func(r *Request, path string) *Request + url string + matches bool }{ - {"GET", "http://foo.com/bar", true}, - {"GET", "http://foo.com/baz", true}, - {"GET", "http://foo.com/foo", false}, - {"POST", "http://foo.com/bar", false}, - {"POST", "http://bar.com/bar", false}, - {"GET", "http://foo.com", false}, + {"GET", (*Request).Get, "http://foo.com/bar", true}, + {"GET", (*Request).Get, "http://foo.com/baz", true}, + {"GET", (*Request).Get, "http://foo.com/foo", false}, + {"POST", (*Request).Post, "http://foo.com/bar", false}, + {"POST", (*Request).Post, "http://bar.com/bar", false}, + {"GET", (*Request).Get, "http://foo.com", false}, } - matcher := DefaultMatcher + matcher := DefaultMatcher() matcher.Flush() st.Expect(t, len(matcher.Matchers), 0) @@ -143,7 +144,7 @@ func TestMatchMock(t *testing.T) { for _, test := range cases { Flush() - mock := New(test.url).method(test.method, "").Mock + mock := test.methodFn(New(test.url), "").Mock u, _ := url.Parse(test.url) req := &http.Request{Method: test.method, URL: u} @@ -157,5 +158,5 @@ func TestMatchMock(t *testing.T) { } } - DefaultMatcher.Matchers = Matchers + DefaultMatcher().Matchers = Matchers() } diff --git a/matchers.go b/matchers.go index 658c9a6..8c3ac77 100644 --- a/matchers.go +++ b/matchers.go @@ -1,266 +1,79 @@ package gock import ( - "compress/gzip" - "encoding/json" - "io" - "io/ioutil" "net/http" - "reflect" - "regexp" - "strings" - "github.com/h2non/parth" + "github.com/h2non/gock/threadsafe" ) // EOL represents the end of line character. -const EOL = 0xa +const EOL = threadsafe.EOL // BodyTypes stores the supported MIME body types for matching. // Currently only text-based types. -var BodyTypes = []string{ - "text/html", - "text/plain", - "application/json", - "application/xml", - "multipart/form-data", - "application/x-www-form-urlencoded", +func BodyTypes() []string { + return g.BodyTypes +} + +func SetBodyTypes(types []string) { + g.BodyTypes = types } // BodyTypeAliases stores a generic MIME type by alias. -var BodyTypeAliases = map[string]string{ - "html": "text/html", - "text": "text/plain", - "json": "application/json", - "xml": "application/xml", - "form": "multipart/form-data", - "url": "application/x-www-form-urlencoded", +func BodyTypeAliases() map[string]string { + return g.BodyTypeAliases +} + +func SetBodyTypeAliases(aliases map[string]string) { + g.BodyTypeAliases = aliases } // CompressionSchemes stores the supported Content-Encoding types for decompression. -var CompressionSchemes = []string{ - "gzip", +func CompressionSchemes() []string { + return g.CompressionSchemes +} + +func SetCompressionSchemes(schemes []string) { + g.CompressionSchemes = schemes } // MatchMethod matches the HTTP method of the given request. func MatchMethod(req *http.Request, ereq *Request) (bool, error) { - return ereq.Method == "" || req.Method == ereq.Method, nil + return g.MatchMethod(req, ereq) } // MatchScheme matches the request URL protocol scheme. func MatchScheme(req *http.Request, ereq *Request) (bool, error) { - return ereq.URLStruct.Scheme == "" || req.URL.Scheme == "" || ereq.URLStruct.Scheme == req.URL.Scheme, nil + return g.MatchScheme(req, ereq) } // MatchHost matches the HTTP host header field of the given request. func MatchHost(req *http.Request, ereq *Request) (bool, error) { - url := ereq.URLStruct - if strings.EqualFold(url.Host, req.URL.Host) { - return true, nil - } - if !ereq.Options.DisableRegexpHost { - return regexp.MatchString(url.Host, req.URL.Host) - } - return false, nil + return g.MatchHost(req, ereq) } // MatchPath matches the HTTP URL path of the given request. func MatchPath(req *http.Request, ereq *Request) (bool, error) { - if req.URL.Path == ereq.URLStruct.Path { - return true, nil - } - return regexp.MatchString(ereq.URLStruct.Path, req.URL.Path) + return g.MatchPath(req, ereq) } // MatchHeaders matches the headers fields of the given request. func MatchHeaders(req *http.Request, ereq *Request) (bool, error) { - for key, value := range ereq.Header { - var err error - var match bool - var matchEscaped bool - - for _, field := range req.Header[key] { - match, err = regexp.MatchString(value[0], field) - // Some values may contain reserved regex params e.g. "()", try matching with these escaped. - matchEscaped, err = regexp.MatchString(regexp.QuoteMeta(value[0]), field) - - if err != nil { - return false, err - } - if match || matchEscaped { - break - } - - } - - if !match && !matchEscaped { - return false, nil - } - } - return true, nil + return g.MatchHeaders(req, ereq) } // MatchQueryParams matches the URL query params fields of the given request. func MatchQueryParams(req *http.Request, ereq *Request) (bool, error) { - for key, value := range ereq.URLStruct.Query() { - var err error - var match bool - - for _, field := range req.URL.Query()[key] { - match, err = regexp.MatchString(value[0], field) - if err != nil { - return false, err - } - if match { - break - } - } - - if !match { - return false, nil - } - } - return true, nil + return g.MatchQueryParams(req, ereq) } // MatchPathParams matches the URL path parameters of the given request. func MatchPathParams(req *http.Request, ereq *Request) (bool, error) { - for key, value := range ereq.PathParams { - var s string - - if err := parth.Sequent(req.URL.Path, key, &s); err != nil { - return false, nil - } - - if s != value { - return false, nil - } - } - return true, nil + return g.MatchPathParams(req, ereq) } // MatchBody tries to match the request body. // TODO: not too smart now, needs several improvements. func MatchBody(req *http.Request, ereq *Request) (bool, error) { - // If match body is empty, just continue - if req.Method == "HEAD" || len(ereq.BodyBuffer) == 0 { - return true, nil - } - - // Only can match certain MIME body types - if !supportedType(req, ereq) { - return false, nil - } - - // Can only match certain compression schemes - if !supportedCompressionScheme(req) { - return false, nil - } - - // Create a reader for the body depending on compression type - bodyReader := req.Body - if ereq.CompressionScheme != "" { - if ereq.CompressionScheme != req.Header.Get("Content-Encoding") { - return false, nil - } - compressedBodyReader, err := compressionReader(req.Body, ereq.CompressionScheme) - if err != nil { - return false, err - } - bodyReader = compressedBodyReader - } - - // Read the whole request body - body, err := ioutil.ReadAll(bodyReader) - if err != nil { - return false, err - } - - // Restore body reader stream - req.Body = createReadCloser(body) - - // If empty, ignore the match - if len(body) == 0 && len(ereq.BodyBuffer) != 0 { - return false, nil - } - - // Match body by atomic string comparison - bodyStr := castToString(body) - matchStr := castToString(ereq.BodyBuffer) - if bodyStr == matchStr { - return true, nil - } - - // Match request body by regexp - match, _ := regexp.MatchString(matchStr, bodyStr) - if match == true { - return true, nil - } - - // todo - add conditional do only perform the conversion of body bytes - // representation of JSON to a map and then compare them for equality. - - // Check if the key + value pairs match - var bodyMap map[string]interface{} - var matchMap map[string]interface{} - - // Ensure that both byte bodies that that should be JSON can be converted to maps. - umErr := json.Unmarshal(body, &bodyMap) - umErr2 := json.Unmarshal(ereq.BodyBuffer, &matchMap) - if umErr == nil && umErr2 == nil && reflect.DeepEqual(bodyMap, matchMap) { - return true, nil - } - - return false, nil -} - -func supportedType(req *http.Request, ereq *Request) bool { - mime := req.Header.Get("Content-Type") - if mime == "" { - return true - } - - mimeToMatch := ereq.Header.Get("Content-Type") - if mimeToMatch != "" { - return mime == mimeToMatch - } - - for _, kind := range BodyTypes { - if match, _ := regexp.MatchString(kind, mime); match { - return true - } - } - return false -} - -func supportedCompressionScheme(req *http.Request) bool { - encoding := req.Header.Get("Content-Encoding") - if encoding == "" { - return true - } - - for _, kind := range CompressionSchemes { - if match, _ := regexp.MatchString(kind, encoding); match { - return true - } - } - return false -} - -func castToString(buf []byte) string { - str := string(buf) - tail := len(str) - 1 - if str[tail] == EOL { - str = str[:tail] - } - return str -} - -func compressionReader(r io.ReadCloser, scheme string) (io.ReadCloser, error) { - switch scheme { - case "gzip": - return gzip.NewReader(r) - default: - return r, nil - } + return g.MatchBody(req, ereq) } diff --git a/matchers_test.go b/matchers_test.go index 56aaa01..cbe30d6 100644 --- a/matchers_test.go +++ b/matchers_test.go @@ -1,6 +1,9 @@ package gock import ( + "bytes" + "io" + "io/ioutil" "net/http" "net/url" "testing" @@ -249,3 +252,9 @@ func TestMatchBody_MatchType(t *testing.T) { st.Expect(t, matches, test.matches) } } + +// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an +// http response body. +func createReadCloser(body []byte) io.ReadCloser { + return ioutil.NopCloser(bytes.NewReader(body)) +} diff --git a/mock.go b/mock.go index d28875b..aa388ff 100644 --- a/mock.go +++ b/mock.go @@ -1,172 +1,19 @@ package gock import ( - "net/http" - "sync" + "github.com/h2non/gock/threadsafe" ) // Mock represents the required interface that must // be implemented by HTTP mock instances. -type Mock interface { - // Disable disables the current mock manually. - Disable() - - // Done returns true if the current mock is disabled. - Done() bool - - // Request returns the mock Request instance. - Request() *Request - - // Response returns the mock Response instance. - Response() *Response - - // Match matches the given http.Request with the current mock. - Match(*http.Request) (bool, error) - - // AddMatcher adds a new matcher function. - AddMatcher(MatchFunc) - - // SetMatcher uses a new matcher implementation. - SetMatcher(Matcher) -} +type Mock = threadsafe.Mock // Mocker implements a Mock capable interface providing // a default mock configuration used internally to store mocks. -type Mocker struct { - // disabler stores a disabler for thread safety checking current mock is disabled - disabler *disabler - - // mutex stores the mock mutex for thread safety. - mutex sync.Mutex - - // matcher stores a Matcher capable instance to match the given http.Request. - matcher Matcher - - // request stores the mock Request to match. - request *Request - - // response stores the mock Response to use in case of match. - response *Response -} - -type disabler struct { - // disabled stores if the current mock is disabled. - disabled bool - - // mutex stores the disabler mutex for thread safety. - mutex sync.RWMutex -} - -func (d *disabler) isDisabled() bool { - d.mutex.RLock() - defer d.mutex.RUnlock() - return d.disabled -} - -func (d *disabler) Disable() { - d.mutex.Lock() - defer d.mutex.Unlock() - d.disabled = true -} +type Mocker = threadsafe.Mocker // NewMock creates a new HTTP mock based on the given request and response instances. // It's mostly used internally. func NewMock(req *Request, res *Response) *Mocker { - mock := &Mocker{ - disabler: new(disabler), - request: req, - response: res, - matcher: DefaultMatcher.Clone(), - } - res.Mock = mock - req.Mock = mock - req.Response = res - return mock -} - -// Disable disables the current mock manually. -func (m *Mocker) Disable() { - m.disabler.Disable() -} - -// Done returns true in case that the current mock -// instance is disabled and therefore must be removed. -func (m *Mocker) Done() bool { - // prevent deadlock with m.mutex - if m.disabler.isDisabled() { - return true - } - - m.mutex.Lock() - defer m.mutex.Unlock() - return !m.request.Persisted && m.request.Counter == 0 -} - -// Request returns the Request instance -// configured for the current HTTP mock. -func (m *Mocker) Request() *Request { - return m.request -} - -// Response returns the Response instance -// configured for the current HTTP mock. -func (m *Mocker) Response() *Response { - return m.response -} - -// Match matches the given http.Request with the current Request -// mock expectation, returning true if matches. -func (m *Mocker) Match(req *http.Request) (bool, error) { - if m.disabler.isDisabled() { - return false, nil - } - - // Filter - for _, filter := range m.request.Filters { - if !filter(req) { - return false, nil - } - } - - // Map - for _, mapper := range m.request.Mappers { - if treq := mapper(req); treq != nil { - req = treq - } - } - - // Match - matches, err := m.matcher.Match(req, m.request) - if matches { - m.decrement() - } - - return matches, err -} - -// SetMatcher sets a new matcher implementation -// for the current mock expectation. -func (m *Mocker) SetMatcher(matcher Matcher) { - m.matcher = matcher -} - -// AddMatcher adds a new matcher function -// for the current mock expectation. -func (m *Mocker) AddMatcher(fn MatchFunc) { - m.matcher.Add(fn) -} - -// decrement decrements the current mock Request counter. -func (m *Mocker) decrement() { - if m.request.Persisted { - return - } - - m.mutex.Lock() - defer m.mutex.Unlock() - - m.request.Counter-- - if m.request.Counter == 0 { - m.disabler.Disable() - } + return g.NewMock(req, res) } diff --git a/mock_test.go b/mock_test.go index 01e6fca..70b0765 100644 --- a/mock_test.go +++ b/mock_test.go @@ -7,63 +7,6 @@ import ( "github.com/nbio/st" ) -func TestNewMock(t *testing.T) { - defer after() - - req := NewRequest() - res := NewResponse() - mock := NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get())) - - st.Expect(t, mock.Request(), req) - st.Expect(t, mock.Request().Mock, mock) - st.Expect(t, mock.Response(), res) - st.Expect(t, mock.Response().Mock, mock) -} - -func TestMockDisable(t *testing.T) { - defer after() - - req := NewRequest() - res := NewResponse() - mock := NewMock(req, res) - - st.Expect(t, mock.disabler.isDisabled(), false) - mock.Disable() - st.Expect(t, mock.disabler.isDisabled(), true) - - matches, err := mock.Match(&http.Request{}) - st.Expect(t, err, nil) - st.Expect(t, matches, false) -} - -func TestMockDone(t *testing.T) { - defer after() - - req := NewRequest() - res := NewResponse() - - mock := NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - st.Expect(t, mock.Done(), false) - - mock = NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - mock.Disable() - st.Expect(t, mock.Done(), true) - - mock = NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - mock.request.Counter = 0 - st.Expect(t, mock.Done(), true) - - mock = NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - mock.request.Persisted = true - st.Expect(t, mock.Done(), false) -} - func TestMockSetMatcher(t *testing.T) { defer after() @@ -71,15 +14,12 @@ func TestMockSetMatcher(t *testing.T) { res := NewResponse() mock := NewMock(req, res) - st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get())) matcher := NewMatcher() matcher.Flush() matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { return true, nil }) mock.SetMatcher(matcher) - st.Expect(t, len(mock.matcher.Get()), 1) - st.Expect(t, mock.disabler.isDisabled(), false) matches, err := mock.Match(&http.Request{}) st.Expect(t, err, nil) @@ -93,15 +33,12 @@ func TestMockAddMatcher(t *testing.T) { res := NewResponse() mock := NewMock(req, res) - st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get())) matcher := NewMatcher() matcher.Flush() mock.SetMatcher(matcher) mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { return true, nil }) - st.Expect(t, mock.disabler.isDisabled(), false) - st.Expect(t, mock.matcher, matcher) matches, err := mock.Match(&http.Request{}) st.Expect(t, err, nil) @@ -127,8 +64,6 @@ func TestMockMatch(t *testing.T) { calls++ return true, nil }) - st.Expect(t, mock.disabler.isDisabled(), false) - st.Expect(t, mock.matcher, matcher) matches, err := mock.Match(&http.Request{}) st.Expect(t, err, nil) diff --git a/options.go b/options.go index 188aa58..f754563 100644 --- a/options.go +++ b/options.go @@ -1,8 +1,6 @@ package gock +import "github.com/h2non/gock/threadsafe" + // Options represents customized option for gock -type Options struct { - // DisableRegexpHost stores if the host is only a plain string rather than regular expression, - // if DisableRegexpHost is true, host sets in gock.New(...) will be treated as plain string - DisableRegexpHost bool -} +type Options = threadsafe.Options diff --git a/request.go b/request.go index 5702417..1563c37 100644 --- a/request.go +++ b/request.go @@ -1,325 +1,20 @@ package gock import ( - "encoding/base64" - "io" - "io/ioutil" - "net/http" - "net/url" - "strings" + "github.com/h2non/gock/threadsafe" ) // MapRequestFunc represents the required function interface for request mappers. -type MapRequestFunc func(*http.Request) *http.Request +type MapRequestFunc = threadsafe.MapRequestFunc // FilterRequestFunc represents the required function interface for request filters. -type FilterRequestFunc func(*http.Request) bool +type FilterRequestFunc = threadsafe.FilterRequestFunc // Request represents the high-level HTTP request used to store // request fields used to match intercepted requests. -type Request struct { - // Mock stores the parent mock reference for the current request mock used for method delegation. - Mock Mock - - // Response stores the current Response instance for the current matches Request. - Response *Response - - // Error stores the latest mock request configuration error. - Error error - - // Counter stores the pending times that the current mock should be active. - Counter int - - // Persisted stores if the current mock should be always active. - Persisted bool - - // Options stores options for current Request. - Options Options - - // URLStruct stores the parsed URL as *url.URL struct. - URLStruct *url.URL - - // Method stores the Request HTTP method to match. - Method string - - // CompressionScheme stores the Request Compression scheme to match and use for decompression. - CompressionScheme string - - // Header stores the HTTP header fields to match. - Header http.Header - - // Cookies stores the Request HTTP cookies values to match. - Cookies []*http.Cookie - - // PathParams stores the path parameters to match. - PathParams map[string]string - - // BodyBuffer stores the body data to match. - BodyBuffer []byte - - // Mappers stores the request functions mappers used for matching. - Mappers []MapRequestFunc - - // Filters stores the request functions filters used for matching. - Filters []FilterRequestFunc -} +type Request = threadsafe.Request // NewRequest creates a new Request instance. func NewRequest() *Request { - return &Request{ - Counter: 1, - URLStruct: &url.URL{}, - Header: make(http.Header), - PathParams: make(map[string]string), - } -} - -// URL defines the mock URL to match. -func (r *Request) URL(uri string) *Request { - r.URLStruct, r.Error = url.Parse(uri) - return r -} - -// SetURL defines the url.URL struct to be used for matching. -func (r *Request) SetURL(u *url.URL) *Request { - r.URLStruct = u - return r -} - -// Path defines the mock URL path value to match. -func (r *Request) Path(path string) *Request { - r.URLStruct.Path = path - return r -} - -// Get specifies the GET method and the given URL path to match. -func (r *Request) Get(path string) *Request { - return r.method("GET", path) -} - -// Post specifies the POST method and the given URL path to match. -func (r *Request) Post(path string) *Request { - return r.method("POST", path) -} - -// Put specifies the PUT method and the given URL path to match. -func (r *Request) Put(path string) *Request { - return r.method("PUT", path) -} - -// Delete specifies the DELETE method and the given URL path to match. -func (r *Request) Delete(path string) *Request { - return r.method("DELETE", path) -} - -// Patch specifies the PATCH method and the given URL path to match. -func (r *Request) Patch(path string) *Request { - return r.method("PATCH", path) -} - -// Head specifies the HEAD method and the given URL path to match. -func (r *Request) Head(path string) *Request { - return r.method("HEAD", path) -} - -// method is a DRY shortcut used to declare the expected HTTP method and URL path. -func (r *Request) method(method, path string) *Request { - if path != "/" { - r.URLStruct.Path = path - } - r.Method = strings.ToUpper(method) - return r -} - -// Body defines the body data to match based on a io.Reader interface. -func (r *Request) Body(body io.Reader) *Request { - r.BodyBuffer, r.Error = ioutil.ReadAll(body) - return r -} - -// BodyString defines the body to match based on a given string. -func (r *Request) BodyString(body string) *Request { - r.BodyBuffer = []byte(body) - return r -} - -// File defines the body to match based on the given file path string. -func (r *Request) File(path string) *Request { - r.BodyBuffer, r.Error = ioutil.ReadFile(path) - return r -} - -// Compression defines the request compression scheme, and enables automatic body decompression. -// Supports only the "gzip" scheme so far. -func (r *Request) Compression(scheme string) *Request { - r.Header.Set("Content-Encoding", scheme) - r.CompressionScheme = scheme - return r -} - -// JSON defines the JSON body to match based on a given structure. -func (r *Request) JSON(data interface{}) *Request { - if r.Header.Get("Content-Type") == "" { - r.Header.Set("Content-Type", "application/json") - } - r.BodyBuffer, r.Error = readAndDecode(data, "json") - return r -} - -// XML defines the XML body to match based on a given structure. -func (r *Request) XML(data interface{}) *Request { - if r.Header.Get("Content-Type") == "" { - r.Header.Set("Content-Type", "application/xml") - } - r.BodyBuffer, r.Error = readAndDecode(data, "xml") - return r -} - -// MatchType defines the request Content-Type MIME header field. -// Supports custom MIME types and type aliases. E.g: json, xml, form, text... -func (r *Request) MatchType(kind string) *Request { - mime := BodyTypeAliases[kind] - if mime != "" { - kind = mime - } - r.Header.Set("Content-Type", kind) - return r -} - -// BasicAuth defines a username and password for HTTP Basic Authentication -func (r *Request) BasicAuth(username, password string) *Request { - r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) - return r -} - -// MatchHeader defines a new key and value header to match. -func (r *Request) MatchHeader(key, value string) *Request { - r.Header.Set(key, value) - return r -} - -// HeaderPresent defines that a header field must be present in the request. -func (r *Request) HeaderPresent(key string) *Request { - r.Header.Set(key, ".*") - return r -} - -// MatchHeaders defines a map of key-value headers to match. -func (r *Request) MatchHeaders(headers map[string]string) *Request { - for key, value := range headers { - r.Header.Set(key, value) - } - return r -} - -// MatchParam defines a new key and value URL query param to match. -func (r *Request) MatchParam(key, value string) *Request { - query := r.URLStruct.Query() - query.Set(key, value) - r.URLStruct.RawQuery = query.Encode() - return r -} - -// MatchParams defines a map of URL query param key-value to match. -func (r *Request) MatchParams(params map[string]string) *Request { - query := r.URLStruct.Query() - for key, value := range params { - query.Set(key, value) - } - r.URLStruct.RawQuery = query.Encode() - return r -} - -// ParamPresent matches if the given query param key is present in the URL. -func (r *Request) ParamPresent(key string) *Request { - r.MatchParam(key, ".*") - return r -} - -// PathParam matches if a given path parameter key is present in the URL. -// -// The value is representative of the restful resource the key defines, e.g. -// // /users/123/name -// r.PathParam("users", "123") -// would match. -func (r *Request) PathParam(key, val string) *Request { - r.PathParams[key] = val - - return r -} - -// Persist defines the current HTTP mock as persistent and won't be removed after intercepting it. -func (r *Request) Persist() *Request { - r.Persisted = true - return r -} - -// WithOptions sets the options for the request. -func (r *Request) WithOptions(options Options) *Request { - r.Options = options - return r -} - -// Times defines the number of times that the current HTTP mock should remain active. -func (r *Request) Times(num int) *Request { - r.Counter = num - return r -} - -// AddMatcher adds a new matcher function to match the request. -func (r *Request) AddMatcher(fn MatchFunc) *Request { - r.Mock.AddMatcher(fn) - return r -} - -// SetMatcher sets a new matcher function to match the request. -func (r *Request) SetMatcher(matcher Matcher) *Request { - r.Mock.SetMatcher(matcher) - return r -} - -// Map adds a new request mapper function to map http.Request before the matching process. -func (r *Request) Map(fn MapRequestFunc) *Request { - r.Mappers = append(r.Mappers, fn) - return r -} - -// Filter filters a new request filter function to filter http.Request before the matching process. -func (r *Request) Filter(fn FilterRequestFunc) *Request { - r.Filters = append(r.Filters, fn) - return r -} - -// EnableNetworking enables the use real networking for the current mock. -func (r *Request) EnableNetworking() *Request { - if r.Response != nil { - r.Response.UseNetwork = true - } - return r -} - -// Reply defines the Response status code and returns the mock Response DSL. -func (r *Request) Reply(status int) *Response { - return r.Response.Status(status) -} - -// ReplyError defines the Response simulated error. -func (r *Request) ReplyError(err error) *Response { - return r.Response.SetError(err) -} - -// ReplyFunc allows the developer to define the mock response via a custom function. -func (r *Request) ReplyFunc(replier func(*Response)) *Response { - replier(r.Response) - return r.Response -} - -// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt -// "To receive authorization, the client sends the userid and password, -// separated by a single colon (":") character, within a base64 -// encoded string in the credentials." -// It is not meant to be urlencoded. -func basicAuth(username, password string) string { - auth := username + ":" + password - return base64.StdEncoding.EncodeToString([]byte(auth)) + return g.NewRequest() } diff --git a/request_test.go b/request_test.go index 463e784..011a0f9 100644 --- a/request_test.go +++ b/request_test.go @@ -266,7 +266,6 @@ func TestRequestAddMatcher(t *testing.T) { ereq := NewRequest() mock := NewMock(ereq, &Response{}) - mock.matcher = NewMatcher() ereq.Mock = mock ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { diff --git a/responder.go b/responder.go index f0f16bb..0ec2de5 100644 --- a/responder.go +++ b/responder.go @@ -1,111 +1,12 @@ package gock import ( - "bytes" - "io" - "io/ioutil" "net/http" - "strconv" - "time" + + "github.com/h2non/gock/threadsafe" ) // Responder builds a mock http.Response based on the given Response mock. func Responder(req *http.Request, mock *Response, res *http.Response) (*http.Response, error) { - // If error present, reply it - err := mock.Error - if err != nil { - return nil, err - } - - if res == nil { - res = createResponse(req) - } - - // Apply response filter - for _, filter := range mock.Filters { - if !filter(res) { - return res, nil - } - } - - // Define mock status code - if mock.StatusCode != 0 { - res.Status = strconv.Itoa(mock.StatusCode) + " " + http.StatusText(mock.StatusCode) - res.StatusCode = mock.StatusCode - } - - // Define headers by merging fields - res.Header = mergeHeaders(res, mock) - - // Define mock body, if present - if len(mock.BodyBuffer) > 0 { - res.ContentLength = int64(len(mock.BodyBuffer)) - res.Body = createReadCloser(mock.BodyBuffer) - } - - // Set raw mock body, if exist - if mock.BodyGen != nil { - res.ContentLength = -1 - res.Body = mock.BodyGen() - } - - // Apply response mappers - for _, mapper := range mock.Mappers { - if tres := mapper(res); tres != nil { - res = tres - } - } - - // Sleep to simulate delay, if necessary - if mock.ResponseDelay > 0 { - // allow escaping from sleep due to request context expiration or cancellation - t := time.NewTimer(mock.ResponseDelay) - select { - case <-t.C: - case <-req.Context().Done(): - // cleanly stop the timer - if !t.Stop() { - <-t.C - } - } - } - - // check if the request context has ended. we could put this up in the delay code above, but putting it here - // has the added benefit of working even when there is no delay (very small timeouts, already-done contexts, etc.) - if err = req.Context().Err(); err != nil { - // cleanly close the response and return the context error - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - return nil, err - } - - return res, err -} - -// createResponse creates a new http.Response with default fields. -func createResponse(req *http.Request) *http.Response { - return &http.Response{ - ProtoMajor: 1, - ProtoMinor: 1, - Proto: "HTTP/1.1", - Request: req, - Header: make(http.Header), - Body: createReadCloser([]byte{}), - } -} - -// mergeHeaders copies the mock headers. -func mergeHeaders(res *http.Response, mres *Response) http.Header { - for key, values := range mres.Header { - for _, value := range values { - res.Header.Add(key, value) - } - } - return res.Header -} - -// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an -// http response body. -func createReadCloser(body []byte) io.ReadCloser { - return ioutil.NopCloser(bytes.NewReader(body)) + return threadsafe.Responder(req, mock, res) } diff --git a/response.go b/response.go index 3e62b9e..0eeb314 100644 --- a/response.go +++ b/response.go @@ -1,196 +1,20 @@ package gock import ( - "bytes" - "encoding/json" - "encoding/xml" - "io" - "io/ioutil" - "net/http" - "time" + "github.com/h2non/gock/threadsafe" ) // MapResponseFunc represents the required function interface impletemed by response mappers. -type MapResponseFunc func(*http.Response) *http.Response +type MapResponseFunc = threadsafe.MapResponseFunc // FilterResponseFunc represents the required function interface impletemed by response filters. -type FilterResponseFunc func(*http.Response) bool +type FilterResponseFunc = threadsafe.FilterResponseFunc // Response represents high-level HTTP fields to configure // and define HTTP responses intercepted by gock. -type Response struct { - // Mock stores the parent mock reference for the current response mock used for method delegation. - Mock Mock - - // Error stores the latest response configuration or injected error. - Error error - - // UseNetwork enables the use of real network for the current mock. - UseNetwork bool - - // StatusCode stores the response status code. - StatusCode int - - // Headers stores the response headers. - Header http.Header - - // Cookies stores the response cookie fields. - Cookies []*http.Cookie - - // BodyGen stores a io.ReadCloser generator to be returned. - BodyGen func() io.ReadCloser - - // BodyBuffer stores the array of bytes to use as body. - BodyBuffer []byte - - // ResponseDelay stores the simulated response delay. - ResponseDelay time.Duration - - // Mappers stores the request functions mappers used for matching. - Mappers []MapResponseFunc - - // Filters stores the request functions filters used for matching. - Filters []FilterResponseFunc -} +type Response = threadsafe.Response // NewResponse creates a new Response. func NewResponse() *Response { - return &Response{Header: make(http.Header)} -} - -// Status defines the desired HTTP status code to reply in the current response. -func (r *Response) Status(code int) *Response { - r.StatusCode = code - return r -} - -// Type defines the response Content-Type MIME header field. -// Supports type alias. E.g: json, xml, form, text... -func (r *Response) Type(kind string) *Response { - mime := BodyTypeAliases[kind] - if mime != "" { - kind = mime - } - r.Header.Set("Content-Type", kind) - return r -} - -// SetHeader sets a new header field in the mock response. -func (r *Response) SetHeader(key, value string) *Response { - r.Header.Set(key, value) - return r -} - -// AddHeader adds a new header field in the mock response -// with out removing an existent one. -func (r *Response) AddHeader(key, value string) *Response { - r.Header.Add(key, value) - return r -} - -// SetHeaders sets a map of header fields in the mock response. -func (r *Response) SetHeaders(headers map[string]string) *Response { - for key, value := range headers { - r.Header.Add(key, value) - } - return r -} - -// Body sets the HTTP response body to be used. -func (r *Response) Body(body io.Reader) *Response { - r.BodyBuffer, r.Error = ioutil.ReadAll(body) - return r -} - -// BodyGenerator accepts a io.ReadCloser generator, returning custom io.ReadCloser -// for every response. This will take priority than other Body methods used. -func (r *Response) BodyGenerator(generator func() io.ReadCloser) *Response { - r.BodyGen = generator - return r -} - -// BodyString defines the response body as string. -func (r *Response) BodyString(body string) *Response { - r.BodyBuffer = []byte(body) - return r -} - -// File defines the response body reading the data -// from disk based on the file path string. -func (r *Response) File(path string) *Response { - r.BodyBuffer, r.Error = ioutil.ReadFile(path) - return r -} - -// JSON defines the response body based on a JSON based input. -func (r *Response) JSON(data interface{}) *Response { - r.Header.Set("Content-Type", "application/json") - r.BodyBuffer, r.Error = readAndDecode(data, "json") - return r -} - -// XML defines the response body based on a XML based input. -func (r *Response) XML(data interface{}) *Response { - r.Header.Set("Content-Type", "application/xml") - r.BodyBuffer, r.Error = readAndDecode(data, "xml") - return r -} - -// SetError defines the response simulated error. -func (r *Response) SetError(err error) *Response { - r.Error = err - return r -} - -// Delay defines the response simulated delay. -// This feature is still experimental and will be improved in the future. -func (r *Response) Delay(delay time.Duration) *Response { - r.ResponseDelay = delay - return r -} - -// Map adds a new response mapper function to map http.Response before the matching process. -func (r *Response) Map(fn MapResponseFunc) *Response { - r.Mappers = append(r.Mappers, fn) - return r -} - -// Filter filters a new request filter function to filter http.Request before the matching process. -func (r *Response) Filter(fn FilterResponseFunc) *Response { - r.Filters = append(r.Filters, fn) - return r -} - -// EnableNetworking enables the use real networking for the current mock. -func (r *Response) EnableNetworking() *Response { - r.UseNetwork = true - return r -} - -// Done returns true if the mock was done and disabled. -func (r *Response) Done() bool { - return r.Mock.Done() -} - -func readAndDecode(data interface{}, kind string) ([]byte, error) { - buf := &bytes.Buffer{} - - switch data.(type) { - case string: - buf.WriteString(data.(string)) - case []byte: - buf.Write(data.([]byte)) - default: - var err error - if kind == "xml" { - err = xml.NewEncoder(buf).Encode(data) - } else { - err = json.NewEncoder(buf).Encode(data) - } - if err != nil { - return nil, err - } - } - - return ioutil.ReadAll(buf) + return g.NewResponse() } diff --git a/response_test.go b/response_test.go index 412ca53..c27781a 100644 --- a/response_test.go +++ b/response_test.go @@ -158,7 +158,7 @@ func TestResponseEnableNetworking(t *testing.T) { func TestResponseDone(t *testing.T) { res := NewResponse() - res.Mock = &Mocker{request: &Request{Counter: 1}, disabler: new(disabler)} + res.Mock = NewMock(&Request{Counter: 1}, res) st.Expect(t, res.Done(), false) res.Mock.Disable() st.Expect(t, res.Done(), true) diff --git a/store.go b/store.go index 7ed1316..adc5021 100644 --- a/store.go +++ b/store.go @@ -1,100 +1,46 @@ package gock -import ( - "sync" -) - -// storeMutex is used interally for store synchronization. -var storeMutex = sync.RWMutex{} - -// mocks is internally used to store registered mocks. -var mocks = []Mock{} - // Register registers a new mock in the current mocks stack. func Register(mock Mock) { - if Exists(mock) { - return - } - - // Make ops thread safe - storeMutex.Lock() - defer storeMutex.Unlock() - - // Expose mock in request/response for delegation - mock.Request().Mock = mock - mock.Response().Mock = mock - - // Registers the mock in the global store - mocks = append(mocks, mock) + g.Register(mock) } // GetAll returns the current stack of registered mocks. func GetAll() []Mock { - storeMutex.RLock() - defer storeMutex.RUnlock() - return mocks + return g.GetAll() } // Exists checks if the given Mock is already registered. func Exists(m Mock) bool { - storeMutex.RLock() - defer storeMutex.RUnlock() - for _, mock := range mocks { - if mock == m { - return true - } - } - return false + return g.Exists(m) } // Remove removes a registered mock by reference. func Remove(m Mock) { - for i, mock := range mocks { - if mock == m { - storeMutex.Lock() - mocks = append(mocks[:i], mocks[i+1:]...) - storeMutex.Unlock() - } - } + g.Remove(m) } // Flush flushes the current stack of registered mocks. func Flush() { - storeMutex.Lock() - defer storeMutex.Unlock() - mocks = []Mock{} + g.Flush() } // Pending returns an slice of pending mocks. func Pending() []Mock { - Clean() - storeMutex.RLock() - defer storeMutex.RUnlock() - return mocks + return g.Pending() } // IsDone returns true if all the registered mocks has been triggered successfully. func IsDone() bool { - return !IsPending() + return g.IsDone() } // IsPending returns true if there are pending mocks. func IsPending() bool { - return len(Pending()) > 0 + return g.IsPending() } // Clean cleans the mocks store removing disabled or obsolete mocks. func Clean() { - storeMutex.Lock() - defer storeMutex.Unlock() - - buf := []Mock{} - for _, mock := range mocks { - if mock.Done() { - continue - } - buf = append(buf, mock) - } - - mocks = buf + g.Clean() } diff --git a/store_test.go b/store_test.go index 4ab4c83..b40a078 100644 --- a/store_test.go +++ b/store_test.go @@ -8,36 +8,36 @@ import ( func TestStoreRegister(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock := New("foo").Mock Register(mock) - st.Expect(t, len(mocks), 1) + st.Expect(t, len(GetAll()), 1) st.Expect(t, mock.Request().Mock, mock) st.Expect(t, mock.Response().Mock, mock) } func TestStoreGetAll(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock := New("foo").Mock store := GetAll() - st.Expect(t, len(mocks), 1) + st.Expect(t, len(GetAll()), 1) st.Expect(t, len(store), 1) st.Expect(t, store[0], mock) } func TestStoreExists(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock := New("foo").Mock - st.Expect(t, len(mocks), 1) + st.Expect(t, len(GetAll()), 1) st.Expect(t, Exists(mock), true) } func TestStorePending(t *testing.T) { defer after() New("foo") - st.Expect(t, mocks, Pending()) + st.Expect(t, GetAll(), Pending()) } func TestStoreIsPending(t *testing.T) { @@ -58,9 +58,9 @@ func TestStoreIsDone(t *testing.T) { func TestStoreRemove(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock := New("foo").Mock - st.Expect(t, len(mocks), 1) + st.Expect(t, len(GetAll()), 1) st.Expect(t, Exists(mock), true) Remove(mock) @@ -72,16 +72,16 @@ func TestStoreRemove(t *testing.T) { func TestStoreFlush(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock1 := New("foo").Mock mock2 := New("foo").Mock - st.Expect(t, len(mocks), 2) + st.Expect(t, len(GetAll()), 2) st.Expect(t, Exists(mock1), true) st.Expect(t, Exists(mock2), true) Flush() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) st.Expect(t, Exists(mock1), false) st.Expect(t, Exists(mock2), false) } diff --git a/threadsafe/gock.go b/threadsafe/gock.go new file mode 100644 index 0000000..7df7b9a --- /dev/null +++ b/threadsafe/gock.go @@ -0,0 +1,270 @@ +package threadsafe + +import ( + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "regexp" + "sync" +) + +type Gock struct { + // mutex is used internally for locking thread-sensitive functions. + mutex sync.Mutex + // config global singleton store. + config struct { + Networking bool + NetworkingFilters []FilterRequestFunc + Observer ObserverFunc + } + // DumpRequest is a default implementation of ObserverFunc that dumps + // the HTTP/1.x wire representation of the http request + DumpRequest ObserverFunc + // track unmatched requests so they can be tested for + unmatchedRequests []*http.Request + + // storeMutex is used internally for store synchronization. + storeMutex sync.RWMutex + + // mocks is internally used to store registered mocks. + mocks []Mock + + // DefaultMatcher stores the default Matcher instance used to match mocks. + DefaultMatcher *MockMatcher + + // MatchersHeader exposes a slice of HTTP header specific mock matchers. + MatchersHeader []MatchFunc + // MatchersBody exposes a slice of HTTP body specific built-in mock matchers. + MatchersBody []MatchFunc + // Matchers stores all the built-in mock matchers. + Matchers []MatchFunc + + // BodyTypes stores the supported MIME body types for matching. + // Currently only text-based types. + BodyTypes []string + + // BodyTypeAliases stores a generic MIME type by alias. + BodyTypeAliases map[string]string + + // CompressionSchemes stores the supported Content-Encoding types for decompression. + CompressionSchemes []string + + intercepting bool + + DisableCallback func() + InterceptCallback func() + InterceptingCallback func() bool +} + +func NewGock() *Gock { + g := &Gock{ + DumpRequest: defaultDumpRequest, + + BodyTypes: []string{ + "text/html", + "text/plain", + "application/json", + "application/xml", + "multipart/form-data", + "application/x-www-form-urlencoded", + }, + + BodyTypeAliases: map[string]string{ + "html": "text/html", + "text": "text/plain", + "json": "application/json", + "xml": "application/xml", + "form": "multipart/form-data", + "url": "application/x-www-form-urlencoded", + }, + + // CompressionSchemes stores the supported Content-Encoding types for decompression. + CompressionSchemes: []string{ + "gzip", + }, + } + g.MatchersHeader = []MatchFunc{ + g.MatchMethod, + g.MatchScheme, + g.MatchHost, + g.MatchPath, + g.MatchHeaders, + g.MatchQueryParams, + g.MatchPathParams, + } + g.MatchersBody = []MatchFunc{ + g.MatchBody, + } + g.Matchers = append(g.MatchersHeader, g.MatchersBody...) + + // DefaultMatcher stores the default Matcher instance used to match mocks. + g.DefaultMatcher = g.NewMatcher() + return g +} + +// ObserverFunc is implemented by users to inspect the outgoing intercepted HTTP traffic +type ObserverFunc func(*http.Request, Mock) + +func defaultDumpRequest(request *http.Request, mock Mock) { + bytes, _ := httputil.DumpRequestOut(request, true) + fmt.Println(string(bytes)) + fmt.Printf("\nMatches: %v\n---\n", mock != nil) +} + +// New creates and registers a new HTTP mock with +// default settings and returns the Request DSL for HTTP mock +// definition and set up. +func (g *Gock) New(uri string) *Request { + g.Intercept() + + res := g.NewResponse() + req := g.NewRequest() + req.URLStruct, res.Error = url.Parse(normalizeURI(uri)) + + // Create the new mock expectation + exp := g.NewMock(req, res) + g.Register(exp) + + return req +} + +// Intercepting returns true if gock is currently able to intercept. +func (g *Gock) Intercepting() bool { + g.mutex.Lock() + defer g.mutex.Unlock() + + callbackResponse := true + if g.InterceptingCallback != nil { + callbackResponse = g.InterceptingCallback() + } + + return g.intercepting && callbackResponse +} + +// Intercept enables HTTP traffic interception via http.DefaultTransport. +// If you are using a custom HTTP transport, you have to use `gock.Transport()` +func (g *Gock) Intercept() { + if !g.Intercepting() { + g.mutex.Lock() + g.intercepting = true + + if g.InterceptCallback != nil { + g.InterceptCallback() + } + + g.mutex.Unlock() + } +} + +// InterceptClient allows the developer to intercept HTTP traffic using +// a custom http.Client who uses a non default http.Transport/http.RoundTripper implementation. +func (g *Gock) InterceptClient(cli *http.Client) { + _, ok := cli.Transport.(*Transport) + if ok { + return // if transport already intercepted, just ignore it + } + cli.Transport = g.NewTransport(cli.Transport) +} + +// RestoreClient allows the developer to disable and restore the +// original transport in the given http.Client. +func (g *Gock) RestoreClient(cli *http.Client) { + trans, ok := cli.Transport.(*Transport) + if !ok { + return + } + cli.Transport = trans.Transport +} + +// Disable disables HTTP traffic interception by gock. +func (g *Gock) Disable() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.intercepting = false + + if g.DisableCallback != nil { + g.DisableCallback() + } +} + +// Off disables the default HTTP interceptors and removes +// all the registered mocks, even if they has not been intercepted yet. +func (g *Gock) Off() { + g.Flush() + g.Disable() +} + +// OffAll is like `Off()`, but it also removes the unmatched requests registry. +func (g *Gock) OffAll() { + g.Flush() + g.Disable() + g.CleanUnmatchedRequest() +} + +// Observe provides a hook to support inspection of the request and matched mock +func (g *Gock) Observe(fn ObserverFunc) { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.Observer = fn +} + +// EnableNetworking enables real HTTP networking +func (g *Gock) EnableNetworking() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.Networking = true +} + +// DisableNetworking disables real HTTP networking +func (g *Gock) DisableNetworking() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.Networking = false +} + +// NetworkingFilter determines if an http.Request should be triggered or not. +func (g *Gock) NetworkingFilter(fn FilterRequestFunc) { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.NetworkingFilters = append(g.config.NetworkingFilters, fn) +} + +// DisableNetworkingFilters disables registered networking filters. +func (g *Gock) DisableNetworkingFilters() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.NetworkingFilters = []FilterRequestFunc{} +} + +// GetUnmatchedRequests returns all requests that have been received but haven't matched any mock +func (g *Gock) GetUnmatchedRequests() []*http.Request { + g.mutex.Lock() + defer g.mutex.Unlock() + return g.unmatchedRequests +} + +// HasUnmatchedRequest returns true if gock has received any requests that didn't match a mock +func (g *Gock) HasUnmatchedRequest() bool { + return len(g.GetUnmatchedRequests()) > 0 +} + +// CleanUnmatchedRequest cleans the unmatched requests internal registry. +func (g *Gock) CleanUnmatchedRequest() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.unmatchedRequests = []*http.Request{} +} + +func (g *Gock) trackUnmatchedRequest(req *http.Request) { + g.mutex.Lock() + defer g.mutex.Unlock() + g.unmatchedRequests = append(g.unmatchedRequests, req) +} + +func normalizeURI(uri string) string { + if ok, _ := regexp.MatchString("^http[s]?", uri); !ok { + return "http://" + uri + } + return uri +} diff --git a/threadsafe/gock_test.go b/threadsafe/gock_test.go new file mode 100644 index 0000000..a503a28 --- /dev/null +++ b/threadsafe/gock_test.go @@ -0,0 +1,517 @@ +package threadsafe + +import ( + "bytes" + "compress/gzip" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/nbio/st" +) + +func TestMockSimple(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Reply(201).JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) +} + +func TestMockOff(t *testing.T) { + g := NewGock() + g.New("http://foo.com").Reply(201).JSON(map[string]string{"foo": "bar"}) + g.Off() + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Get("http://127.0.0.1:3123") + st.Reject(t, err, nil) +} + +func TestMockBodyStringResponse(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Reply(200).BodyString("foo bar") + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo bar") +} + +func TestMockBodyMatch(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").BodyString("foo bar").Reply(201).BodyString("foo foo") + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar"))) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo foo") +} + +func TestMockBodyCannotMatch(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").BodyString("foo foo").Reply(201).BodyString("foo foo") + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar"))) + st.Reject(t, err, nil) +} + +func TestMockBodyMatchCompressed(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Compression("gzip").BodyString("foo bar").Reply(201).BodyString("foo foo") + + var compressed bytes.Buffer + w := gzip.NewWriter(&compressed) + w.Write([]byte("foo bar")) + w.Close() + c := &http.Client{} + g.InterceptClient(c) + req, err := http.NewRequest("POST", "http://foo.com", &compressed) + st.Expect(t, err, nil) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Content-Type", "text/plain") + res, err := c.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo foo") +} + +func TestMockBodyCannotMatchCompressed(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Compression("gzip").BodyString("foo bar").Reply(201).BodyString("foo foo") + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar"))) + st.Reject(t, err, nil) +} + +func TestMockBodyMatchJSON(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + JSON(map[string]string{"foo": "bar"}). + Reply(201). + JSON(map[string]string{"bar": "foo"}) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Post("http://foo.com/bar", "application/json", bytes.NewBuffer([]byte(`{"foo":"bar"}`))) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"bar":"foo"}`) +} + +func TestMockBodyCannotMatchJSON(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + JSON(map[string]string{"bar": "bar"}). + Reply(201). + JSON(map[string]string{"bar": "foo"}) + + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Post("http://foo.com/bar", "application/json", bytes.NewBuffer([]byte(`{"foo":"bar"}`))) + st.Reject(t, err, nil) +} + +func TestMockBodyMatchCompressedJSON(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + Compression("gzip"). + JSON(map[string]string{"foo": "bar"}). + Reply(201). + JSON(map[string]string{"bar": "foo"}) + + var compressed bytes.Buffer + w := gzip.NewWriter(&compressed) + w.Write([]byte(`{"foo":"bar"}`)) + w.Close() + c := &http.Client{} + g.InterceptClient(c) + req, err := http.NewRequest("POST", "http://foo.com/bar", &compressed) + st.Expect(t, err, nil) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Content-Type", "application/json") + res, err := c.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"bar":"foo"}`) +} + +func TestMockBodyCannotMatchCompressedJSON(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + JSON(map[string]string{"bar": "bar"}). + Reply(201). + JSON(map[string]string{"bar": "foo"}) + + var compressed bytes.Buffer + w := gzip.NewWriter(&compressed) + w.Write([]byte(`{"foo":"bar"}`)) + w.Close() + c := &http.Client{} + g.InterceptClient(c) + req, err := http.NewRequest("POST", "http://foo.com/bar", &compressed) + st.Expect(t, err, nil) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Content-Type", "application/json") + _, err = c.Do(req) + st.Reject(t, err, nil) +} + +func TestMockMatchHeaders(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + MatchHeader("Content-Type", "(.*)/plain"). + Reply(200). + BodyString("foo foo") + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar"))) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo foo") +} + +func TestMockMap(t *testing.T) { + g := NewGock() + defer after(g) + + mock := g.New("http://bar.com") + mock.Map(func(req *http.Request) *http.Request { + req.URL.Host = "bar.com" + return req + }) + mock.Reply(201).JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) +} + +func TestMockFilter(t *testing.T) { + g := NewGock() + defer after(g) + + mock := g.New("http://foo.com") + mock.Filter(func(req *http.Request) bool { + return req.URL.Host == "foo.com" + }) + mock.Reply(201).JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) +} + +func TestMockCounterDisabled(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Reply(204) + st.Expect(t, len(g.GetAll()), 1) + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) + st.Expect(t, len(g.GetAll()), 0) +} + +func TestMockEnableNetwork(t *testing.T) { + g := NewGock() + defer after(g) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, world") + })) + defer ts.Close() + + g.EnableNetworking() + defer g.DisableNetworking() + + g.New(ts.URL).Reply(204) + st.Expect(t, len(g.GetAll()), 1) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get(ts.URL) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) + st.Expect(t, len(g.GetAll()), 0) + + res, err = c.Get(ts.URL) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) +} + +func TestMockEnableNetworkFilter(t *testing.T) { + g := NewGock() + defer after(g) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, world") + })) + defer ts.Close() + + g.EnableNetworking() + defer g.DisableNetworking() + + g.NetworkingFilter(func(req *http.Request) bool { + return strings.Contains(req.URL.Host, "127.0.0.1") + }) + defer g.DisableNetworkingFilters() + + g.New(ts.URL).Reply(0).SetHeader("foo", "bar") + st.Expect(t, len(g.GetAll()), 1) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get(ts.URL) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + st.Expect(t, res.Header.Get("foo"), "bar") + st.Expect(t, len(g.GetAll()), 0) +} + +func TestMockPersistent(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Get("/bar"). + Persist(). + Reply(200). + JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + for i := 0; i < 5; i++ { + res, err := c.Get("http://foo.com/bar") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) + } +} + +func TestMockPersistTimes(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://127.0.0.1:1234"). + Get("/bar"). + Times(4). + Reply(200). + JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + for i := 0; i < 5; i++ { + res, err := c.Get("http://127.0.0.1:1234/bar") + if i == 4 { + st.Reject(t, err, nil) + break + } + + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) + } +} + +func TestUnmatched(t *testing.T) { + g := NewGock() + defer after(g) + + // clear out any unmatchedRequests from other tests + g.unmatchedRequests = []*http.Request{} + + g.Intercept() + + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Get("http://server.com/unmatched") + st.Reject(t, err, nil) + + unmatched := g.GetUnmatchedRequests() + st.Expect(t, len(unmatched), 1) + st.Expect(t, unmatched[0].URL.Host, "server.com") + st.Expect(t, unmatched[0].URL.Path, "/unmatched") + st.Expect(t, g.HasUnmatchedRequest(), true) +} + +func TestMultipleMocks(t *testing.T) { + g := NewGock() + defer g.Disable() + + g.New("http://server.com"). + Get("/foo"). + Reply(200). + JSON(map[string]string{"value": "foo"}) + + g.New("http://server.com"). + Get("/bar"). + Reply(200). + JSON(map[string]string{"value": "bar"}) + + g.New("http://server.com"). + Get("/baz"). + Reply(200). + JSON(map[string]string{"value": "baz"}) + + tests := []struct { + path string + }{ + {"/foo"}, + {"/bar"}, + {"/baz"}, + } + + c := &http.Client{} + g.InterceptClient(c) + for _, test := range tests { + res, err := c.Get("http://server.com" + test.path) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:15], `{"value":"`+test.path[1:]+`"}`) + } + + _, err := c.Get("http://server.com/foo") + st.Reject(t, err, nil) +} + +func TestInterceptClient(t *testing.T) { + g := NewGock() + defer after(g) + + g.New("http://foo.com").Reply(204) + st.Expect(t, len(g.GetAll()), 1) + + req, err := http.NewRequest("GET", "http://foo.com", nil) + client := &http.Client{Transport: &http.Transport{}} + g.InterceptClient(client) + + res, err := client.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) +} + +func TestRestoreClient(t *testing.T) { + g := NewGock() + defer after(g) + + g.New("http://foo.com").Reply(204) + st.Expect(t, len(g.GetAll()), 1) + + req, err := http.NewRequest("GET", "http://foo.com", nil) + client := &http.Client{Transport: &http.Transport{}} + g.InterceptClient(client) + trans := client.Transport + + res, err := client.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) + + g.RestoreClient(client) + st.Reject(t, trans, client.Transport) +} + +func TestMockRegExpMatching(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + MatchHeader("Authorization", "Bearer (.*)"). + BodyString(`{"foo":".*"}`). + Reply(200). + SetHeader("Server", "gock"). + JSON(map[string]string{"foo": "bar"}) + + req, _ := http.NewRequest("POST", "http://foo.com/bar", bytes.NewBuffer([]byte(`{"foo":"baz"}`))) + req.Header.Set("Authorization", "Bearer s3cr3t") + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + st.Expect(t, res.Header.Get("Server"), "gock") + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) +} + +func TestObserve(t *testing.T) { + g := NewGock() + defer after(g) + var observedRequest *http.Request + var observedMock Mock + g.Observe(func(request *http.Request, mock Mock) { + observedRequest = request + observedMock = mock + }) + g.New("http://observe-foo.com").Reply(200) + req, _ := http.NewRequest("POST", "http://observe-foo.com", nil) + + c := &http.Client{} + g.InterceptClient(c) + c.Do(req) + + st.Expect(t, observedRequest.Host, "observe-foo.com") + st.Expect(t, observedMock.Request().URLStruct.Host, "observe-foo.com") +} + +func TestTryCreatingRacesInNew(t *testing.T) { + g := NewGock() + defer after(g) + for i := 0; i < 10; i++ { + go func() { + g.New("http://example.com") + }() + } +} + +func after(g *Gock) { + g.Flush() + g.Disable() +} diff --git a/threadsafe/matcher.go b/threadsafe/matcher.go new file mode 100644 index 0000000..ad10a84 --- /dev/null +++ b/threadsafe/matcher.go @@ -0,0 +1,116 @@ +package threadsafe + +import "net/http" + +// MatchFunc represents the required function +// interface implemented by matchers. +type MatchFunc func(*http.Request, *Request) (bool, error) + +// Matcher represents the required interface implemented by mock matchers. +type Matcher interface { + // Get returns a slice of registered function matchers. + Get() []MatchFunc + + // Add adds a new matcher function. + Add(MatchFunc) + + // Set sets the matchers functions stack. + Set([]MatchFunc) + + // Flush flushes the current matchers function stack. + Flush() + + // Match matches the given http.Request with a mock Request. + Match(*http.Request, *Request) (bool, error) +} + +// MockMatcher implements a mock matcher +type MockMatcher struct { + Matchers []MatchFunc + g *Gock +} + +// NewMatcher creates a new mock matcher +// using the default matcher functions. +func (g *Gock) NewMatcher() *MockMatcher { + m := g.NewEmptyMatcher() + for _, matchFn := range g.Matchers { + m.Add(matchFn) + } + return m +} + +// NewBasicMatcher creates a new matcher with header only mock matchers. +func (g *Gock) NewBasicMatcher() *MockMatcher { + m := g.NewEmptyMatcher() + for _, matchFn := range g.MatchersHeader { + m.Add(matchFn) + } + return m +} + +// NewEmptyMatcher creates a new empty matcher without default matchers. +func (g *Gock) NewEmptyMatcher() *MockMatcher { + return &MockMatcher{g: g, Matchers: []MatchFunc{}} +} + +// Get returns a slice of registered function matchers. +func (m *MockMatcher) Get() []MatchFunc { + m.g.mutex.Lock() + defer m.g.mutex.Unlock() + return m.Matchers +} + +// Add adds a new function matcher. +func (m *MockMatcher) Add(fn MatchFunc) { + m.Matchers = append(m.Matchers, fn) +} + +// Set sets a new stack of matchers functions. +func (m *MockMatcher) Set(stack []MatchFunc) { + m.Matchers = stack +} + +// Flush flushes the current matcher +func (m *MockMatcher) Flush() { + m.Matchers = []MatchFunc{} +} + +// Clone returns a separate MockMatcher instance that has a copy of the same MatcherFuncs +func (m *MockMatcher) Clone() *MockMatcher { + m2 := m.g.NewEmptyMatcher() + for _, mFn := range m.Get() { + m2.Add(mFn) + } + return m2 +} + +// Match matches the given http.Request with a mock request +// returning true in case that the request matches, otherwise false. +func (m *MockMatcher) Match(req *http.Request, ereq *Request) (bool, error) { + for _, matcher := range m.Matchers { + matches, err := matcher(req, ereq) + if err != nil { + return false, err + } + if !matches { + return false, nil + } + } + return true, nil +} + +// MatchMock is a helper function that matches the given http.Request +// in the list of registered mocks, returning it if matches or error if it fails. +func (g *Gock) MatchMock(req *http.Request) (Mock, error) { + for _, mock := range g.GetAll() { + matches, err := mock.Match(req) + if err != nil { + return nil, err + } + if matches { + return mock, nil + } + } + return nil, nil +} diff --git a/threadsafe/matcher_test.go b/threadsafe/matcher_test.go new file mode 100644 index 0000000..c33a475 --- /dev/null +++ b/threadsafe/matcher_test.go @@ -0,0 +1,172 @@ +package threadsafe + +import ( + "net/http" + "net/url" + "testing" + + "github.com/nbio/st" +) + +func TestRegisteredMatchers(t *testing.T) { + g := NewGock() + st.Expect(t, len(g.MatchersHeader), 7) + st.Expect(t, len(g.MatchersBody), 1) +} + +func TestNewMatcher(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + // Funcs are not comparable, checking slice length as it's better than nothing + // See https://golang.org/pkg/reflect/#DeepEqual + st.Expect(t, len(matcher.Matchers), len(g.Matchers)) + st.Expect(t, len(matcher.Get()), len(g.Matchers)) +} + +func TestNewBasicMatcher(t *testing.T) { + g := NewGock() + matcher := g.NewBasicMatcher() + // Funcs are not comparable, checking slice length as it's better than nothing + // See https://golang.org/pkg/reflect/#DeepEqual + st.Expect(t, len(matcher.Matchers), len(g.MatchersHeader)) + st.Expect(t, len(matcher.Get()), len(g.MatchersHeader)) +} + +func TestNewEmptyMatcher(t *testing.T) { + g := NewGock() + matcher := g.NewEmptyMatcher() + st.Expect(t, len(matcher.Matchers), 0) + st.Expect(t, len(matcher.Get()), 0) +} + +func TestMatcherAdd(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + st.Expect(t, len(matcher.Matchers), len(g.Matchers)) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return true, nil + }) + st.Expect(t, len(matcher.Get()), len(g.Matchers)+1) +} + +func TestMatcherSet(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + matchers := []MatchFunc{} + st.Expect(t, len(matcher.Matchers), len(g.Matchers)) + matcher.Set(matchers) + st.Expect(t, matcher.Matchers, matchers) + st.Expect(t, len(matcher.Get()), 0) +} + +func TestMatcherGet(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + matchers := []MatchFunc{} + matcher.Set(matchers) + st.Expect(t, matcher.Get(), matchers) +} + +func TestMatcherFlush(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + st.Expect(t, len(matcher.Matchers), len(g.Matchers)) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return true, nil + }) + st.Expect(t, len(matcher.Get()), len(g.Matchers)+1) + matcher.Flush() + st.Expect(t, len(matcher.Get()), 0) +} + +func TestMatcherClone(t *testing.T) { + g := NewGock() + matcher := g.DefaultMatcher.Clone() + st.Expect(t, len(matcher.Get()), len(g.DefaultMatcher.Get())) +} + +func TestMatcher(t *testing.T) { + cases := []struct { + method string + url string + matches bool + }{ + {"GET", "http://foo.com/bar", true}, + {"GET", "http://foo.com/baz", true}, + {"GET", "http://foo.com/foo", false}, + {"POST", "http://foo.com/bar", false}, + {"POST", "http://bar.com/bar", false}, + {"GET", "http://foo.com", false}, + } + + g := NewGock() + matcher := g.NewMatcher() + matcher.Flush() + st.Expect(t, len(matcher.Matchers), 0) + + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.Method == "GET", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Host == "foo.com", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Path == "/baz" || req.URL.Path == "/bar", nil + }) + + for _, test := range cases { + u, _ := url.Parse(test.url) + req := &http.Request{Method: test.method, URL: u} + matches, err := matcher.Match(req, nil) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchMock(t *testing.T) { + cases := []struct { + method string + url string + matches bool + }{ + {"GET", "http://foo.com/bar", true}, + {"GET", "http://foo.com/baz", true}, + {"GET", "http://foo.com/foo", false}, + {"POST", "http://foo.com/bar", false}, + {"POST", "http://bar.com/bar", false}, + {"GET", "http://foo.com", false}, + } + + g := NewGock() + matcher := g.DefaultMatcher + matcher.Flush() + st.Expect(t, len(matcher.Matchers), 0) + + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.Method == "GET", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Host == "foo.com", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Path == "/baz" || req.URL.Path == "/bar", nil + }) + + for _, test := range cases { + g.Flush() + mock := g.New(test.url).method(test.method, "").Mock + + u, _ := url.Parse(test.url) + req := &http.Request{Method: test.method, URL: u} + + match, err := g.MatchMock(req) + st.Expect(t, err, nil) + if test.matches { + st.Expect(t, match, mock) + } else { + st.Expect(t, match, nil) + } + } + + g.DefaultMatcher.Matchers = g.Matchers +} diff --git a/threadsafe/matchers.go b/threadsafe/matchers.go new file mode 100644 index 0000000..9b1c0b3 --- /dev/null +++ b/threadsafe/matchers.go @@ -0,0 +1,240 @@ +package threadsafe + +import ( + "compress/gzip" + "encoding/json" + "io" + "io/ioutil" + "net/http" + "reflect" + "regexp" + "strings" + + "github.com/h2non/parth" +) + +// EOL represents the end of line character. +const EOL = 0xa + +// MatchMethod matches the HTTP method of the given request. +func (g *Gock) MatchMethod(req *http.Request, ereq *Request) (bool, error) { + return ereq.Method == "" || req.Method == ereq.Method, nil +} + +// MatchScheme matches the request URL protocol scheme. +func (g *Gock) MatchScheme(req *http.Request, ereq *Request) (bool, error) { + return ereq.URLStruct.Scheme == "" || req.URL.Scheme == "" || ereq.URLStruct.Scheme == req.URL.Scheme, nil +} + +// MatchHost matches the HTTP host header field of the given request. +func (g *Gock) MatchHost(req *http.Request, ereq *Request) (bool, error) { + url := ereq.URLStruct + if strings.EqualFold(url.Host, req.URL.Host) { + return true, nil + } + if !ereq.Options.DisableRegexpHost { + return regexp.MatchString(url.Host, req.URL.Host) + } + return false, nil +} + +// MatchPath matches the HTTP URL path of the given request. +func (g *Gock) MatchPath(req *http.Request, ereq *Request) (bool, error) { + if req.URL.Path == ereq.URLStruct.Path { + return true, nil + } + return regexp.MatchString(ereq.URLStruct.Path, req.URL.Path) +} + +// MatchHeaders matches the headers fields of the given request. +func (g *Gock) MatchHeaders(req *http.Request, ereq *Request) (bool, error) { + for key, value := range ereq.Header { + var err error + var match bool + var matchEscaped bool + + for _, field := range req.Header[key] { + match, err = regexp.MatchString(value[0], field) + // Some values may contain reserved regex params e.g. "()", try matching with these escaped. + matchEscaped, err = regexp.MatchString(regexp.QuoteMeta(value[0]), field) + + if err != nil { + return false, err + } + if match || matchEscaped { + break + } + + } + + if !match && !matchEscaped { + return false, nil + } + } + return true, nil +} + +// MatchQueryParams matches the URL query params fields of the given request. +func (g *Gock) MatchQueryParams(req *http.Request, ereq *Request) (bool, error) { + for key, value := range ereq.URLStruct.Query() { + var err error + var match bool + + for _, field := range req.URL.Query()[key] { + match, err = regexp.MatchString(value[0], field) + if err != nil { + return false, err + } + if match { + break + } + } + + if !match { + return false, nil + } + } + return true, nil +} + +// MatchPathParams matches the URL path parameters of the given request. +func (g *Gock) MatchPathParams(req *http.Request, ereq *Request) (bool, error) { + for key, value := range ereq.PathParams { + var s string + + if err := parth.Sequent(req.URL.Path, key, &s); err != nil { + return false, nil + } + + if s != value { + return false, nil + } + } + return true, nil +} + +// MatchBody tries to match the request body. +// TODO: not too smart now, needs several improvements. +func (g *Gock) MatchBody(req *http.Request, ereq *Request) (bool, error) { + // If match body is empty, just continue + if req.Method == "HEAD" || len(ereq.BodyBuffer) == 0 { + return true, nil + } + + // Only can match certain MIME body types + if !g.supportedType(req, ereq) { + return false, nil + } + + // Can only match certain compression schemes + if !g.supportedCompressionScheme(req) { + return false, nil + } + + // Create a reader for the body depending on compression type + bodyReader := req.Body + if ereq.CompressionScheme != "" { + if ereq.CompressionScheme != req.Header.Get("Content-Encoding") { + return false, nil + } + compressedBodyReader, err := compressionReader(req.Body, ereq.CompressionScheme) + if err != nil { + return false, err + } + bodyReader = compressedBodyReader + } + + // Read the whole request body + body, err := ioutil.ReadAll(bodyReader) + if err != nil { + return false, err + } + + // Restore body reader stream + req.Body = createReadCloser(body) + + // If empty, ignore the match + if len(body) == 0 && len(ereq.BodyBuffer) != 0 { + return false, nil + } + + // Match body by atomic string comparison + bodyStr := castToString(body) + matchStr := castToString(ereq.BodyBuffer) + if bodyStr == matchStr { + return true, nil + } + + // Match request body by regexp + match, _ := regexp.MatchString(matchStr, bodyStr) + if match == true { + return true, nil + } + + // todo - add conditional do only perform the conversion of body bytes + // representation of JSON to a map and then compare them for equality. + + // Check if the key + value pairs match + var bodyMap map[string]interface{} + var matchMap map[string]interface{} + + // Ensure that both byte bodies that that should be JSON can be converted to maps. + umErr := json.Unmarshal(body, &bodyMap) + umErr2 := json.Unmarshal(ereq.BodyBuffer, &matchMap) + if umErr == nil && umErr2 == nil && reflect.DeepEqual(bodyMap, matchMap) { + return true, nil + } + + return false, nil +} + +func (g *Gock) supportedType(req *http.Request, ereq *Request) bool { + mime := req.Header.Get("Content-Type") + if mime == "" { + return true + } + + mimeToMatch := ereq.Header.Get("Content-Type") + if mimeToMatch != "" { + return mime == mimeToMatch + } + + for _, kind := range g.BodyTypes { + if match, _ := regexp.MatchString(kind, mime); match { + return true + } + } + return false +} + +func (g *Gock) supportedCompressionScheme(req *http.Request) bool { + encoding := req.Header.Get("Content-Encoding") + if encoding == "" { + return true + } + + for _, kind := range g.CompressionSchemes { + if match, _ := regexp.MatchString(kind, encoding); match { + return true + } + } + return false +} + +func castToString(buf []byte) string { + str := string(buf) + tail := len(str) - 1 + if str[tail] == EOL { + str = str[:tail] + } + return str +} + +func compressionReader(r io.ReadCloser, scheme string) (io.ReadCloser, error) { + switch scheme { + case "gzip": + return gzip.NewReader(r) + default: + return r, nil + } +} diff --git a/threadsafe/matchers_test.go b/threadsafe/matchers_test.go new file mode 100644 index 0000000..6db6c08 --- /dev/null +++ b/threadsafe/matchers_test.go @@ -0,0 +1,253 @@ +package threadsafe + +import ( + "net/http" + "net/url" + "testing" + + "github.com/nbio/st" +) + +func TestMatchMethod(t *testing.T) { + cases := []struct { + value string + method string + matches bool + }{ + {"GET", "GET", true}, + {"POST", "POST", true}, + {"", "POST", true}, + {"POST", "GET", false}, + {"PUT", "GET", false}, + } + + for _, test := range cases { + req := &http.Request{Method: test.method} + ereq := &Request{Method: test.value} + matches, err := NewGock().MatchMethod(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchScheme(t *testing.T) { + cases := []struct { + value string + scheme string + matches bool + }{ + {"http", "http", true}, + {"https", "https", true}, + {"http", "https", false}, + {"", "https", true}, + {"https", "", true}, + } + + for _, test := range cases { + req := &http.Request{URL: &url.URL{Scheme: test.scheme}} + ereq := &Request{URLStruct: &url.URL{Scheme: test.value}} + matches, err := NewGock().MatchScheme(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchHost(t *testing.T) { + cases := []struct { + value string + url string + matches bool + matchesNonRegexp bool + }{ + {"foo.com", "foo.com", true, true}, + {"FOO.com", "foo.com", true, true}, + {"foo.net", "foo.com", false, false}, + {"foo.bar.net", "foo-bar.net", true, false}, + {"foo", "foo.com", true, false}, + {"(.*).com", "foo.com", true, false}, + {"127.0.0.1", "127.0.0.1", true, true}, + {"127.0.0.2", "127.0.0.1", false, false}, + {"127.0.0.*", "127.0.0.1", true, false}, + {"127.0.0.[0-9]", "127.0.0.7", true, false}, + } + + for _, test := range cases { + req := &http.Request{URL: &url.URL{Host: test.url}} + ereq := &Request{URLStruct: &url.URL{Host: test.value}} + matches, err := NewGock().MatchHost(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + ereq.WithOptions(Options{DisableRegexpHost: true}) + matches, err = NewGock().MatchHost(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matchesNonRegexp) + } +} + +func TestMatchPath(t *testing.T) { + cases := []struct { + value string + path string + matches bool + }{ + {"/foo", "/foo", true}, + {"/foo", "/foo/bar", true}, + {"bar", "/foo/bar", true}, + {"foo", "/foo/bar", true}, + {"bar$", "/foo/bar", true}, + {"/foo/*", "/foo/bar", true}, + {"/foo/[a-z]+", "/foo/bar", true}, + {"/foo/baz", "/foo/bar", false}, + {"/foo/baz", "/foo/bar", false}, + {"/foo/bar%3F+%C3%A9", "/foo/bar%3F+%C3%A9", true}, + } + + for _, test := range cases { + u, _ := url.Parse("http://foo.com" + test.path) + mu, _ := url.Parse("http://foo.com" + test.value) + req := &http.Request{URL: u} + ereq := &Request{URLStruct: mu} + matches, err := NewGock().MatchPath(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchHeaders(t *testing.T) { + cases := []struct { + values http.Header + headers http.Header + matches bool + }{ + {http.Header{"foo": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, true}, + {http.Header{"foo": []string{"bar"}}, http.Header{"foo": []string{"barbar"}}, true}, + {http.Header{"bar": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, false}, + {http.Header{"foofoo": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, false}, + {http.Header{"foo": []string{"bar(.*)"}}, http.Header{"foo": []string{"barbar"}}, true}, + {http.Header{"foo": []string{"b(.*)"}}, http.Header{"foo": []string{"barbar"}}, true}, + {http.Header{"foo": []string{"^bar$"}}, http.Header{"foo": []string{"bar"}}, true}, + {http.Header{"foo": []string{"^bar$"}}, http.Header{"foo": []string{"barbar"}}, false}, + {http.Header{"UPPERCASE": []string{"bar"}}, http.Header{"UPPERCASE": []string{"bar"}}, true}, + {http.Header{"Mixed-CASE": []string{"bar"}}, http.Header{"Mixed-CASE": []string{"bar"}}, true}, + {http.Header{"User-Agent": []string{"Agent (version1.0)"}}, http.Header{"User-Agent": []string{"Agent (version1.0)"}}, true}, + {http.Header{"Content-Type": []string{"(.*)/plain"}}, http.Header{"Content-Type": []string{"text/plain"}}, true}, + } + + for _, test := range cases { + req := &http.Request{Header: test.headers} + ereq := &Request{Header: test.values} + matches, err := NewGock().MatchHeaders(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchQueryParams(t *testing.T) { + cases := []struct { + value string + path string + matches bool + }{ + {"foo=bar", "foo=bar", true}, + {"foo=bar", "foo=foo&foo=bar", true}, + {"foo=b*", "foo=bar", true}, + {"foo=.*", "foo=bar", true}, + {"foo=f[o]{2}", "foo=foo", true}, + {"foo=bar&bar=foo", "foo=bar&foo=foo&bar=foo", true}, + {"foo=", "foo=bar", true}, + {"foo=foo", "foo=bar", false}, + {"bar=bar", "foo=bar bar", false}, + } + + for _, test := range cases { + u, _ := url.Parse("http://foo.com/?" + test.path) + mu, _ := url.Parse("http://foo.com/?" + test.value) + req := &http.Request{URL: u} + ereq := &Request{URLStruct: mu} + matches, err := NewGock().MatchQueryParams(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchPathParams(t *testing.T) { + cases := []struct { + key string + value string + path string + matches bool + }{ + {"foo", "bar", "/foo/bar", true}, + {"foo", "bar", "/foo/test/bar", false}, + {"foo", "bar", "/test/foo/bar/ack", true}, + {"foo", "bar", "/foo", false}, + } + + for i, test := range cases { + u, _ := url.Parse("http://foo.com" + test.path) + mu, _ := url.Parse("http://foo.com" + test.path) + req := &http.Request{URL: u} + ereq := &Request{ + URLStruct: mu, + PathParams: map[string]string{test.key: test.value}, + } + matches, err := NewGock().MatchPathParams(req, ereq) + st.Expect(t, err, nil, i) + st.Expect(t, matches, test.matches, i) + } +} + +func TestMatchBody(t *testing.T) { + cases := []struct { + value string + body string + matches bool + }{ + {"foo bar", "foo bar\n", true}, + {"foo", "foo bar\n", true}, + {"f[o]+", "foo\n", true}, + {`"foo"`, `{"foo":"bar"}\n`, true}, + {`{"foo":"bar"}`, `{"foo":"bar"}\n`, true}, + {`{"foo":"foo"}`, `{"foo":"bar"}\n`, false}, + + {`{"foo":"bar","bar":"foo"}`, `{"bar":"foo","foo":"bar"}`, true}, + {`{"bar":"foo","foo":{"two":"three","three":"two"}}`, `{"foo":{"three":"two","two":"three"},"bar":"foo"}`, true}, + } + + g := NewGock() + for _, test := range cases { + req := &http.Request{Body: createReadCloser([]byte(test.body))} + ereq := &Request{BodyBuffer: []byte(test.value)} + matches, err := g.MatchBody(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchBody_MatchType(t *testing.T) { + body := `{"foo":"bar"}` + cases := []struct { + body string + requestContentType string + customBodyType string + matches bool + }{ + {body, "application/vnd.apiname.v1+json", "foobar", false}, + {body, "application/vnd.apiname.v1+json", "application/vnd.apiname.v1+json", true}, + {body, "application/json", "foobar", false}, + {body, "application/json", "", true}, + {"", "", "", true}, + } + + g := NewGock() + for _, test := range cases { + req := &http.Request{ + Header: http.Header{"Content-Type": []string{test.requestContentType}}, + Body: createReadCloser([]byte(test.body)), + } + ereq := g.NewRequest().BodyString(test.body).MatchType(test.customBodyType) + matches, err := g.MatchBody(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} diff --git a/threadsafe/mock.go b/threadsafe/mock.go new file mode 100644 index 0000000..004263a --- /dev/null +++ b/threadsafe/mock.go @@ -0,0 +1,172 @@ +package threadsafe + +import ( + "net/http" + "sync" +) + +// Mock represents the required interface that must +// be implemented by HTTP mock instances. +type Mock interface { + // Disable disables the current mock manually. + Disable() + + // Done returns true if the current mock is disabled. + Done() bool + + // Request returns the mock Request instance. + Request() *Request + + // Response returns the mock Response instance. + Response() *Response + + // Match matches the given http.Request with the current mock. + Match(*http.Request) (bool, error) + + // AddMatcher adds a new matcher function. + AddMatcher(MatchFunc) + + // SetMatcher uses a new matcher implementation. + SetMatcher(Matcher) +} + +// Mocker implements a Mock capable interface providing +// a default mock configuration used internally to store mocks. +type Mocker struct { + // disabler stores a disabler for thread safety checking current mock is disabled + disabler *disabler + + // mutex stores the mock mutex for thread safety. + mutex sync.Mutex + + // matcher stores a Matcher capable instance to match the given http.Request. + matcher Matcher + + // request stores the mock Request to match. + request *Request + + // response stores the mock Response to use in case of match. + response *Response +} + +type disabler struct { + // disabled stores if the current mock is disabled. + disabled bool + + // mutex stores the disabler mutex for thread safety. + mutex sync.RWMutex +} + +func (d *disabler) isDisabled() bool { + d.mutex.RLock() + defer d.mutex.RUnlock() + return d.disabled +} + +func (d *disabler) Disable() { + d.mutex.Lock() + defer d.mutex.Unlock() + d.disabled = true +} + +// NewMock creates a new HTTP mock based on the given request and response instances. +// It's mostly used internally. +func (g *Gock) NewMock(req *Request, res *Response) *Mocker { + mock := &Mocker{ + disabler: new(disabler), + request: req, + response: res, + matcher: g.DefaultMatcher.Clone(), + } + res.Mock = mock + req.Mock = mock + req.Response = res + return mock +} + +// Disable disables the current mock manually. +func (m *Mocker) Disable() { + m.disabler.Disable() +} + +// Done returns true in case that the current mock +// instance is disabled and therefore must be removed. +func (m *Mocker) Done() bool { + // prevent deadlock with m.mutex + if m.disabler.isDisabled() { + return true + } + + m.mutex.Lock() + defer m.mutex.Unlock() + return !m.request.Persisted && m.request.Counter == 0 +} + +// Request returns the Request instance +// configured for the current HTTP mock. +func (m *Mocker) Request() *Request { + return m.request +} + +// Response returns the Response instance +// configured for the current HTTP mock. +func (m *Mocker) Response() *Response { + return m.response +} + +// Match matches the given http.Request with the current Request +// mock expectation, returning true if matches. +func (m *Mocker) Match(req *http.Request) (bool, error) { + if m.disabler.isDisabled() { + return false, nil + } + + // Filter + for _, filter := range m.request.Filters { + if !filter(req) { + return false, nil + } + } + + // Map + for _, mapper := range m.request.Mappers { + if treq := mapper(req); treq != nil { + req = treq + } + } + + // Match + matches, err := m.matcher.Match(req, m.request) + if matches { + m.decrement() + } + + return matches, err +} + +// SetMatcher sets a new matcher implementation +// for the current mock expectation. +func (m *Mocker) SetMatcher(matcher Matcher) { + m.matcher = matcher +} + +// AddMatcher adds a new matcher function +// for the current mock expectation. +func (m *Mocker) AddMatcher(fn MatchFunc) { + m.matcher.Add(fn) +} + +// decrement decrements the current mock Request counter. +func (m *Mocker) decrement() { + if m.request.Persisted { + return + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + m.request.Counter-- + if m.request.Counter == 0 { + m.disabler.Disable() + } +} diff --git a/threadsafe/mock_test.go b/threadsafe/mock_test.go new file mode 100644 index 0000000..f277842 --- /dev/null +++ b/threadsafe/mock_test.go @@ -0,0 +1,143 @@ +package threadsafe + +import ( + "net/http" + "testing" + + "github.com/nbio/st" +) + +func TestNewMock(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get())) + + st.Expect(t, mock.Request(), req) + st.Expect(t, mock.Request().Mock, mock) + st.Expect(t, mock.Response(), res) + st.Expect(t, mock.Response().Mock, mock) +} + +func TestMockDisable(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + + st.Expect(t, mock.disabler.isDisabled(), false) + mock.Disable() + st.Expect(t, mock.disabler.isDisabled(), true) + + matches, err := mock.Match(&http.Request{}) + st.Expect(t, err, nil) + st.Expect(t, matches, false) +} + +func TestMockDone(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + + mock := g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + st.Expect(t, mock.Done(), false) + + mock = g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + mock.Disable() + st.Expect(t, mock.Done(), true) + + mock = g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + mock.request.Counter = 0 + st.Expect(t, mock.Done(), true) + + mock = g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + mock.request.Persisted = true + st.Expect(t, mock.Done(), false) +} + +func TestMockSetMatcher(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + + st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get())) + matcher := g.NewMatcher() + matcher.Flush() + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return true, nil + }) + mock.SetMatcher(matcher) + st.Expect(t, len(mock.matcher.Get()), 1) + st.Expect(t, mock.disabler.isDisabled(), false) + + matches, err := mock.Match(&http.Request{}) + st.Expect(t, err, nil) + st.Expect(t, matches, true) +} + +func TestMockAddMatcher(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + + st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get())) + matcher := g.NewMatcher() + matcher.Flush() + mock.SetMatcher(matcher) + mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + return true, nil + }) + st.Expect(t, mock.disabler.isDisabled(), false) + st.Expect(t, mock.matcher, matcher) + + matches, err := mock.Match(&http.Request{}) + st.Expect(t, err, nil) + st.Expect(t, matches, true) +} + +func TestMockMatch(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + + matcher := g.NewMatcher() + matcher.Flush() + mock.SetMatcher(matcher) + calls := 0 + mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + calls++ + return true, nil + }) + mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + calls++ + return true, nil + }) + st.Expect(t, mock.disabler.isDisabled(), false) + st.Expect(t, mock.matcher, matcher) + + matches, err := mock.Match(&http.Request{}) + st.Expect(t, err, nil) + st.Expect(t, calls, 2) + st.Expect(t, matches, true) +} diff --git a/threadsafe/options.go b/threadsafe/options.go new file mode 100644 index 0000000..98497f9 --- /dev/null +++ b/threadsafe/options.go @@ -0,0 +1,8 @@ +package threadsafe + +// Options represents customized option for gock +type Options struct { + // DisableRegexpHost stores if the host is only a plain string rather than regular expression, + // if DisableRegexpHost is true, host sets in gock.New(...) will be treated as plain string + DisableRegexpHost bool +} diff --git a/threadsafe/request.go b/threadsafe/request.go new file mode 100644 index 0000000..3508bbb --- /dev/null +++ b/threadsafe/request.go @@ -0,0 +1,330 @@ +package threadsafe + +import ( + "encoding/base64" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" +) + +// MapRequestFunc represents the required function interface for request mappers. +type MapRequestFunc func(*http.Request) *http.Request + +// FilterRequestFunc represents the required function interface for request filters. +type FilterRequestFunc func(*http.Request) bool + +// Request represents the high-level HTTP request used to store +// request fields used to match intercepted requests. +type Request struct { + g *Gock + + // Mock stores the parent mock reference for the current request mock used for method delegation. + Mock Mock + + // Response stores the current Response instance for the current matches Request. + Response *Response + + // Error stores the latest mock request configuration error. + Error error + + // Counter stores the pending times that the current mock should be active. + Counter int + + // Persisted stores if the current mock should be always active. + Persisted bool + + // Options stores options for current Request. + Options Options + + // URLStruct stores the parsed URL as *url.URL struct. + URLStruct *url.URL + + // Method stores the Request HTTP method to match. + Method string + + // CompressionScheme stores the Request Compression scheme to match and use for decompression. + CompressionScheme string + + // Header stores the HTTP header fields to match. + Header http.Header + + // Cookies stores the Request HTTP cookies values to match. + Cookies []*http.Cookie + + // PathParams stores the path parameters to match. + PathParams map[string]string + + // BodyBuffer stores the body data to match. + BodyBuffer []byte + + // Mappers stores the request functions mappers used for matching. + Mappers []MapRequestFunc + + // Filters stores the request functions filters used for matching. + Filters []FilterRequestFunc +} + +// NewRequest creates a new Request instance. +func (g *Gock) NewRequest() *Request { + return &Request{ + g: g, + Counter: 1, + URLStruct: &url.URL{}, + Header: make(http.Header), + PathParams: make(map[string]string), + } +} + +// URL defines the mock URL to match. +func (r *Request) URL(uri string) *Request { + r.URLStruct, r.Error = url.Parse(uri) + return r +} + +// SetURL defines the url.URL struct to be used for matching. +func (r *Request) SetURL(u *url.URL) *Request { + r.URLStruct = u + return r +} + +// Path defines the mock URL path value to match. +func (r *Request) Path(path string) *Request { + r.URLStruct.Path = path + return r +} + +// Get specifies the GET method and the given URL path to match. +func (r *Request) Get(path string) *Request { + return r.method("GET", path) +} + +// Post specifies the POST method and the given URL path to match. +func (r *Request) Post(path string) *Request { + return r.method("POST", path) +} + +// Put specifies the PUT method and the given URL path to match. +func (r *Request) Put(path string) *Request { + return r.method("PUT", path) +} + +// Delete specifies the DELETE method and the given URL path to match. +func (r *Request) Delete(path string) *Request { + return r.method("DELETE", path) +} + +// Patch specifies the PATCH method and the given URL path to match. +func (r *Request) Patch(path string) *Request { + return r.method("PATCH", path) +} + +// Head specifies the HEAD method and the given URL path to match. +func (r *Request) Head(path string) *Request { + return r.method("HEAD", path) +} + +// method is a DRY shortcut used to declare the expected HTTP method and URL path. +func (r *Request) method(method, path string) *Request { + if path != "/" { + r.URLStruct.Path = path + } + r.Method = strings.ToUpper(method) + return r +} + +// Body defines the body data to match based on a io.Reader interface. +func (r *Request) Body(body io.Reader) *Request { + r.BodyBuffer, r.Error = ioutil.ReadAll(body) + return r +} + +// BodyString defines the body to match based on a given string. +func (r *Request) BodyString(body string) *Request { + r.BodyBuffer = []byte(body) + return r +} + +// File defines the body to match based on the given file path string. +func (r *Request) File(path string) *Request { + r.BodyBuffer, r.Error = ioutil.ReadFile(path) + return r +} + +// Compression defines the request compression scheme, and enables automatic body decompression. +// Supports only the "gzip" scheme so far. +func (r *Request) Compression(scheme string) *Request { + r.Header.Set("Content-Encoding", scheme) + r.CompressionScheme = scheme + return r +} + +// JSON defines the JSON body to match based on a given structure. +func (r *Request) JSON(data interface{}) *Request { + if r.Header.Get("Content-Type") == "" { + r.Header.Set("Content-Type", "application/json") + } + r.BodyBuffer, r.Error = readAndDecode(data, "json") + return r +} + +// XML defines the XML body to match based on a given structure. +func (r *Request) XML(data interface{}) *Request { + if r.Header.Get("Content-Type") == "" { + r.Header.Set("Content-Type", "application/xml") + } + r.BodyBuffer, r.Error = readAndDecode(data, "xml") + return r +} + +// MatchType defines the request Content-Type MIME header field. +// Supports custom MIME types and type aliases. E.g: json, xml, form, text... +func (r *Request) MatchType(kind string) *Request { + mime := r.g.BodyTypeAliases[kind] + if mime != "" { + kind = mime + } + r.Header.Set("Content-Type", kind) + return r +} + +// BasicAuth defines a username and password for HTTP Basic Authentication +func (r *Request) BasicAuth(username, password string) *Request { + r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) + return r +} + +// MatchHeader defines a new key and value header to match. +func (r *Request) MatchHeader(key, value string) *Request { + r.Header.Set(key, value) + return r +} + +// HeaderPresent defines that a header field must be present in the request. +func (r *Request) HeaderPresent(key string) *Request { + r.Header.Set(key, ".*") + return r +} + +// MatchHeaders defines a map of key-value headers to match. +func (r *Request) MatchHeaders(headers map[string]string) *Request { + for key, value := range headers { + r.Header.Set(key, value) + } + return r +} + +// MatchParam defines a new key and value URL query param to match. +func (r *Request) MatchParam(key, value string) *Request { + query := r.URLStruct.Query() + query.Set(key, value) + r.URLStruct.RawQuery = query.Encode() + return r +} + +// MatchParams defines a map of URL query param key-value to match. +func (r *Request) MatchParams(params map[string]string) *Request { + query := r.URLStruct.Query() + for key, value := range params { + query.Set(key, value) + } + r.URLStruct.RawQuery = query.Encode() + return r +} + +// ParamPresent matches if the given query param key is present in the URL. +func (r *Request) ParamPresent(key string) *Request { + r.MatchParam(key, ".*") + return r +} + +// PathParam matches if a given path parameter key is present in the URL. +// +// The value is representative of the restful resource the key defines, e.g. +// +// // /users/123/name +// r.PathParam("users", "123") +// +// would match. +func (r *Request) PathParam(key, val string) *Request { + r.PathParams[key] = val + + return r +} + +// Persist defines the current HTTP mock as persistent and won't be removed after intercepting it. +func (r *Request) Persist() *Request { + r.Persisted = true + return r +} + +// WithOptions sets the options for the request. +func (r *Request) WithOptions(options Options) *Request { + r.Options = options + return r +} + +// Times defines the number of times that the current HTTP mock should remain active. +func (r *Request) Times(num int) *Request { + r.Counter = num + return r +} + +// AddMatcher adds a new matcher function to match the request. +func (r *Request) AddMatcher(fn MatchFunc) *Request { + r.Mock.AddMatcher(fn) + return r +} + +// SetMatcher sets a new matcher function to match the request. +func (r *Request) SetMatcher(matcher Matcher) *Request { + r.Mock.SetMatcher(matcher) + return r +} + +// Map adds a new request mapper function to map http.Request before the matching process. +func (r *Request) Map(fn MapRequestFunc) *Request { + r.Mappers = append(r.Mappers, fn) + return r +} + +// Filter filters a new request filter function to filter http.Request before the matching process. +func (r *Request) Filter(fn FilterRequestFunc) *Request { + r.Filters = append(r.Filters, fn) + return r +} + +// EnableNetworking enables the use real networking for the current mock. +func (r *Request) EnableNetworking() *Request { + if r.Response != nil { + r.Response.UseNetwork = true + } + return r +} + +// Reply defines the Response status code and returns the mock Response DSL. +func (r *Request) Reply(status int) *Response { + return r.Response.Status(status) +} + +// ReplyError defines the Response simulated error. +func (r *Request) ReplyError(err error) *Response { + return r.Response.SetError(err) +} + +// ReplyFunc allows the developer to define the mock response via a custom function. +func (r *Request) ReplyFunc(replier func(*Response)) *Response { + replier(r.Response) + return r.Response +} + +// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} diff --git a/threadsafe/request_test.go b/threadsafe/request_test.go new file mode 100644 index 0000000..e67614f --- /dev/null +++ b/threadsafe/request_test.go @@ -0,0 +1,318 @@ +package threadsafe + +import ( + "bytes" + "net/http" + "net/url" + "path/filepath" + "testing" + + "github.com/nbio/st" +) + +func TestNewRequest(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.URL("http://foo.com") + st.Expect(t, req.URLStruct.Host, "foo.com") + st.Expect(t, req.URLStruct.Scheme, "http") + req.MatchHeader("foo", "bar") + st.Expect(t, req.Header.Get("foo"), "bar") +} + +func TestRequestSetURL(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.URL("http://foo.com") + req.SetURL(&url.URL{Host: "bar.com", Path: "/foo"}) + st.Expect(t, req.URLStruct.Host, "bar.com") + st.Expect(t, req.URLStruct.Path, "/foo") +} + +func TestRequestPath(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.URL("http://foo.com") + req.Path("/foo") + st.Expect(t, req.URLStruct.Scheme, "http") + st.Expect(t, req.URLStruct.Host, "foo.com") + st.Expect(t, req.URLStruct.Path, "/foo") +} + +func TestRequestBody(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.Body(bytes.NewBuffer([]byte("foo bar"))) + st.Expect(t, string(req.BodyBuffer), "foo bar") +} + +func TestRequestBodyString(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.BodyString("foo bar") + st.Expect(t, string(req.BodyBuffer), "foo bar") +} + +func TestRequestFile(t *testing.T) { + g := NewGock() + req := g.NewRequest() + absPath, err := filepath.Abs("../version.go") + st.Expect(t, err, nil) + req.File(absPath) + st.Expect(t, string(req.BodyBuffer)[:12], "package gock") +} + +func TestRequestJSON(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.JSON(map[string]string{"foo": "bar"}) + st.Expect(t, string(req.BodyBuffer)[:13], `{"foo":"bar"}`) + st.Expect(t, req.Header.Get("Content-Type"), "application/json") +} + +func TestRequestXML(t *testing.T) { + g := NewGock() + req := g.NewRequest() + type xml struct { + Data string `xml:"data"` + } + req.XML(xml{Data: "foo"}) + st.Expect(t, string(req.BodyBuffer), `foo`) + st.Expect(t, req.Header.Get("Content-Type"), "application/xml") +} + +func TestRequestMatchType(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.MatchType("json") + st.Expect(t, req.Header.Get("Content-Type"), "application/json") + + req = g.NewRequest() + req.MatchType("html") + st.Expect(t, req.Header.Get("Content-Type"), "text/html") + + req = g.NewRequest() + req.MatchType("foo/bar") + st.Expect(t, req.Header.Get("Content-Type"), "foo/bar") +} + +func TestRequestBasicAuth(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.BasicAuth("bob", "qwerty") + st.Expect(t, req.Header.Get("Authorization"), "Basic Ym9iOnF3ZXJ0eQ==") +} + +func TestRequestMatchHeader(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.MatchHeader("foo", "bar") + req.MatchHeader("bar", "baz") + req.MatchHeader("UPPERCASE", "bat") + req.MatchHeader("Mixed-CASE", "foo") + + st.Expect(t, req.Header.Get("foo"), "bar") + st.Expect(t, req.Header.Get("bar"), "baz") + st.Expect(t, req.Header.Get("UPPERCASE"), "bat") + st.Expect(t, req.Header.Get("Mixed-CASE"), "foo") +} + +func TestRequestHeaderPresent(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.HeaderPresent("foo") + req.HeaderPresent("bar") + req.HeaderPresent("UPPERCASE") + req.HeaderPresent("Mixed-CASE") + st.Expect(t, req.Header.Get("foo"), ".*") + st.Expect(t, req.Header.Get("bar"), ".*") + st.Expect(t, req.Header.Get("UPPERCASE"), ".*") + st.Expect(t, req.Header.Get("Mixed-CASE"), ".*") +} + +func TestRequestMatchParam(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.MatchParam("foo", "bar") + req.MatchParam("bar", "baz") + st.Expect(t, req.URLStruct.Query().Get("foo"), "bar") + st.Expect(t, req.URLStruct.Query().Get("bar"), "baz") +} + +func TestRequestMatchParams(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.MatchParams(map[string]string{"foo": "bar", "bar": "baz"}) + st.Expect(t, req.URLStruct.Query().Get("foo"), "bar") + st.Expect(t, req.URLStruct.Query().Get("bar"), "baz") +} + +func TestRequestPresentParam(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.ParamPresent("key") + st.Expect(t, req.URLStruct.Query().Get("key"), ".*") +} + +func TestRequestPathParam(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.PathParam("key", "value") + st.Expect(t, req.PathParams["key"], "value") +} + +func TestRequestPersist(t *testing.T) { + g := NewGock() + req := g.NewRequest() + st.Expect(t, req.Persisted, false) + req.Persist() + st.Expect(t, req.Persisted, true) +} + +func TestRequestTimes(t *testing.T) { + g := NewGock() + req := g.NewRequest() + st.Expect(t, req.Counter, 1) + req.Times(3) + st.Expect(t, req.Counter, 3) +} + +func TestRequestMap(t *testing.T) { + g := NewGock() + req := g.NewRequest() + st.Expect(t, len(req.Mappers), 0) + req.Map(func(req *http.Request) *http.Request { + return req + }) + st.Expect(t, len(req.Mappers), 1) +} + +func TestRequestFilter(t *testing.T) { + g := NewGock() + req := g.NewRequest() + st.Expect(t, len(req.Filters), 0) + req.Filter(func(req *http.Request) bool { + return true + }) + st.Expect(t, len(req.Filters), 1) +} + +func TestRequestEnableNetworking(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.Response = &Response{} + st.Expect(t, req.Response.UseNetwork, false) + req.EnableNetworking() + st.Expect(t, req.Response.UseNetwork, true) +} + +func TestRequestResponse(t *testing.T) { + g := NewGock() + req := g.NewRequest() + res := g.NewResponse() + req.Response = res + chain := req.Reply(200) + st.Expect(t, chain, res) + st.Expect(t, chain.StatusCode, 200) +} + +func TestRequestReplyFunc(t *testing.T) { + g := NewGock() + req := g.NewRequest() + res := g.NewResponse() + req.Response = res + chain := req.ReplyFunc(func(r *Response) { + r.Status(204) + }) + st.Expect(t, chain, res) + st.Expect(t, chain.StatusCode, 204) +} + +func TestRequestMethods(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.Get("/foo") + st.Expect(t, req.Method, "GET") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Post("/foo") + st.Expect(t, req.Method, "POST") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Put("/foo") + st.Expect(t, req.Method, "PUT") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Delete("/foo") + st.Expect(t, req.Method, "DELETE") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Patch("/foo") + st.Expect(t, req.Method, "PATCH") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Head("/foo") + st.Expect(t, req.Method, "HEAD") + st.Expect(t, req.URLStruct.Path, "/foo") +} + +func TestRequestSetMatcher(t *testing.T) { + g := NewGock() + defer after(g) + + matcher := g.NewEmptyMatcher() + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Host == "foo.com", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.Header.Get("foo") == "bar", nil + }) + ereq := g.NewRequest() + mock := g.NewMock(ereq, &Response{}) + mock.SetMatcher(matcher) + ereq.Mock = mock + + headers := make(http.Header) + headers.Set("foo", "bar") + req := &http.Request{ + URL: &url.URL{Host: "foo.com", Path: "/bar"}, + Header: headers, + } + + match, err := ereq.Mock.Match(req) + st.Expect(t, err, nil) + st.Expect(t, match, true) +} + +func TestRequestAddMatcher(t *testing.T) { + g := NewGock() + defer after(g) + + ereq := g.NewRequest() + mock := g.NewMock(ereq, &Response{}) + mock.matcher = g.NewMatcher() + ereq.Mock = mock + + ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Host == "foo.com", nil + }) + ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + return req.Header.Get("foo") == "bar", nil + }) + + headers := make(http.Header) + headers.Set("foo", "bar") + req := &http.Request{ + URL: &url.URL{Host: "foo.com", Path: "/bar"}, + Header: headers, + } + + match, err := ereq.Mock.Match(req) + st.Expect(t, err, nil) + st.Expect(t, match, true) +} diff --git a/threadsafe/responder.go b/threadsafe/responder.go new file mode 100644 index 0000000..5dc4a7d --- /dev/null +++ b/threadsafe/responder.go @@ -0,0 +1,111 @@ +package threadsafe + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" + "strconv" + "time" +) + +// Responder builds a mock http.Response based on the given Response mock. +func Responder(req *http.Request, mock *Response, res *http.Response) (*http.Response, error) { + // If error present, reply it + err := mock.Error + if err != nil { + return nil, err + } + + if res == nil { + res = createResponse(req) + } + + // Apply response filter + for _, filter := range mock.Filters { + if !filter(res) { + return res, nil + } + } + + // Define mock status code + if mock.StatusCode != 0 { + res.Status = strconv.Itoa(mock.StatusCode) + " " + http.StatusText(mock.StatusCode) + res.StatusCode = mock.StatusCode + } + + // Define headers by merging fields + res.Header = mergeHeaders(res, mock) + + // Define mock body, if present + if len(mock.BodyBuffer) > 0 { + res.ContentLength = int64(len(mock.BodyBuffer)) + res.Body = createReadCloser(mock.BodyBuffer) + } + + // Set raw mock body, if exist + if mock.BodyGen != nil { + res.ContentLength = -1 + res.Body = mock.BodyGen() + } + + // Apply response mappers + for _, mapper := range mock.Mappers { + if tres := mapper(res); tres != nil { + res = tres + } + } + + // Sleep to simulate delay, if necessary + if mock.ResponseDelay > 0 { + // allow escaping from sleep due to request context expiration or cancellation + t := time.NewTimer(mock.ResponseDelay) + select { + case <-t.C: + case <-req.Context().Done(): + // cleanly stop the timer + if !t.Stop() { + <-t.C + } + } + } + + // check if the request context has ended. we could put this up in the delay code above, but putting it here + // has the added benefit of working even when there is no delay (very small timeouts, already-done contexts, etc.) + if err = req.Context().Err(); err != nil { + // cleanly close the response and return the context error + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + return nil, err + } + + return res, err +} + +// createResponse creates a new http.Response with default fields. +func createResponse(req *http.Request) *http.Response { + return &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Proto: "HTTP/1.1", + Request: req, + Header: make(http.Header), + Body: createReadCloser([]byte{}), + } +} + +// mergeHeaders copies the mock headers. +func mergeHeaders(res *http.Response, mres *Response) http.Header { + for key, values := range mres.Header { + for _, value := range values { + res.Header.Add(key, value) + } + } + return res.Header +} + +// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an +// http response body. +func createReadCloser(body []byte) io.ReadCloser { + return ioutil.NopCloser(bytes.NewReader(body)) +} diff --git a/threadsafe/responder_test.go b/threadsafe/responder_test.go new file mode 100644 index 0000000..7d18820 --- /dev/null +++ b/threadsafe/responder_test.go @@ -0,0 +1,191 @@ +package threadsafe + +import ( + "context" + "errors" + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + "github.com/nbio/st" +) + +func TestResponder(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Reply(200).BodyString("foo") + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") +} + +func TestResponder_ReadTwice(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Reply(200).BodyString("foo") + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") + + body, err = ioutil.ReadAll(res.Body) + st.Expect(t, err, nil) + st.Expect(t, body, []byte{}) +} + +func TestResponderBodyGenerator(t *testing.T) { + g := NewGock() + defer after(g) + generator := func() io.ReadCloser { + return io.NopCloser(strings.NewReader("foo")) + } + mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator) + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") +} + +func TestResponderBodyGenerator_ReadTwice(t *testing.T) { + g := NewGock() + defer after(g) + generator := func() io.ReadCloser { + return io.NopCloser(strings.NewReader("foo")) + } + mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator) + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") + + body, err = ioutil.ReadAll(res.Body) + st.Expect(t, err, nil) + st.Expect(t, body, []byte{}) +} + +func TestResponderBodyGenerator_Override(t *testing.T) { + g := NewGock() + defer after(g) + generator := func() io.ReadCloser { + return io.NopCloser(strings.NewReader("foo")) + } + mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator).BodyString("bar") + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") + + body, err = ioutil.ReadAll(res.Body) + st.Expect(t, err, nil) + st.Expect(t, body, []byte{}) +} + +func TestResponderSupportsMultipleHeadersWithSameKey(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo"). + Reply(200). + AddHeader("Set-Cookie", "a=1"). + AddHeader("Set-Cookie", "b=2") + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Header, http.Header{"Set-Cookie": []string{"a=1", "b=2"}}) +} + +func TestResponderError(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").ReplyError(errors.New("error")) + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err.Error(), "error") + st.Expect(t, res == nil, true) +} + +func TestResponderCancelledContext(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Get("").Reply(200).Delay(20 * time.Millisecond).BodyString("foo") + + // create a context and schedule a call to cancel in 10ms + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil) + + res, err := Responder(req, mres, nil) + + // verify that we got a context cancellation error and nil response + st.Expect(t, err, context.Canceled) + st.Expect(t, res == nil, true) +} + +func TestResponderExpiredContext(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Get("").Reply(200).Delay(20 * time.Millisecond).BodyString("foo") + + // create a context that is set to expire in 10ms + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil) + + res, err := Responder(req, mres, nil) + + // verify that we got a context cancellation error and nil response + st.Expect(t, err, context.DeadlineExceeded) + st.Expect(t, res == nil, true) +} + +func TestResponderPreExpiredContext(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Get("").Reply(200).BodyString("foo") + + // create a context and wait to ensure it is expired + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Microsecond) + defer cancel() + time.Sleep(1 * time.Millisecond) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil) + + res, err := Responder(req, mres, nil) + + // verify that we got a context cancellation error and nil response + st.Expect(t, err, context.DeadlineExceeded) + st.Expect(t, res == nil, true) +} diff --git a/threadsafe/response.go b/threadsafe/response.go new file mode 100644 index 0000000..04de096 --- /dev/null +++ b/threadsafe/response.go @@ -0,0 +1,198 @@ +package threadsafe + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "io" + "io/ioutil" + "net/http" + "time" +) + +// MapResponseFunc represents the required function interface impletemed by response mappers. +type MapResponseFunc func(*http.Response) *http.Response + +// FilterResponseFunc represents the required function interface impletemed by response filters. +type FilterResponseFunc func(*http.Response) bool + +// Response represents high-level HTTP fields to configure +// and define HTTP responses intercepted by gock. +type Response struct { + g *Gock + + // Mock stores the parent mock reference for the current response mock used for method delegation. + Mock Mock + + // Error stores the latest response configuration or injected error. + Error error + + // UseNetwork enables the use of real network for the current mock. + UseNetwork bool + + // StatusCode stores the response status code. + StatusCode int + + // Headers stores the response headers. + Header http.Header + + // Cookies stores the response cookie fields. + Cookies []*http.Cookie + + // BodyGen stores a io.ReadCloser generator to be returned. + BodyGen func() io.ReadCloser + + // BodyBuffer stores the array of bytes to use as body. + BodyBuffer []byte + + // ResponseDelay stores the simulated response delay. + ResponseDelay time.Duration + + // Mappers stores the request functions mappers used for matching. + Mappers []MapResponseFunc + + // Filters stores the request functions filters used for matching. + Filters []FilterResponseFunc +} + +// NewResponse creates a new Response. +func (g *Gock) NewResponse() *Response { + return &Response{g: g, Header: make(http.Header)} +} + +// Status defines the desired HTTP status code to reply in the current response. +func (r *Response) Status(code int) *Response { + r.StatusCode = code + return r +} + +// Type defines the response Content-Type MIME header field. +// Supports type alias. E.g: json, xml, form, text... +func (r *Response) Type(kind string) *Response { + mime := r.g.BodyTypeAliases[kind] + if mime != "" { + kind = mime + } + r.Header.Set("Content-Type", kind) + return r +} + +// SetHeader sets a new header field in the mock response. +func (r *Response) SetHeader(key, value string) *Response { + r.Header.Set(key, value) + return r +} + +// AddHeader adds a new header field in the mock response +// with out removing an existent one. +func (r *Response) AddHeader(key, value string) *Response { + r.Header.Add(key, value) + return r +} + +// SetHeaders sets a map of header fields in the mock response. +func (r *Response) SetHeaders(headers map[string]string) *Response { + for key, value := range headers { + r.Header.Add(key, value) + } + return r +} + +// Body sets the HTTP response body to be used. +func (r *Response) Body(body io.Reader) *Response { + r.BodyBuffer, r.Error = ioutil.ReadAll(body) + return r +} + +// BodyGenerator accepts a io.ReadCloser generator, returning custom io.ReadCloser +// for every response. This will take priority than other Body methods used. +func (r *Response) BodyGenerator(generator func() io.ReadCloser) *Response { + r.BodyGen = generator + return r +} + +// BodyString defines the response body as string. +func (r *Response) BodyString(body string) *Response { + r.BodyBuffer = []byte(body) + return r +} + +// File defines the response body reading the data +// from disk based on the file path string. +func (r *Response) File(path string) *Response { + r.BodyBuffer, r.Error = ioutil.ReadFile(path) + return r +} + +// JSON defines the response body based on a JSON based input. +func (r *Response) JSON(data interface{}) *Response { + r.Header.Set("Content-Type", "application/json") + r.BodyBuffer, r.Error = readAndDecode(data, "json") + return r +} + +// XML defines the response body based on a XML based input. +func (r *Response) XML(data interface{}) *Response { + r.Header.Set("Content-Type", "application/xml") + r.BodyBuffer, r.Error = readAndDecode(data, "xml") + return r +} + +// SetError defines the response simulated error. +func (r *Response) SetError(err error) *Response { + r.Error = err + return r +} + +// Delay defines the response simulated delay. +// This feature is still experimental and will be improved in the future. +func (r *Response) Delay(delay time.Duration) *Response { + r.ResponseDelay = delay + return r +} + +// Map adds a new response mapper function to map http.Response before the matching process. +func (r *Response) Map(fn MapResponseFunc) *Response { + r.Mappers = append(r.Mappers, fn) + return r +} + +// Filter filters a new request filter function to filter http.Request before the matching process. +func (r *Response) Filter(fn FilterResponseFunc) *Response { + r.Filters = append(r.Filters, fn) + return r +} + +// EnableNetworking enables the use real networking for the current mock. +func (r *Response) EnableNetworking() *Response { + r.UseNetwork = true + return r +} + +// Done returns true if the mock was done and disabled. +func (r *Response) Done() bool { + return r.Mock.Done() +} + +func readAndDecode(data interface{}, kind string) ([]byte, error) { + buf := &bytes.Buffer{} + + switch data.(type) { + case string: + buf.WriteString(data.(string)) + case []byte: + buf.Write(data.([]byte)) + default: + var err error + if kind == "xml" { + err = xml.NewEncoder(buf).Encode(data) + } else { + err = json.NewEncoder(buf).Encode(data) + } + if err != nil { + return nil, err + } + } + + return ioutil.ReadAll(buf) +} diff --git a/threadsafe/response_test.go b/threadsafe/response_test.go new file mode 100644 index 0000000..0a44c63 --- /dev/null +++ b/threadsafe/response_test.go @@ -0,0 +1,186 @@ +package threadsafe + +import ( + "bytes" + "errors" + "io" + "net/http" + "path/filepath" + "testing" + "time" + + "github.com/nbio/st" +) + +func TestNewResponse(t *testing.T) { + g := NewGock() + res := g.NewResponse() + + res.Status(200) + st.Expect(t, res.StatusCode, 200) + + res.SetHeader("foo", "bar") + st.Expect(t, res.Header.Get("foo"), "bar") + + res.Delay(1000 * time.Millisecond) + st.Expect(t, res.ResponseDelay, 1000*time.Millisecond) + + res.EnableNetworking() + st.Expect(t, res.UseNetwork, true) +} + +func TestResponseStatus(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, res.StatusCode, 0) + res.Status(200) + st.Expect(t, res.StatusCode, 200) +} + +func TestResponseType(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.Type("json") + st.Expect(t, res.Header.Get("Content-Type"), "application/json") + + res = g.NewResponse() + res.Type("xml") + st.Expect(t, res.Header.Get("Content-Type"), "application/xml") + + res = g.NewResponse() + res.Type("foo/bar") + st.Expect(t, res.Header.Get("Content-Type"), "foo/bar") +} + +func TestResponseSetHeader(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.SetHeader("foo", "bar") + res.SetHeader("bar", "baz") + st.Expect(t, res.Header.Get("foo"), "bar") + st.Expect(t, res.Header.Get("bar"), "baz") +} + +func TestResponseAddHeader(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.AddHeader("foo", "bar") + res.AddHeader("foo", "baz") + st.Expect(t, res.Header.Get("foo"), "bar") + st.Expect(t, res.Header["Foo"][1], "baz") +} + +func TestResponseSetHeaders(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.SetHeaders(map[string]string{"foo": "bar", "bar": "baz"}) + st.Expect(t, res.Header.Get("foo"), "bar") + st.Expect(t, res.Header.Get("bar"), "baz") +} + +func TestResponseBody(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.Body(bytes.NewBuffer([]byte("foo bar"))) + st.Expect(t, string(res.BodyBuffer), "foo bar") +} + +func TestResponseBodyGenerator(t *testing.T) { + g := NewGock() + res := g.NewResponse() + generator := func() io.ReadCloser { + return io.NopCloser(bytes.NewBuffer([]byte("foo bar"))) + } + res.BodyGenerator(generator) + bytes, err := io.ReadAll(res.BodyGen()) + st.Expect(t, err, nil) + st.Expect(t, string(bytes), "foo bar") +} + +func TestResponseBodyString(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.BodyString("foo bar") + st.Expect(t, string(res.BodyBuffer), "foo bar") +} + +func TestResponseFile(t *testing.T) { + g := NewGock() + res := g.NewResponse() + absPath, err := filepath.Abs("../version.go") + st.Expect(t, err, nil) + res.File(absPath) + st.Expect(t, string(res.BodyBuffer)[:12], "package gock") +} + +func TestResponseJSON(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.JSON(map[string]string{"foo": "bar"}) + st.Expect(t, string(res.BodyBuffer)[:13], `{"foo":"bar"}`) + st.Expect(t, res.Header.Get("Content-Type"), "application/json") +} + +func TestResponseXML(t *testing.T) { + g := NewGock() + res := g.NewResponse() + type xml struct { + Data string `xml:"data"` + } + res.XML(xml{Data: "foo"}) + st.Expect(t, string(res.BodyBuffer), `foo`) + st.Expect(t, res.Header.Get("Content-Type"), "application/xml") +} + +func TestResponseMap(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, len(res.Mappers), 0) + res.Map(func(res *http.Response) *http.Response { + return res + }) + st.Expect(t, len(res.Mappers), 1) +} + +func TestResponseFilter(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, len(res.Filters), 0) + res.Filter(func(res *http.Response) bool { + return true + }) + st.Expect(t, len(res.Filters), 1) +} + +func TestResponseSetError(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, res.Error, nil) + res.SetError(errors.New("foo error")) + st.Expect(t, res.Error.Error(), "foo error") +} + +func TestResponseDelay(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, res.ResponseDelay, 0*time.Microsecond) + res.Delay(100 * time.Millisecond) + st.Expect(t, res.ResponseDelay, 100*time.Millisecond) +} + +func TestResponseEnableNetworking(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, res.UseNetwork, false) + res.EnableNetworking() + st.Expect(t, res.UseNetwork, true) +} + +func TestResponseDone(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.Mock = &Mocker{request: &Request{Counter: 1}, disabler: new(disabler)} + st.Expect(t, res.Done(), false) + res.Mock.Disable() + st.Expect(t, res.Done(), true) +} diff --git a/threadsafe/store.go b/threadsafe/store.go new file mode 100644 index 0000000..d22a02e --- /dev/null +++ b/threadsafe/store.go @@ -0,0 +1,90 @@ +package threadsafe + +// Register registers a new mock in the current mocks stack. +func (g *Gock) Register(mock Mock) { + if g.Exists(mock) { + return + } + + // Make ops thread safe + g.storeMutex.Lock() + defer g.storeMutex.Unlock() + + // Expose mock in request/response for delegation + mock.Request().Mock = mock + mock.Response().Mock = mock + + // Registers the mock in the global store + g.mocks = append(g.mocks, mock) +} + +// GetAll returns the current stack of registered mocks. +func (g *Gock) GetAll() []Mock { + g.storeMutex.RLock() + defer g.storeMutex.RUnlock() + return g.mocks +} + +// Exists checks if the given Mock is already registered. +func (g *Gock) Exists(m Mock) bool { + g.storeMutex.RLock() + defer g.storeMutex.RUnlock() + for _, mock := range g.mocks { + if mock == m { + return true + } + } + return false +} + +// Remove removes a registered mock by reference. +func (g *Gock) Remove(m Mock) { + for i, mock := range g.mocks { + if mock == m { + g.storeMutex.Lock() + g.mocks = append(g.mocks[:i], g.mocks[i+1:]...) + g.storeMutex.Unlock() + } + } +} + +// Flush flushes the current stack of registered mocks. +func (g *Gock) Flush() { + g.storeMutex.Lock() + defer g.storeMutex.Unlock() + g.mocks = []Mock{} +} + +// Pending returns an slice of pending mocks. +func (g *Gock) Pending() []Mock { + g.Clean() + g.storeMutex.RLock() + defer g.storeMutex.RUnlock() + return g.mocks +} + +// IsDone returns true if all the registered mocks has been triggered successfully. +func (g *Gock) IsDone() bool { + return !g.IsPending() +} + +// IsPending returns true if there are pending mocks. +func (g *Gock) IsPending() bool { + return len(g.Pending()) > 0 +} + +// Clean cleans the mocks store removing disabled or obsolete mocks. +func (g *Gock) Clean() { + g.storeMutex.Lock() + defer g.storeMutex.Unlock() + + buf := []Mock{} + for _, mock := range g.mocks { + if mock.Done() { + continue + } + buf = append(buf, mock) + } + + g.mocks = buf +} diff --git a/threadsafe/store_test.go b/threadsafe/store_test.go new file mode 100644 index 0000000..e4081bb --- /dev/null +++ b/threadsafe/store_test.go @@ -0,0 +1,95 @@ +package threadsafe + +import ( + "testing" + + "github.com/nbio/st" +) + +func TestStoreRegister(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + mock := g.New("foo").Mock + g.Register(mock) + st.Expect(t, len(g.mocks), 1) + st.Expect(t, mock.Request().Mock, mock) + st.Expect(t, mock.Response().Mock, mock) +} + +func TestStoreGetAll(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + mock := g.New("foo").Mock + store := g.GetAll() + st.Expect(t, len(g.mocks), 1) + st.Expect(t, len(store), 1) + st.Expect(t, store[0], mock) +} + +func TestStoreExists(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + mock := g.New("foo").Mock + st.Expect(t, len(g.mocks), 1) + st.Expect(t, g.Exists(mock), true) +} + +func TestStorePending(t *testing.T) { + g := NewGock() + defer after(g) + g.New("foo") + st.Expect(t, g.mocks, g.Pending()) +} + +func TestStoreIsPending(t *testing.T) { + g := NewGock() + defer after(g) + g.New("foo") + st.Expect(t, g.IsPending(), true) + g.Flush() + st.Expect(t, g.IsPending(), false) +} + +func TestStoreIsDone(t *testing.T) { + g := NewGock() + defer after(g) + g.New("foo") + st.Expect(t, g.IsDone(), false) + g.Flush() + st.Expect(t, g.IsDone(), true) +} + +func TestStoreRemove(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + mock := g.New("foo").Mock + st.Expect(t, len(g.mocks), 1) + st.Expect(t, g.Exists(mock), true) + + g.Remove(mock) + st.Expect(t, g.Exists(mock), false) + + g.Remove(mock) + st.Expect(t, g.Exists(mock), false) +} + +func TestStoreFlush(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + + mock1 := g.New("foo").Mock + mock2 := g.New("foo").Mock + st.Expect(t, len(g.mocks), 2) + st.Expect(t, g.Exists(mock1), true) + st.Expect(t, g.Exists(mock2), true) + + g.Flush() + st.Expect(t, len(g.mocks), 0) + st.Expect(t, g.Exists(mock1), false) + st.Expect(t, g.Exists(mock2), false) +} diff --git a/threadsafe/transport.go b/threadsafe/transport.go new file mode 100644 index 0000000..ee1af5e --- /dev/null +++ b/threadsafe/transport.go @@ -0,0 +1,112 @@ +package threadsafe + +import ( + "errors" + "net/http" + "sync" +) + +var ( + // ErrCannotMatch store the error returned in case of no matches. + ErrCannotMatch = errors.New("gock: cannot match any request") +) + +// Transport implements http.RoundTripper, which fulfills single http requests issued by +// an http.Client. +// +// gock's Transport encapsulates a given or default http.Transport for further +// delegation, if needed. +type Transport struct { + g *Gock + + // mutex is used to make transport thread-safe of concurrent uses across goroutines. + mutex sync.Mutex + + // Transport encapsulates the original http.RoundTripper transport interface for delegation. + Transport http.RoundTripper +} + +// NewTransport creates a new *Transport with no responders. +func (g *Gock) NewTransport(transport http.RoundTripper) *Transport { + return &Transport{g: g, Transport: transport} +} + +// transport is used to always return a non-nil transport. This is the same as `(http.Client).transport`, and is what +// would be invoked if gock's transport were not present. +func (m *Transport) transport() http.RoundTripper { + if m.Transport != nil { + return m.Transport + } + return http.DefaultTransport +} + +// RoundTrip receives HTTP requests and routes them to the appropriate responder. It is required to +// implement the http.RoundTripper interface. You will not interact with this directly, instead +// the *http.Client you are using will call it for you. +func (m *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + // Just act as a proxy if not intercepting + if !m.g.Intercepting() { + return m.transport().RoundTrip(req) + } + + m.mutex.Lock() + defer m.g.Clean() + + var err error + var res *http.Response + + // Match mock for the incoming http.Request + mock, err := m.g.MatchMock(req) + if err != nil { + m.mutex.Unlock() + return nil, err + } + + // Invoke the observer with the intercepted http.Request and matched mock + if m.g.config.Observer != nil { + m.g.config.Observer(req, mock) + } + + // Verify if should use real networking + networking := shouldUseNetwork(m.g, req, mock) + if !networking && mock == nil { + m.mutex.Unlock() + m.g.trackUnmatchedRequest(req) + return nil, ErrCannotMatch + } + + // Ensure me unlock the mutex before building the response + m.mutex.Unlock() + + // Perform real networking via original transport + if networking { + res, err = m.transport().RoundTrip(req) + // In no mock matched, continue with the response + if err != nil || mock == nil { + return res, err + } + } + + return Responder(req, mock.Response(), res) +} + +// CancelRequest is a no-op function. +func (m *Transport) CancelRequest(req *http.Request) {} + +func shouldUseNetwork(g *Gock, req *http.Request, mock Mock) bool { + if mock != nil && mock.Response().UseNetwork { + return true + } + if !g.config.Networking { + return false + } + if len(g.config.NetworkingFilters) == 0 { + return true + } + for _, filter := range g.config.NetworkingFilters { + if !filter(req) { + return false + } + } + return true +} diff --git a/threadsafe/transport_test.go b/threadsafe/transport_test.go new file mode 100644 index 0000000..5215da6 --- /dev/null +++ b/threadsafe/transport_test.go @@ -0,0 +1,55 @@ +package threadsafe + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/nbio/st" +) + +func TestTransportMatch(t *testing.T) { + g := NewGock() + defer after(g) + const uri = "http://foo.com" + g.New(uri).Reply(204) + u, _ := url.Parse(uri) + req := &http.Request{URL: u} + res, err := g.NewTransport(http.DefaultTransport).RoundTrip(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) + st.Expect(t, res.Request, req) +} + +func TestTransportCannotMatch(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Reply(204) + u, _ := url.Parse("http://127.0.0.1:1234") + req := &http.Request{URL: u} + _, err := g.NewTransport(http.DefaultTransport).RoundTrip(req) + st.Expect(t, err, ErrCannotMatch) +} + +func TestTransportNotIntercepting(t *testing.T) { + g := NewGock() + defer after(g) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, world") + })) + defer ts.Close() + + g.New(ts.URL).Reply(200) + g.Disable() + + u, _ := url.Parse(ts.URL) + req := &http.Request{URL: u, Header: make(http.Header)} + + res, err := g.NewTransport(http.DefaultTransport).RoundTrip(req) + st.Expect(t, err, nil) + st.Expect(t, g.Intercepting(), false) + st.Expect(t, res.StatusCode, 200) +} diff --git a/transport.go b/transport.go index 5b2bba2..985e3ad 100644 --- a/transport.go +++ b/transport.go @@ -1,9 +1,9 @@ package gock import ( - "errors" "net/http" - "sync" + + "github.com/h2non/gock/threadsafe" ) // var mutex *sync.Mutex = &sync.Mutex{} @@ -19,7 +19,7 @@ var ( var ( // ErrCannotMatch store the error returned in case of no matches. - ErrCannotMatch = errors.New("gock: cannot match any request") + ErrCannotMatch = threadsafe.ErrCannotMatch ) // Transport implements http.RoundTripper, which fulfills single http requests issued by @@ -27,86 +27,9 @@ var ( // // gock's Transport encapsulates a given or default http.Transport for further // delegation, if needed. -type Transport struct { - // mutex is used to make transport thread-safe of concurrent uses across goroutines. - mutex sync.Mutex - - // Transport encapsulates the original http.RoundTripper transport interface for delegation. - Transport http.RoundTripper -} +type Transport = threadsafe.Transport // NewTransport creates a new *Transport with no responders. func NewTransport() *Transport { - return &Transport{Transport: NativeTransport} -} - -// RoundTrip receives HTTP requests and routes them to the appropriate responder. It is required to -// implement the http.RoundTripper interface. You will not interact with this directly, instead -// the *http.Client you are using will call it for you. -func (m *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - // Just act as a proxy if not intercepting - if !Intercepting() { - return m.Transport.RoundTrip(req) - } - - m.mutex.Lock() - defer Clean() - - var err error - var res *http.Response - - // Match mock for the incoming http.Request - mock, err := MatchMock(req) - if err != nil { - m.mutex.Unlock() - return nil, err - } - - // Invoke the observer with the intercepted http.Request and matched mock - if config.Observer != nil { - config.Observer(req, mock) - } - - // Verify if should use real networking - networking := shouldUseNetwork(req, mock) - if !networking && mock == nil { - m.mutex.Unlock() - trackUnmatchedRequest(req) - return nil, ErrCannotMatch - } - - // Ensure me unlock the mutex before building the response - m.mutex.Unlock() - - // Perform real networking via original transport - if networking { - res, err = m.Transport.RoundTrip(req) - // In no mock matched, continue with the response - if err != nil || mock == nil { - return res, err - } - } - - return Responder(req, mock.Response(), res) -} - -// CancelRequest is a no-op function. -func (m *Transport) CancelRequest(req *http.Request) {} - -func shouldUseNetwork(req *http.Request, mock Mock) bool { - if mock != nil && mock.Response().UseNetwork { - return true - } - if !config.Networking { - return false - } - if len(config.NetworkingFilters) == 0 { - return true - } - for _, filter := range config.NetworkingFilters { - if !filter(req) { - return false - } - } - return true + return g.NewTransport(NativeTransport) }