aboutsummaryrefslogtreecommitdiff
path: root/instance_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'instance_test.go')
-rw-r--r--instance_test.go320
1 files changed, 113 insertions, 207 deletions
diff --git a/instance_test.go b/instance_test.go
index 0ceedb8..21ef808 100644
--- a/instance_test.go
+++ b/instance_test.go
@@ -1,121 +1,90 @@
package stfe
import (
- "bytes"
+ "crypto"
"testing"
- "time"
- "crypto"
- "crypto/ed25519"
+ "net/http"
+ "net/http/httptest"
- "github.com/system-transparency/stfe/namespace"
- "github.com/system-transparency/stfe/namespace/testdata"
+ "github.com/golang/mock/gomock"
+ "github.com/google/certificate-transparency-go/trillian/mockclient"
+ "github.com/system-transparency/stfe/testdata"
+ "github.com/system-transparency/stfe/types"
)
-var (
- testLogId = append([]byte{0x00, 0x01, 0x20}, testdata.Ed25519Vk3...)
- testTreeId = int64(0)
- testMaxRange = int64(3)
- testPrefix = "test"
- testHashType = crypto.SHA256
- testSignature = make([]byte, 32)
- testNodeHash = make([]byte, 32)
- testMessage = []byte("test message")
- testPackage = []byte("foobar")
- testChecksum = make([]byte, 32)
- testTreeSize = uint64(128)
- testTreeSizeLarger = uint64(256)
- testTimestamp = uint64(0)
- testProof = [][]byte{
- testNodeHash,
- testNodeHash,
- }
- testIndex = uint64(0)
- testHashLen = 31
- testDeadline = time.Second * 5
- testInterval = time.Second * 10
-)
+type testInstance struct {
+ ctrl *gomock.Controller
+ client *mockclient.MockTrillianLogClient
+ instance *Instance
+}
-// TestNewLogParamters checks that invalid ones are rejected and that a valid
-// set of parameters are accepted.
-func TestNewLogParameters(t *testing.T) {
- testLogId := mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk3)
- namespaces := mustNewNamespacePool(t, []*namespace.Namespace{
- mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk),
- })
- witnesses := mustNewNamespacePool(t, []*namespace.Namespace{})
- signer := ed25519.PrivateKey(testdata.Ed25519Sk)
- for _, table := range []struct {
- description string
- logId *namespace.Namespace
- maxRange int64
- signer crypto.Signer
- wantErr bool
- }{
- {
- description: "invalid signer: nil",
- logId: testLogId,
- maxRange: testMaxRange,
- signer: nil,
- wantErr: true,
- },
- {
- description: "invalid max range",
- logId: testLogId,
- maxRange: 0,
- signer: signer,
- wantErr: true,
- },
- {
- description: "invalid log identifier",
- logId: &namespace.Namespace{
- Format: namespace.NamespaceFormatEd25519V1,
- NamespaceEd25519V1: &namespace.NamespaceEd25519V1{
- Namespace: make([]byte, 31), // too short
- },
+// newTestInstances sets up a test instance that uses default log parameters
+// with an optional signer, see newLogParameters() for further details. The
+// SthSource is instantiated with an ActiveSthSource that has (i) the default
+// STH as the currently cosigned STH based on testdata.Ed25519VkWitness, and
+// (ii) the default STH without any cosignatures as the currently stable STH.
+func newTestInstance(t *testing.T, signer crypto.Signer) *testInstance {
+ t.Helper()
+ ctrl := gomock.NewController(t)
+ client := mockclient.NewMockTrillianLogClient(ctrl)
+ return &testInstance{
+ ctrl: ctrl,
+ client: client,
+ instance: &Instance{
+ Client: client,
+ LogParameters: newLogParameters(t, signer),
+ SthSource: &ActiveSthSource{
+ client: client,
+ logParameters: newLogParameters(t, signer),
+ currCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}),
+ nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil),
+ cosignatureFrom: make(map[[types.NamespaceFingerprintSize]byte]bool),
},
- maxRange: testMaxRange,
- signer: signer,
- wantErr: true,
},
- {
- description: "valid log parameters",
- logId: testLogId,
- maxRange: testMaxRange,
- signer: signer,
- },
- } {
- lp, err := NewLogParameters(table.signer, table.logId, testTreeId, testPrefix, namespaces, witnesses, table.maxRange, testInterval, testDeadline)
- if got, want := err != nil, table.wantErr; got != want {
- t.Errorf("got error=%v but wanted %v in test %q: %v", got, want, table.description, err)
- }
- if err != nil {
- continue
- }
- lid, err := table.logId.Marshal()
- if err != nil {
- t.Fatalf("must marshal log id: %v", err)
- }
+ }
+}
- if got, want := lp.LogId, lid; !bytes.Equal(got, want) {
- t.Errorf("got log id %X but wanted %X in test %q", got, want, table.description)
- }
- if got, want := lp.TreeId, testTreeId; got != want {
- t.Errorf("got tree id %d but wanted %d in test %q", got, want, table.description)
- }
- if got, want := lp.Prefix, testPrefix; got != want {
- t.Errorf("got prefix %s but wanted %s in test %q", got, want, table.description)
- }
- if got, want := lp.MaxRange, testMaxRange; got != want {
- t.Errorf("got max range %d but wanted %d in test %q", got, want, table.description)
- }
- if got, want := len(lp.Submitters.List()), len(namespaces.List()); got != want {
- t.Errorf("got %d anchors but wanted %d in test %q", got, want, table.description)
- }
- if got, want := len(lp.Witnesses.List()), len(witnesses.List()); got != want {
- t.Errorf("got %d anchors but wanted %d in test %q", got, want, table.description)
- }
+// getHandlers returns all endpoints that use HTTP GET as a map to handlers
+func (ti *testInstance) getHandlers(t *testing.T) map[Endpoint]Handler {
+ t.Helper()
+ return map[Endpoint]Handler{
+ EndpointGetLatestSth: Handler{Instance: ti.instance, Handler: getLatestSth, Endpoint: EndpointGetLatestSth, Method: http.MethodGet},
+ EndpointGetStableSth: Handler{Instance: ti.instance, Handler: getStableSth, Endpoint: EndpointGetStableSth, Method: http.MethodGet},
+ EndpointGetCosignedSth: Handler{Instance: ti.instance, Handler: getCosignedSth, Endpoint: EndpointGetCosignedSth, Method: http.MethodGet},
+ }
+}
+
+// postHandlers returns all endpoints that use HTTP POST as a map to handlers
+func (ti *testInstance) postHandlers(t *testing.T) map[Endpoint]Handler {
+ t.Helper()
+ return map[Endpoint]Handler{
+ EndpointAddEntry: Handler{Instance: ti.instance, Handler: addEntry, Endpoint: EndpointAddEntry, Method: http.MethodPost},
+ EndpointAddCosignature: Handler{Instance: ti.instance, Handler: addCosignature, Endpoint: EndpointAddCosignature, Method: http.MethodPost},
+ EndpointGetConsistencyProof: Handler{Instance: ti.instance, Handler: getConsistencyProof, Endpoint: EndpointGetConsistencyProof, Method: http.MethodPost},
+ EndpointGetProofByHash: Handler{Instance: ti.instance, Handler: getProofByHash, Endpoint: EndpointGetProofByHash, Method: http.MethodPost},
+ EndpointGetEntries: Handler{Instance: ti.instance, Handler: getEntries, Endpoint: EndpointGetEntries, Method: http.MethodPost},
+ }
+}
+
+// getHandler must return a particular HTTP GET handler
+func (ti *testInstance) getHandler(t *testing.T, endpoint Endpoint) Handler {
+ t.Helper()
+ handler, ok := ti.getHandlers(t)[endpoint]
+ if !ok {
+ t.Fatalf("must return HTTP GET handler for endpoint: %s", endpoint)
+ }
+ return handler
+}
+
+// postHandler must return a particular HTTP POST handler
+func (ti *testInstance) postHandler(t *testing.T, endpoint Endpoint) Handler {
+ t.Helper()
+ handler, ok := ti.postHandlers(t)[endpoint]
+ if !ok {
+ t.Fatalf("must return HTTP POST handler for endpoint: %s", endpoint)
}
+ return handler
}
// TestHandlers checks that we configured all endpoints and that there are no
@@ -123,21 +92,20 @@ func TestNewLogParameters(t *testing.T) {
func TestHandlers(t *testing.T) {
endpoints := map[Endpoint]bool{
EndpointAddEntry: false,
- EndpointGetEntries: false,
+ EndpointAddCosignature: false,
EndpointGetLatestSth: false,
- EndpointGetProofByHash: false,
- EndpointGetConsistencyProof: false,
- EndpointGetAnchors: false,
EndpointGetStableSth: false,
EndpointGetCosignedSth: false,
- EndpointAddCosignature: false,
+ EndpointGetConsistencyProof: false,
+ EndpointGetProofByHash: false,
+ EndpointGetEntries: false,
}
- i := NewInstance(makeTestLogParameters(t, nil), nil, nil)
+ i := &Instance{nil, newLogParameters(t, nil), nil}
for _, handler := range i.Handlers() {
- if _, ok := endpoints[handler.endpoint]; !ok {
- t.Errorf("got unexpected endpoint: %s", handler.endpoint)
+ if _, ok := endpoints[handler.Endpoint]; !ok {
+ t.Errorf("got unexpected endpoint: %s", handler.Endpoint)
}
- endpoints[handler.endpoint] = true
+ endpoints[handler.Endpoint] = true
}
for endpoint, ok := range endpoints {
if !ok {
@@ -146,104 +114,42 @@ func TestHandlers(t *testing.T) {
}
}
-// TestEndpointPath checks that the endpoint path builder works as expected
-func TestEndpointPath(t *testing.T) {
- base, prefix := "http://example.com", "test"
- for _, table := range []struct {
- endpoint Endpoint
- want string
- }{
- {
- endpoint: EndpointAddEntry,
- want: "http://example.com/test/add-entry",
- },
- {
- endpoint: EndpointGetEntries,
- want: "http://example.com/test/get-entries",
- },
- {
- endpoint: EndpointGetProofByHash,
- want: "http://example.com/test/get-proof-by-hash",
- },
- {
- endpoint: EndpointGetConsistencyProof,
- want: "http://example.com/test/get-consistency-proof",
- },
- {
- endpoint: EndpointGetLatestSth,
- want: "http://example.com/test/get-latest-sth",
- },
- {
- endpoint: EndpointGetAnchors,
- want: "http://example.com/test/get-anchors",
- },
- {
- endpoint: EndpointGetStableSth,
- want: "http://example.com/test/get-stable-sth",
- },
- {
- endpoint: EndpointGetCosignedSth,
- want: "http://example.com/test/get-cosigned-sth",
- },
- {
- endpoint: EndpointAddCosignature,
- want: "http://example.com/test/add-cosignature",
- },
- } {
- if got, want := table.endpoint.Path(base, prefix), table.want; got != want {
- t.Errorf("got %s but wanted %s with multiple components", got, want)
- }
- if got, want := table.endpoint.Path(base+"/"+prefix), table.want; got != want {
- t.Errorf("got %s but wanted %s with one component", got, want)
- }
- }
-}
+// TestGetHandlersRejectPost checks that all get handlers reject post requests
+func TestGetHandlersRejectPost(t *testing.T) {
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
-func mustNewLogId(t *testing.T, namespace *namespace.Namespace) []byte {
- b, err := namespace.Marshal()
- if err != nil {
- t.Fatalf("must marshal log id: %v", err)
- }
- return b
-}
+ for endpoint, handler := range ti.getHandlers(t) {
+ t.Run(string(endpoint), func(t *testing.T) {
+ s := httptest.NewServer(handler)
+ defer s.Close()
-func mustNewNamespaceEd25519V1(t *testing.T, vk []byte) *namespace.Namespace {
- namespace, err := namespace.NewNamespaceEd25519V1(vk)
- if err != nil {
- t.Fatalf("must make ed25519 namespace: %v", err)
+ url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix)
+ if rsp, err := http.Post(url, "application/json", nil); err != nil {
+ t.Fatalf("http.Post(%s)=(_,%q), want (_,nil)", url, err)
+ } else if rsp.StatusCode != http.StatusMethodNotAllowed {
+ t.Errorf("http.Post(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed)
+ }
+ })
}
- return namespace
}
-func mustNewNamespacePool(t *testing.T, anchors []*namespace.Namespace) *namespace.NamespacePool {
- namespaces, err := namespace.NewNamespacePool(anchors)
- if err != nil {
- t.Fatalf("must make namespaces: %v", err)
- }
- return namespaces
-}
+// TestPostHandlersRejectGet checks that all post handlers reject get requests
+func TestPostHandlersRejectGet(t *testing.T) {
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
+
+ for endpoint, handler := range ti.postHandlers(t) {
+ t.Run(string(endpoint), func(t *testing.T) {
+ s := httptest.NewServer(handler)
+ defer s.Close()
-// makeTestLogParameters makes a collection of test log parameters.
-//
-// The log's identity is based on testdata.Ed25519{Vk3,Sk3}. The log's accepted
-// submitters are based on testdata.Ed25519Vk. The log's accepted witnesses are
-// based on testdata.Ed25519Vk. The remaining log parameters are based on the
-// global test* variables in this file.
-//
-// For convenience the passed signer is optional (i.e., it may be nil).
-func makeTestLogParameters(t *testing.T, signer crypto.Signer) *LogParameters {
- return &LogParameters{
- LogId: mustNewLogId(t, mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk3)),
- TreeId: testTreeId,
- Prefix: testPrefix,
- MaxRange: testMaxRange,
- Submitters: mustNewNamespacePool(t, []*namespace.Namespace{
- mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk),
- }),
- Witnesses: mustNewNamespacePool(t, []*namespace.Namespace{
- mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk),
- }),
- Signer: signer,
- HashType: testHashType,
+ url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix)
+ if rsp, err := http.Get(url); err != nil {
+ t.Fatalf("http.Get(%s)=(_,%q), want (_,nil)", url, err)
+ } else if rsp.StatusCode != http.StatusMethodNotAllowed {
+ t.Errorf("http.Get(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed)
+ }
+ })
}
}