package stfe import ( "bytes" "context" "fmt" "reflect" "testing" "net/http" "net/http/httptest" "github.com/golang/mock/gomock" cttestdata "github.com/google/certificate-transparency-go/trillian/testdata" "github.com/google/trillian" "github.com/system-transparency/stfe/testdata" "github.com/system-transparency/stfe/types" ) func TestEndpointAddEntry(t *testing.T) { for _, table := range []struct { description string breq *bytes.Buffer trsp *trillian.QueueLeafResponse terr error wantCode int }{ { description: "invalid: bad request: empty", breq: bytes.NewBuffer(nil), wantCode: http.StatusBadRequest, }, { description: "invalid: bad Trillian response: error", breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), terr: fmt.Errorf("backend failure"), wantCode: http.StatusInternalServerError, }, { description: "valid", breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), trsp: testdata.DefaultTQlr(t, false), wantCode: http.StatusOK, }, } { func() { // run deferred functions at the end of each iteration ti := newTestInstance(t, nil) defer ti.ctrl.Finish() url := EndpointAddEntry.Path("http://example.com", ti.instance.LogParameters.Prefix) req, err := http.NewRequest("POST", url, table.breq) if err != nil { t.Fatalf("must create http request: %v", err) } req.Header.Set("Content-Type", "application/octet-stream") if table.trsp != nil || table.terr != nil { ti.client.EXPECT().QueueLeaf(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() ti.postHandler(t, EndpointAddEntry).ServeHTTP(w, req) if got, want := w.Code, table.wantCode; got != want { t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) } }() } } func TestEndpointAddCosignature(t *testing.T) { for _, table := range []struct { description string breq *bytes.Buffer wantCode int }{ { description: "invalid: bad request: empty", breq: bytes.NewBuffer(nil), wantCode: http.StatusBadRequest, }, { description: "invalid: signed wrong sth", // newLogParameters() use testdata.Ed25519VkLog as default breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog2), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), wantCode: http.StatusBadRequest, }, { description: "valid", breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), wantCode: http.StatusOK, }, } { func() { // run deferred functions at the end of each iteration ti := newTestInstance(t, nil) defer ti.ctrl.Finish() url := EndpointAddCosignature.Path("http://example.com", ti.instance.LogParameters.Prefix) req, err := http.NewRequest("POST", url, table.breq) if err != nil { t.Fatalf("must create http request: %v", err) } req.Header.Set("Content-Type", "application/octet-stream") w := httptest.NewRecorder() ti.postHandler(t, EndpointAddCosignature).ServeHTTP(w, req) if got, want := w.Code, table.wantCode; got != want { t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) } }() } } func TestEndpointGetLatestSth(t *testing.T) { for _, table := range []struct { description string trsp *trillian.GetLatestSignedLogRootResponse terr error wantCode int wantItem *types.StItem }{ { description: "backend failure", terr: fmt.Errorf("backend failure"), wantCode: http.StatusInternalServerError, }, { description: "valid", trsp: testdata.DefaultTSlr(t), wantCode: http.StatusOK, wantItem: testdata.DefaultSth(t, testdata.Ed25519VkLog), }, } { func() { // run deferred functions at the end of each iteration ti := newTestInstance(t, cttestdata.NewSignerWithFixedSig(nil, testdata.Signature)) ti.ctrl.Finish() // Setup and run client query url := EndpointGetLatestSth.Path("http://example.com", ti.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("must create http request: %v", err) } if table.trsp != nil || table.terr != nil { ti.client.EXPECT().GetLatestSignedLogRoot(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() ti.getHandler(t, EndpointGetLatestSth).ServeHTTP(w, req) if got, want := w.Code, table.wantCode; got != want { t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) } if w.Code != http.StatusOK { return } var item types.StItem if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) } if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) } }() } } func TestEndpointGetStableSth(t *testing.T) { for _, table := range []struct { description string useBadSource bool wantCode int wantItem *types.StItem }{ { description: "invalid: sth source failure", useBadSource: true, wantCode: http.StatusInternalServerError, }, { description: "valid", wantCode: http.StatusOK, wantItem: testdata.DefaultSth(t, testdata.Ed25519VkLog), }, } { func() { // run deferred functions at the end of each iteration ti := newTestInstance(t, nil) ti.ctrl.Finish() if table.useBadSource { ti.instance.SthSource = &ActiveSthSource{} } // Setup and run client query url := EndpointGetStableSth.Path("http://example.com", ti.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("must create http request: %v", err) } w := httptest.NewRecorder() ti.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req) if got, want := w.Code, table.wantCode; got != want { t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) } if w.Code != http.StatusOK { return } var item types.StItem if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) } if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) } }() } } func TestEndpointGetCosignedSth(t *testing.T) { for _, table := range []struct { description string useBadSource bool wantCode int wantItem *types.StItem }{ { description: "invalid: sth source failure", useBadSource: true, wantCode: http.StatusInternalServerError, }, { description: "valid", wantCode: http.StatusOK, wantItem: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), }, } { func() { // run deferred functions at the end of each iteration ti := newTestInstance(t, nil) ti.ctrl.Finish() if table.useBadSource { ti.instance.SthSource = &ActiveSthSource{} } // Setup and run client query url := EndpointGetCosignedSth.Path("http://example.com", ti.instance.LogParameters.Prefix) req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("must create http request: %v", err) } w := httptest.NewRecorder() ti.getHandler(t, EndpointGetCosignedSth).ServeHTTP(w, req) if got, want := w.Code, table.wantCode; got != want { t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) } if w.Code != http.StatusOK { return } var item types.StItem if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) } if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) } }() } } func TestEndpointGetProofByHash(t *testing.T) { for _, table := range []struct { description string breq *bytes.Buffer trsp *trillian.GetInclusionProofByHashResponse terr error wantCode int wantItem *types.StItem }{ { description: "invalid: bad request: empty", breq: bytes.NewBuffer(nil), wantCode: http.StatusBadRequest, }, { description: "invalid: bad Trillian response: error", breq: bytes.NewBuffer(marshal(t, types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash})), terr: fmt.Errorf("backend failure"), wantCode: http.StatusInternalServerError, }, { description: "valid", breq: bytes.NewBuffer(marshal(t, types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash})), trsp: testdata.DefaultTGipbhr(t), wantCode: http.StatusOK, wantItem: testdata.DefaultInclusionProof(t, 1), }, } { func() { // run deferred functions at the end of each iteration ti := newTestInstance(t, nil) defer ti.ctrl.Finish() url := EndpointGetProofByHash.Path("http://example.com", ti.instance.LogParameters.Prefix) req, err := http.NewRequest("POST", url, table.breq) if err != nil { t.Fatalf("must create http request: %v", err) } req.Header.Set("Content-Type", "application/octet-stream") if table.trsp != nil || table.terr != nil { ti.client.EXPECT().GetInclusionProofByHash(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() ti.postHandler(t, EndpointGetProofByHash).ServeHTTP(w, req) if got, want := w.Code, table.wantCode; got != want { t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) } if w.Code != http.StatusOK { return } var item types.StItem if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) } if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) } }() } } func TestEndpointGetConsistencyProof(t *testing.T) { for _, table := range []struct { description string breq *bytes.Buffer trsp *trillian.GetConsistencyProofResponse terr error wantCode int wantItem *types.StItem }{ { description: "invalid: bad request: empty", breq: bytes.NewBuffer(nil), wantCode: http.StatusBadRequest, }, { description: "invalid: bad Trillian response: error", breq: bytes.NewBuffer(marshal(t, types.GetConsistencyProofV1{First: 1, Second: 2})), terr: fmt.Errorf("backend failure"), wantCode: http.StatusInternalServerError, }, { description: "valid", breq: bytes.NewBuffer(marshal(t, types.GetConsistencyProofV1{First: 1, Second: 2})), trsp: testdata.DefaultTGcpr(t), wantCode: http.StatusOK, wantItem: testdata.DefaultConsistencyProof(t, 1, 2), }, } { func() { // run deferred functions at the end of each iteration ti := newTestInstance(t, nil) defer ti.ctrl.Finish() url := EndpointGetConsistencyProof.Path("http://example.com", ti.instance.LogParameters.Prefix) req, err := http.NewRequest("POST", url, table.breq) if err != nil { t.Fatalf("must create http request: %v", err) } req.Header.Set("Content-Type", "application/octet-stream") if table.trsp != nil || table.terr != nil { ti.client.EXPECT().GetConsistencyProof(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() ti.postHandler(t, EndpointGetConsistencyProof).ServeHTTP(w, req) if got, want := w.Code, table.wantCode; got != want { t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) } if w.Code != http.StatusOK { return } var item types.StItem if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) } if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) } }() } } func TestEndpointGetEntriesV1(t *testing.T) { for _, table := range []struct { description string breq *bytes.Buffer trsp *trillian.GetLeavesByRangeResponse terr error wantCode int wantItem *types.StItemList }{ { description: "invalid: bad request: empty", breq: bytes.NewBuffer(nil), wantCode: http.StatusBadRequest, }, { description: "invalid: bad Trillian response: error", breq: bytes.NewBuffer(marshal(t, types.GetEntriesV1{Start: 0, End: 0})), terr: fmt.Errorf("backend failure"), wantCode: http.StatusInternalServerError, }, { description: "valid", // remember that newLogParameters() have testdata.MaxRange configured breq: bytes.NewBuffer(marshal(t, types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange - 1)})), trsp: testdata.DefaultTGlbrr(t, 0, testdata.MaxRange-1), wantCode: http.StatusOK, wantItem: testdata.DefaultStItemList(t, 0, uint64(testdata.MaxRange)-1), }, } { func() { // run deferred functions at the end of each iteration ti := newTestInstance(t, nil) defer ti.ctrl.Finish() url := EndpointGetEntries.Path("http://example.com", ti.instance.LogParameters.Prefix) req, err := http.NewRequest("POST", url, table.breq) if err != nil { t.Fatalf("must create http request: %v", err) } req.Header.Set("Content-Type", "application/octet-stream") if table.trsp != nil || table.terr != nil { ti.client.EXPECT().GetLeavesByRange(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() ti.postHandler(t, EndpointGetEntries).ServeHTTP(w, req) if got, want := w.Code, table.wantCode; got != want { t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) } if w.Code != http.StatusOK { return } var item types.StItemList if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) } if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) } }() } } func TestEndpointPath(t *testing.T) { base, prefix, proto := "http://example.com", "test", "st/v1" for _, table := range []struct { endpoint Endpoint want string }{ { endpoint: EndpointAddEntry, want: "http://example.com/test/st/v1/add-entry", }, { endpoint: EndpointAddCosignature, want: "http://example.com/test/st/v1/add-cosignature", }, { endpoint: EndpointGetLatestSth, want: "http://example.com/test/st/v1/get-latest-sth", }, { endpoint: EndpointGetStableSth, want: "http://example.com/test/st/v1/get-stable-sth", }, { endpoint: EndpointGetCosignedSth, want: "http://example.com/test/st/v1/get-cosigned-sth", }, { endpoint: EndpointGetConsistencyProof, want: "http://example.com/test/st/v1/get-consistency-proof", }, { endpoint: EndpointGetProofByHash, want: "http://example.com/test/st/v1/get-proof-by-hash", }, { endpoint: EndpointGetEntries, want: "http://example.com/test/st/v1/get-entries", }, } { if got, want := table.endpoint.Path(base+"/"+prefix+"/"+proto), table.want; got != want { t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\twith one component", got, want) } if got, want := table.endpoint.Path(base, prefix, proto), table.want; got != want { t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\tmultiple components", got, want) } } } // TODO: TestWriteOctetResponse func TestWriteOctetResponse(t *testing.T) { } // deadlineMatcher implements gomock.Matcher, such that an error is raised if // there is no context.Context deadline set type deadlineMatcher struct{} // newDeadlineMatcher returns a new DeadlineMatcher func newDeadlineMatcher() gomock.Matcher { return &deadlineMatcher{} } // Matches returns true if the passed interface is a context with a deadline func (dm *deadlineMatcher) Matches(i interface{}) bool { ctx, ok := i.(context.Context) if !ok { return false } _, ok = ctx.Deadline() return ok } // String is needed to implement gomock.Matcher func (dm *deadlineMatcher) String() string { return fmt.Sprintf("deadlineMatcher{}") }