From aa903b2f5356f35a486a8e7e6ef92e9db332748e Mon Sep 17 00:00:00 2001 From: Rasmus Dahlberg Date: Fri, 1 Apr 2022 02:27:52 +0200 Subject: fix non-compliant use of HTTP status code 405 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See RFC 7231, ยง6.5.5. --- pkg/instance/handler_test.go | 64 +++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 21 deletions(-) (limited to 'pkg/instance/handler_test.go') diff --git a/pkg/instance/handler_test.go b/pkg/instance/handler_test.go index 8a48860..05388fa 100644 --- a/pkg/instance/handler_test.go +++ b/pkg/instance/handler_test.go @@ -76,28 +76,50 @@ func TestHandlers(t *testing.T) { } } -// TestServeHTTP checks that invalid HTTP methods are rejected -func TestServeHTTP(t *testing.T) { - i := &Instance{ - Config: testConfig, - } - for _, handler := range i.Handlers() { - // Prepare invalid HTTP request - method := http.MethodPost - if method == handler.Method { - method = http.MethodGet - } - url := handler.Endpoint.Path("http://example.com", i.Prefix) - req, err := http.NewRequest(method, url, nil) - if err != nil { - t.Fatalf("must create HTTP request: %v", err) - } - w := httptest.NewRecorder() +func TestVerifyMethod(t *testing.T) { + badMethod := http.MethodHead + instance := Instance{Config: testConfig} + for _, handler := range instance.Handlers() { + for _, method := range []string{ + http.MethodGet, + http.MethodPost, + badMethod, + } { + url := handler.Endpoint.Path("http://log.example.com", instance.Prefix) + req, err := http.NewRequest(method, url, nil) + if err != nil { + t.Fatalf("must create HTTP request: %v", err) + } - // Check that it is rejected - handler.ServeHTTP(w, req) - if got, want := w.Code, http.StatusMethodNotAllowed; got != want { - t.Errorf("got HTTP code %v but wanted %v for endpoint %q", got, want, handler.Endpoint) + w := httptest.NewRecorder() + code := handler.verifyMethod(w, req) + if got, want := code == 0, handler.Method == method; got != want { + t.Errorf("%s %s: got %v but wanted %v: %v", method, url, got, want, err) + continue + } + if code == 0 { + continue + } + + if method == badMethod { + if got, want := code, http.StatusBadRequest; got != want { + t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want) + } + if _, ok := w.Header()["Allow"]; ok { + t.Errorf("%s %s: got Allow header, wanted none", method, url) + } + continue + } + + if got, want := code, http.StatusMethodNotAllowed; got != want { + t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want) + } else if methods, ok := w.Header()["Allow"]; !ok { + t.Errorf("%s %s: got no allow header, expected one", method, url) + } else if got, want := len(methods), 1; got != want { + t.Errorf("%s %s: got %d allowed method(s), wanted %d", method, url, got, want) + } else if got, want := methods[0], handler.Method; got != want { + t.Errorf("%s %s: got allowed method %s, wanted %s", method, url, got, want) + } } } } -- cgit v1.2.3