diff options
author | Rasmus Dahlberg <rasmus.dahlberg@kau.se> | 2021-02-17 19:58:27 +0100 |
---|---|---|
committer | Rasmus Dahlberg <rasmus.dahlberg@kau.se> | 2021-02-17 19:58:27 +0100 |
commit | 238518951868db81cd3a004e5c3f0b99f8e82b06 (patch) | |
tree | 1df8e71e869272bc5324e7412eab9236276f3548 /handler_test.go | |
parent | 72c8492ee1bd07d5960c9920e51b7addac11b806 (diff) |
added basic server-side cosigning (work in progress)
Diffstat (limited to 'handler_test.go')
-rw-r--r-- | handler_test.go | 225 |
1 files changed, 214 insertions, 11 deletions
diff --git a/handler_test.go b/handler_test.go index dd32c37..daa1a6c 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,5 +1,7 @@ package stfe +// TODO: refactor tests + import ( "bytes" "context" @@ -8,7 +10,6 @@ import ( "testing" "crypto/ed25519" - //"crypto/tls" "encoding/base64" "encoding/json" "net/http" @@ -28,16 +29,31 @@ type testHandler struct { instance *Instance } -func newTestHandler(t *testing.T, signer crypto.Signer) *testHandler { +func newTestHandler(t *testing.T, signer crypto.Signer, sth *StItem) *testHandler { ctrl := gomock.NewController(t) client := mockclient.NewMockTrillianLogClient(ctrl) + lp := makeTestLogParameters(t, signer) + source := &ActiveSthSource{ + client: client, + logParameters: lp, + } + if sth != nil { + source.currSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), + Signature: testSignature, + }, + }) + source.nextSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil) + source.cosignatureFrom = make(map[string]bool) + } return &testHandler{ mockCtrl: ctrl, client: client, instance: &Instance{ - Deadline: testDeadline, Client: client, - LogParameters: makeTestLogParameters(t, signer), + LogParameters: lp, + SthSource: source, }, } } @@ -49,6 +65,8 @@ func (th *testHandler) getHandlers(t *testing.T) map[Endpoint]Handler { EndpointGetProofByHash: Handler{instance: th.instance, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet}, EndpointGetAnchors: Handler{instance: th.instance, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet}, EndpointGetEntries: Handler{instance: th.instance, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet}, + EndpointGetStableSth: Handler{instance: th.instance, handler: getStableSth, endpoint: EndpointGetStableSth, method: http.MethodGet}, + EndpointGetCosi: Handler{instance: th.instance, handler: getCosi, endpoint: EndpointGetCosi, method: http.MethodGet}, } } @@ -63,6 +81,7 @@ func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) Handler { func (th *testHandler) postHandlers(t *testing.T) map[Endpoint]Handler { return map[Endpoint]Handler{ EndpointAddEntry: Handler{instance: th.instance, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost}, + EndpointAddCosi: Handler{instance: th.instance, handler: addCosi, endpoint: EndpointAddCosi, method: http.MethodPost}, } } @@ -76,7 +95,7 @@ func (th *testHandler) postHandler(t *testing.T, endpoint Endpoint) Handler { // TestGetHandlersRejectPost checks that all get handlers reject post requests func TestGetHandlersRejectPost(t *testing.T) { - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() for endpoint, handler := range th.getHandlers(t) { @@ -96,7 +115,7 @@ func TestGetHandlersRejectPost(t *testing.T) { // TestPostHandlersRejectGet checks that all post handlers reject get requests func TestPostHandlersRejectGet(t *testing.T) { - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() for endpoint, handler := range th.postHandlers(t) { @@ -196,7 +215,7 @@ func TestGetEntries(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetEntries.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -298,7 +317,7 @@ func TestAddEntry(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, table.signer) + th := newTestHandler(t, table.signer, nil) defer th.mockCtrl.Finish() url := EndpointAddEntry.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -392,7 +411,7 @@ func TestGetSth(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, table.signer) + th := newTestHandler(t, table.signer, nil) defer th.mockCtrl.Finish() url := EndpointGetSth.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -496,7 +515,7 @@ func TestGetConsistencyProof(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetConsistencyProof.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -605,7 +624,7 @@ func TestGetProofByHash(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetProofByHash.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -671,6 +690,166 @@ func TestGetProofByHash(t *testing.T) { } } +func TestGetStableSth(t *testing.T) { + for _, table := range cosiTestCases(t) { + func() { // run deferred functions at the end of each iteration + th := newTestHandler(t, nil, table.sth) + defer th.mockCtrl.Finish() + + // Setup and run client query + url := EndpointGetStableSth.Path("http://example.com", th.instance.LogParameters.Prefix) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + w := httptest.NewRecorder() + th.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req) + + // Check response code + if w.Code != table.wantCode { + t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) + } + if w.Code != http.StatusOK { + return + } + // Check response bytes + var gotBytes []byte + if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil { + t.Errorf("failed unmarshaling json: %v, wanted ok", err) + return + } + wantBytes, _ := table.sth.Marshal() + if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) { + t.Errorf("wanted response %X but got %X in test %q", got, want, table.description) + } + }() + } +} + +func TestGetCosi(t *testing.T) { + for _, table := range cosiTestCases(t) { + func() { // run deferred functions at the end of each iteration + th := newTestHandler(t, nil, table.sth) + defer th.mockCtrl.Finish() + + // Setup and run client query + url := EndpointGetCosi.Path("http://example.com", th.instance.LogParameters.Prefix) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + w := httptest.NewRecorder() + th.getHandler(t, EndpointGetCosi).ServeHTTP(w, req) + + // Check response code + if w.Code != table.wantCode { + t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) + } + if w.Code != http.StatusOK { + return + } + // Check response bytes + var gotBytes []byte + if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil { + t.Errorf("failed unmarshaling json: %v, wanted ok", err) + return + } + wantCosth := NewCosignedTreeHeadV1(table.sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), + Signature: testSignature, + }, + }) + wantBytes, _ := wantCosth.Marshal() + if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) { + t.Errorf("wanted response %X but got %X in test %q", got, want, table.description) + } + }() + } +} + +func TestAddCosi(t *testing.T) { + validSth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) + validSth2 := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp+1000000, testTreeSize, testNodeHash)), testLogId, testSignature) + for _, table := range []struct { + description string + sth *StItem + breq *bytes.Buffer + wantCode int + }{ + { + description: "invalid request: untrusted witness", // more specific tests can be found in TestNewAddCosignatureRequest + sth: validSth, + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk2, testdata.Ed25519Vk2, validSth), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid request: cosigned wrong sth", // more specific tests can be found in TestAddCosignature + sth: validSth, + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth2), + wantCode: http.StatusBadRequest, + }, + { + description: "valid", + sth: validSth, + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth), + wantCode: http.StatusOK, + }, + } { + func() { // run deferred functions at the end of each iteration + th := newTestHandler(t, nil, table.sth) + defer th.mockCtrl.Finish() + + // Setup and run client query + url := EndpointAddCosi.Path("http://example.com", th.instance.LogParameters.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + th.postHandler(t, EndpointAddCosi).ServeHTTP(w, req) + if w.Code != table.wantCode { + t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) + } + + // Check response + if w.Code != table.wantCode { + t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) + } + }() + } +} + +type cosiTestCase struct { + description string + sth *StItem + wantCode int +} + +// cosiTestCases returns test cases used by TestGetStableSth and TestGetCosi +func cosiTestCases(t *testing.T) []cosiTestCase { + t.Helper() + return []cosiTestCase{ + { + description: "no cosigned/stable sth", + sth: nil, + wantCode: http.StatusInternalServerError, + }, + { + description: "malformed cosigned/stable sth", + sth: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), []byte("not a log id"), testSignature), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + sth: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature), + wantCode: http.StatusOK, + }, + } +} + // mustMakeEd25519ChecksumV1 creates an ed25519-signed ChecksumV1 leaf func mustMakeEd25519ChecksumV1(t *testing.T, id, checksum, vk, sk []byte) ([]byte, []byte) { t.Helper() @@ -697,6 +876,30 @@ func mustMakeEd25519ChecksumV1Buffer(t *testing.T, identifier, checksum, vk, sk return bytes.NewBuffer(data) } +// mustMakeAddCosiBuffer creates an add-cosi data buffer +func mustMakeAddCosiBuffer(t *testing.T, sk, vk []byte, sth *StItem) *bytes.Buffer { + t.Helper() + msg, err := sth.Marshal() + if err != nil { + t.Fatalf("must marshal sth: %v", err) + } + costh := NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *mustNewNamespaceEd25519V1(t, vk), + Signature: ed25519.Sign(ed25519.PrivateKey(sk), msg), + }, + }) + item, err := costh.Marshal() + if err != nil { + t.Fatalf("must marshal costh: %v", err) + } + data, err := json.Marshal(AddCosignatureRequest{item}) + if err != nil { + t.Fatalf("must marshal add-cosi request: %v", err) + } + return bytes.NewBuffer(data) +} + // deadlineMatcher implements gomock.Matcher, such that an error is raised if // there is no context.Context deadline set type deadlineMatcher struct{} |