package stfe // TODO: refactor tests import ( "bytes" "context" "crypto" "fmt" "testing" "crypto/ed25519" "encoding/base64" "encoding/json" "net/http" "net/http/httptest" "github.com/golang/mock/gomock" "github.com/google/certificate-transparency-go/trillian/mockclient" cttestdata "github.com/google/certificate-transparency-go/trillian/testdata" "github.com/google/trillian" "github.com/system-transparency/stfe/namespace/testdata" ) type testHandler struct { mockCtrl *gomock.Controller client *mockclient.MockTrillianLogClient instance *Instance } func newTestHandler(t *testing.T, signer crypto.Signer, sth *StItem) *testHandler { ctrl := gomock.NewController(t) client := mockclient.NewMockTrillianLogClient(ctrl) lp := makeTestLogParameters(t, signer) source := &ActiveSthSource{ client: client, logParameters: lp, } if sth != nil { source.currCosth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ SignatureV1{ Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), Signature: testSignature, }, }) source.nextCosth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil) source.cosignatureFrom = make(map[string]bool) } return &testHandler{ mockCtrl: ctrl, client: client, instance: &Instance{ Client: client, LogParameters: lp, SthSource: source, }, } } func (th *testHandler) getHandlers(t *testing.T) map[Endpoint]Handler { return map[Endpoint]Handler{ EndpointGetLatestSth: Handler{instance: th.instance, handler: getSth, endpoint: EndpointGetLatestSth, method: http.MethodGet}, EndpointGetConsistencyProof: Handler{instance: th.instance, handler: getConsistencyProof, endpoint: EndpointGetConsistencyProof, method: http.MethodGet}, EndpointGetProofByHash: Handler{instance: th.instance, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet}, EndpointGetAnchors: Handler{instance: th.instance, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet}, EndpointGetEntries: Handler{instance: th.instance, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet}, EndpointGetStableSth: Handler{instance: th.instance, handler: getStableSth, endpoint: EndpointGetStableSth, method: http.MethodGet}, EndpointGetCosignedSth: Handler{instance: th.instance, handler: getCosi, endpoint: EndpointGetCosignedSth, method: http.MethodGet}, } } func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) Handler { handler, ok := th.getHandlers(t)[endpoint] if !ok { t.Fatalf("no such get endpoint: %s", endpoint) } return handler } func (th *testHandler) postHandlers(t *testing.T) map[Endpoint]Handler { return map[Endpoint]Handler{ EndpointAddEntry: Handler{instance: th.instance, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost}, EndpointAddCosignature: Handler{instance: th.instance, handler: addCosi, endpoint: EndpointAddCosignature, method: http.MethodPost}, } } func (th *testHandler) postHandler(t *testing.T, endpoint Endpoint) Handler { handler, ok := th.postHandlers(t)[endpoint] if !ok { t.Fatalf("no such post endpoint: %s", endpoint) } return handler } // TestGetHandlersRejectPost checks that all get handlers reject post requests func TestGetHandlersRejectPost(t *testing.T) { th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() for endpoint, handler := range th.getHandlers(t) { t.Run(string(endpoint), func(t *testing.T) { s := httptest.NewServer(handler) defer s.Close() url := endpoint.Path(s.URL, th.instance.LogParameters.Prefix) if rsp, err := http.Post(url, "application/json", nil); err != nil { t.Fatalf("http.Post(%s)=(_,%q), want (_,nil)", url, err) } else if rsp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("http.Post(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed) } }) } } // TestPostHandlersRejectGet checks that all post handlers reject get requests func TestPostHandlersRejectGet(t *testing.T) { th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() for endpoint, handler := range th.postHandlers(t) { t.Run(string(endpoint), func(t *testing.T) { s := httptest.NewServer(handler) defer s.Close() url := endpoint.Path(s.URL, th.instance.LogParameters.Prefix) if rsp, err := http.Get(url); err != nil { t.Fatalf("http.Get(%s)=(_,%q), want (_,nil)", url, err) } else if rsp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("http.Get(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed) } }) } } //// TestGetAnchors checks for a valid number of decodable trust anchors //func TestGetAnchors(t *testing.T) { // // TODO: refactor with namespaces // //th := newTestHandler(t, nil) // //defer th.mockCtrl.Finish() // // //url := EndpointGetAnchors.Path("http://example.com", th.instance.LogParameters.Prefix) // //req, err := http.NewRequest("GET", url, nil) // //if err != nil { // // t.Fatalf("failed creating http request: %v", err) // //} // // //w := httptest.NewRecorder() // //th.getHandler(t, EndpointGetAnchors).ServeHTTP(w, req) // //if w.Code != http.StatusOK { // // t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, http.StatusOK) // // return // //} // // //var derAnchors [][]byte // //if err := json.Unmarshal([]byte(w.Body.String()), &derAnchors); err != nil { // // t.Errorf("failed unmarshaling trust anchors response: %v", err) // // return // //} // //if got, want := len(derAnchors), len(th.instance.LogParameters.); got != want { // // t.Errorf("unexpected trust anchor count %d, want %d", got, want) // //} // //if _, err := x509util.ParseDerList(derAnchors); err != nil { // // t.Errorf("failed decoding trust anchors: %v", err) // //} //} func TestGetEntries(t *testing.T) { for _, table := range []struct { description string breq *GetEntriesRequest trsp *trillian.GetLeavesByRangeResponse terr error wantCode int wantErrText string }{ { description: "bad request parameters", breq: &GetEntriesRequest{ Start: 1, End: 0, }, wantCode: http.StatusBadRequest, wantErrText: http.StatusText(http.StatusBadRequest) + "\n", }, { description: "empty trillian response", breq: &GetEntriesRequest{ Start: 0, End: 1, }, terr: fmt.Errorf("back-end failure"), wantCode: http.StatusInternalServerError, wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", }, // TODO: make invalid get-entries response //{ // description: "invalid get-entries response", // breq: &GetEntriesRequest{ // Start: 0, // End: 1, // }, // trsp: makeTrillianGetLeavesByRangeResponse(t, 0, 1, testPackage, testdata.Ed25519Vk, testdata.Ed25519Sk), // wantCode: http.StatusInternalServerError, // wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", //}, { description: "valid get-entries response", breq: &GetEntriesRequest{ Start: 0, End: 1, }, trsp: makeTrillianGetLeavesByRangeResponse(t, 0, 1, testPackage, testdata.Ed25519Vk, testdata.Ed25519Sk), wantCode: http.StatusOK, }, } { func() { // run deferred functions at the end of each iteration th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetEntries.Path("http://example.com", th.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("must create http request: %v", err) } q := req.URL.Query() q.Add("start", fmt.Sprintf("%d", table.breq.Start)) q.Add("end", fmt.Sprintf("%d", table.breq.End)) req.URL.RawQuery = q.Encode() if table.trsp != nil || table.terr != nil { th.client.EXPECT().GetLeavesByRange(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() th.getHandler(t, EndpointGetEntries).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } body := w.Body.String() if w.Code != http.StatusOK { if body != table.wantErrText { t.Errorf("GET(%s)=%q, want text %q", url, body, table.wantErrText) } return } var rsps []*GetEntryResponse if err := json.Unmarshal([]byte(body), &rsps); err != nil { t.Errorf("failed parsing list of log entries: %v", err) return } for i, rsp := range rsps { var item StItem if err := item.Unmarshal(rsp.Item); err != nil { t.Errorf("failed unmarshaling StItem: %v", err) } else { if item.Format != StFormatChecksumV1 { t.Errorf("invalid StFormat: got %v, want %v", item.Format, StFormatChecksumV1) } checksum := item.ChecksumV1 if got, want := checksum.Package, []byte(fmt.Sprintf("%s_%d", testPackage, int64(i)+table.breq.Start)); !bytes.Equal(got, want) { t.Errorf("got package name %s, want %s", string(got), string(want)) } if got, want := checksum.Checksum, make([]byte, 32); !bytes.Equal(got, want) { t.Errorf("got package checksum %X, want %X", got, want) } // TODO: check namespace? } // TODO: verify signaturew w/ namespace? //if !ed25519.Verify(chain[0].PublicKey.(ed25519.PublicKey), rsp.Item, rsp.Signature) { // t.Errorf("invalid ed25519 signature") //} } }() } } func TestAddEntry(t *testing.T) { for _, table := range []struct { description string breq *bytes.Buffer trsp *trillian.QueueLeafResponse terr error wantCode int wantErrText string signer crypto.Signer }{ { description: "empty trillian response", breq: mustMakeEd25519ChecksumV1Buffer(t, testPackage, testChecksum, testdata.Ed25519Vk, testdata.Ed25519Sk), terr: fmt.Errorf("back-end failure"), wantCode: http.StatusInternalServerError, wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", }, { description: "bad request parameters: invalid signature", breq: mustMakeEd25519ChecksumV1Buffer(t, testPackage, testChecksum, make([]byte, 32), testdata.Ed25519Sk), wantCode: http.StatusBadRequest, wantErrText: http.StatusText(http.StatusBadRequest) + "\n", }, { description: "log signature failure", breq: mustMakeEd25519ChecksumV1Buffer(t, testPackage, testChecksum, testdata.Ed25519Vk, testdata.Ed25519Sk), trsp: makeTrillianQueueLeafResponse(t, testPackage, testdata.Ed25519Vk, testdata.Ed25519Sk, false), wantCode: http.StatusInternalServerError, wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", signer: cttestdata.NewSignerWithErr(nil, fmt.Errorf("signing failed")), }, { description: "valid add-entry request-response", breq: mustMakeEd25519ChecksumV1Buffer(t, testPackage, testChecksum, testdata.Ed25519Vk, testdata.Ed25519Sk), trsp: makeTrillianQueueLeafResponse(t, testPackage, testdata.Ed25519Vk, testdata.Ed25519Sk, false), wantCode: http.StatusOK, signer: cttestdata.NewSignerWithFixedSig(nil, make([]byte, 32)), }, } { func() { // run deferred functions at the end of each iteration th := newTestHandler(t, table.signer, nil) defer th.mockCtrl.Finish() url := EndpointAddEntry.Path("http://example.com", th.instance.LogParameters.Prefix) req, err := http.NewRequest("POST", url, table.breq) if err != nil { t.Fatalf("failed creating http request: %v", err) } req.Header.Set("Content-Type", "application/json") if table.trsp != nil || table.terr != nil { th.client.EXPECT().QueueLeaf(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() th.postHandler(t, EndpointAddEntry).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } body := w.Body.String() if w.Code != http.StatusOK { if body != table.wantErrText { t.Errorf("GET(%s)=%q, want text %q", url, body, table.wantErrText) } return } // status code is http.StatusOK, check response var data []byte if err := json.Unmarshal([]byte(body), &data); err != nil { t.Errorf("failed unmarshaling json: %v, wanted ok", err) return } var item StItem if err := item.Unmarshal(data); err != nil { t.Errorf("failed unmarshaling StItem: %v, wanted ok", err) return } if item.Format != StFormatSignedDebugInfoV1 { t.Errorf("invalid StFormat: got %v, want %v", item.Format, StFormatSignedDebugInfoV1) } sdi := item.SignedDebugInfoV1 if !bytes.Equal(sdi.LogId, th.instance.LogParameters.LogId) { t.Errorf("want log id %X, got %X", sdi.LogId, th.instance.LogParameters.LogId) } if len(sdi.Message) == 0 { t.Errorf("expected message, got none") } if !bytes.Equal(sdi.Signature, make([]byte, 32)) { t.Errorf("want signature %X, got %X", sdi.Signature, make([]byte, 32)) } }() } } func TestGetSth(t *testing.T) { tr := makeLatestSignedLogRootResponse(t, 0, 0, make([]byte, 32)) tr.SignedLogRoot.LogRoot = tr.SignedLogRoot.LogRoot[1:] for _, table := range []struct { description string trsp *trillian.GetLatestSignedLogRootResponse terr error wantCode int wantErrText string signer crypto.Signer }{ { description: "empty trillian response", terr: fmt.Errorf("back-end failure"), wantCode: http.StatusInternalServerError, wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", }, { description: "marshal failure: no signature", trsp: makeLatestSignedLogRootResponse(t, 0, 0, make([]byte, 32)), wantCode: http.StatusInternalServerError, wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", signer: cttestdata.NewSignerWithFixedSig(nil, make([]byte, 0)), }, { description: "signature failure", trsp: makeLatestSignedLogRootResponse(t, 0, 0, make([]byte, 32)), wantCode: http.StatusInternalServerError, wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", signer: cttestdata.NewSignerWithErr(nil, fmt.Errorf("signing failed")), }, { description: "valid request and response", trsp: makeLatestSignedLogRootResponse(t, 0, 0, make([]byte, 32)), wantCode: http.StatusOK, signer: cttestdata.NewSignerWithFixedSig(nil, make([]byte, 32)), }, } { func() { // run deferred functions at the end of each iteration th := newTestHandler(t, table.signer, nil) defer th.mockCtrl.Finish() url := EndpointGetLatestSth.Path("http://example.com", th.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) } w := httptest.NewRecorder() th.client.EXPECT().GetLatestSignedLogRoot(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) th.getHandler(t, EndpointGetLatestSth).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } body := w.Body.String() if w.Code != http.StatusOK { if body != table.wantErrText { t.Errorf("GET(%s)=%q, want text %q", url, body, table.wantErrText) } return } // status code is http.StatusOK, check response var data []byte if err := json.Unmarshal([]byte(body), &data); err != nil { t.Errorf("failed unmarshaling json: %v, wanted ok", err) return } var item StItem if err := item.Unmarshal(data); err != nil { t.Errorf("failed unmarshaling StItem: %v, wanted ok", err) return } if item.Format != StFormatSignedTreeHeadV1 { t.Errorf("invalid StFormat: got %v, want %v", item.Format, StFormatSignedTreeHeadV1) } sth := item.SignedTreeHeadV1 if !bytes.Equal(sth.LogId, th.instance.LogParameters.LogId) { t.Errorf("want log id %X, got %X", sth.LogId, th.instance.LogParameters.LogId) } if !bytes.Equal(sth.Signature, make([]byte, 32)) { t.Errorf("want signature %X, got %X", sth.Signature, make([]byte, 32)) } if sth.TreeHead.TreeSize != 0 { t.Errorf("want tree size %d, got %d", 0, sth.TreeHead.TreeSize) } if sth.TreeHead.Timestamp != 0 { t.Errorf("want timestamp %d, got %d", 0, sth.TreeHead.Timestamp) } if !bytes.Equal(sth.TreeHead.RootHash.Data, make([]byte, 32)) { t.Errorf("want root hash %X, got %X", make([]byte, 32), sth.TreeHead.RootHash) } if len(sth.TreeHead.Extension) != 0 { t.Errorf("want no extensions, got %v", sth.TreeHead.Extension) } }() } } func TestGetConsistencyProof(t *testing.T) { fixedProof := [][]byte{ make([]byte, 32), make([]byte, 32), } for _, table := range []struct { description string breq *GetConsistencyProofRequest trsp *trillian.GetConsistencyProofResponse terr error wantCode int wantErrText string }{ { description: "bad request parameters", breq: &GetConsistencyProofRequest{ First: 2, Second: 1, }, wantCode: http.StatusBadRequest, wantErrText: http.StatusText(http.StatusBadRequest) + "\n", }, { description: "empty trillian response", breq: &GetConsistencyProofRequest{ First: 1, Second: 2, }, terr: fmt.Errorf("back-end failure"), wantCode: http.StatusInternalServerError, wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", }, { description: "valid request and response", breq: &GetConsistencyProofRequest{ First: 1, Second: 2, }, trsp: makeTrillianGetConsistencyProofResponse(t, fixedProof), wantCode: http.StatusOK, }, } { func() { // run deferred functions at the end of each iteration th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetConsistencyProof.Path("http://example.com", th.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) } q := req.URL.Query() q.Add("first", fmt.Sprintf("%d", table.breq.First)) q.Add("second", fmt.Sprintf("%d", table.breq.Second)) req.URL.RawQuery = q.Encode() w := httptest.NewRecorder() if table.trsp != nil || table.terr != nil { th.client.EXPECT().GetConsistencyProof(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } th.getHandler(t, EndpointGetConsistencyProof).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } body := w.Body.String() if w.Code != http.StatusOK { if body != table.wantErrText { t.Errorf("GET(%s)=%q, want text %q", url, body, table.wantErrText) } return } // status code is http.StatusOK, check response var data []byte if err := json.Unmarshal([]byte(body), &data); err != nil { t.Errorf("failed unmarshaling json: %v, wanted ok", err) return } var item StItem if err := item.Unmarshal(data); err != nil { t.Errorf("failed unmarshaling StItem: %v, wanted ok", err) return } if item.Format != StFormatConsistencyProofV1 { t.Errorf("invalid StFormat: got %v, want %v", item.Format, StFormatInclusionProofV1) } proof := item.ConsistencyProofV1 if !bytes.Equal(proof.LogId, th.instance.LogParameters.LogId) { t.Errorf("want log id %X, got %X", proof.LogId, th.instance.LogParameters.LogId) } if got, want := proof.TreeSize1, uint64(table.breq.First); got != want { t.Errorf("want tree size %d, got %d", want, got) } if got, want := proof.TreeSize2, uint64(table.breq.Second); got != want { t.Errorf("want tree size %d, got %d", want, got) } if got, want := len(proof.ConsistencyPath), len(fixedProof); got != want { t.Errorf("want proof length %d, got %d", want, got) return } for i, nh := range proof.ConsistencyPath { if !bytes.Equal(nh.Data, fixedProof[i]) { t.Errorf("want proof[%d]=%X, got %X", i, fixedProof[i], nh.Data) } } }() } } func TestGetProofByHash(t *testing.T) { fixedProof := [][]byte{ make([]byte, 32), make([]byte, 32), } for _, table := range []struct { description string breq *GetProofByHashRequest trsp *trillian.GetInclusionProofByHashResponse terr error wantCode int wantErrText string }{ { description: "bad request parameters", breq: &GetProofByHashRequest{ Hash: make([]byte, 32), TreeSize: 0, }, wantCode: http.StatusBadRequest, wantErrText: http.StatusText(http.StatusBadRequest) + "\n", }, { description: "empty trillian response", breq: &GetProofByHashRequest{ Hash: make([]byte, 32), TreeSize: 128, }, terr: fmt.Errorf("back-end failure"), wantCode: http.StatusInternalServerError, wantErrText: http.StatusText(http.StatusInternalServerError) + "\n", }, { description: "valid request and response", breq: &GetProofByHashRequest{ Hash: make([]byte, 32), TreeSize: 128, }, trsp: makeTrillianGetInclusionProofByHashResponse(t, 0, fixedProof), wantCode: http.StatusOK, }, } { func() { // run deferred functions at the end of each iteration th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetProofByHash.Path("http://example.com", th.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) } q := req.URL.Query() q.Add("hash", base64.StdEncoding.EncodeToString(table.breq.Hash)) q.Add("tree_size", fmt.Sprintf("%d", table.breq.TreeSize)) req.URL.RawQuery = q.Encode() w := httptest.NewRecorder() if table.trsp != nil || table.terr != nil { th.client.EXPECT().GetInclusionProofByHash(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } th.getHandler(t, EndpointGetProofByHash).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } body := w.Body.String() if w.Code != http.StatusOK { if body != table.wantErrText { t.Errorf("GET(%s)=%q, want text %q", url, body, table.wantErrText) } return } // status code is http.StatusOK, check response var data []byte if err := json.Unmarshal([]byte(body), &data); err != nil { t.Errorf("failed unmarshaling json: %v, wanted ok", err) return } var item StItem if err := item.Unmarshal(data); err != nil { t.Errorf("failed unmarshaling StItem: %v, wanted ok", err) return } if item.Format != StFormatInclusionProofV1 { t.Errorf("invalid StFormat: got %v, want %v", item.Format, StFormatInclusionProofV1) } proof := item.InclusionProofV1 if !bytes.Equal(proof.LogId, th.instance.LogParameters.LogId) { t.Errorf("want log id %X, got %X", proof.LogId, th.instance.LogParameters.LogId) } if proof.TreeSize != uint64(table.breq.TreeSize) { t.Errorf("want tree size %d, got %d", table.breq.TreeSize, proof.TreeSize) } if proof.LeafIndex != 0 { t.Errorf("want index %d, got %d", 0, proof.LeafIndex) } if got, want := len(proof.InclusionPath), len(fixedProof); got != want { t.Errorf("want proof length %d, got %d", want, got) return } for i, nh := range proof.InclusionPath { if !bytes.Equal(nh.Data, fixedProof[i]) { t.Errorf("want proof[%d]=%X, got %X", i, fixedProof[i], nh.Data) } } }() } } func TestGetStableSth(t *testing.T) { for _, table := range cosiTestCases(t) { func() { // run deferred functions at the end of each iteration th := newTestHandler(t, nil, table.sth) defer th.mockCtrl.Finish() // Setup and run client query url := EndpointGetStableSth.Path("http://example.com", th.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) } w := httptest.NewRecorder() th.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req) // Check response code if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } if w.Code != http.StatusOK { return } // Check response bytes var gotBytes []byte if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil { t.Errorf("failed unmarshaling json: %v, wanted ok", err) return } wantBytes, _ := table.sth.Marshal() if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) { t.Errorf("wanted response %X but got %X in test %q", got, want, table.description) } }() } } func TestGetCosi(t *testing.T) { for _, table := range cosiTestCases(t) { func() { // run deferred functions at the end of each iteration th := newTestHandler(t, nil, table.sth) defer th.mockCtrl.Finish() // Setup and run client query url := EndpointGetCosignedSth.Path("http://example.com", th.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) } w := httptest.NewRecorder() th.getHandler(t, EndpointGetCosignedSth).ServeHTTP(w, req) // Check response code if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } if w.Code != http.StatusOK { return } // Check response bytes var gotBytes []byte if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil { t.Errorf("failed unmarshaling json: %v, wanted ok", err) return } wantCosth := NewCosignedTreeHeadV1(table.sth.SignedTreeHeadV1, []SignatureV1{ SignatureV1{ Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), Signature: testSignature, }, }) wantBytes, _ := wantCosth.Marshal() if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) { t.Errorf("wanted response %X but got %X in test %q", got, want, table.description) } }() } } func TestAddCosi(t *testing.T) { validSth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) validSth2 := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp+1000000, testTreeSize, testNodeHash)), testLogId, testSignature) for _, table := range []struct { description string sth *StItem breq *bytes.Buffer wantCode int }{ { description: "invalid request: untrusted witness", // more specific tests can be found in TestNewAddCosignatureRequest sth: validSth, breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk2, testdata.Ed25519Vk2, validSth), wantCode: http.StatusBadRequest, }, { description: "invalid request: cosigned wrong sth", // more specific tests can be found in TestAddCosignature sth: validSth, breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth2), wantCode: http.StatusBadRequest, }, { description: "valid", sth: validSth, breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth), wantCode: http.StatusOK, }, } { func() { // run deferred functions at the end of each iteration th := newTestHandler(t, nil, table.sth) defer th.mockCtrl.Finish() // Setup and run client query url := EndpointAddCosignature.Path("http://example.com", th.instance.LogParameters.Prefix) req, err := http.NewRequest("POST", url, table.breq) if err != nil { t.Fatalf("failed creating http request: %v", err) } req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() th.postHandler(t, EndpointAddCosignature).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } // Check response if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } }() } } type cosiTestCase struct { description string sth *StItem wantCode int } // cosiTestCases returns test cases used by TestGetStableSth and TestGetCosi func cosiTestCases(t *testing.T) []cosiTestCase { t.Helper() return []cosiTestCase{ { description: "no cosigned/stable sth", sth: nil, wantCode: http.StatusInternalServerError, }, { description: "malformed cosigned/stable sth", sth: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), []byte("not a log id"), testSignature), wantCode: http.StatusInternalServerError, }, { description: "valid", sth: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature), wantCode: http.StatusOK, }, } } // mustMakeEd25519ChecksumV1 creates an ed25519-signed ChecksumV1 leaf func mustMakeEd25519ChecksumV1(t *testing.T, id, checksum, vk, sk []byte) ([]byte, []byte) { t.Helper() leaf, err := NewChecksumV1(id, checksum, mustNewNamespaceEd25519V1(t, vk)).Marshal() if err != nil { t.Fatalf("must serialize checksum_v1: %v", err) } return leaf, ed25519.Sign(ed25519.PrivateKey(sk), leaf) } // mustMakeEd25519ChecksumV1Buffer creates an add-entry data buffer with an // Ed25519 namespace that can be posted. func mustMakeEd25519ChecksumV1Buffer(t *testing.T, identifier, checksum, vk, sk []byte) *bytes.Buffer { t.Helper() leaf, signature := mustMakeEd25519ChecksumV1(t, identifier, checksum, vk, sk) req := AddEntryRequest{ Item: leaf, Signature: signature, } data, err := json.Marshal(req) if err != nil { t.Fatalf("must marshal add-entry request: %v", err) } return bytes.NewBuffer(data) } // mustMakeAddCosiBuffer creates an add-cosi data buffer func mustMakeAddCosiBuffer(t *testing.T, sk, vk []byte, sth *StItem) *bytes.Buffer { t.Helper() msg, err := sth.Marshal() if err != nil { t.Fatalf("must marshal sth: %v", err) } costh := NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ SignatureV1{ Namespace: *mustNewNamespaceEd25519V1(t, vk), Signature: ed25519.Sign(ed25519.PrivateKey(sk), msg), }, }) item, err := costh.Marshal() if err != nil { t.Fatalf("must marshal costh: %v", err) } data, err := json.Marshal(AddCosignatureRequest{item}) if err != nil { t.Fatalf("must marshal add-cosi request: %v", err) } return bytes.NewBuffer(data) } // deadlineMatcher implements gomock.Matcher, such that an error is raised if // there is no context.Context deadline set type deadlineMatcher struct{} // newDeadlineMatcher returns a new DeadlineMatcher func newDeadlineMatcher() gomock.Matcher { return &deadlineMatcher{} } // Matches returns true if the passed interface is a context with a deadline func (dm *deadlineMatcher) Matches(i interface{}) bool { ctx, ok := i.(context.Context) if !ok { return false } _, ok = ctx.Deadline() return ok } // String is needed to implement gomock.Matcher func (dm *deadlineMatcher) String() string { return fmt.Sprintf("deadlineMatcher{}") }