diff options
Diffstat (limited to 'pkg/instance')
| -rw-r--r-- | pkg/instance/endpoint_test.go | 209 | ||||
| -rw-r--r-- | pkg/instance/instance.go | 8 | ||||
| -rw-r--r-- | pkg/instance/instance_test.go | 77 | 
3 files changed, 278 insertions, 16 deletions
| diff --git a/pkg/instance/endpoint_test.go b/pkg/instance/endpoint_test.go index efcd4c0..fabf2e9 100644 --- a/pkg/instance/endpoint_test.go +++ b/pkg/instance/endpoint_test.go @@ -1,6 +1,7 @@  package stfe  import ( +	//"reflect"  	"bytes"  	"encoding/hex"  	"fmt" @@ -366,14 +367,6 @@ func TestGetConsistencyProof(t *testing.T) {  			types.NewSize, types.Delim, newSize, types.EOL,  		))  	} -	// values in testProof are not relevant for the test, just need a path -	testProof := &types.ConsistencyProof{ -		OldSize: 1, -		NewSize: 2, -		Path: []*[types.HashSize]byte{ -			types.Hash(nil), -		}, -	}  	for _, table := range []struct {  		description string  		ascii       io.Reader               // buffer used to populate HTTP request @@ -388,11 +381,34 @@ func TestGetConsistencyProof(t *testing.T) {  			wantCode:    http.StatusBadRequest,  		},  		{ +			description: "invalid: bad request (OldSize is zero)", +			ascii:       buf(0, 1), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (OldSize > NewSize)", +			ascii:       buf(2, 1), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			ascii:       buf(1, 2), +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{  			description: "valid",  			ascii:       buf(1, 2),  			expect:      true, -			rsp:         testProof, -			wantCode:    http.StatusOK, +			rsp: &types.ConsistencyProof{ +				OldSize: 1, +				NewSize: 2, +				Path: []*[types.HashSize]byte{ +					types.Hash(nil), +				}, +			}, +			wantCode: http.StatusOK,  		},  	} {  		// Run deferred functions at the end of each iteration @@ -426,7 +442,180 @@ func TestGetConsistencyProof(t *testing.T) {  }  func TestGetInclusionProof(t *testing.T) { +	buf := func(hash *[types.HashSize]byte, treeSize int) io.Reader { +		return bytes.NewBufferString(fmt.Sprintf( +			"%s%s%x%s"+"%s%s%d%s", +			types.LeafHash, types.Delim, hash[:], types.EOL, +			types.TreeSize, types.Delim, treeSize, types.EOL, +		)) +	} +	for _, table := range []struct { +		description string +		ascii       io.Reader             // buffer used to populate HTTP request +		expect      bool                  // set if a mock answer is expected +		rsp         *types.InclusionProof // inclusion proof from Trillian client +		err         error                 // error from Trillian client +		wantCode    int                   // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			ascii:       bytes.NewBufferString("key=value\n"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (no proof for tree size)", +			ascii:       buf(types.Hash(nil), 1), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			ascii:       buf(types.Hash(nil), 2), +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			ascii:       buf(types.Hash(nil), 2), +			expect:      true, +			rsp: &types.InclusionProof{ +				TreeSize:  2, +				LeafIndex: 0, +				Path: []*[types.HashSize]byte{ +					types.Hash(nil), +				}, +			}, +			wantCode: http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocks.NewMockClient(ctrl) +			if table.expect { +				client.EXPECT().GetInclusionProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			} +			i := Instance{ +				Config: testConfig, +				Client: client, +			} + +			// Create HTTP request +			url := types.EndpointGetProofByHash.Path("http://example.com", i.Prefix) +			req, err := http.NewRequest("POST", url, table.ascii) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandle(t, i, types.EndpointGetProofByHash).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	}  }  func TestGetLeaves(t *testing.T) { +	buf := func(startSize, endSize int64) io.Reader { +		return bytes.NewBufferString(fmt.Sprintf( +			"%s%s%d%s"+"%s%s%d%s", +			types.StartSize, types.Delim, startSize, types.EOL, +			types.EndSize, types.Delim, endSize, types.EOL, +		)) +	} +	for _, table := range []struct { +		description string +		ascii       io.Reader       // buffer used to populate HTTP request +		expect      bool            // set if a mock answer is expected +		rsp         *types.LeafList // list of leaves from Trillian client +		err         error           // error from Trillian client +		wantCode    int             // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			ascii:       bytes.NewBufferString("key=value\n"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (StartSize > EndSize)", +			ascii:       buf(1, 0), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			ascii:       buf(0, 0), +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid: one more entry than the configured MaxRange", +			ascii:       buf(0, testConfig.MaxRange), // query will be pruned +			expect:      true, +			rsp: func() *types.LeafList { +				var list types.LeafList +				for i := int64(0); i < testConfig.MaxRange; i++ { +					list = append(list[:], &types.Leaf{ +						Message: types.Message{ +							ShardHint: 0, +							Checksum:  types.Hash(nil), +						}, +						SigIdent: types.SigIdent{ +							Signature: &[types.SignatureSize]byte{}, +							KeyHash:   types.Hash(nil), +						}, +					}) +				} +				return &list +			}(), +			wantCode: http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocks.NewMockClient(ctrl) +			if table.expect { +				client.EXPECT().GetLeaves(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			} +			i := Instance{ +				Config: testConfig, +				Client: client, +			} + +			// Create HTTP request +			url := types.EndpointGetLeaves.Path("http://example.com", i.Prefix) +			req, err := http.NewRequest("POST", url, table.ascii) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandle(t, i, types.EndpointGetLeaves).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +			if w.Code != http.StatusOK { +				return +			} + +			// TODO: check that we got the right leaves back.  It is especially +			// important that we check that we got the right number of leaves. +			// +			// Pseuducode for when we have types.LeafList.UnmarshalASCII() +			// +			//list := &types.LeafList{} +			//if err := list.UnmarshalASCII(w.Body); err != nil { +			//	t.Fatalf("must unmarshal leaf list: %v", err) +			//} +			//if got, want := list, table.rsp; !reflect.DeepEqual(got, want) { +			//	t.Errorf("got leaf list\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			//} +		}() +	}  } diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index c2fe8fa..536eb60 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -110,7 +110,7 @@ func (i *Instance) leafRequestFromHTTP(r *http.Request) (*types.LeafRequest, err  func (i *Instance) cosignatureRequestFromHTTP(r *http.Request) (*types.CosignatureRequest, error) {  	var req types.CosignatureRequest  	if err := req.UnmarshalASCII(r.Body); err != nil { -		return nil, fmt.Errorf("unpackOctetPost: %v", err) +		return nil, fmt.Errorf("UnmarshalASCII: %v", err)  	}  	if _, ok := i.Witnesses[*req.KeyHash]; !ok {  		return nil, fmt.Errorf("Unknown witness: %x", req.KeyHash) @@ -137,8 +137,10 @@ func (i *Instance) inclusionProofRequestFromHTTP(r *http.Request) (*types.Inclus  	if err := req.UnmarshalASCII(r.Body); err != nil {  		return nil, fmt.Errorf("UnmarshalASCII: %v", err)  	} -	if req.TreeSize < 1 { -		return nil, fmt.Errorf("TreeSize(%d) must be larger than zero", req.TreeSize) +	if req.TreeSize < 2 { +		// TreeSize:0 => not possible to prove inclusion of anything +		// TreeSize:1 => you don't need an inclusion proof (it is always empty) +		return nil, fmt.Errorf("TreeSize(%d) must be larger than one", req.TreeSize)  	}  	return &req, nil  } diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go index 45a2837..f864628 100644 --- a/pkg/instance/instance_test.go +++ b/pkg/instance/instance_test.go @@ -1,9 +1,80 @@  package stfe  import ( +	"net/http" +	"net/http/httptest"  	"testing" + +	"github.com/system-transparency/stfe/pkg/types"  ) -func TestHandlers(t *testing.T)  {} -func TestPath(t *testing.T)      {} -func TestServeHTTP(t *testing.T) {} +// TestHandlers check that the expected handlers are configured +func TestHandlers(t *testing.T) { +	endpoints := map[types.Endpoint]bool{ +		types.EndpointAddLeaf:             false, +		types.EndpointAddCosignature:      false, +		types.EndpointGetTreeHeadLatest:   false, +		types.EndpointGetTreeHeadToSign:   false, +		types.EndpointGetTreeHeadCosigned: false, +		types.EndpointGetConsistencyProof: false, +		types.EndpointGetProofByHash:      false, +		types.EndpointGetLeaves:           false, +	} +	i := &Instance{ +		Config: testConfig, +	} +	for _, handler := range i.Handlers() { +		if _, ok := endpoints[handler.Endpoint]; !ok { +			t.Errorf("got unexpected endpoint: %s", handler.Endpoint) +		} +		endpoints[handler.Endpoint] = true +	} +	for endpoint, ok := range endpoints { +		if !ok { +			t.Errorf("endpoint %s is not configured", endpoint) +		} +	} +} + +// TestServeHTTP checks that invalid HTTP methods are rejected +func TestServeHTTP(t *testing.T) { +	i := &Instance{ +		Config: testConfig, +	} +	for _, handler := range i.Handlers() { +		// Prepare invalid HTTP request +		method := http.MethodPost +		if method == handler.Method { +			method = http.MethodGet +		} +		url := handler.Endpoint.Path("http://example.com", i.Prefix) +		req, err := http.NewRequest(method, url, nil) +		if err != nil { +			t.Fatalf("must create HTTP request: %v", err) +		} +		w := httptest.NewRecorder() + +		// Check that it is rejected +		handler.ServeHTTP(w, req) +		if got, want := w.Code, http.StatusMethodNotAllowed; got != want { +			t.Errorf("got HTTP code %v but wanted %v for endpoint %q", got, want, handler.Endpoint) +		} +	} +} + +func TestPath(t *testing.T) { +	instance := &Instance{ +		Config: Config{ +			Prefix: "testonly", +		}, +	} +	handler := Handler{ +		Instance: instance, +		Handler:  addLeaf, +		Endpoint: types.EndpointAddLeaf, +		Method:   http.MethodPost, +	} +	if got, want := handler.Path(), "testonly/st/v0/add-leaf"; got != want { +		t.Errorf("got path %v but wanted %v", got, want) +	} +} | 
