diff options
Diffstat (limited to 'pkg/instance')
-rw-r--r-- | pkg/instance/handler.go | 21 | ||||
-rw-r--r-- | pkg/instance/handler_test.go | 64 | ||||
-rw-r--r-- | pkg/instance/instance.go | 5 | ||||
-rw-r--r-- | pkg/instance/instance_test.go | 23 |
4 files changed, 90 insertions, 23 deletions
diff --git a/pkg/instance/handler.go b/pkg/instance/handler.go index f2bc621..95d90a8 100644 --- a/pkg/instance/handler.go +++ b/pkg/instance/handler.go @@ -41,9 +41,9 @@ func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.Instance.Deadline)) defer cancel() - if r.Method != a.Method { + statusCode = a.verifyMethod(w, r) + if statusCode != 0 { glog.Warningf("%s/%s: got HTTP %s, wanted HTTP %s", a.Instance.Prefix, string(a.Endpoint), r.Method, a.Method) - http.Error(w, "", http.StatusMethodNotAllowed) return } @@ -54,6 +54,23 @@ func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +// verifyMethod checks that an appropriate HTTP method is used. Error handling +// is based on RFC 7231, see Sections 6.5.5 (Status 405) and 6.5.1 (Status 400). +func (h *Handler) verifyMethod(w http.ResponseWriter, r *http.Request) int { + if h.Method == r.Method { + return 0 + } + + code := http.StatusBadRequest + if ok := h.Instance.checkHTTPMethod(r.Method); ok { + w.Header().Set("Allow", h.Method) + code = http.StatusMethodNotAllowed + } + + http.Error(w, fmt.Sprintf("error=%s", http.StatusText(code)), code) + return code +} + func addLeaf(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { glog.V(3).Info("handling add-entry request") req, err := i.leafRequestFromHTTP(ctx, r) diff --git a/pkg/instance/handler_test.go b/pkg/instance/handler_test.go index 8a48860..05388fa 100644 --- a/pkg/instance/handler_test.go +++ b/pkg/instance/handler_test.go @@ -76,28 +76,50 @@ func TestHandlers(t *testing.T) { } } -// TestServeHTTP checks that invalid HTTP methods are rejected -func TestServeHTTP(t *testing.T) { - i := &Instance{ - Config: testConfig, - } - for _, handler := range i.Handlers() { - // Prepare invalid HTTP request - method := http.MethodPost - if method == handler.Method { - method = http.MethodGet - } - url := handler.Endpoint.Path("http://example.com", i.Prefix) - req, err := http.NewRequest(method, url, nil) - if err != nil { - t.Fatalf("must create HTTP request: %v", err) - } - w := httptest.NewRecorder() +func TestVerifyMethod(t *testing.T) { + badMethod := http.MethodHead + instance := Instance{Config: testConfig} + for _, handler := range instance.Handlers() { + for _, method := range []string{ + http.MethodGet, + http.MethodPost, + badMethod, + } { + url := handler.Endpoint.Path("http://log.example.com", instance.Prefix) + req, err := http.NewRequest(method, url, nil) + if err != nil { + t.Fatalf("must create HTTP request: %v", err) + } - // Check that it is rejected - handler.ServeHTTP(w, req) - if got, want := w.Code, http.StatusMethodNotAllowed; got != want { - t.Errorf("got HTTP code %v but wanted %v for endpoint %q", got, want, handler.Endpoint) + w := httptest.NewRecorder() + code := handler.verifyMethod(w, req) + if got, want := code == 0, handler.Method == method; got != want { + t.Errorf("%s %s: got %v but wanted %v: %v", method, url, got, want, err) + continue + } + if code == 0 { + continue + } + + if method == badMethod { + if got, want := code, http.StatusBadRequest; got != want { + t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want) + } + if _, ok := w.Header()["Allow"]; ok { + t.Errorf("%s %s: got Allow header, wanted none", method, url) + } + continue + } + + if got, want := code, http.StatusMethodNotAllowed; got != want { + t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want) + } else if methods, ok := w.Header()["Allow"]; !ok { + t.Errorf("%s %s: got no allow header, expected one", method, url) + } else if got, want := len(methods), 1; got != want { + t.Errorf("%s %s: got %d allowed method(s), wanted %d", method, url, got, want) + } else if got, want := methods[0], handler.Method; got != want { + t.Errorf("%s %s: got allowed method %s, wanted %s", method, url, got, want) + } } } } diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index 7ade955..f62922e 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -51,6 +51,11 @@ func (i *Instance) Handlers() []Handler { } } +// checkHTTPMethod checks if an HTTP method is supported +func (i *Instance) checkHTTPMethod(m string) bool { + return m == http.MethodGet || m == http.MethodPost +} + func (i *Instance) leafRequestFromHTTP(ctx context.Context, r *http.Request) (*requests.Leaf, error) { var req requests.Leaf if err := req.FromASCII(r.Body); err != nil { diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go new file mode 100644 index 0000000..00d996d --- /dev/null +++ b/pkg/instance/instance_test.go @@ -0,0 +1,23 @@ +package instance + +import ( + "net/http" + "testing" +) + +func CheckHTTPMethod(t *testing.T) { + var instance Instance + for _, table := range []struct { + method string + wantOK bool + }{ + {wantOK: false, method: http.MethodHead}, + {wantOK: true, method: http.MethodPost}, + {wantOK: true, method: http.MethodGet}, + } { + ok := instance.checkHTTPMethod(table.method) + if got, want := ok, table.wantOK; got != want { + t.Errorf("%s: got %v but wanted %v", table.method, got, want) + } + } +} |