diff options
| -rw-r--r-- | handler_test.go | 202 | 
1 files changed, 202 insertions, 0 deletions
| diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..a98c9d6 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,202 @@ +package stfe + +import ( +	"crypto" +	"errors" +	"fmt" +	"strings" +	"testing" +	"time" + +	"crypto/x509" +	"net/http" +	"net/http/httptest" + +	"github.com/golang/mock/gomock" +	"github.com/google/certificate-transparency-go/trillian/mockclient" +	"github.com/google/go-cmp/cmp" +	"github.com/google/trillian" +	"github.com/system-transparency/stfe/server/testdata" +	"github.com/system-transparency/stfe/x509util" + +	"google.golang.org/protobuf/proto" +) + +type testHandler struct { +	mockCtrl *gomock.Controller +	client   *mockclient.MockTrillianLogClient +	instance *Instance +} + +func newTestHandler(t *testing.T, signer crypto.Signer) *testHandler { +	anchorList, err := x509util.NewCertificateList(testdata.PemAnchors) +	if err != nil { +		t.Fatalf("failed parsing trust anchors: %v", err) +	} +	ctrl := gomock.NewController(t) +	client := mockclient.NewMockTrillianLogClient(ctrl) +	return &testHandler{ +		mockCtrl: ctrl, +		client:   client, +		instance: &Instance{ +			Deadline: time.Second * 10, // TODO: fix me? +			Client:   client, +			LogParameters: &LogParameters{ +				LogId:      make([]byte, 32), +				TreeId:     0, +				Prefix:     "/test", +				MaxRange:   3, +				MaxChain:   3, +				AnchorPool: x509util.NewCertPool(anchorList), +				AnchorList: anchorList, +				KeyUsage:   []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, +				Signer:     signer, +				HashType:   crypto.SHA256, +			}, +		}, +	} +} + +func (th *testHandler) getHandlers(t *testing.T) map[string]handler { +	return map[string]handler{ +		"get-sth":               handler{instance: th.instance, handler: getSth, endpoint: "get-sth", method: http.MethodGet}, +		"get-consistency-proof": handler{instance: th.instance, handler: getConsistencyProof, endpoint: "get-consistency-proof", method: http.MethodGet}, +		"get-proof-by-hash":     handler{instance: th.instance, handler: getProofByHash, endpoint: "get-proof-by-hash", method: http.MethodGet}, +		"get-anchors":           handler{instance: th.instance, handler: getAnchors, endpoint: "get-anchors", method: http.MethodGet}, +		"get-entries":           handler{instance: th.instance, handler: getEntries, endpoint: "get-entries", method: http.MethodGet}, +	} +} + +func (th *testHandler) getHandler(t *testing.T, endpoint string) handler { +	handler, ok := th.getHandlers(t)[endpoint] +	if !ok { +		t.Fatalf("no such get endpoint: %s", endpoint) +	} +	return handler +} + +func (th *testHandler) postHandlers(t *testing.T) map[string]handler { +	return map[string]handler{ +		"add-entry": handler{instance: th.instance, handler: addEntry, endpoint: "add-entry", method: http.MethodPost}, +	} +} + +func (th *testHandler) postHandler(t *testing.T, endpoint string) handler { +	handler, ok := th.postHandlers(t)[endpoint] +	if !ok { +		t.Fatalf("no such post endpoint: %s", endpoint) +	} +	return handler +} + +// TestGetHandlersRejectPost checks that all get handlers reject post requests +func TestGetHandlersRejectPost(t *testing.T) { +	th := newTestHandler(t, nil) +	defer th.mockCtrl.Finish() + +	for endpoint, handler := range th.getHandlers(t) { +		t.Run(endpoint, func(t *testing.T) { +			s := httptest.NewServer(handler) +			defer s.Close() + +			url := s.URL + strings.Join([]string{th.instance.LogParameters.Prefix, endpoint}, "/") +			if rsp, err := http.Post(url, "application/json", nil); err != nil { +				t.Fatalf("http.Post(%s)=(_,%q), want (_,nil)", url, err) +			} else if rsp.StatusCode != http.StatusMethodNotAllowed { +				t.Errorf("http.Post(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed) +			} +		}) +	} +} + +// TestPostHandlersRejectGet checks that all post handlers reject get requests +func TestPostHandlersRejectGet(t *testing.T) { +	th := newTestHandler(t, nil) +	defer th.mockCtrl.Finish() + +	for endpoint, handler := range th.postHandlers(t) { +		t.Run(endpoint, func(t *testing.T) { +			s := httptest.NewServer(handler) +			defer s.Close() + +			url := s.URL + strings.Join([]string{th.instance.LogParameters.Prefix, endpoint}, "/") +			if rsp, err := http.Get(url); err != nil { +				t.Fatalf("http.Get(%s)=(_,%q), want (_,nil)", url, err) +			} else if rsp.StatusCode != http.StatusMethodNotAllowed { +				t.Errorf("http.Get(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed) +			} +		}) +	} +} + +func TestGetSth(t *testing.T) { +	for _, table := range []struct { +		description string +		trsp        *trillian.GetLatestSignedLogRootResponse +		terr        error +		wantCode    int +		wantErrText string +	}{ +		{ +			description: "empty trillian response", +			trsp:        nil, +			terr:        errors.New("back-end failure"), +			wantCode:    http.StatusInternalServerError, +			wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", +		}, +	} { +		func() { // run deferred functions at the end of each iteration +			th := newTestHandler(t, nil) +			defer th.mockCtrl.Finish() + +			treq := &trillian.GetLatestSignedLogRootRequest{ +				LogId: th.instance.LogParameters.TreeId, +			} +			th.client.EXPECT().GetLatestSignedLogRoot(deadlineMatcher{}, compareMatcher{treq}).Return(table.trsp, table.terr) + +			url := "http://example.com" + th.instance.LogParameters.Prefix + "/get-sth" +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("failed creating http request: %v", err) +			} + +			w := httptest.NewRecorder() +			th.getHandler(t, "get-sth").ServeHTTP(w, req) +			if w.Code != table.wantCode { +				t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) +			} + +			body := w.Body.String() +			if w.Code != http.StatusOK { +				if body != table.wantErrText { +					t.Errorf("GET(%s)=%q, want text %q", url, body, table.wantErrText) +				} +				return +			} +			// TODO: check that response is in fact valid +		}() +	} +} + +type deadlineMatcher struct { +} + +func (dm deadlineMatcher) Matches(x interface{}) bool { +	return true // TODO: deadlineMatcher.Matches +} + +func (dm deadlineMatcher) String() string { +	return fmt.Sprintf("deadline is: TODO") +} + +type compareMatcher struct { +	want interface{} +} + +func (cm compareMatcher) Matches(got interface{}) bool { +	return cmp.Equal(got, cm.want, cmp.Comparer(proto.Equal)) +} + +func (cm compareMatcher) String() string { +	return fmt.Sprintf("equals: TODO") +} | 
