diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/instance/endpoint_test.go | 209 | ||||
-rw-r--r-- | pkg/instance/instance.go | 8 | ||||
-rw-r--r-- | pkg/instance/instance_test.go | 77 |
3 files changed, 278 insertions, 16 deletions
diff --git a/pkg/instance/endpoint_test.go b/pkg/instance/endpoint_test.go index efcd4c0..fabf2e9 100644 --- a/pkg/instance/endpoint_test.go +++ b/pkg/instance/endpoint_test.go @@ -1,6 +1,7 @@ package stfe import ( + //"reflect" "bytes" "encoding/hex" "fmt" @@ -366,14 +367,6 @@ func TestGetConsistencyProof(t *testing.T) { types.NewSize, types.Delim, newSize, types.EOL, )) } - // values in testProof are not relevant for the test, just need a path - testProof := &types.ConsistencyProof{ - OldSize: 1, - NewSize: 2, - Path: []*[types.HashSize]byte{ - types.Hash(nil), - }, - } for _, table := range []struct { description string ascii io.Reader // buffer used to populate HTTP request @@ -388,11 +381,34 @@ func TestGetConsistencyProof(t *testing.T) { wantCode: http.StatusBadRequest, }, { + description: "invalid: bad request (OldSize is zero)", + ascii: buf(0, 1), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: bad request (OldSize > NewSize)", + ascii: buf(2, 1), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: backend failure", + ascii: buf(1, 2), + expect: true, + err: fmt.Errorf("something went wrong"), + wantCode: http.StatusInternalServerError, + }, + { description: "valid", ascii: buf(1, 2), expect: true, - rsp: testProof, - wantCode: http.StatusOK, + rsp: &types.ConsistencyProof{ + OldSize: 1, + NewSize: 2, + Path: []*[types.HashSize]byte{ + types.Hash(nil), + }, + }, + wantCode: http.StatusOK, }, } { // Run deferred functions at the end of each iteration @@ -426,7 +442,180 @@ func TestGetConsistencyProof(t *testing.T) { } func TestGetInclusionProof(t *testing.T) { + buf := func(hash *[types.HashSize]byte, treeSize int) io.Reader { + return bytes.NewBufferString(fmt.Sprintf( + "%s%s%x%s"+"%s%s%d%s", + types.LeafHash, types.Delim, hash[:], types.EOL, + types.TreeSize, types.Delim, treeSize, types.EOL, + )) + } + for _, table := range []struct { + description string + ascii io.Reader // buffer used to populate HTTP request + expect bool // set if a mock answer is expected + rsp *types.InclusionProof // inclusion proof from Trillian client + err error // error from Trillian client + wantCode int // HTTP status ok + }{ + { + description: "invalid: bad request (parser error)", + ascii: bytes.NewBufferString("key=value\n"), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: bad request (no proof for tree size)", + ascii: buf(types.Hash(nil), 1), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: backend failure", + ascii: buf(types.Hash(nil), 2), + expect: true, + err: fmt.Errorf("something went wrong"), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + ascii: buf(types.Hash(nil), 2), + expect: true, + rsp: &types.InclusionProof{ + TreeSize: 2, + LeafIndex: 0, + Path: []*[types.HashSize]byte{ + types.Hash(nil), + }, + }, + wantCode: http.StatusOK, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + client := mocks.NewMockClient(ctrl) + if table.expect { + client.EXPECT().GetInclusionProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + } + i := Instance{ + Config: testConfig, + Client: client, + } + + // Create HTTP request + url := types.EndpointGetProofByHash.Path("http://example.com", i.Prefix) + req, err := http.NewRequest("POST", url, table.ascii) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + + // Run HTTP request + w := httptest.NewRecorder() + mustHandle(t, i, types.EndpointGetProofByHash).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) + } + }() + } } func TestGetLeaves(t *testing.T) { + buf := func(startSize, endSize int64) io.Reader { + return bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%d%s", + types.StartSize, types.Delim, startSize, types.EOL, + types.EndSize, types.Delim, endSize, types.EOL, + )) + } + for _, table := range []struct { + description string + ascii io.Reader // buffer used to populate HTTP request + expect bool // set if a mock answer is expected + rsp *types.LeafList // list of leaves from Trillian client + err error // error from Trillian client + wantCode int // HTTP status ok + }{ + { + description: "invalid: bad request (parser error)", + ascii: bytes.NewBufferString("key=value\n"), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: bad request (StartSize > EndSize)", + ascii: buf(1, 0), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: backend failure", + ascii: buf(0, 0), + expect: true, + err: fmt.Errorf("something went wrong"), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid: one more entry than the configured MaxRange", + ascii: buf(0, testConfig.MaxRange), // query will be pruned + expect: true, + rsp: func() *types.LeafList { + var list types.LeafList + for i := int64(0); i < testConfig.MaxRange; i++ { + list = append(list[:], &types.Leaf{ + Message: types.Message{ + ShardHint: 0, + Checksum: types.Hash(nil), + }, + SigIdent: types.SigIdent{ + Signature: &[types.SignatureSize]byte{}, + KeyHash: types.Hash(nil), + }, + }) + } + return &list + }(), + wantCode: http.StatusOK, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + client := mocks.NewMockClient(ctrl) + if table.expect { + client.EXPECT().GetLeaves(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + } + i := Instance{ + Config: testConfig, + Client: client, + } + + // Create HTTP request + url := types.EndpointGetLeaves.Path("http://example.com", i.Prefix) + req, err := http.NewRequest("POST", url, table.ascii) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + + // Run HTTP request + w := httptest.NewRecorder() + mustHandle(t, i, types.EndpointGetLeaves).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) + } + if w.Code != http.StatusOK { + return + } + + // TODO: check that we got the right leaves back. It is especially + // important that we check that we got the right number of leaves. + // + // Pseuducode for when we have types.LeafList.UnmarshalASCII() + // + //list := &types.LeafList{} + //if err := list.UnmarshalASCII(w.Body); err != nil { + // t.Fatalf("must unmarshal leaf list: %v", err) + //} + //if got, want := list, table.rsp; !reflect.DeepEqual(got, want) { + // t.Errorf("got leaf list\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + //} + }() + } } diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index c2fe8fa..536eb60 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -110,7 +110,7 @@ func (i *Instance) leafRequestFromHTTP(r *http.Request) (*types.LeafRequest, err func (i *Instance) cosignatureRequestFromHTTP(r *http.Request) (*types.CosignatureRequest, error) { var req types.CosignatureRequest if err := req.UnmarshalASCII(r.Body); err != nil { - return nil, fmt.Errorf("unpackOctetPost: %v", err) + return nil, fmt.Errorf("UnmarshalASCII: %v", err) } if _, ok := i.Witnesses[*req.KeyHash]; !ok { return nil, fmt.Errorf("Unknown witness: %x", req.KeyHash) @@ -137,8 +137,10 @@ func (i *Instance) inclusionProofRequestFromHTTP(r *http.Request) (*types.Inclus if err := req.UnmarshalASCII(r.Body); err != nil { return nil, fmt.Errorf("UnmarshalASCII: %v", err) } - if req.TreeSize < 1 { - return nil, fmt.Errorf("TreeSize(%d) must be larger than zero", req.TreeSize) + if req.TreeSize < 2 { + // TreeSize:0 => not possible to prove inclusion of anything + // TreeSize:1 => you don't need an inclusion proof (it is always empty) + return nil, fmt.Errorf("TreeSize(%d) must be larger than one", req.TreeSize) } return &req, nil } diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go index 45a2837..f864628 100644 --- a/pkg/instance/instance_test.go +++ b/pkg/instance/instance_test.go @@ -1,9 +1,80 @@ package stfe import ( + "net/http" + "net/http/httptest" "testing" + + "github.com/system-transparency/stfe/pkg/types" ) -func TestHandlers(t *testing.T) {} -func TestPath(t *testing.T) {} -func TestServeHTTP(t *testing.T) {} +// TestHandlers check that the expected handlers are configured +func TestHandlers(t *testing.T) { + endpoints := map[types.Endpoint]bool{ + types.EndpointAddLeaf: false, + types.EndpointAddCosignature: false, + types.EndpointGetTreeHeadLatest: false, + types.EndpointGetTreeHeadToSign: false, + types.EndpointGetTreeHeadCosigned: false, + types.EndpointGetConsistencyProof: false, + types.EndpointGetProofByHash: false, + types.EndpointGetLeaves: false, + } + i := &Instance{ + Config: testConfig, + } + for _, handler := range i.Handlers() { + if _, ok := endpoints[handler.Endpoint]; !ok { + t.Errorf("got unexpected endpoint: %s", handler.Endpoint) + } + endpoints[handler.Endpoint] = true + } + for endpoint, ok := range endpoints { + if !ok { + t.Errorf("endpoint %s is not configured", endpoint) + } + } +} + +// TestServeHTTP checks that invalid HTTP methods are rejected +func TestServeHTTP(t *testing.T) { + i := &Instance{ + Config: testConfig, + } + for _, handler := range i.Handlers() { + // Prepare invalid HTTP request + method := http.MethodPost + if method == handler.Method { + method = http.MethodGet + } + url := handler.Endpoint.Path("http://example.com", i.Prefix) + req, err := http.NewRequest(method, url, nil) + if err != nil { + t.Fatalf("must create HTTP request: %v", err) + } + w := httptest.NewRecorder() + + // Check that it is rejected + handler.ServeHTTP(w, req) + if got, want := w.Code, http.StatusMethodNotAllowed; got != want { + t.Errorf("got HTTP code %v but wanted %v for endpoint %q", got, want, handler.Endpoint) + } + } +} + +func TestPath(t *testing.T) { + instance := &Instance{ + Config: Config{ + Prefix: "testonly", + }, + } + handler := Handler{ + Instance: instance, + Handler: addLeaf, + Endpoint: types.EndpointAddLeaf, + Method: http.MethodPost, + } + if got, want := handler.Path(), "testonly/st/v0/add-leaf"; got != want { + t.Errorf("got path %v but wanted %v", got, want) + } +} |