diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/instance/handler.go | 21 | ||||
| -rw-r--r-- | pkg/instance/handler_test.go | 64 | ||||
| -rw-r--r-- | pkg/instance/instance.go | 5 | ||||
| -rw-r--r-- | pkg/instance/instance_test.go | 23 | 
4 files changed, 90 insertions, 23 deletions
| diff --git a/pkg/instance/handler.go b/pkg/instance/handler.go index f2bc621..95d90a8 100644 --- a/pkg/instance/handler.go +++ b/pkg/instance/handler.go @@ -41,9 +41,9 @@ func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {  	ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.Instance.Deadline))  	defer cancel() -	if r.Method != a.Method { +	statusCode = a.verifyMethod(w, r) +	if statusCode != 0 {  		glog.Warningf("%s/%s: got HTTP %s, wanted HTTP %s", a.Instance.Prefix, string(a.Endpoint), r.Method, a.Method) -		http.Error(w, "", http.StatusMethodNotAllowed)  		return  	} @@ -54,6 +54,23 @@ func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {  	}  } +// verifyMethod checks that an appropriate HTTP method is used.  Error handling +// is based on RFC 7231, see Sections 6.5.5 (Status 405) and 6.5.1 (Status 400). +func (h *Handler) verifyMethod(w http.ResponseWriter, r *http.Request) int { +	if h.Method == r.Method { +		return 0 +	} + +	code := http.StatusBadRequest +	if ok := h.Instance.checkHTTPMethod(r.Method); ok { +		w.Header().Set("Allow", h.Method) +		code = http.StatusMethodNotAllowed +	} + +	http.Error(w, fmt.Sprintf("error=%s", http.StatusText(code)), code) +	return code +} +  func addLeaf(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {  	glog.V(3).Info("handling add-entry request")  	req, err := i.leafRequestFromHTTP(ctx, r) diff --git a/pkg/instance/handler_test.go b/pkg/instance/handler_test.go index 8a48860..05388fa 100644 --- a/pkg/instance/handler_test.go +++ b/pkg/instance/handler_test.go @@ -76,28 +76,50 @@ func TestHandlers(t *testing.T) {  	}  } -// TestServeHTTP checks that invalid HTTP methods are rejected -func TestServeHTTP(t *testing.T) { -	i := &Instance{ -		Config: testConfig, -	} -	for _, handler := range i.Handlers() { -		// Prepare invalid HTTP request -		method := http.MethodPost -		if method == handler.Method { -			method = http.MethodGet -		} -		url := handler.Endpoint.Path("http://example.com", i.Prefix) -		req, err := http.NewRequest(method, url, nil) -		if err != nil { -			t.Fatalf("must create HTTP request: %v", err) -		} -		w := httptest.NewRecorder() +func TestVerifyMethod(t *testing.T) { +	badMethod := http.MethodHead +	instance := Instance{Config: testConfig} +	for _, handler := range instance.Handlers() { +		for _, method := range []string{ +			http.MethodGet, +			http.MethodPost, +			badMethod, +		} { +			url := handler.Endpoint.Path("http://log.example.com", instance.Prefix) +			req, err := http.NewRequest(method, url, nil) +			if err != nil { +				t.Fatalf("must create HTTP request: %v", err) +			} -		// Check that it is rejected -		handler.ServeHTTP(w, req) -		if got, want := w.Code, http.StatusMethodNotAllowed; got != want { -			t.Errorf("got HTTP code %v but wanted %v for endpoint %q", got, want, handler.Endpoint) +			w := httptest.NewRecorder() +			code := handler.verifyMethod(w, req) +			if got, want := code == 0, handler.Method == method; got != want { +				t.Errorf("%s %s: got %v but wanted %v: %v", method, url, got, want, err) +				continue +			} +			if code == 0 { +				continue +			} + +			if method == badMethod { +				if got, want := code, http.StatusBadRequest; got != want { +					t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want) +				} +				if _, ok := w.Header()["Allow"]; ok { +					t.Errorf("%s %s: got Allow header, wanted none", method, url) +				} +				continue +			} + +			if got, want := code, http.StatusMethodNotAllowed; got != want { +				t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want) +			} else if methods, ok := w.Header()["Allow"]; !ok { +				t.Errorf("%s %s: got no allow header, expected one", method, url) +			} else if got, want := len(methods), 1; got != want { +				t.Errorf("%s %s: got %d allowed method(s), wanted %d", method, url, got, want) +			} else if got, want := methods[0], handler.Method; got != want { +				t.Errorf("%s %s: got allowed method %s, wanted %s", method, url, got, want) +			}  		}  	}  } diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index 7ade955..f62922e 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -51,6 +51,11 @@ func (i *Instance) Handlers() []Handler {  	}  } +// checkHTTPMethod checks if an HTTP method is supported +func (i *Instance) checkHTTPMethod(m string) bool { +	return m == http.MethodGet || m == http.MethodPost +} +  func (i *Instance) leafRequestFromHTTP(ctx context.Context, r *http.Request) (*requests.Leaf, error) {  	var req requests.Leaf  	if err := req.FromASCII(r.Body); err != nil { diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go new file mode 100644 index 0000000..00d996d --- /dev/null +++ b/pkg/instance/instance_test.go @@ -0,0 +1,23 @@ +package instance + +import ( +	"net/http" +	"testing" +) + +func CheckHTTPMethod(t *testing.T) { +	var instance Instance +	for _, table := range []struct { +		method string +		wantOK bool +	}{ +		{wantOK: false, method: http.MethodHead}, +		{wantOK: true, method: http.MethodPost}, +		{wantOK: true, method: http.MethodGet}, +	} { +		ok := instance.checkHTTPMethod(table.method) +		if got, want := ok, table.wantOK; got != want { +			t.Errorf("%s: got %v but wanted %v", table.method, got, want) +		} +	} +} | 
