aboutsummaryrefslogtreecommitdiff
path: root/handler_test.go
diff options
context:
space:
mode:
authorRasmus Dahlberg <rasmus.dahlberg@kau.se>2021-02-17 19:58:27 +0100
committerRasmus Dahlberg <rasmus.dahlberg@kau.se>2021-02-17 19:58:27 +0100
commit238518951868db81cd3a004e5c3f0b99f8e82b06 (patch)
tree1df8e71e869272bc5324e7412eab9236276f3548 /handler_test.go
parent72c8492ee1bd07d5960c9920e51b7addac11b806 (diff)
added basic server-side cosigning (work in progress)
Diffstat (limited to 'handler_test.go')
-rw-r--r--handler_test.go225
1 files changed, 214 insertions, 11 deletions
diff --git a/handler_test.go b/handler_test.go
index dd32c37..daa1a6c 100644
--- a/handler_test.go
+++ b/handler_test.go
@@ -1,5 +1,7 @@
package stfe
+// TODO: refactor tests
+
import (
"bytes"
"context"
@@ -8,7 +10,6 @@ import (
"testing"
"crypto/ed25519"
- //"crypto/tls"
"encoding/base64"
"encoding/json"
"net/http"
@@ -28,16 +29,31 @@ type testHandler struct {
instance *Instance
}
-func newTestHandler(t *testing.T, signer crypto.Signer) *testHandler {
+func newTestHandler(t *testing.T, signer crypto.Signer, sth *StItem) *testHandler {
ctrl := gomock.NewController(t)
client := mockclient.NewMockTrillianLogClient(ctrl)
+ lp := makeTestLogParameters(t, signer)
+ source := &ActiveSthSource{
+ client: client,
+ logParameters: lp,
+ }
+ if sth != nil {
+ source.currSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{
+ SignatureV1{
+ Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk),
+ Signature: testSignature,
+ },
+ })
+ source.nextSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil)
+ source.cosignatureFrom = make(map[string]bool)
+ }
return &testHandler{
mockCtrl: ctrl,
client: client,
instance: &Instance{
- Deadline: testDeadline,
Client: client,
- LogParameters: makeTestLogParameters(t, signer),
+ LogParameters: lp,
+ SthSource: source,
},
}
}
@@ -49,6 +65,8 @@ func (th *testHandler) getHandlers(t *testing.T) map[Endpoint]Handler {
EndpointGetProofByHash: Handler{instance: th.instance, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet},
EndpointGetAnchors: Handler{instance: th.instance, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet},
EndpointGetEntries: Handler{instance: th.instance, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet},
+ EndpointGetStableSth: Handler{instance: th.instance, handler: getStableSth, endpoint: EndpointGetStableSth, method: http.MethodGet},
+ EndpointGetCosi: Handler{instance: th.instance, handler: getCosi, endpoint: EndpointGetCosi, method: http.MethodGet},
}
}
@@ -63,6 +81,7 @@ func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) Handler {
func (th *testHandler) postHandlers(t *testing.T) map[Endpoint]Handler {
return map[Endpoint]Handler{
EndpointAddEntry: Handler{instance: th.instance, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost},
+ EndpointAddCosi: Handler{instance: th.instance, handler: addCosi, endpoint: EndpointAddCosi, method: http.MethodPost},
}
}
@@ -76,7 +95,7 @@ func (th *testHandler) postHandler(t *testing.T, endpoint Endpoint) Handler {
// TestGetHandlersRejectPost checks that all get handlers reject post requests
func TestGetHandlersRejectPost(t *testing.T) {
- th := newTestHandler(t, nil)
+ th := newTestHandler(t, nil, nil)
defer th.mockCtrl.Finish()
for endpoint, handler := range th.getHandlers(t) {
@@ -96,7 +115,7 @@ func TestGetHandlersRejectPost(t *testing.T) {
// TestPostHandlersRejectGet checks that all post handlers reject get requests
func TestPostHandlersRejectGet(t *testing.T) {
- th := newTestHandler(t, nil)
+ th := newTestHandler(t, nil, nil)
defer th.mockCtrl.Finish()
for endpoint, handler := range th.postHandlers(t) {
@@ -196,7 +215,7 @@ func TestGetEntries(t *testing.T) {
},
} {
func() { // run deferred functions at the end of each iteration
- th := newTestHandler(t, nil)
+ th := newTestHandler(t, nil, nil)
defer th.mockCtrl.Finish()
url := EndpointGetEntries.Path("http://example.com", th.instance.LogParameters.Prefix)
@@ -298,7 +317,7 @@ func TestAddEntry(t *testing.T) {
},
} {
func() { // run deferred functions at the end of each iteration
- th := newTestHandler(t, table.signer)
+ th := newTestHandler(t, table.signer, nil)
defer th.mockCtrl.Finish()
url := EndpointAddEntry.Path("http://example.com", th.instance.LogParameters.Prefix)
@@ -392,7 +411,7 @@ func TestGetSth(t *testing.T) {
},
} {
func() { // run deferred functions at the end of each iteration
- th := newTestHandler(t, table.signer)
+ th := newTestHandler(t, table.signer, nil)
defer th.mockCtrl.Finish()
url := EndpointGetSth.Path("http://example.com", th.instance.LogParameters.Prefix)
@@ -496,7 +515,7 @@ func TestGetConsistencyProof(t *testing.T) {
},
} {
func() { // run deferred functions at the end of each iteration
- th := newTestHandler(t, nil)
+ th := newTestHandler(t, nil, nil)
defer th.mockCtrl.Finish()
url := EndpointGetConsistencyProof.Path("http://example.com", th.instance.LogParameters.Prefix)
@@ -605,7 +624,7 @@ func TestGetProofByHash(t *testing.T) {
},
} {
func() { // run deferred functions at the end of each iteration
- th := newTestHandler(t, nil)
+ th := newTestHandler(t, nil, nil)
defer th.mockCtrl.Finish()
url := EndpointGetProofByHash.Path("http://example.com", th.instance.LogParameters.Prefix)
@@ -671,6 +690,166 @@ func TestGetProofByHash(t *testing.T) {
}
}
+func TestGetStableSth(t *testing.T) {
+ for _, table := range cosiTestCases(t) {
+ func() { // run deferred functions at the end of each iteration
+ th := newTestHandler(t, nil, table.sth)
+ defer th.mockCtrl.Finish()
+
+ // Setup and run client query
+ url := EndpointGetStableSth.Path("http://example.com", th.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("failed creating http request: %v", err)
+ }
+ w := httptest.NewRecorder()
+ th.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req)
+
+ // Check response code
+ if w.Code != table.wantCode {
+ t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode)
+ }
+ if w.Code != http.StatusOK {
+ return
+ }
+ // Check response bytes
+ var gotBytes []byte
+ if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil {
+ t.Errorf("failed unmarshaling json: %v, wanted ok", err)
+ return
+ }
+ wantBytes, _ := table.sth.Marshal()
+ if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) {
+ t.Errorf("wanted response %X but got %X in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetCosi(t *testing.T) {
+ for _, table := range cosiTestCases(t) {
+ func() { // run deferred functions at the end of each iteration
+ th := newTestHandler(t, nil, table.sth)
+ defer th.mockCtrl.Finish()
+
+ // Setup and run client query
+ url := EndpointGetCosi.Path("http://example.com", th.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("failed creating http request: %v", err)
+ }
+ w := httptest.NewRecorder()
+ th.getHandler(t, EndpointGetCosi).ServeHTTP(w, req)
+
+ // Check response code
+ if w.Code != table.wantCode {
+ t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode)
+ }
+ if w.Code != http.StatusOK {
+ return
+ }
+ // Check response bytes
+ var gotBytes []byte
+ if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil {
+ t.Errorf("failed unmarshaling json: %v, wanted ok", err)
+ return
+ }
+ wantCosth := NewCosignedTreeHeadV1(table.sth.SignedTreeHeadV1, []SignatureV1{
+ SignatureV1{
+ Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk),
+ Signature: testSignature,
+ },
+ })
+ wantBytes, _ := wantCosth.Marshal()
+ if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) {
+ t.Errorf("wanted response %X but got %X in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestAddCosi(t *testing.T) {
+ validSth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature)
+ validSth2 := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp+1000000, testTreeSize, testNodeHash)), testLogId, testSignature)
+ for _, table := range []struct {
+ description string
+ sth *StItem
+ breq *bytes.Buffer
+ wantCode int
+ }{
+ {
+ description: "invalid request: untrusted witness", // more specific tests can be found in TestNewAddCosignatureRequest
+ sth: validSth,
+ breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk2, testdata.Ed25519Vk2, validSth),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid request: cosigned wrong sth", // more specific tests can be found in TestAddCosignature
+ sth: validSth,
+ breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth2),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "valid",
+ sth: validSth,
+ breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth),
+ wantCode: http.StatusOK,
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ th := newTestHandler(t, nil, table.sth)
+ defer th.mockCtrl.Finish()
+
+ // Setup and run client query
+ url := EndpointAddCosi.Path("http://example.com", th.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("POST", url, table.breq)
+ if err != nil {
+ t.Fatalf("failed creating http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ w := httptest.NewRecorder()
+ th.postHandler(t, EndpointAddCosi).ServeHTTP(w, req)
+ if w.Code != table.wantCode {
+ t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode)
+ }
+
+ // Check response
+ if w.Code != table.wantCode {
+ t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode)
+ }
+ }()
+ }
+}
+
+type cosiTestCase struct {
+ description string
+ sth *StItem
+ wantCode int
+}
+
+// cosiTestCases returns test cases used by TestGetStableSth and TestGetCosi
+func cosiTestCases(t *testing.T) []cosiTestCase {
+ t.Helper()
+ return []cosiTestCase{
+ {
+ description: "no cosigned/stable sth",
+ sth: nil,
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "malformed cosigned/stable sth",
+ sth: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), []byte("not a log id"), testSignature),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ sth: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature),
+ wantCode: http.StatusOK,
+ },
+ }
+}
+
// mustMakeEd25519ChecksumV1 creates an ed25519-signed ChecksumV1 leaf
func mustMakeEd25519ChecksumV1(t *testing.T, id, checksum, vk, sk []byte) ([]byte, []byte) {
t.Helper()
@@ -697,6 +876,30 @@ func mustMakeEd25519ChecksumV1Buffer(t *testing.T, identifier, checksum, vk, sk
return bytes.NewBuffer(data)
}
+// mustMakeAddCosiBuffer creates an add-cosi data buffer
+func mustMakeAddCosiBuffer(t *testing.T, sk, vk []byte, sth *StItem) *bytes.Buffer {
+ t.Helper()
+ msg, err := sth.Marshal()
+ if err != nil {
+ t.Fatalf("must marshal sth: %v", err)
+ }
+ costh := NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{
+ SignatureV1{
+ Namespace: *mustNewNamespaceEd25519V1(t, vk),
+ Signature: ed25519.Sign(ed25519.PrivateKey(sk), msg),
+ },
+ })
+ item, err := costh.Marshal()
+ if err != nil {
+ t.Fatalf("must marshal costh: %v", err)
+ }
+ data, err := json.Marshal(AddCosignatureRequest{item})
+ if err != nil {
+ t.Fatalf("must marshal add-cosi request: %v", err)
+ }
+ return bytes.NewBuffer(data)
+}
+
// deadlineMatcher implements gomock.Matcher, such that an error is raised if
// there is no context.Context deadline set
type deadlineMatcher struct{}