aboutsummaryrefslogtreecommitdiff
path: root/pkg/instance
diff options
context:
space:
mode:
authorRasmus Dahlberg <rasmus@mullvad.net>2022-04-01 02:27:52 +0200
committerRasmus Dahlberg <rasmus@mullvad.net>2022-04-01 02:56:42 +0200
commitaa903b2f5356f35a486a8e7e6ef92e9db332748e (patch)
tree3fff6448da782fdebffe9d24bf9b70edca14d396 /pkg/instance
parentb09d20111227be5e6d5126ec905b44a7a4e96b0d (diff)
fix non-compliant use of HTTP status code 405
See RFC 7231, ยง6.5.5.
Diffstat (limited to 'pkg/instance')
-rw-r--r--pkg/instance/handler.go21
-rw-r--r--pkg/instance/handler_test.go64
-rw-r--r--pkg/instance/instance.go5
-rw-r--r--pkg/instance/instance_test.go23
4 files changed, 90 insertions, 23 deletions
diff --git a/pkg/instance/handler.go b/pkg/instance/handler.go
index f2bc621..95d90a8 100644
--- a/pkg/instance/handler.go
+++ b/pkg/instance/handler.go
@@ -41,9 +41,9 @@ func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.Instance.Deadline))
defer cancel()
- if r.Method != a.Method {
+ statusCode = a.verifyMethod(w, r)
+ if statusCode != 0 {
glog.Warningf("%s/%s: got HTTP %s, wanted HTTP %s", a.Instance.Prefix, string(a.Endpoint), r.Method, a.Method)
- http.Error(w, "", http.StatusMethodNotAllowed)
return
}
@@ -54,6 +54,23 @@ func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
+// verifyMethod checks that an appropriate HTTP method is used. Error handling
+// is based on RFC 7231, see Sections 6.5.5 (Status 405) and 6.5.1 (Status 400).
+func (h *Handler) verifyMethod(w http.ResponseWriter, r *http.Request) int {
+ if h.Method == r.Method {
+ return 0
+ }
+
+ code := http.StatusBadRequest
+ if ok := h.Instance.checkHTTPMethod(r.Method); ok {
+ w.Header().Set("Allow", h.Method)
+ code = http.StatusMethodNotAllowed
+ }
+
+ http.Error(w, fmt.Sprintf("error=%s", http.StatusText(code)), code)
+ return code
+}
+
func addLeaf(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
glog.V(3).Info("handling add-entry request")
req, err := i.leafRequestFromHTTP(ctx, r)
diff --git a/pkg/instance/handler_test.go b/pkg/instance/handler_test.go
index 8a48860..05388fa 100644
--- a/pkg/instance/handler_test.go
+++ b/pkg/instance/handler_test.go
@@ -76,28 +76,50 @@ func TestHandlers(t *testing.T) {
}
}
-// TestServeHTTP checks that invalid HTTP methods are rejected
-func TestServeHTTP(t *testing.T) {
- i := &Instance{
- Config: testConfig,
- }
- for _, handler := range i.Handlers() {
- // Prepare invalid HTTP request
- method := http.MethodPost
- if method == handler.Method {
- method = http.MethodGet
- }
- url := handler.Endpoint.Path("http://example.com", i.Prefix)
- req, err := http.NewRequest(method, url, nil)
- if err != nil {
- t.Fatalf("must create HTTP request: %v", err)
- }
- w := httptest.NewRecorder()
+func TestVerifyMethod(t *testing.T) {
+ badMethod := http.MethodHead
+ instance := Instance{Config: testConfig}
+ for _, handler := range instance.Handlers() {
+ for _, method := range []string{
+ http.MethodGet,
+ http.MethodPost,
+ badMethod,
+ } {
+ url := handler.Endpoint.Path("http://log.example.com", instance.Prefix)
+ req, err := http.NewRequest(method, url, nil)
+ if err != nil {
+ t.Fatalf("must create HTTP request: %v", err)
+ }
- // Check that it is rejected
- handler.ServeHTTP(w, req)
- if got, want := w.Code, http.StatusMethodNotAllowed; got != want {
- t.Errorf("got HTTP code %v but wanted %v for endpoint %q", got, want, handler.Endpoint)
+ w := httptest.NewRecorder()
+ code := handler.verifyMethod(w, req)
+ if got, want := code == 0, handler.Method == method; got != want {
+ t.Errorf("%s %s: got %v but wanted %v: %v", method, url, got, want, err)
+ continue
+ }
+ if code == 0 {
+ continue
+ }
+
+ if method == badMethod {
+ if got, want := code, http.StatusBadRequest; got != want {
+ t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want)
+ }
+ if _, ok := w.Header()["Allow"]; ok {
+ t.Errorf("%s %s: got Allow header, wanted none", method, url)
+ }
+ continue
+ }
+
+ if got, want := code, http.StatusMethodNotAllowed; got != want {
+ t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want)
+ } else if methods, ok := w.Header()["Allow"]; !ok {
+ t.Errorf("%s %s: got no allow header, expected one", method, url)
+ } else if got, want := len(methods), 1; got != want {
+ t.Errorf("%s %s: got %d allowed method(s), wanted %d", method, url, got, want)
+ } else if got, want := methods[0], handler.Method; got != want {
+ t.Errorf("%s %s: got allowed method %s, wanted %s", method, url, got, want)
+ }
}
}
}
diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go
index 7ade955..f62922e 100644
--- a/pkg/instance/instance.go
+++ b/pkg/instance/instance.go
@@ -51,6 +51,11 @@ func (i *Instance) Handlers() []Handler {
}
}
+// checkHTTPMethod checks if an HTTP method is supported
+func (i *Instance) checkHTTPMethod(m string) bool {
+ return m == http.MethodGet || m == http.MethodPost
+}
+
func (i *Instance) leafRequestFromHTTP(ctx context.Context, r *http.Request) (*requests.Leaf, error) {
var req requests.Leaf
if err := req.FromASCII(r.Body); err != nil {
diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go
new file mode 100644
index 0000000..00d996d
--- /dev/null
+++ b/pkg/instance/instance_test.go
@@ -0,0 +1,23 @@
+package instance
+
+import (
+ "net/http"
+ "testing"
+)
+
+func CheckHTTPMethod(t *testing.T) {
+ var instance Instance
+ for _, table := range []struct {
+ method string
+ wantOK bool
+ }{
+ {wantOK: false, method: http.MethodHead},
+ {wantOK: true, method: http.MethodPost},
+ {wantOK: true, method: http.MethodGet},
+ } {
+ ok := instance.checkHTTPMethod(table.method)
+ if got, want := ok, table.wantOK; got != want {
+ t.Errorf("%s: got %v but wanted %v", table.method, got, want)
+ }
+ }
+}