diff options
Diffstat (limited to 'handler_test.go')
| -rw-r--r-- | handler_test.go | 225 | 
1 files changed, 214 insertions, 11 deletions
diff --git a/handler_test.go b/handler_test.go index dd32c37..daa1a6c 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,5 +1,7 @@  package stfe +// TODO: refactor tests +  import (  	"bytes"  	"context" @@ -8,7 +10,6 @@ import (  	"testing"  	"crypto/ed25519" -	//"crypto/tls"  	"encoding/base64"  	"encoding/json"  	"net/http" @@ -28,16 +29,31 @@ type testHandler struct {  	instance *Instance  } -func newTestHandler(t *testing.T, signer crypto.Signer) *testHandler { +func newTestHandler(t *testing.T, signer crypto.Signer, sth *StItem) *testHandler {  	ctrl := gomock.NewController(t)  	client := mockclient.NewMockTrillianLogClient(ctrl) +	lp := makeTestLogParameters(t, signer) +	source := &ActiveSthSource{ +		client:        client, +		logParameters: lp, +	} +	if sth != nil { +		source.currSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ +			SignatureV1{ +				Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), +				Signature: testSignature, +			}, +		}) +		source.nextSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil) +		source.cosignatureFrom = make(map[string]bool) +	}  	return &testHandler{  		mockCtrl: ctrl,  		client:   client,  		instance: &Instance{ -			Deadline:      testDeadline,  			Client:        client, -			LogParameters: makeTestLogParameters(t, signer), +			LogParameters: lp, +			SthSource:     source,  		},  	}  } @@ -49,6 +65,8 @@ func (th *testHandler) getHandlers(t *testing.T) map[Endpoint]Handler {  		EndpointGetProofByHash:      Handler{instance: th.instance, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet},  		EndpointGetAnchors:          Handler{instance: th.instance, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet},  		EndpointGetEntries:          Handler{instance: th.instance, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet}, +		EndpointGetStableSth:        Handler{instance: th.instance, handler: getStableSth, endpoint: EndpointGetStableSth, method: http.MethodGet}, +		EndpointGetCosi:             Handler{instance: th.instance, handler: getCosi, endpoint: EndpointGetCosi, method: http.MethodGet},  	}  } @@ -63,6 +81,7 @@ func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) Handler {  func (th *testHandler) postHandlers(t *testing.T) map[Endpoint]Handler {  	return map[Endpoint]Handler{  		EndpointAddEntry: Handler{instance: th.instance, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost}, +		EndpointAddCosi:  Handler{instance: th.instance, handler: addCosi, endpoint: EndpointAddCosi, method: http.MethodPost},  	}  } @@ -76,7 +95,7 @@ func (th *testHandler) postHandler(t *testing.T, endpoint Endpoint) Handler {  // TestGetHandlersRejectPost checks that all get handlers reject post requests  func TestGetHandlersRejectPost(t *testing.T) { -	th := newTestHandler(t, nil) +	th := newTestHandler(t, nil, nil)  	defer th.mockCtrl.Finish()  	for endpoint, handler := range th.getHandlers(t) { @@ -96,7 +115,7 @@ func TestGetHandlersRejectPost(t *testing.T) {  // TestPostHandlersRejectGet checks that all post handlers reject get requests  func TestPostHandlersRejectGet(t *testing.T) { -	th := newTestHandler(t, nil) +	th := newTestHandler(t, nil, nil)  	defer th.mockCtrl.Finish()  	for endpoint, handler := range th.postHandlers(t) { @@ -196,7 +215,7 @@ func TestGetEntries(t *testing.T) {  		},  	} {  		func() { // run deferred functions at the end of each iteration -			th := newTestHandler(t, nil) +			th := newTestHandler(t, nil, nil)  			defer th.mockCtrl.Finish()  			url := EndpointGetEntries.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -298,7 +317,7 @@ func TestAddEntry(t *testing.T) {  		},  	} {  		func() { // run deferred functions at the end of each iteration -			th := newTestHandler(t, table.signer) +			th := newTestHandler(t, table.signer, nil)  			defer th.mockCtrl.Finish()  			url := EndpointAddEntry.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -392,7 +411,7 @@ func TestGetSth(t *testing.T) {  		},  	} {  		func() { // run deferred functions at the end of each iteration -			th := newTestHandler(t, table.signer) +			th := newTestHandler(t, table.signer, nil)  			defer th.mockCtrl.Finish()  			url := EndpointGetSth.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -496,7 +515,7 @@ func TestGetConsistencyProof(t *testing.T) {  		},  	} {  		func() { // run deferred functions at the end of each iteration -			th := newTestHandler(t, nil) +			th := newTestHandler(t, nil, nil)  			defer th.mockCtrl.Finish()  			url := EndpointGetConsistencyProof.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -605,7 +624,7 @@ func TestGetProofByHash(t *testing.T) {  		},  	} {  		func() { // run deferred functions at the end of each iteration -			th := newTestHandler(t, nil) +			th := newTestHandler(t, nil, nil)  			defer th.mockCtrl.Finish()  			url := EndpointGetProofByHash.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -671,6 +690,166 @@ func TestGetProofByHash(t *testing.T) {  	}  } +func TestGetStableSth(t *testing.T) { +	for _, table := range cosiTestCases(t) { +		func() { // run deferred functions at the end of each iteration +			th := newTestHandler(t, nil, table.sth) +			defer th.mockCtrl.Finish() + +			// Setup and run client query +			url := EndpointGetStableSth.Path("http://example.com", th.instance.LogParameters.Prefix) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("failed creating http request: %v", err) +			} +			w := httptest.NewRecorder() +			th.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req) + +			// Check response code +			if w.Code != table.wantCode { +				t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) +			} +			if w.Code != http.StatusOK { +				return +			} +			// Check response bytes +			var gotBytes []byte +			if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil { +				t.Errorf("failed unmarshaling json: %v, wanted ok", err) +				return +			} +			wantBytes, _ := table.sth.Marshal() +			if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) { +				t.Errorf("wanted response %X but got %X in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetCosi(t *testing.T) { +	for _, table := range cosiTestCases(t) { +		func() { // run deferred functions at the end of each iteration +			th := newTestHandler(t, nil, table.sth) +			defer th.mockCtrl.Finish() + +			// Setup and run client query +			url := EndpointGetCosi.Path("http://example.com", th.instance.LogParameters.Prefix) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("failed creating http request: %v", err) +			} +			w := httptest.NewRecorder() +			th.getHandler(t, EndpointGetCosi).ServeHTTP(w, req) + +			// Check response code +			if w.Code != table.wantCode { +				t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) +			} +			if w.Code != http.StatusOK { +				return +			} +			// Check response bytes +			var gotBytes []byte +			if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil { +				t.Errorf("failed unmarshaling json: %v, wanted ok", err) +				return +			} +			wantCosth := NewCosignedTreeHeadV1(table.sth.SignedTreeHeadV1, []SignatureV1{ +				SignatureV1{ +					Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), +					Signature: testSignature, +				}, +			}) +			wantBytes, _ := wantCosth.Marshal() +			if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) { +				t.Errorf("wanted response %X but got %X in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestAddCosi(t *testing.T) { +	validSth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) +	validSth2 := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp+1000000, testTreeSize, testNodeHash)), testLogId, testSignature) +	for _, table := range []struct { +		description string +		sth         *StItem +		breq        *bytes.Buffer +		wantCode    int +	}{ +		{ +			description: "invalid request: untrusted witness", // more specific tests can be found in TestNewAddCosignatureRequest +			sth:         validSth, +			breq:        mustMakeAddCosiBuffer(t, testdata.Ed25519Sk2, testdata.Ed25519Vk2, validSth), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid request: cosigned wrong sth", // more specific tests can be found in TestAddCosignature +			sth:         validSth, +			breq:        mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth2), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "valid", +			sth:         validSth, +			breq:        mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth), +			wantCode:    http.StatusOK, +		}, +	} { +		func() { // run deferred functions at the end of each iteration +			th := newTestHandler(t, nil, table.sth) +			defer th.mockCtrl.Finish() + +			// Setup and run client query +			url := EndpointAddCosi.Path("http://example.com", th.instance.LogParameters.Prefix) +			req, err := http.NewRequest("POST", url, table.breq) +			if err != nil { +				t.Fatalf("failed creating http request: %v", err) +			} +			req.Header.Set("Content-Type", "application/json") + +			w := httptest.NewRecorder() +			th.postHandler(t, EndpointAddCosi).ServeHTTP(w, req) +			if w.Code != table.wantCode { +				t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) +			} + +			// Check response +			if w.Code != table.wantCode { +				t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) +			} +		}() +	} +} + +type cosiTestCase struct { +	description string +	sth         *StItem +	wantCode    int +} + +// cosiTestCases returns test cases used by TestGetStableSth and TestGetCosi +func cosiTestCases(t *testing.T) []cosiTestCase { +	t.Helper() +	return []cosiTestCase{ +		{ +			description: "no cosigned/stable sth", +			sth:         nil, +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "malformed cosigned/stable sth", +			sth:         NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), []byte("not a log id"), testSignature), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			sth:         NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature), +			wantCode:    http.StatusOK, +		}, +	} +} +  // mustMakeEd25519ChecksumV1 creates an ed25519-signed ChecksumV1 leaf  func mustMakeEd25519ChecksumV1(t *testing.T, id, checksum, vk, sk []byte) ([]byte, []byte) {  	t.Helper() @@ -697,6 +876,30 @@ func mustMakeEd25519ChecksumV1Buffer(t *testing.T, identifier, checksum, vk, sk  	return bytes.NewBuffer(data)  } +// mustMakeAddCosiBuffer creates an add-cosi data buffer +func mustMakeAddCosiBuffer(t *testing.T, sk, vk []byte, sth *StItem) *bytes.Buffer { +	t.Helper() +	msg, err := sth.Marshal() +	if err != nil { +		t.Fatalf("must marshal sth: %v", err) +	} +	costh := NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ +		SignatureV1{ +			Namespace: *mustNewNamespaceEd25519V1(t, vk), +			Signature: ed25519.Sign(ed25519.PrivateKey(sk), msg), +		}, +	}) +	item, err := costh.Marshal() +	if err != nil { +		t.Fatalf("must marshal costh: %v", err) +	} +	data, err := json.Marshal(AddCosignatureRequest{item}) +	if err != nil { +		t.Fatalf("must marshal add-cosi request: %v", err) +	} +	return bytes.NewBuffer(data) +} +  // deadlineMatcher implements gomock.Matcher, such that an error is raised if  // there is no context.Context deadline set  type deadlineMatcher struct{}  | 
