aboutsummaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
authorRasmus Dahlberg <rasmus.dahlberg@kau.se>2021-06-06 15:59:35 +0200
committerRasmus Dahlberg <rasmus.dahlberg@kau.se>2021-06-06 15:59:35 +0200
commit47bacfd5c5d22470340e0823c4ad37b45914b68e (patch)
tree0984b55ced1e537e8d1d681c77157f1a47e47cdc /pkg
parent0285454c34b0b3003bc8ede3e304b843ad949be8 (diff)
started using the refactored packages in siglog server
Diffstat (limited to 'pkg')
-rw-r--r--pkg/instance/endpoint.go122
-rw-r--r--pkg/instance/endpoint_test.go480
-rw-r--r--pkg/instance/instance.go90
-rw-r--r--pkg/instance/instance_test.go158
-rw-r--r--pkg/instance/metric.go23
-rw-r--r--pkg/instance/request.go77
-rw-r--r--pkg/instance/request_test.go318
-rw-r--r--pkg/mocks/crypto.go23
-rw-r--r--pkg/mocks/stfe.go110
-rw-r--r--pkg/mocks/trillian.go317
-rw-r--r--pkg/state/state_manager.go154
-rw-r--r--pkg/state/state_manager_test.go393
-rw-r--r--pkg/trillian/client.go178
-rw-r--r--pkg/trillian/client_test.go533
-rw-r--r--pkg/trillian/util.go33
-rw-r--r--pkg/types/ascii.go421
-rw-r--r--pkg/types/ascii_test.go465
-rw-r--r--pkg/types/trunnel.go60
-rw-r--r--pkg/types/trunnel_test.go114
-rw-r--r--pkg/types/types.go155
-rw-r--r--pkg/types/types_test.go58
-rw-r--r--pkg/types/util.go21
22 files changed, 4303 insertions, 0 deletions
diff --git a/pkg/instance/endpoint.go b/pkg/instance/endpoint.go
new file mode 100644
index 0000000..5085c49
--- /dev/null
+++ b/pkg/instance/endpoint.go
@@ -0,0 +1,122 @@
+package stfe
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/golang/glog"
+)
+
+func addLeaf(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling add-entry request")
+ req, err := i.leafRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+ if err := i.Client.AddLeaf(ctx, req); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func addCosignature(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling add-cosignature request")
+ req, err := i.cosignatureRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+ vk := i.Witnesses[*req.KeyHash]
+ if err := i.Stateman.AddCosignature(ctx, &vk, req.Signature); err != nil {
+ return http.StatusBadRequest, err
+ }
+ return http.StatusOK, nil
+}
+
+func getTreeHeadLatest(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) {
+ glog.V(3).Info("handling get-tree-head-latest request")
+ sth, err := i.Stateman.Latest(ctx)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := sth.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getTreeHeadToSign(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) {
+ glog.V(3).Info("handling get-tree-head-to-sign request")
+ sth, err := i.Stateman.ToSign(ctx)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := sth.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getTreeHeadCosigned(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) {
+ glog.V(3).Info("handling get-tree-head-cosigned request")
+ sth, err := i.Stateman.Cosigned(ctx)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := sth.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getConsistencyProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling get-consistency-proof request")
+ req, err := i.consistencyProofRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+
+ proof, err := i.Client.GetConsistencyProof(ctx, req)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := proof.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getInclusionProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling get-proof-by-hash request")
+ req, err := i.inclusionProofRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+
+ proof, err := i.Client.GetInclusionProof(ctx, req)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := proof.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getLeaves(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling get-leaves request")
+ req, err := i.leavesRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+
+ leaves, err := i.Client.GetLeaves(ctx, req)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ for _, leaf := range *leaves {
+ if err := leaf.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ }
+ return http.StatusOK, nil
+}
diff --git a/pkg/instance/endpoint_test.go b/pkg/instance/endpoint_test.go
new file mode 100644
index 0000000..8511b8d
--- /dev/null
+++ b/pkg/instance/endpoint_test.go
@@ -0,0 +1,480 @@
+package stfe
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "reflect"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ cttestdata "github.com/google/certificate-transparency-go/trillian/testdata"
+ "github.com/google/trillian"
+ "github.com/system-transparency/stfe/pkg/testdata"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+func TestEndpointAddEntry(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ breq *bytes.Buffer
+ trsp *trillian.QueueLeafResponse
+ terr error
+ wantCode int
+ }{
+ {
+ description: "invalid: bad request: empty",
+ breq: bytes.NewBuffer(nil),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: bad Trillian response: error",
+ breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter),
+ terr: fmt.Errorf("backend failure"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter),
+ trsp: testdata.DefaultTQlr(t, false),
+ wantCode: http.StatusOK,
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
+
+ url := EndpointAddEntry.Path("http://example.com", ti.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("POST", url, table.breq)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+ if table.trsp != nil || table.terr != nil {
+ ti.client.EXPECT().QueueLeaf(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr)
+ }
+
+ w := httptest.NewRecorder()
+ ti.postHandler(t, EndpointAddEntry).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestEndpointAddCosignature(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ breq *bytes.Buffer
+ wantCode int
+ }{
+ {
+ description: "invalid: bad request: empty",
+ breq: bytes.NewBuffer(nil),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: signed wrong sth", // newLogParameters() use testdata.Ed25519VkLog as default
+ breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog2), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "valid",
+ breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness),
+ wantCode: http.StatusOK,
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
+
+ url := EndpointAddCosignature.Path("http://example.com", ti.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("POST", url, table.breq)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+
+ w := httptest.NewRecorder()
+ ti.postHandler(t, EndpointAddCosignature).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestEndpointGetLatestSth(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ trsp *trillian.GetLatestSignedLogRootResponse
+ terr error
+ wantCode int
+ wantItem *types.StItem
+ }{
+ {
+ description: "backend failure",
+ terr: fmt.Errorf("backend failure"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ trsp: testdata.DefaultTSlr(t),
+ wantCode: http.StatusOK,
+ wantItem: testdata.DefaultSth(t, testdata.Ed25519VkLog),
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ ti := newTestInstance(t, cttestdata.NewSignerWithFixedSig(nil, testdata.Signature))
+ ti.ctrl.Finish()
+
+ // Setup and run client query
+ url := EndpointGetLatestSth.Path("http://example.com", ti.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+ if table.trsp != nil || table.terr != nil {
+ ti.client.EXPECT().GetLatestSignedLogRoot(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr)
+ }
+
+ w := httptest.NewRecorder()
+ ti.getHandler(t, EndpointGetLatestSth).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description)
+ }
+ if w.Code != http.StatusOK {
+ return
+ }
+
+ var item types.StItem
+ if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil {
+ t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err)
+ }
+ if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) {
+ t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestEndpointGetStableSth(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ useBadSource bool
+ wantCode int
+ wantItem *types.StItem
+ }{
+ {
+ description: "invalid: sth source failure",
+ useBadSource: true,
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ wantCode: http.StatusOK,
+ wantItem: testdata.DefaultSth(t, testdata.Ed25519VkLog),
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ ti := newTestInstance(t, nil)
+ ti.ctrl.Finish()
+ if table.useBadSource {
+ ti.instance.SthSource = &ActiveSthSource{}
+ }
+
+ // Setup and run client query
+ url := EndpointGetStableSth.Path("http://example.com", ti.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ ti.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description)
+ }
+ if w.Code != http.StatusOK {
+ return
+ }
+
+ var item types.StItem
+ if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil {
+ t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err)
+ }
+ if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) {
+ t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestEndpointGetCosignedSth(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ useBadSource bool
+ wantCode int
+ wantItem *types.StItem
+ }{
+ {
+ description: "invalid: sth source failure",
+ useBadSource: true,
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ wantCode: http.StatusOK,
+ wantItem: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}),
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ ti := newTestInstance(t, nil)
+ ti.ctrl.Finish()
+ if table.useBadSource {
+ ti.instance.SthSource = &ActiveSthSource{}
+ }
+
+ // Setup and run client query
+ url := EndpointGetCosignedSth.Path("http://example.com", ti.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ ti.getHandler(t, EndpointGetCosignedSth).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description)
+ }
+ if w.Code != http.StatusOK {
+ return
+ }
+
+ var item types.StItem
+ if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil {
+ t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err)
+ }
+ if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) {
+ t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestEndpointGetProofByHash(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ breq *bytes.Buffer
+ trsp *trillian.GetInclusionProofByHashResponse
+ terr error
+ wantCode int
+ wantItem *types.StItem
+ }{
+ {
+ description: "invalid: bad request: empty",
+ breq: bytes.NewBuffer(nil),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: bad Trillian response: error",
+ breq: bytes.NewBuffer(marshal(t, types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash})),
+ terr: fmt.Errorf("backend failure"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ breq: bytes.NewBuffer(marshal(t, types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash})),
+ trsp: testdata.DefaultTGipbhr(t),
+ wantCode: http.StatusOK,
+ wantItem: testdata.DefaultInclusionProof(t, 1),
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
+
+ url := EndpointGetProofByHash.Path("http://example.com", ti.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("POST", url, table.breq)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+ if table.trsp != nil || table.terr != nil {
+ ti.client.EXPECT().GetInclusionProofByHash(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr)
+ }
+
+ w := httptest.NewRecorder()
+ ti.postHandler(t, EndpointGetProofByHash).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description)
+ }
+ if w.Code != http.StatusOK {
+ return
+ }
+
+ var item types.StItem
+ if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil {
+ t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err)
+ }
+ if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) {
+ t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestEndpointGetConsistencyProof(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ breq *bytes.Buffer
+ trsp *trillian.GetConsistencyProofResponse
+ terr error
+ wantCode int
+ wantItem *types.StItem
+ }{
+ {
+ description: "invalid: bad request: empty",
+ breq: bytes.NewBuffer(nil),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: bad Trillian response: error",
+ breq: bytes.NewBuffer(marshal(t, types.GetConsistencyProofV1{First: 1, Second: 2})),
+ terr: fmt.Errorf("backend failure"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ breq: bytes.NewBuffer(marshal(t, types.GetConsistencyProofV1{First: 1, Second: 2})),
+ trsp: testdata.DefaultTGcpr(t),
+ wantCode: http.StatusOK,
+ wantItem: testdata.DefaultConsistencyProof(t, 1, 2),
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
+
+ url := EndpointGetConsistencyProof.Path("http://example.com", ti.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("POST", url, table.breq)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+ if table.trsp != nil || table.terr != nil {
+ ti.client.EXPECT().GetConsistencyProof(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr)
+ }
+
+ w := httptest.NewRecorder()
+ ti.postHandler(t, EndpointGetConsistencyProof).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description)
+ }
+ if w.Code != http.StatusOK {
+ return
+ }
+
+ var item types.StItem
+ if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil {
+ t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err)
+ }
+ if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) {
+ t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestEndpointGetEntriesV1(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ breq *bytes.Buffer
+ trsp *trillian.GetLeavesByRangeResponse
+ terr error
+ wantCode int
+ wantItem *types.StItemList
+ }{
+ {
+ description: "invalid: bad request: empty",
+ breq: bytes.NewBuffer(nil),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: bad Trillian response: error",
+ breq: bytes.NewBuffer(marshal(t, types.GetEntriesV1{Start: 0, End: 0})),
+ terr: fmt.Errorf("backend failure"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid", // remember that newLogParameters() have testdata.MaxRange configured
+ breq: bytes.NewBuffer(marshal(t, types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange - 1)})),
+ trsp: testdata.DefaultTGlbrr(t, 0, testdata.MaxRange-1),
+ wantCode: http.StatusOK,
+ wantItem: testdata.DefaultStItemList(t, 0, uint64(testdata.MaxRange)-1),
+ },
+ } {
+ func() { // run deferred functions at the end of each iteration
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
+
+ url := EndpointGetEntries.Path("http://example.com", ti.instance.LogParameters.Prefix)
+ req, err := http.NewRequest("POST", url, table.breq)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+ if table.trsp != nil || table.terr != nil {
+ ti.client.EXPECT().GetLeavesByRange(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr)
+ }
+
+ w := httptest.NewRecorder()
+ ti.postHandler(t, EndpointGetEntries).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description)
+ }
+ if w.Code != http.StatusOK {
+ return
+ }
+
+ var item types.StItemList
+ if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil {
+ t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err)
+ }
+ if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) {
+ t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+// TODO: TestWriteOctetResponse
+func TestWriteOctetResponse(t *testing.T) {
+}
+
+// deadlineMatcher implements gomock.Matcher, such that an error is raised if
+// there is no context.Context deadline set
+type deadlineMatcher struct{}
+
+// newDeadlineMatcher returns a new DeadlineMatcher
+func newDeadlineMatcher() gomock.Matcher {
+ return &deadlineMatcher{}
+}
+
+// Matches returns true if the passed interface is a context with a deadline
+func (dm *deadlineMatcher) Matches(i interface{}) bool {
+ ctx, ok := i.(context.Context)
+ if !ok {
+ return false
+ }
+ _, ok = ctx.Deadline()
+ return ok
+}
+
+// String is needed to implement gomock.Matcher
+func (dm *deadlineMatcher) String() string {
+ return fmt.Sprintf("deadlineMatcher{}")
+}
diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go
new file mode 100644
index 0000000..3441a0a
--- /dev/null
+++ b/pkg/instance/instance.go
@@ -0,0 +1,90 @@
+package stfe
+
+import (
+ "context"
+ "crypto"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/golang/glog"
+ "github.com/system-transparency/stfe/pkg/state"
+ "github.com/system-transparency/stfe/pkg/trillian"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+// Config is a collection of log parameters
+type Config struct {
+ LogID string // H(public key), then hex-encoded
+ TreeID int64 // Merkle tree identifier used by Trillian
+ Prefix string // The portion between base URL and st/v0 (may be "")
+ MaxRange int64 // Maximum number of leaves per get-leaves request
+ Deadline time.Duration // Deadline used for gRPC requests
+ Interval time.Duration // Cosigning frequency
+
+ // Witnesses map trusted witness identifiers to public verification keys
+ Witnesses map[[types.HashSize]byte][types.VerificationKeySize]byte
+}
+
+// Instance is an instance of the log's front-end
+type Instance struct {
+ Config // configuration parameters
+ Client trillian.Client // provides access to the Trillian backend
+ Signer crypto.Signer // provides access to Ed25519 private key
+ Stateman state.StateManager // coordinates access to (co)signed tree heads
+}
+
+// Handler implements the http.Handler interface, and contains a reference
+// to an STFE server instance as well as a function that uses it.
+type Handler struct {
+ Instance *Instance
+ Endpoint types.Endpoint
+ Method string
+ Handler func(context.Context, *Instance, http.ResponseWriter, *http.Request) (int, error)
+}
+
+// Handlers returns a list of STFE handlers
+func (i *Instance) Handlers() []Handler {
+ return []Handler{
+ Handler{Instance: i, Handler: addLeaf, Endpoint: types.EndpointAddLeaf, Method: http.MethodPost},
+ Handler{Instance: i, Handler: addCosignature, Endpoint: types.EndpointAddCosignature, Method: http.MethodPost},
+ Handler{Instance: i, Handler: getTreeHeadLatest, Endpoint: types.EndpointGetTreeHeadLatest, Method: http.MethodGet},
+ Handler{Instance: i, Handler: getTreeHeadToSign, Endpoint: types.EndpointGetTreeHeadToSign, Method: http.MethodGet},
+ Handler{Instance: i, Handler: getTreeHeadCosigned, Endpoint: types.EndpointGetTreeHeadCosigned, Method: http.MethodGet},
+ Handler{Instance: i, Handler: getConsistencyProof, Endpoint: types.EndpointGetConsistencyProof, Method: http.MethodPost},
+ Handler{Instance: i, Handler: getInclusionProof, Endpoint: types.EndpointGetProofByHash, Method: http.MethodPost},
+ Handler{Instance: i, Handler: getLeaves, Endpoint: types.EndpointGetLeaves, Method: http.MethodPost},
+ }
+}
+
+// Path returns a path that should be configured for this handler
+func (h Handler) Path() string {
+ return h.Endpoint.Path(h.Instance.Prefix, "st", "v0")
+}
+
+// ServeHTTP is part of the http.Handler interface
+func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // export prometheus metrics
+ var now time.Time = time.Now()
+ var statusCode int
+ defer func() {
+ rspcnt.Inc(a.Instance.LogID, string(a.Endpoint), fmt.Sprintf("%d", statusCode))
+ latency.Observe(time.Now().Sub(now).Seconds(), a.Instance.LogID, string(a.Endpoint), fmt.Sprintf("%d", statusCode))
+ }()
+ reqcnt.Inc(a.Instance.LogID, string(a.Endpoint))
+
+ ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.Instance.Deadline))
+ defer cancel()
+
+ if r.Method != a.Method {
+ glog.Warningf("%s/%s: got HTTP %s, wanted HTTP %s", a.Instance.Prefix, string(a.Endpoint), r.Method, a.Method)
+ http.Error(w, "", http.StatusMethodNotAllowed)
+ return
+ }
+
+ statusCode, err := a.Handler(ctx, a.Instance, w, r)
+ if err != nil {
+ glog.Warningf("handler error %s/%s: %v", a.Instance.Prefix, a.Endpoint, err)
+ http.Error(w, fmt.Sprintf("%s%s%s%s", "Error", types.Delim, err.Error(), types.EOL), statusCode)
+ }
+}
diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go
new file mode 100644
index 0000000..a7a3d8a
--- /dev/null
+++ b/pkg/instance/instance_test.go
@@ -0,0 +1,158 @@
+package stfe
+
+import (
+ "crypto"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ "github.com/google/certificate-transparency-go/trillian/mockclient"
+ "github.com/system-transparency/stfe/pkg/testdata"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+type testInstance struct {
+ ctrl *gomock.Controller
+ client *mockclient.MockTrillianLogClient
+ instance *Instance
+}
+
+// newTestInstances sets up a test instance that uses default log parameters
+// with an optional signer, see newLogParameters() for further details. The
+// SthSource is instantiated with an ActiveSthSource that has (i) the default
+// STH as the currently cosigned STH based on testdata.Ed25519VkWitness, and
+// (ii) the default STH without any cosignatures as the currently stable STH.
+func newTestInstance(t *testing.T, signer crypto.Signer) *testInstance {
+ t.Helper()
+ ctrl := gomock.NewController(t)
+ client := mockclient.NewMockTrillianLogClient(ctrl)
+ return &testInstance{
+ ctrl: ctrl,
+ client: client,
+ instance: &Instance{
+ Client: client,
+ LogParameters: newLogParameters(t, signer),
+ SthSource: &ActiveSthSource{
+ client: client,
+ logParameters: newLogParameters(t, signer),
+ currCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}),
+ nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil),
+ cosignatureFrom: make(map[[types.NamespaceFingerprintSize]byte]bool),
+ },
+ },
+ }
+}
+
+// getHandlers returns all endpoints that use HTTP GET as a map to handlers
+func (ti *testInstance) getHandlers(t *testing.T) map[Endpoint]Handler {
+ t.Helper()
+ return map[Endpoint]Handler{
+ EndpointGetLatestSth: Handler{Instance: ti.instance, Handler: getLatestSth, Endpoint: EndpointGetLatestSth, Method: http.MethodGet},
+ EndpointGetStableSth: Handler{Instance: ti.instance, Handler: getStableSth, Endpoint: EndpointGetStableSth, Method: http.MethodGet},
+ EndpointGetCosignedSth: Handler{Instance: ti.instance, Handler: getCosignedSth, Endpoint: EndpointGetCosignedSth, Method: http.MethodGet},
+ }
+}
+
+// postHandlers returns all endpoints that use HTTP POST as a map to handlers
+func (ti *testInstance) postHandlers(t *testing.T) map[Endpoint]Handler {
+ t.Helper()
+ return map[Endpoint]Handler{
+ EndpointAddEntry: Handler{Instance: ti.instance, Handler: addEntry, Endpoint: EndpointAddEntry, Method: http.MethodPost},
+ EndpointAddCosignature: Handler{Instance: ti.instance, Handler: addCosignature, Endpoint: EndpointAddCosignature, Method: http.MethodPost},
+ EndpointGetConsistencyProof: Handler{Instance: ti.instance, Handler: getConsistencyProof, Endpoint: EndpointGetConsistencyProof, Method: http.MethodPost},
+ EndpointGetProofByHash: Handler{Instance: ti.instance, Handler: getProofByHash, Endpoint: EndpointGetProofByHash, Method: http.MethodPost},
+ EndpointGetEntries: Handler{Instance: ti.instance, Handler: getEntries, Endpoint: EndpointGetEntries, Method: http.MethodPost},
+ }
+}
+
+// getHandler must return a particular HTTP GET handler
+func (ti *testInstance) getHandler(t *testing.T, endpoint Endpoint) Handler {
+ t.Helper()
+ handler, ok := ti.getHandlers(t)[endpoint]
+ if !ok {
+ t.Fatalf("must return HTTP GET handler for endpoint: %s", endpoint)
+ }
+ return handler
+}
+
+// postHandler must return a particular HTTP POST handler
+func (ti *testInstance) postHandler(t *testing.T, endpoint Endpoint) Handler {
+ t.Helper()
+ handler, ok := ti.postHandlers(t)[endpoint]
+ if !ok {
+ t.Fatalf("must return HTTP POST handler for endpoint: %s", endpoint)
+ }
+ return handler
+}
+
+// TestHandlers checks that we configured all endpoints and that there are no
+// unexpected ones.
+func TestHandlers(t *testing.T) {
+ endpoints := map[Endpoint]bool{
+ EndpointAddEntry: false,
+ EndpointAddCosignature: false,
+ EndpointGetLatestSth: false,
+ EndpointGetStableSth: false,
+ EndpointGetCosignedSth: false,
+ EndpointGetConsistencyProof: false,
+ EndpointGetProofByHash: false,
+ EndpointGetEntries: false,
+ }
+ i := &Instance{nil, newLogParameters(t, nil), nil}
+ for _, handler := range i.Handlers() {
+ if _, ok := endpoints[handler.Endpoint]; !ok {
+ t.Errorf("got unexpected endpoint: %s", handler.Endpoint)
+ }
+ endpoints[handler.Endpoint] = true
+ }
+ for endpoint, ok := range endpoints {
+ if !ok {
+ t.Errorf("endpoint %s is not configured", endpoint)
+ }
+ }
+}
+
+// TestGetHandlersRejectPost checks that all get handlers reject post requests
+func TestGetHandlersRejectPost(t *testing.T) {
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
+
+ for endpoint, handler := range ti.getHandlers(t) {
+ t.Run(string(endpoint), func(t *testing.T) {
+ s := httptest.NewServer(handler)
+ defer s.Close()
+
+ url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix)
+ if rsp, err := http.Post(url, "application/json", nil); err != nil {
+ t.Fatalf("http.Post(%s)=(_,%q), want (_,nil)", url, err)
+ } else if rsp.StatusCode != http.StatusMethodNotAllowed {
+ t.Errorf("http.Post(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed)
+ }
+ })
+ }
+}
+
+// TestPostHandlersRejectGet checks that all post handlers reject get requests
+func TestPostHandlersRejectGet(t *testing.T) {
+ ti := newTestInstance(t, nil)
+ defer ti.ctrl.Finish()
+
+ for endpoint, handler := range ti.postHandlers(t) {
+ t.Run(string(endpoint), func(t *testing.T) {
+ s := httptest.NewServer(handler)
+ defer s.Close()
+
+ url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix)
+ if rsp, err := http.Get(url); err != nil {
+ t.Fatalf("http.Get(%s)=(_,%q), want (_,nil)", url, err)
+ } else if rsp.StatusCode != http.StatusMethodNotAllowed {
+ t.Errorf("http.Get(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed)
+ }
+ })
+ }
+}
+
+// TODO: TestHandlerPath
+func TestHandlerPath(t *testing.T) {
+}
diff --git a/pkg/instance/metric.go b/pkg/instance/metric.go
new file mode 100644
index 0000000..7e3e8b2
--- /dev/null
+++ b/pkg/instance/metric.go
@@ -0,0 +1,23 @@
+package stfe
+
+import (
+ "github.com/google/trillian/monitoring"
+ "github.com/google/trillian/monitoring/prometheus"
+)
+
+var (
+ reqcnt monitoring.Counter // number of incoming http requests
+ rspcnt monitoring.Counter // number of valid http responses
+ latency monitoring.Histogram // request-response latency
+ lastSthTimestamp monitoring.Gauge // unix timestamp from the most recent sth
+ lastSthSize monitoring.Gauge // tree size of most recent sth
+)
+
+func init() {
+ mf := prometheus.MetricFactory{}
+ reqcnt = mf.NewCounter("http_req", "number of http requests", "logid", "endpoint")
+ rspcnt = mf.NewCounter("http_rsp", "number of http requests", "logid", "endpoint", "status")
+ latency = mf.NewHistogram("http_latency", "http request-response latency", "logid", "endpoint", "status")
+ lastSthTimestamp = mf.NewGauge("last_sth_timestamp", "unix timestamp while handling the most recent sth", "logid")
+ lastSthSize = mf.NewGauge("last_sth_size", "most recent sth tree size", "logid")
+}
diff --git a/pkg/instance/request.go b/pkg/instance/request.go
new file mode 100644
index 0000000..7475b26
--- /dev/null
+++ b/pkg/instance/request.go
@@ -0,0 +1,77 @@
+package stfe
+
+import (
+ "crypto/ed25519"
+ "fmt"
+ "net/http"
+
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+func (i *Instance) leafRequestFromHTTP(r *http.Request) (*types.LeafRequest, error) {
+ var req types.LeafRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("UnmarshalASCII: %v", err)
+ }
+
+ vk := ed25519.PublicKey(req.VerificationKey[:])
+ msg := req.Message.Marshal()
+ sig := req.Signature[:]
+ if !ed25519.Verify(vk, msg, sig) {
+ return nil, fmt.Errorf("invalid signature")
+ }
+ // TODO: check shard hint
+ // TODO: check domain hint
+ return &req, nil
+}
+
+func (i *Instance) cosignatureRequestFromHTTP(r *http.Request) (*types.CosignatureRequest, error) {
+ var req types.CosignatureRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("unpackOctetPost: %v", err)
+ }
+ if _, ok := i.Witnesses[*req.KeyHash]; !ok {
+ return nil, fmt.Errorf("Unknown witness: %x", req.KeyHash)
+ }
+ return &req, nil
+}
+
+func (i *Instance) consistencyProofRequestFromHTTP(r *http.Request) (*types.ConsistencyProofRequest, error) {
+ var req types.ConsistencyProofRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("UnmarshalASCII: %v", err)
+ }
+ if req.OldSize < 1 {
+ return nil, fmt.Errorf("OldSize(%d) must be larger than zero", req.OldSize)
+ }
+ if req.NewSize <= req.OldSize {
+ return nil, fmt.Errorf("NewSize(%d) must be larger than OldSize(%d)", req.NewSize, req.OldSize)
+ }
+ return &req, nil
+}
+
+func (i *Instance) inclusionProofRequestFromHTTP(r *http.Request) (*types.InclusionProofRequest, error) {
+ var req types.InclusionProofRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("UnmarshalASCII: %v", err)
+ }
+ if req.TreeSize < 1 {
+ return nil, fmt.Errorf("TreeSize(%d) must be larger than zero", req.TreeSize)
+ }
+ return &req, nil
+}
+
+func (i *Instance) leavesRequestFromHTTP(r *http.Request) (*types.LeavesRequest, error) {
+ var req types.LeavesRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("UnmarshalASCII: %v", err)
+ }
+
+ if req.StartSize > req.EndSize {
+ return nil, fmt.Errorf("StartSize(%d) must be less than or equal to EndSize(%d)", req.StartSize, req.EndSize)
+ }
+ if req.EndSize-req.StartSize+1 > uint64(i.MaxRange) {
+ req.EndSize = req.StartSize + uint64(i.MaxRange) - 1
+ }
+ return &req, nil
+}
diff --git a/pkg/instance/request_test.go b/pkg/instance/request_test.go
new file mode 100644
index 0000000..0a5a908
--- /dev/null
+++ b/pkg/instance/request_test.go
@@ -0,0 +1,318 @@
+package stfe
+
+import (
+ "bytes"
+ //"fmt"
+ "reflect"
+ "testing"
+ //"testing/iotest"
+
+ "net/http"
+
+ "github.com/system-transparency/stfe/pkg/testdata"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+func TestParseAddEntryV1Request(t *testing.T) {
+ lp := newLogParameters(t, nil)
+ for _, table := range []struct {
+ description string
+ breq *bytes.Buffer
+ wantErr bool
+ }{
+ {
+ description: "invalid: nothing to unpack",
+ breq: bytes.NewBuffer(nil),
+ wantErr: true,
+ },
+ {
+ description: "invalid: not a signed checksum entry",
+ breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness),
+ wantErr: true,
+ },
+ {
+ description: "invalid: untrusted submitter", // only testdata.Ed25519VkSubmitter is registered by default in newLogParameters()
+
+ breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter2, testdata.Ed25519VkSubmitter2),
+ wantErr: true,
+ },
+ {
+ description: "invalid: signature does not cover message",
+
+ breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter2, testdata.Ed25519VkSubmitter),
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter),
+ }, // TODO: add test case that disables submitter policy (i.e., unregistered namespaces are accepted)
+ } {
+ url := EndpointAddEntry.Path("http://example.com", lp.Prefix)
+ req, err := http.NewRequest("POST", url, table.breq)
+ if err != nil {
+ t.Fatalf("failed creating http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+
+ _, err = lp.parseAddEntryV1Request(req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ }
+}
+
+func TestParseAddCosignatureV1Request(t *testing.T) {
+ lp := newLogParameters(t, nil)
+ for _, table := range []struct {
+ description string
+ breq *bytes.Buffer
+ wantErr bool
+ }{
+ {
+ description: "invalid: nothing to unpack",
+ breq: bytes.NewBuffer(nil),
+ wantErr: true,
+ },
+ {
+ description: "invalid: not a cosigned sth",
+ breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no cosignature",
+ breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, nil),
+ wantErr: true,
+ },
+ {
+ description: "invalid: untrusted witness", // only testdata.Ed25519VkWitness is registered by default in newLogParameters()
+ breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness2, &testdata.Ed25519VkWitness2),
+ wantErr: true,
+ },
+ {
+ description: "invalid: signature does not cover message",
+ breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness2, &testdata.Ed25519VkWitness),
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness),
+ }, // TODO: add test case that disables witness policy (i.e., unregistered namespaces are accepted)
+ } {
+ url := EndpointAddCosignature.Path("http://example.com", lp.Prefix)
+ req, err := http.NewRequest("POST", url, table.breq)
+ if err != nil {
+ t.Fatalf("failed creating http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+
+ _, err = lp.parseAddCosignatureV1Request(req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ }
+}
+
+func TestNewGetConsistencyProofRequest(t *testing.T) {
+ lp := newLogParameters(t, nil)
+ for _, table := range []struct {
+ description string
+ req *types.GetConsistencyProofV1
+ wantErr bool
+ }{
+ {
+ description: "invalid: nothing to unpack",
+ req: nil,
+ wantErr: true,
+ },
+ {
+ description: "invalid: first must be larger than zero",
+ req: &types.GetConsistencyProofV1{First: 0, Second: 0},
+ wantErr: true,
+ },
+ {
+ description: "invalid: second must be larger than first",
+ req: &types.GetConsistencyProofV1{First: 2, Second: 1},
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: &types.GetConsistencyProofV1{First: 1, Second: 2},
+ },
+ } {
+ var buf *bytes.Buffer
+ if table.req == nil {
+ buf = bytes.NewBuffer(nil)
+ } else {
+ buf = bytes.NewBuffer(marshal(t, *table.req))
+ }
+
+ url := EndpointGetConsistencyProof.Path("http://example.com", lp.Prefix)
+ req, err := http.NewRequest("POST", url, buf)
+ if err != nil {
+ t.Fatalf("failed creating http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+
+ _, err = lp.parseGetConsistencyProofV1Request(req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ }
+}
+
+func TestNewGetProofByHashRequest(t *testing.T) {
+ lp := newLogParameters(t, nil)
+ for _, table := range []struct {
+ description string
+ req *types.GetProofByHashV1
+ wantErr bool
+ }{
+ {
+ description: "invalid: nothing to unpack",
+ req: nil,
+ wantErr: true,
+ },
+ {
+ description: "invalid: no entry in an empty tree",
+ req: &types.GetProofByHashV1{TreeSize: 0, Hash: testdata.LeafHash},
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: &types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash},
+ },
+ } {
+ var buf *bytes.Buffer
+ if table.req == nil {
+ buf = bytes.NewBuffer(nil)
+ } else {
+ buf = bytes.NewBuffer(marshal(t, *table.req))
+ }
+
+ url := EndpointGetProofByHash.Path("http://example.com", lp.Prefix)
+ req, err := http.NewRequest("POST", url, buf)
+ if err != nil {
+ t.Fatalf("failed creating http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+
+ _, err = lp.parseGetProofByHashV1Request(req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ }
+}
+
+func TestParseGetEntriesV1Request(t *testing.T) {
+ lp := newLogParameters(t, nil)
+ for _, table := range []struct {
+ description string
+ req *types.GetEntriesV1
+ wantErr bool
+ wantReq *types.GetEntriesV1
+ }{
+ {
+ description: "invalid: nothing to unpack",
+ req: nil,
+ wantErr: true,
+ },
+ {
+ description: "invalid: start must be larger than end",
+ req: &types.GetEntriesV1{Start: 1, End: 0},
+ wantErr: true,
+ },
+ {
+ description: "valid: want truncated range",
+ req: &types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange)},
+ wantReq: &types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange) - 1},
+ },
+ {
+ description: "valid",
+ req: &types.GetEntriesV1{Start: 0, End: 0},
+ wantReq: &types.GetEntriesV1{Start: 0, End: 0},
+ },
+ } {
+ var buf *bytes.Buffer
+ if table.req == nil {
+ buf = bytes.NewBuffer(nil)
+ } else {
+ buf = bytes.NewBuffer(marshal(t, *table.req))
+ }
+
+ url := EndpointGetEntries.Path("http://example.com", lp.Prefix)
+ req, err := http.NewRequest("POST", url, buf)
+ if err != nil {
+ t.Fatalf("failed creating http request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/octet-stream")
+
+ output, err := lp.parseGetEntriesV1Request(req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := output, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got request\n%v\n\tbut wanted\n%v\n\t in test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestUnpackOctetPost(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ req *http.Request
+ out interface{}
+ wantErr bool
+ }{
+ //{
+ // description: "invalid: cannot read request body",
+ // req: func() *http.Request {
+ // req, err := http.NewRequest(http.MethodPost, "", iotest.ErrReader(fmt.Errorf("bad reader")))
+ // if err != nil {
+ // t.Fatalf("must make new http request: %v", err)
+ // }
+ // return req
+ // }(),
+ // out: &types.StItem{},
+ // wantErr: true,
+ //}, // testcase requires Go 1.16
+ {
+ description: "invalid: cannot unmarshal",
+ req: func() *http.Request {
+ req, err := http.NewRequest(http.MethodPost, "", bytes.NewBuffer(nil))
+ if err != nil {
+ t.Fatalf("must make new http request: %v", err)
+ }
+ return req
+ }(),
+ out: &types.StItem{},
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: func() *http.Request {
+ req, err := http.NewRequest(http.MethodPost, "", bytes.NewBuffer([]byte{0}))
+ if err != nil {
+ t.Fatalf("must make new http request: %v", err)
+ }
+ return req
+ }(),
+ out: &struct{ SomeUint8 uint8 }{},
+ },
+ } {
+ err := unpackOctetPost(table.req, table.out)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q", got, want, table.description)
+ }
+ }
+}
+
+func marshal(t *testing.T, out interface{}) []byte {
+ b, err := types.Marshal(out)
+ if err != nil {
+ t.Fatalf("must marshal: %v", err)
+ }
+ return b
+}
diff --git a/pkg/mocks/crypto.go b/pkg/mocks/crypto.go
new file mode 100644
index 0000000..87c883a
--- /dev/null
+++ b/pkg/mocks/crypto.go
@@ -0,0 +1,23 @@
+package mocks
+
+import (
+ "crypto"
+ "crypto/ed25519"
+ "io"
+)
+
+// TestSign implements the signer interface. It can be used to mock an Ed25519
+// signer that always return the same public key, signature, and error.
+type TestSigner struct {
+ PublicKey *[ed25519.PublicKeySize]byte
+ Signature *[ed25519.SignatureSize]byte
+ Error error
+}
+
+func (ts *TestSigner) Public() crypto.PublicKey {
+ return ed25519.PublicKey(ts.PublicKey[:])
+}
+
+func (ts *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
+ return ts.Signature[:], ts.Error
+}
diff --git a/pkg/mocks/stfe.go b/pkg/mocks/stfe.go
new file mode 100644
index 0000000..def5bc6
--- /dev/null
+++ b/pkg/mocks/stfe.go
@@ -0,0 +1,110 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/system-transparency/stfe/trillian (interfaces: Client)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+ context "context"
+ reflect "reflect"
+
+ gomock "github.com/golang/mock/gomock"
+ types "github.com/system-transparency/stfe/pkg/types"
+)
+
+// MockClient is a mock of Client interface.
+type MockClient struct {
+ ctrl *gomock.Controller
+ recorder *MockClientMockRecorder
+}
+
+// MockClientMockRecorder is the mock recorder for MockClient.
+type MockClientMockRecorder struct {
+ mock *MockClient
+}
+
+// NewMockClient creates a new mock instance.
+func NewMockClient(ctrl *gomock.Controller) *MockClient {
+ mock := &MockClient{ctrl: ctrl}
+ mock.recorder = &MockClientMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockClient) EXPECT() *MockClientMockRecorder {
+ return m.recorder
+}
+
+// AddLeaf mocks base method.
+func (m *MockClient) AddLeaf(arg0 context.Context, arg1 *types.LeafRequest) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AddLeaf", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// AddLeaf indicates an expected call of AddLeaf.
+func (mr *MockClientMockRecorder) AddLeaf(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLeaf", reflect.TypeOf((*MockClient)(nil).AddLeaf), arg0, arg1)
+}
+
+// GetConsistencyProof mocks base method.
+func (m *MockClient) GetConsistencyProof(arg0 context.Context, arg1 *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetConsistencyProof", arg0, arg1)
+ ret0, _ := ret[0].(*types.ConsistencyProof)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetConsistencyProof indicates an expected call of GetConsistencyProof.
+func (mr *MockClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockClient)(nil).GetConsistencyProof), arg0, arg1)
+}
+
+// GetInclusionProof mocks base method.
+func (m *MockClient) GetInclusionProof(arg0 context.Context, arg1 *types.InclusionProofRequest) (*types.InclusionProof, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetInclusionProof", arg0, arg1)
+ ret0, _ := ret[0].(*types.InclusionProof)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetInclusionProof indicates an expected call of GetInclusionProof.
+func (mr *MockClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockClient)(nil).GetInclusionProof), arg0, arg1)
+}
+
+// GetLeaves mocks base method.
+func (m *MockClient) GetLeaves(arg0 context.Context, arg1 *types.LeavesRequest) (*types.LeafList, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetLeaves", arg0, arg1)
+ ret0, _ := ret[0].(*types.LeafList)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLeaves indicates an expected call of GetLeaves.
+func (mr *MockClientMockRecorder) GetLeaves(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeaves", reflect.TypeOf((*MockClient)(nil).GetLeaves), arg0, arg1)
+}
+
+// GetTreeHead mocks base method.
+func (m *MockClient) GetTreeHead(arg0 context.Context) (*types.TreeHead, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetTreeHead", arg0)
+ ret0, _ := ret[0].(*types.TreeHead)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetTreeHead indicates an expected call of GetTreeHead.
+func (mr *MockClientMockRecorder) GetTreeHead(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTreeHead", reflect.TypeOf((*MockClient)(nil).GetTreeHead), arg0)
+}
diff --git a/pkg/mocks/trillian.go b/pkg/mocks/trillian.go
new file mode 100644
index 0000000..8aa3a58
--- /dev/null
+++ b/pkg/mocks/trillian.go
@@ -0,0 +1,317 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/google/trillian (interfaces: TrillianLogClient)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+ context "context"
+ reflect "reflect"
+
+ gomock "github.com/golang/mock/gomock"
+ trillian "github.com/google/trillian"
+ grpc "google.golang.org/grpc"
+)
+
+// MockTrillianLogClient is a mock of TrillianLogClient interface.
+type MockTrillianLogClient struct {
+ ctrl *gomock.Controller
+ recorder *MockTrillianLogClientMockRecorder
+}
+
+// MockTrillianLogClientMockRecorder is the mock recorder for MockTrillianLogClient.
+type MockTrillianLogClientMockRecorder struct {
+ mock *MockTrillianLogClient
+}
+
+// NewMockTrillianLogClient creates a new mock instance.
+func NewMockTrillianLogClient(ctrl *gomock.Controller) *MockTrillianLogClient {
+ mock := &MockTrillianLogClient{ctrl: ctrl}
+ mock.recorder = &MockTrillianLogClientMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockTrillianLogClient) EXPECT() *MockTrillianLogClientMockRecorder {
+ return m.recorder
+}
+
+// AddSequencedLeaf mocks base method.
+func (m *MockTrillianLogClient) AddSequencedLeaf(arg0 context.Context, arg1 *trillian.AddSequencedLeafRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeafResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "AddSequencedLeaf", varargs...)
+ ret0, _ := ret[0].(*trillian.AddSequencedLeafResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// AddSequencedLeaf indicates an expected call of AddSequencedLeaf.
+func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaf), varargs...)
+}
+
+// AddSequencedLeaves mocks base method.
+func (m *MockTrillianLogClient) AddSequencedLeaves(arg0 context.Context, arg1 *trillian.AddSequencedLeavesRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeavesResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "AddSequencedLeaves", varargs...)
+ ret0, _ := ret[0].(*trillian.AddSequencedLeavesResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// AddSequencedLeaves indicates an expected call of AddSequencedLeaves.
+func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaves), varargs...)
+}
+
+// GetConsistencyProof mocks base method.
+func (m *MockTrillianLogClient) GetConsistencyProof(arg0 context.Context, arg1 *trillian.GetConsistencyProofRequest, arg2 ...grpc.CallOption) (*trillian.GetConsistencyProofResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetConsistencyProof", varargs...)
+ ret0, _ := ret[0].(*trillian.GetConsistencyProofResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetConsistencyProof indicates an expected call of GetConsistencyProof.
+func (mr *MockTrillianLogClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetConsistencyProof), varargs...)
+}
+
+// GetEntryAndProof mocks base method.
+func (m *MockTrillianLogClient) GetEntryAndProof(arg0 context.Context, arg1 *trillian.GetEntryAndProofRequest, arg2 ...grpc.CallOption) (*trillian.GetEntryAndProofResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetEntryAndProof", varargs...)
+ ret0, _ := ret[0].(*trillian.GetEntryAndProofResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetEntryAndProof indicates an expected call of GetEntryAndProof.
+func (mr *MockTrillianLogClientMockRecorder) GetEntryAndProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntryAndProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetEntryAndProof), varargs...)
+}
+
+// GetInclusionProof mocks base method.
+func (m *MockTrillianLogClient) GetInclusionProof(arg0 context.Context, arg1 *trillian.GetInclusionProofRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetInclusionProof", varargs...)
+ ret0, _ := ret[0].(*trillian.GetInclusionProofResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetInclusionProof indicates an expected call of GetInclusionProof.
+func (mr *MockTrillianLogClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProof), varargs...)
+}
+
+// GetInclusionProofByHash mocks base method.
+func (m *MockTrillianLogClient) GetInclusionProofByHash(arg0 context.Context, arg1 *trillian.GetInclusionProofByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofByHashResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetInclusionProofByHash", varargs...)
+ ret0, _ := ret[0].(*trillian.GetInclusionProofByHashResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetInclusionProofByHash indicates an expected call of GetInclusionProofByHash.
+func (mr *MockTrillianLogClientMockRecorder) GetInclusionProofByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProofByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProofByHash), varargs...)
+}
+
+// GetLatestSignedLogRoot mocks base method.
+func (m *MockTrillianLogClient) GetLatestSignedLogRoot(arg0 context.Context, arg1 *trillian.GetLatestSignedLogRootRequest, arg2 ...grpc.CallOption) (*trillian.GetLatestSignedLogRootResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetLatestSignedLogRoot", varargs...)
+ ret0, _ := ret[0].(*trillian.GetLatestSignedLogRootResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLatestSignedLogRoot indicates an expected call of GetLatestSignedLogRoot.
+func (mr *MockTrillianLogClientMockRecorder) GetLatestSignedLogRoot(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestSignedLogRoot", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLatestSignedLogRoot), varargs...)
+}
+
+// GetLeavesByHash mocks base method.
+func (m *MockTrillianLogClient) GetLeavesByHash(arg0 context.Context, arg1 *trillian.GetLeavesByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByHashResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetLeavesByHash", varargs...)
+ ret0, _ := ret[0].(*trillian.GetLeavesByHashResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLeavesByHash indicates an expected call of GetLeavesByHash.
+func (mr *MockTrillianLogClientMockRecorder) GetLeavesByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByHash), varargs...)
+}
+
+// GetLeavesByIndex mocks base method.
+func (m *MockTrillianLogClient) GetLeavesByIndex(arg0 context.Context, arg1 *trillian.GetLeavesByIndexRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByIndexResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetLeavesByIndex", varargs...)
+ ret0, _ := ret[0].(*trillian.GetLeavesByIndexResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLeavesByIndex indicates an expected call of GetLeavesByIndex.
+func (mr *MockTrillianLogClientMockRecorder) GetLeavesByIndex(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByIndex", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByIndex), varargs...)
+}
+
+// GetLeavesByRange mocks base method.
+func (m *MockTrillianLogClient) GetLeavesByRange(arg0 context.Context, arg1 *trillian.GetLeavesByRangeRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByRangeResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetLeavesByRange", varargs...)
+ ret0, _ := ret[0].(*trillian.GetLeavesByRangeResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLeavesByRange indicates an expected call of GetLeavesByRange.
+func (mr *MockTrillianLogClientMockRecorder) GetLeavesByRange(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByRange", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByRange), varargs...)
+}
+
+// GetSequencedLeafCount mocks base method.
+func (m *MockTrillianLogClient) GetSequencedLeafCount(arg0 context.Context, arg1 *trillian.GetSequencedLeafCountRequest, arg2 ...grpc.CallOption) (*trillian.GetSequencedLeafCountResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetSequencedLeafCount", varargs...)
+ ret0, _ := ret[0].(*trillian.GetSequencedLeafCountResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetSequencedLeafCount indicates an expected call of GetSequencedLeafCount.
+func (mr *MockTrillianLogClientMockRecorder) GetSequencedLeafCount(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSequencedLeafCount", reflect.TypeOf((*MockTrillianLogClient)(nil).GetSequencedLeafCount), varargs...)
+}
+
+// InitLog mocks base method.
+func (m *MockTrillianLogClient) InitLog(arg0 context.Context, arg1 *trillian.InitLogRequest, arg2 ...grpc.CallOption) (*trillian.InitLogResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "InitLog", varargs...)
+ ret0, _ := ret[0].(*trillian.InitLogResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// InitLog indicates an expected call of InitLog.
+func (mr *MockTrillianLogClientMockRecorder) InitLog(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitLog", reflect.TypeOf((*MockTrillianLogClient)(nil).InitLog), varargs...)
+}
+
+// QueueLeaf mocks base method.
+func (m *MockTrillianLogClient) QueueLeaf(arg0 context.Context, arg1 *trillian.QueueLeafRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeafResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "QueueLeaf", varargs...)
+ ret0, _ := ret[0].(*trillian.QueueLeafResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// QueueLeaf indicates an expected call of QueueLeaf.
+func (mr *MockTrillianLogClientMockRecorder) QueueLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaf), varargs...)
+}
+
+// QueueLeaves mocks base method.
+func (m *MockTrillianLogClient) QueueLeaves(arg0 context.Context, arg1 *trillian.QueueLeavesRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeavesResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "QueueLeaves", varargs...)
+ ret0, _ := ret[0].(*trillian.QueueLeavesResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// QueueLeaves indicates an expected call of QueueLeaves.
+func (mr *MockTrillianLogClientMockRecorder) QueueLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaves), varargs...)
+}
diff --git a/pkg/state/state_manager.go b/pkg/state/state_manager.go
new file mode 100644
index 0000000..dfa73f4
--- /dev/null
+++ b/pkg/state/state_manager.go
@@ -0,0 +1,154 @@
+package state
+
+import (
+ "context"
+ "crypto"
+ "fmt"
+ "reflect"
+ "sync"
+ "time"
+
+ "github.com/golang/glog"
+ "github.com/google/certificate-transparency-go/schedule"
+ "github.com/system-transparency/stfe/pkg/trillian"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+// StateManager coordinates access to the log's tree heads and (co)signatures
+type StateManager interface {
+ Latest(context.Context) (*types.SignedTreeHead, error)
+ ToSign(context.Context) (*types.SignedTreeHead, error)
+ Cosigned(context.Context) (*types.SignedTreeHead, error)
+ AddCosignature(context.Context, *[types.VerificationKeySize]byte, *[types.SignatureSize]byte) error
+ Run(context.Context)
+}
+
+// StateManagerSingle implements the StateManager interface. It is assumed that
+// the log server is running on a single-instance machine. So, no coordination.
+type StateManagerSingle struct {
+ client trillian.Client
+ signer crypto.Signer
+ interval time.Duration
+ deadline time.Duration
+ sync.RWMutex
+
+ // cosigned is the current cosigned tree head that is being served
+ cosigned types.SignedTreeHead
+
+ // tosign is the current tree head that is being cosigned by witnesses
+ tosign types.SignedTreeHead
+
+ // cosignature keeps track of all cosignatures for the tosign tree head
+ cosignature map[[types.HashSize]byte]*types.SigIdent
+}
+
+func NewStateManagerSingle(client trillian.Client, signer crypto.Signer, interval, deadline time.Duration) (*StateManagerSingle, error) {
+ sm := &StateManagerSingle{
+ client: client,
+ signer: signer,
+ interval: interval,
+ deadline: deadline,
+ }
+
+ ctx, _ := context.WithTimeout(context.Background(), sm.deadline)
+ sth, err := sm.Latest(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("Latest: %v", err)
+ }
+
+ sm.cosigned = *sth
+ sm.tosign = *sth
+ sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{
+ *sth.SigIdent[0].KeyHash: sth.SigIdent[0], // log signature
+ }
+ return sm, nil
+}
+
+func (sm *StateManagerSingle) Run(ctx context.Context) {
+ schedule.Every(ctx, sm.interval, func(ctx context.Context) {
+ ictx, _ := context.WithTimeout(ctx, sm.deadline)
+ nextTreeHead, err := sm.Latest(ictx)
+ if err != nil {
+ glog.Warningf("rotate failed: Latest: %v", err)
+ return
+ }
+
+ sm.Lock()
+ defer sm.Unlock()
+ sm.rotate(nextTreeHead)
+ })
+}
+
+func (sm *StateManagerSingle) Latest(ctx context.Context) (*types.SignedTreeHead, error) {
+ th, err := sm.client.GetTreeHead(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("LatestTreeHead: %v", err)
+ }
+ sth, err := th.Sign(sm.signer)
+ if err != nil {
+ return nil, fmt.Errorf("sign: %v", err)
+ }
+ return sth, nil
+}
+
+func (sm *StateManagerSingle) ToSign(_ context.Context) (*types.SignedTreeHead, error) {
+ sm.RLock()
+ defer sm.RUnlock()
+ return &sm.tosign, nil
+}
+
+func (sm *StateManagerSingle) Cosigned(_ context.Context) (*types.SignedTreeHead, error) {
+ sm.RLock()
+ defer sm.RUnlock()
+ return &sm.cosigned, nil
+}
+
+func (sm *StateManagerSingle) AddCosignature(_ context.Context, vk *[types.VerificationKeySize]byte, sig *[types.SignatureSize]byte) error {
+ sm.Lock()
+ defer sm.Unlock()
+
+ if err := sm.tosign.TreeHead.Verify(vk, sig); err != nil {
+ return fmt.Errorf("Verify: %v", err)
+ }
+ witness := types.Hash(vk[:])
+ if _, ok := sm.cosignature[*witness]; ok {
+ return fmt.Errorf("signature-signer pair is a duplicate")
+ }
+ sm.cosignature[*witness] = &types.SigIdent{
+ Signature: sig,
+ KeyHash: witness,
+ }
+
+ glog.V(3).Infof("accepted new cosignature from witness: %x", *witness)
+ return nil
+}
+
+// rotate rotates the log's cosigned and stable STH. The caller must aquire the
+// source's read-write lock if there are concurrent reads and/or writes.
+func (sm *StateManagerSingle) rotate(next *types.SignedTreeHead) {
+ if reflect.DeepEqual(sm.cosigned.TreeHead, sm.tosign.TreeHead) {
+ // cosigned and tosign are the same. So, we need to merge all
+ // cosignatures that we already had with the new collected ones.
+ for _, sigident := range sm.cosigned.SigIdent {
+ if _, ok := sm.cosignature[*sigident.KeyHash]; !ok {
+ sm.cosignature[*sigident.KeyHash] = sigident
+ }
+ }
+ glog.V(3).Infof("cosigned tree head repeated, merged signatures")
+ }
+ var cosignatures []*types.SigIdent
+ for _, sigident := range sm.cosignature {
+ cosignatures = append(cosignatures, sigident)
+ }
+
+ // Update cosigned tree head
+ sm.cosigned.TreeHead = sm.tosign.TreeHead
+ sm.cosigned.SigIdent = cosignatures
+
+ // Update to-sign tree head
+ sm.tosign = *next
+ sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{
+ *next.SigIdent[0].KeyHash: next.SigIdent[0], // log signature
+ }
+ glog.V(3).Infof("rotated tree heads")
+}
diff --git a/pkg/state/state_manager_test.go b/pkg/state/state_manager_test.go
new file mode 100644
index 0000000..08990cc
--- /dev/null
+++ b/pkg/state/state_manager_test.go
@@ -0,0 +1,393 @@
+package state
+
+import (
+ "bytes"
+ "context"
+ "crypto"
+ "crypto/ed25519"
+ "crypto/rand"
+ "fmt"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/golang/mock/gomock"
+ "github.com/system-transparency/stfe/pkg/mocks"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+var (
+ testSig = &[types.SignatureSize]byte{}
+ testPub = &[types.VerificationKeySize]byte{}
+ testTH = &types.TreeHead{
+ Timestamp: 0,
+ TreeSize: 0,
+ RootHash: types.Hash(nil),
+ }
+ testSigIdent = &types.SigIdent{
+ Signature: testSig,
+ KeyHash: types.Hash(testPub[:]),
+ }
+ testSTH = &types.SignedTreeHead{
+ TreeHead: *testTH,
+ SigIdent: []*types.SigIdent{testSigIdent},
+ }
+ testSignerOK = &mocks.TestSigner{testPub, testSig, nil}
+ testSignerErr = &mocks.TestSigner{testPub, testSig, fmt.Errorf("something went wrong")}
+)
+
+func TestNewStateManagerSingle(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ signer crypto.Signer
+ rsp *types.TreeHead
+ err error
+ wantErr bool
+ wantSth *types.SignedTreeHead
+ }{
+ {
+ description: "invalid: backend failure",
+ signer: testSignerOK,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ signer: testSignerOK,
+ rsp: testTH,
+ wantSth: testSTH,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ client := mocks.NewMockClient(ctrl)
+ client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err)
+
+ sm, err := NewStateManagerSingle(client, table.signer, time.Duration(0), time.Duration(0))
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := &sm.cosigned, table.wantSth; !reflect.DeepEqual(got, want) {
+ t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ if got, want := &sm.tosign, table.wantSth; !reflect.DeepEqual(got, want) {
+ t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ // we only have log signature on startup
+ if got, want := len(sm.cosignature), 1; got != want {
+ t.Errorf("got %d cosignatures but wanted %d in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestLatest(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ signer crypto.Signer
+ rsp *types.TreeHead
+ err error
+ wantErr bool
+ wantSth *types.SignedTreeHead
+ }{
+ {
+ description: "invalid: backend failure",
+ signer: testSignerOK,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: signature failure",
+ rsp: testTH,
+ signer: testSignerErr,
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ signer: testSignerOK,
+ rsp: testTH,
+ wantSth: testSTH,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ client := mocks.NewMockClient(ctrl)
+ client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err)
+ sm := StateManagerSingle{
+ client: client,
+ signer: table.signer,
+ }
+
+ sth, err := sm.Latest(context.Background())
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := sth, table.wantSth; !reflect.DeepEqual(got, want) {
+ t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestToSign(t *testing.T) {
+ description := "valid"
+ sm := StateManagerSingle{
+ tosign: *testSTH,
+ }
+ sth, err := sm.ToSign(context.Background())
+ if err != nil {
+ t.Errorf("ToSign should not fail with error: %v", err)
+ return
+ }
+ if got, want := sth, testSTH; !reflect.DeepEqual(got, want) {
+ t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description)
+ }
+}
+
+func TestCosigned(t *testing.T) {
+ description := "valid"
+ sm := StateManagerSingle{
+ cosigned: *testSTH,
+ }
+ sth, err := sm.Cosigned(context.Background())
+ if err != nil {
+ t.Errorf("Cosigned should not fail with error: %v", err)
+ return
+ }
+ if got, want := sth, testSTH; !reflect.DeepEqual(got, want) {
+ t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description)
+ }
+}
+
+func TestAddCosignature(t *testing.T) {
+ vk, sk, err := ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ t.Fatalf("GenerateKey: %v", err)
+ }
+ if bytes.Equal(vk[:], testPub[:]) {
+ t.Fatalf("Sampled same key as testPub, aborting...")
+ }
+ var vkArray [types.VerificationKeySize]byte
+ copy(vkArray[:], vk[:])
+
+ for _, table := range []struct {
+ description string
+ signer crypto.Signer
+ vk *[types.VerificationKeySize]byte
+ th *types.TreeHead
+ wantErr bool
+ }{
+ {
+ description: "invalid: signature error",
+ signer: sk,
+ vk: testPub, // wrong key for message
+ th: testTH,
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ signer: sk,
+ vk: &vkArray,
+ th: testTH,
+ },
+ } {
+ sth, _ := table.th.Sign(testSignerOK)
+ logKeyHash := sth.SigIdent[0].KeyHash
+ logSigIdent := sth.SigIdent[0]
+ sm := &StateManagerSingle{
+ signer: testSignerOK,
+ cosigned: *sth,
+ tosign: *sth,
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *logKeyHash: logSigIdent,
+ },
+ }
+
+ // Prepare witness signature
+ sth, err := table.th.Sign(table.signer)
+ if err != nil {
+ t.Fatalf("Sign: %v", err)
+ }
+ witnessKeyHash := sth.SigIdent[0].KeyHash
+ witnessSigIdent := sth.SigIdent[0]
+
+ // Add witness signature
+ err = sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+
+ // We should have two signatures (log + witness)
+ if got, want := len(sm.cosignature), 2; got != want {
+ t.Errorf("got %d cosignatures but wanted %v in test %q", got, want, table.description)
+ continue
+ }
+ // check that log signature is there
+ sigident, ok := sm.cosignature[*logKeyHash]
+ if !ok {
+ t.Errorf("log signature is missing")
+ continue
+ }
+ if got, want := sigident, logSigIdent; !reflect.DeepEqual(got, want) {
+ t.Errorf("got log sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ // check that witness signature is there
+ sigident, ok = sm.cosignature[*witnessKeyHash]
+ if !ok {
+ t.Errorf("witness signature is missing")
+ continue
+ }
+ if got, want := sigident, witnessSigIdent; !reflect.DeepEqual(got, want) {
+ t.Errorf("got witness sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ continue
+ }
+
+ // Adding a duplicate signature should give an error
+ if err := sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature); err == nil {
+ t.Errorf("duplicate witness signature accepted as valid")
+ }
+ }
+}
+
+func TestRotate(t *testing.T) {
+ log := testSigIdent
+ wit1 := &types.SigIdent{
+ Signature: testSig,
+ KeyHash: types.Hash([]byte("wit1 key")),
+ }
+ wit2 := &types.SigIdent{
+ Signature: testSig,
+ KeyHash: types.Hash([]byte("wit2 key")),
+ }
+ th0 := testTH
+ th1 := &types.TreeHead{
+ Timestamp: 1,
+ TreeSize: 1,
+ RootHash: types.Hash([]byte("1")),
+ }
+ th2 := &types.TreeHead{
+ Timestamp: 2,
+ TreeSize: 2,
+ RootHash: types.Hash([]byte("2")),
+ }
+
+ for _, table := range []struct {
+ description string
+ before, after *StateManagerSingle
+ next *types.SignedTreeHead
+ }{
+ {
+ description: "tosign tree head repated, but got one new witnes signature",
+ before: &StateManagerSingle{
+ cosigned: types.SignedTreeHead{
+ TreeHead: *th0,
+ SigIdent: []*types.SigIdent{log, wit1},
+ },
+ tosign: types.SignedTreeHead{
+ TreeHead: *th0,
+ SigIdent: []*types.SigIdent{log},
+ },
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *log.KeyHash: log,
+ *wit2.KeyHash: wit2, // the new witness signature
+ },
+ },
+ next: &types.SignedTreeHead{
+ TreeHead: *th1,
+ SigIdent: []*types.SigIdent{log},
+ },
+ after: &StateManagerSingle{
+ cosigned: types.SignedTreeHead{
+ TreeHead: *th0,
+ SigIdent: []*types.SigIdent{log, wit1, wit2},
+ },
+ tosign: types.SignedTreeHead{
+ TreeHead: *th1,
+ SigIdent: []*types.SigIdent{log},
+ },
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *log.KeyHash: log, // after rotate we always have log sig
+ },
+ },
+ },
+ {
+ description: "tosign tree head did not repeat, it got one witness signature",
+ before: &StateManagerSingle{
+ cosigned: types.SignedTreeHead{
+ TreeHead: *th0,
+ SigIdent: []*types.SigIdent{log, wit1},
+ },
+ tosign: types.SignedTreeHead{
+ TreeHead: *th1,
+ SigIdent: []*types.SigIdent{log},
+ },
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *log.KeyHash: log,
+ *wit2.KeyHash: wit2, // the only witness that signed tosign
+ },
+ },
+ next: &types.SignedTreeHead{
+ TreeHead: *th2,
+ SigIdent: []*types.SigIdent{log},
+ },
+ after: &StateManagerSingle{
+ cosigned: types.SignedTreeHead{
+ TreeHead: *th1,
+ SigIdent: []*types.SigIdent{log, wit2},
+ },
+ tosign: types.SignedTreeHead{
+ TreeHead: *th2,
+ SigIdent: []*types.SigIdent{log},
+ },
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *log.KeyHash: log, // after rotate we always have log sig
+ },
+ },
+ },
+ } {
+ table.before.rotate(table.next)
+ if got, want := table.before.cosigned.TreeHead, table.after.cosigned.TreeHead; !reflect.DeepEqual(got, want) {
+ t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ checkWitnessList(t, table.description, table.before.cosigned.SigIdent, table.after.cosigned.SigIdent)
+ if got, want := table.before.tosign.TreeHead, table.after.tosign.TreeHead; !reflect.DeepEqual(got, want) {
+ t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ checkWitnessList(t, table.description, table.before.tosign.SigIdent, table.after.tosign.SigIdent)
+ if got, want := table.before.cosignature, table.after.cosignature; !reflect.DeepEqual(got, want) {
+ t.Errorf("got cosignature map\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func checkWitnessList(t *testing.T, description string, got, want []*types.SigIdent) {
+ t.Helper()
+ for _, si := range got {
+ found := false
+ for _, sj := range want {
+ if reflect.DeepEqual(si, sj) {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("got unexpected signature-signer pair with key hash in test %q: %x", description, si.KeyHash[:])
+ }
+ }
+ if len(got) != len(want) {
+ t.Errorf("got %d signature-signer pairs but wanted %d in test %q", len(got), len(want), description)
+ }
+}
diff --git a/pkg/trillian/client.go b/pkg/trillian/client.go
new file mode 100644
index 0000000..9523e56
--- /dev/null
+++ b/pkg/trillian/client.go
@@ -0,0 +1,178 @@
+package trillian
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/golang/glog"
+ "github.com/google/trillian"
+ ttypes "github.com/google/trillian/types"
+ "github.com/system-transparency/stfe/pkg/types"
+ "google.golang.org/grpc/codes"
+)
+
+type Client interface {
+ AddLeaf(context.Context, *types.LeafRequest) error
+ GetConsistencyProof(context.Context, *types.ConsistencyProofRequest) (*types.ConsistencyProof, error)
+ GetTreeHead(context.Context) (*types.TreeHead, error)
+ GetInclusionProof(context.Context, *types.InclusionProofRequest) (*types.InclusionProof, error)
+ GetLeaves(context.Context, *types.LeavesRequest) (*types.LeafList, error)
+}
+
+// TrillianClient is a wrapper around the Trillian gRPC client.
+type TrillianClient struct {
+ // TreeID is a Merkle tree identifier that Trillian uses
+ TreeID int64
+
+ // GRPC is a Trillian gRPC client
+ GRPC trillian.TrillianLogClient
+}
+
+func (c *TrillianClient) AddLeaf(ctx context.Context, req *types.LeafRequest) error {
+ leaf := types.Leaf{
+ Message: req.Message,
+ SigIdent: types.SigIdent{
+ Signature: req.Signature,
+ KeyHash: types.Hash(req.VerificationKey[:]),
+ },
+ }
+ serialized := leaf.Marshal()
+
+ glog.V(3).Infof("queueing leaf request: %x", types.HashLeaf(serialized))
+ rsp, err := c.GRPC.QueueLeaf(ctx, &trillian.QueueLeafRequest{
+ LogId: c.TreeID,
+ Leaf: &trillian.LogLeaf{
+ LeafValue: serialized,
+ },
+ })
+ if err != nil {
+ return fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return fmt.Errorf("no response")
+ }
+ if rsp.QueuedLeaf == nil {
+ return fmt.Errorf("no queued leaf")
+ }
+ if codes.Code(rsp.QueuedLeaf.GetStatus().GetCode()) == codes.AlreadyExists {
+ return fmt.Errorf("leaf is already queued or included")
+ }
+ return nil
+}
+
+func (c *TrillianClient) GetTreeHead(ctx context.Context) (*types.TreeHead, error) {
+ rsp, err := c.GRPC.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{
+ LogId: c.TreeID,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("no response")
+ }
+ if rsp.SignedLogRoot == nil {
+ return nil, fmt.Errorf("no signed log root")
+ }
+ if rsp.SignedLogRoot.LogRoot == nil {
+ return nil, fmt.Errorf("no log root")
+ }
+ var r ttypes.LogRootV1
+ if err := r.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil {
+ return nil, fmt.Errorf("no log root: unmarshal failed: %v", err)
+ }
+ if len(r.RootHash) != types.HashSize {
+ return nil, fmt.Errorf("unexpected hash length: %d", len(r.RootHash))
+ }
+ return treeHeadFromLogRoot(&r), nil
+}
+
+func (c *TrillianClient) GetConsistencyProof(ctx context.Context, req *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) {
+ rsp, err := c.GRPC.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{
+ LogId: c.TreeID,
+ FirstTreeSize: int64(req.OldSize),
+ SecondTreeSize: int64(req.NewSize),
+ })
+ if err != nil {
+ return nil, fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("no response")
+ }
+ if rsp.Proof == nil {
+ return nil, fmt.Errorf("no consistency proof")
+ }
+ if len(rsp.Proof.Hashes) == 0 {
+ return nil, fmt.Errorf("not a consistency proof: empty")
+ }
+ path, err := nodePathFromHashes(rsp.Proof.Hashes)
+ if err != nil {
+ return nil, fmt.Errorf("not a consistency proof: %v", err)
+ }
+ return &types.ConsistencyProof{
+ OldSize: req.OldSize,
+ NewSize: req.NewSize,
+ Path: path,
+ }, nil
+}
+
+func (c *TrillianClient) GetInclusionProof(ctx context.Context, req *types.InclusionProofRequest) (*types.InclusionProof, error) {
+ rsp, err := c.GRPC.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{
+ LogId: c.TreeID,
+ LeafHash: req.LeafHash[:],
+ TreeSize: int64(req.TreeSize),
+ OrderBySequence: true,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("no response")
+ }
+ if len(rsp.Proof) != 1 {
+ return nil, fmt.Errorf("bad proof count: %d", len(rsp.Proof))
+ }
+ proof := rsp.Proof[0]
+ if len(proof.Hashes) == 0 {
+ return nil, fmt.Errorf("not an inclusion proof: empty")
+ }
+ path, err := nodePathFromHashes(proof.Hashes)
+ if err != nil {
+ return nil, fmt.Errorf("not an inclusion proof: %v", err)
+ }
+ return &types.InclusionProof{
+ TreeSize: req.TreeSize,
+ LeafIndex: uint64(proof.LeafIndex),
+ Path: path,
+ }, nil
+}
+
+func (c *TrillianClient) GetLeaves(ctx context.Context, req *types.LeavesRequest) (*types.LeafList, error) {
+ rsp, err := c.GRPC.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{
+ LogId: c.TreeID,
+ StartIndex: int64(req.StartSize),
+ Count: int64(req.EndSize-req.StartSize) + 1,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("no response")
+ }
+ if got, want := len(rsp.Leaves), int(req.EndSize-req.StartSize+1); got != want {
+ return nil, fmt.Errorf("unexpected number of leaves: %d", got)
+ }
+ var list types.LeafList
+ for i, leaf := range rsp.Leaves {
+ leafIndex := int64(req.StartSize + uint64(i))
+ if leafIndex != leaf.LeafIndex {
+ return nil, fmt.Errorf("unexpected leaf(%d): got index %d", leafIndex, leaf.LeafIndex)
+ }
+
+ var l types.Leaf
+ if err := l.Unmarshal(leaf.LeafValue); err != nil {
+ return nil, fmt.Errorf("unexpected leaf(%d): %v", leafIndex, err)
+ }
+ list = append(list[:], &l)
+ }
+ return &list, nil
+}
diff --git a/pkg/trillian/client_test.go b/pkg/trillian/client_test.go
new file mode 100644
index 0000000..6b3d881
--- /dev/null
+++ b/pkg/trillian/client_test.go
@@ -0,0 +1,533 @@
+package trillian
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ "github.com/google/trillian"
+ ttypes "github.com/google/trillian/types"
+ "github.com/system-transparency/stfe/pkg/mocks"
+ "github.com/system-transparency/stfe/pkg/types"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+func TestAddLeaf(t *testing.T) {
+ req := &types.LeafRequest{
+ Message: types.Message{
+ ShardHint: 0,
+ Checksum: &[types.HashSize]byte{},
+ },
+ Signature: &[types.SignatureSize]byte{},
+ VerificationKey: &[types.VerificationKeySize]byte{},
+ DomainHint: "example.com",
+ }
+ for _, table := range []struct {
+ description string
+ req *types.LeafRequest
+ rsp *trillian.QueueLeafResponse
+ err error
+ wantErr bool
+ }{
+ {
+ description: "invalid: backend failure",
+ req: req,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ req: req,
+ wantErr: true,
+ },
+ {
+ description: "invalid: no queued leaf",
+ req: req,
+ rsp: &trillian.QueueLeafResponse{},
+ wantErr: true,
+ },
+ {
+ description: "invalid: leaf is already queued or included",
+ req: req,
+ rsp: &trillian.QueueLeafResponse{
+ QueuedLeaf: &trillian.QueuedLogLeaf{
+ Leaf: &trillian.LogLeaf{
+ LeafValue: req.Message.Marshal(),
+ },
+ Status: status.New(codes.AlreadyExists, "duplicate").Proto(),
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: req,
+ rsp: &trillian.QueueLeafResponse{
+ QueuedLeaf: &trillian.QueuedLogLeaf{
+ Leaf: &trillian.LogLeaf{
+ LeafValue: req.Message.Marshal(),
+ },
+ Status: status.New(codes.OK, "ok").Proto(),
+ },
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().QueueLeaf(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ err := client.AddLeaf(context.Background(), table.req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ }()
+ }
+}
+
+func TestGetTreeHead(t *testing.T) {
+ // valid root
+ root := &ttypes.LogRootV1{
+ TreeSize: 0,
+ RootHash: make([]byte, types.HashSize),
+ TimestampNanos: 1622585623133599429,
+ }
+ buf, err := root.MarshalBinary()
+ if err != nil {
+ t.Fatalf("must marshal log root: %v", err)
+ }
+ // invalid root
+ root.RootHash = make([]byte, types.HashSize+1)
+ bufBadHash, err := root.MarshalBinary()
+ if err != nil {
+ t.Fatalf("must marshal log root: %v", err)
+ }
+
+ for _, table := range []struct {
+ description string
+ rsp *trillian.GetLatestSignedLogRootResponse
+ err error
+ wantErr bool
+ wantTh *types.TreeHead
+ }{
+ {
+ description: "invalid: backend failure",
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ wantErr: true,
+ },
+ {
+ description: "invalid: no signed log root",
+ rsp: &trillian.GetLatestSignedLogRootResponse{},
+ wantErr: true,
+ },
+ {
+ description: "invalid: no log root",
+ rsp: &trillian.GetLatestSignedLogRootResponse{
+ SignedLogRoot: &trillian.SignedLogRoot{},
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: no log root: unmarshal failed",
+ rsp: &trillian.GetLatestSignedLogRootResponse{
+ SignedLogRoot: &trillian.SignedLogRoot{
+ LogRoot: buf[1:],
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: unexpected hash length",
+ rsp: &trillian.GetLatestSignedLogRootResponse{
+ SignedLogRoot: &trillian.SignedLogRoot{
+ LogRoot: bufBadHash,
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ rsp: &trillian.GetLatestSignedLogRootResponse{
+ SignedLogRoot: &trillian.SignedLogRoot{
+ LogRoot: buf,
+ },
+ },
+ wantTh: &types.TreeHead{
+ Timestamp: 1622585623,
+ TreeSize: 0,
+ RootHash: &[types.HashSize]byte{},
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ th, err := client.GetTreeHead(context.Background())
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := th, table.wantTh; !reflect.DeepEqual(got, want) {
+ t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetConsistencyProof(t *testing.T) {
+ req := &types.ConsistencyProofRequest{
+ OldSize: 1,
+ NewSize: 3,
+ }
+ for _, table := range []struct {
+ description string
+ req *types.ConsistencyProofRequest
+ rsp *trillian.GetConsistencyProofResponse
+ err error
+ wantErr bool
+ wantProof *types.ConsistencyProof
+ }{
+ {
+ description: "invalid: backend failure",
+ req: req,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ req: req,
+ wantErr: true,
+ },
+ {
+ description: "invalid: no consistency proof",
+ req: req,
+ rsp: &trillian.GetConsistencyProofResponse{},
+ wantErr: true,
+ },
+ {
+ description: "invalid: not a consistency proof (1/2)",
+ req: req,
+ rsp: &trillian.GetConsistencyProofResponse{
+ Proof: &trillian.Proof{
+ Hashes: [][]byte{},
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: not a consistency proof (2/2)",
+ req: req,
+ rsp: &trillian.GetConsistencyProofResponse{
+ Proof: &trillian.Proof{
+ Hashes: [][]byte{
+ make([]byte, types.HashSize),
+ make([]byte, types.HashSize+1),
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: req,
+ rsp: &trillian.GetConsistencyProofResponse{
+ Proof: &trillian.Proof{
+ Hashes: [][]byte{
+ make([]byte, types.HashSize),
+ make([]byte, types.HashSize),
+ },
+ },
+ },
+ wantProof: &types.ConsistencyProof{
+ OldSize: 1,
+ NewSize: 3,
+ Path: []*[types.HashSize]byte{
+ &[types.HashSize]byte{},
+ &[types.HashSize]byte{},
+ },
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ proof, err := client.GetConsistencyProof(context.Background(), table.req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) {
+ t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetInclusionProof(t *testing.T) {
+ req := &types.InclusionProofRequest{
+ TreeSize: 4,
+ LeafHash: &[types.HashSize]byte{},
+ }
+ for _, table := range []struct {
+ description string
+ req *types.InclusionProofRequest
+ rsp *trillian.GetInclusionProofByHashResponse
+ err error
+ wantErr bool
+ wantProof *types.InclusionProof
+ }{
+ {
+ description: "invalid: backend failure",
+ req: req,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ req: req,
+ wantErr: true,
+ },
+ {
+ description: "invalid: bad proof count",
+ req: req,
+ rsp: &trillian.GetInclusionProofByHashResponse{
+ Proof: []*trillian.Proof{
+ &trillian.Proof{},
+ &trillian.Proof{},
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: not an inclusion proof (1/2)",
+ req: req,
+ rsp: &trillian.GetInclusionProofByHashResponse{
+ Proof: []*trillian.Proof{
+ &trillian.Proof{
+ LeafIndex: 1,
+ Hashes: [][]byte{},
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: not an inclusion proof (2/2)",
+ req: req,
+ rsp: &trillian.GetInclusionProofByHashResponse{
+ Proof: []*trillian.Proof{
+ &trillian.Proof{
+ LeafIndex: 1,
+ Hashes: [][]byte{
+ make([]byte, types.HashSize),
+ make([]byte, types.HashSize+1),
+ },
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: req,
+ rsp: &trillian.GetInclusionProofByHashResponse{
+ Proof: []*trillian.Proof{
+ &trillian.Proof{
+ LeafIndex: 1,
+ Hashes: [][]byte{
+ make([]byte, types.HashSize),
+ make([]byte, types.HashSize),
+ },
+ },
+ },
+ },
+ wantProof: &types.InclusionProof{
+ TreeSize: 4,
+ LeafIndex: 1,
+ Path: []*[types.HashSize]byte{
+ &[types.HashSize]byte{},
+ &[types.HashSize]byte{},
+ },
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ proof, err := client.GetInclusionProof(context.Background(), table.req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) {
+ t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetLeaves(t *testing.T) {
+ req := &types.LeavesRequest{
+ StartSize: 1,
+ EndSize: 2,
+ }
+ firstLeaf := &types.Leaf{
+ Message: types.Message{
+ ShardHint: 0,
+ Checksum: &[types.HashSize]byte{},
+ },
+ SigIdent: types.SigIdent{
+ Signature: &[types.SignatureSize]byte{},
+ KeyHash: &[types.HashSize]byte{},
+ },
+ }
+ secondLeaf := &types.Leaf{
+ Message: types.Message{
+ ShardHint: 0,
+ Checksum: &[types.HashSize]byte{},
+ },
+ SigIdent: types.SigIdent{
+ Signature: &[types.SignatureSize]byte{},
+ KeyHash: &[types.HashSize]byte{},
+ },
+ }
+
+ for _, table := range []struct {
+ description string
+ req *types.LeavesRequest
+ rsp *trillian.GetLeavesByRangeResponse
+ err error
+ wantErr bool
+ wantLeaves *types.LeafList
+ }{
+ {
+ description: "invalid: backend failure",
+ req: req,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ req: req,
+ wantErr: true,
+ },
+ {
+ description: "invalid: unexpected number of leaves",
+ req: req,
+ rsp: &trillian.GetLeavesByRangeResponse{
+ Leaves: []*trillian.LogLeaf{
+ &trillian.LogLeaf{
+ LeafValue: firstLeaf.Marshal(),
+ LeafIndex: 1,
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: unexpected leaf (1/2)",
+ req: req,
+ rsp: &trillian.GetLeavesByRangeResponse{
+ Leaves: []*trillian.LogLeaf{
+ &trillian.LogLeaf{
+ LeafValue: firstLeaf.Marshal(),
+ LeafIndex: 1,
+ },
+ &trillian.LogLeaf{
+ LeafValue: secondLeaf.Marshal(),
+ LeafIndex: 3,
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: unexpected leaf (2/2)",
+ req: req,
+ rsp: &trillian.GetLeavesByRangeResponse{
+ Leaves: []*trillian.LogLeaf{
+ &trillian.LogLeaf{
+ LeafValue: firstLeaf.Marshal(),
+ LeafIndex: 1,
+ },
+ &trillian.LogLeaf{
+ LeafValue: secondLeaf.Marshal()[1:],
+ LeafIndex: 2,
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: req,
+ rsp: &trillian.GetLeavesByRangeResponse{
+ Leaves: []*trillian.LogLeaf{
+ &trillian.LogLeaf{
+ LeafValue: firstLeaf.Marshal(),
+ LeafIndex: 1,
+ },
+ &trillian.LogLeaf{
+ LeafValue: secondLeaf.Marshal(),
+ LeafIndex: 2,
+ },
+ },
+ },
+ wantLeaves: &types.LeafList{
+ firstLeaf,
+ secondLeaf,
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ leaves, err := client.GetLeaves(context.Background(), table.req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := leaves, table.wantLeaves; !reflect.DeepEqual(got, want) {
+ t.Errorf("got leaves\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
diff --git a/pkg/trillian/util.go b/pkg/trillian/util.go
new file mode 100644
index 0000000..4cf31fb
--- /dev/null
+++ b/pkg/trillian/util.go
@@ -0,0 +1,33 @@
+package trillian
+
+import (
+ "fmt"
+
+ trillian "github.com/google/trillian/types"
+ siglog "github.com/system-transparency/stfe/pkg/types"
+)
+
+func treeHeadFromLogRoot(lr *trillian.LogRootV1) *siglog.TreeHead {
+ var hash [siglog.HashSize]byte
+ th := siglog.TreeHead{
+ Timestamp: uint64(lr.TimestampNanos / 1000 / 1000 / 1000),
+ TreeSize: uint64(lr.TreeSize),
+ RootHash: &hash,
+ }
+ copy(th.RootHash[:], lr.RootHash)
+ return &th
+}
+
+func nodePathFromHashes(hashes [][]byte) ([]*[siglog.HashSize]byte, error) {
+ var path []*[siglog.HashSize]byte
+ for _, hash := range hashes {
+ if len(hash) != siglog.HashSize {
+ return nil, fmt.Errorf("unexpected hash length: %v", len(hash))
+ }
+
+ var h [siglog.HashSize]byte
+ copy(h[:], hash)
+ path = append(path, &h)
+ }
+ return path, nil
+}
diff --git a/pkg/types/ascii.go b/pkg/types/ascii.go
new file mode 100644
index 0000000..d27d79b
--- /dev/null
+++ b/pkg/types/ascii.go
@@ -0,0 +1,421 @@
+package types
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "strconv"
+)
+
+const (
+ // Delim is a key-value separator
+ Delim = "="
+
+ // EOL is a line sepator
+ EOL = "\n"
+
+ // NumField* is the number of unique keys in an incoming ASCII message
+ NumFieldLeaf = 4
+ NumFieldSignedTreeHead = 5
+ NumFieldConsistencyProof = 3
+ NumFieldInclusionProof = 3
+ NumFieldLeavesRequest = 2
+ NumFieldInclusionProofRequest = 2
+ NumFieldConsistencyProofRequest = 2
+ NumFieldLeafRequest = 5
+ NumFieldCosignatureRequest = 2
+
+ // New leaf keys
+ ShardHint = "shard_hint"
+ Checksum = "checksum"
+ SignatureOverMessage = "signature_over_message"
+ VerificationKey = "verification_key"
+ DomainHint = "domain_hint"
+
+ // Inclusion proof keys
+ LeafHash = "leaf_hash"
+ LeafIndex = "leaf_index"
+ InclusionPath = "inclusion_path"
+
+ // Consistency proof keys
+ NewSize = "new_size"
+ OldSize = "old_size"
+ ConsistencyPath = "consistency_path"
+
+ // Range of leaves keys
+ StartSize = "start_size"
+ EndSize = "end_size"
+
+ // Tree head keys
+ Timestamp = "timestamp"
+ TreeSize = "tree_size"
+ RootHash = "root_hash"
+
+ // Signature and signer-identity keys
+ Signature = "signature"
+ KeyHash = "key_hash"
+)
+
+// MessageASCI is a wrapper that manages ASCII key-value pairs
+type MessageASCII struct {
+ m map[string][]string
+}
+
+// NewMessageASCII unpacks an incoming ASCII message
+func NewMessageASCII(r io.Reader, numFieldExpected int) (*MessageASCII, error) {
+ buf, err := ioutil.ReadAll(r)
+ if err != nil {
+ return nil, fmt.Errorf("ReadAll: %v", err)
+ }
+ lines := bytes.Split(buf, []byte(EOL))
+ if len(lines) <= 1 {
+ return nil, fmt.Errorf("Not enough lines: empty")
+ }
+ lines = lines[:len(lines)-1] // valid message => split gives empty last line
+
+ msg := MessageASCII{make(map[string][]string)}
+ for _, line := range lines {
+ split := bytes.Index(line, []byte(Delim))
+ if split == -1 {
+ return nil, fmt.Errorf("invalid line: %v", string(line))
+ }
+
+ key := string(line[:split])
+ value := string(line[split+len(Delim):])
+ values, ok := msg.m[key]
+ if !ok {
+ values = nil
+ msg.m[key] = values
+ }
+ msg.m[key] = append(values, value)
+ }
+
+ if msg.NumField() != numFieldExpected {
+ return nil, fmt.Errorf("Unexpected number of keys: %v", msg.NumField())
+ }
+ return &msg, nil
+}
+
+// NumField returns the number of unique keys
+func (msg *MessageASCII) NumField() int {
+ return len(msg.m)
+}
+
+// GetStrings returns a list of strings
+func (msg *MessageASCII) GetStrings(key string) []string {
+ strs, ok := msg.m[key]
+ if !ok {
+ return nil
+ }
+ return strs
+}
+
+// GetString unpacks a string
+func (msg *MessageASCII) GetString(key string) (string, error) {
+ strs := msg.GetStrings(key)
+ if len(strs) != 1 {
+ return "", fmt.Errorf("expected one string: %v", strs)
+ }
+ return strs[0], nil
+}
+
+// GetUint64 unpacks an uint64
+func (msg *MessageASCII) GetUint64(key string) (uint64, error) {
+ str, err := msg.GetString(key)
+ if err != nil {
+ return 0, fmt.Errorf("GetString: %v", err)
+ }
+ num, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ return 0, fmt.Errorf("ParseUint: %v", err)
+ }
+ return num, nil
+}
+
+// GetHash unpacks a hash
+func (msg *MessageASCII) GetHash(key string) (*[HashSize]byte, error) {
+ str, err := msg.GetString(key)
+ if err != nil {
+ return nil, fmt.Errorf("GetString: %v", err)
+ }
+
+ var hash [HashSize]byte
+ if err := decodeHex(str, hash[:]); err != nil {
+ return nil, fmt.Errorf("decodeHex: %v", err)
+ }
+ return &hash, nil
+}
+
+// GetSignature unpacks a signature
+func (msg *MessageASCII) GetSignature(key string) (*[SignatureSize]byte, error) {
+ str, err := msg.GetString(key)
+ if err != nil {
+ return nil, fmt.Errorf("GetString: %v", err)
+ }
+
+ var signature [SignatureSize]byte
+ if err := decodeHex(str, signature[:]); err != nil {
+ return nil, fmt.Errorf("decodeHex: %v", err)
+ }
+ return &signature, nil
+}
+
+// GetVerificationKey unpacks a verification key
+func (msg *MessageASCII) GetVerificationKey(key string) (*[VerificationKeySize]byte, error) {
+ str, err := msg.GetString(key)
+ if err != nil {
+ return nil, fmt.Errorf("GetString: %v", err)
+ }
+
+ var vk [VerificationKeySize]byte
+ if err := decodeHex(str, vk[:]); err != nil {
+ return nil, fmt.Errorf("decodeHex: %v", err)
+ }
+ return &vk, nil
+}
+
+// decodeHex decodes a hex-encoded string into an already-sized byte slice
+func decodeHex(str string, out []byte) error {
+ buf, err := hex.DecodeString(str)
+ if err != nil {
+ return fmt.Errorf("DecodeString: %v", err)
+ }
+ if len(buf) != len(out) {
+ return fmt.Errorf("invalid length: %v", len(buf))
+ }
+ copy(out, buf)
+ return nil
+}
+
+/*
+ *
+ * MarshalASCII wrappers for types that the log server outputs
+ *
+ */
+func (l *Leaf) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, ShardHint, strconv.FormatUint(l.ShardHint, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, Checksum, hex.EncodeToString(l.Checksum[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, SignatureOverMessage, hex.EncodeToString(l.Signature[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, KeyHash, hex.EncodeToString(l.KeyHash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ return nil
+}
+
+func (sth *SignedTreeHead) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, Timestamp, strconv.FormatUint(sth.Timestamp, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, TreeSize, strconv.FormatUint(sth.TreeSize, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, RootHash, hex.EncodeToString(sth.RootHash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ for _, sigident := range sth.SigIdent {
+ if err := sigident.MarshalASCII(w); err != nil {
+ return fmt.Errorf("MarshalASCII: %v", err)
+ }
+ }
+ return nil
+}
+
+func (si *SigIdent) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, Signature, hex.EncodeToString(si.Signature[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, KeyHash, hex.EncodeToString(si.KeyHash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ return nil
+}
+
+func (p *ConsistencyProof) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, NewSize, strconv.FormatUint(p.NewSize, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, OldSize, strconv.FormatUint(p.OldSize, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ for _, hash := range p.Path {
+ if err := writeASCII(w, ConsistencyPath, hex.EncodeToString(hash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ }
+ return nil
+}
+
+func (p *InclusionProof) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, TreeSize, strconv.FormatUint(p.TreeSize, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, LeafIndex, strconv.FormatUint(p.LeafIndex, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ for _, hash := range p.Path {
+ if err := writeASCII(w, InclusionPath, hex.EncodeToString(hash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ }
+ return nil
+}
+
+func writeASCII(w io.Writer, key, value string) error {
+ if _, err := fmt.Fprintf(w, "%s%s%s%s", key, Delim, value, EOL); err != nil {
+ return fmt.Errorf("Fprintf: %v", err)
+ }
+ return nil
+}
+
+/*
+ *
+ * Unmarshal ASCII wrappers that the log server and/or log clients receive.
+ *
+ */
+func (ll *LeafList) UnmarshalASCII(r io.Reader) error {
+ return nil
+}
+
+func (sth *SignedTreeHead) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldSignedTreeHead)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ // TreeHead
+ if sth.Timestamp, err = msg.GetUint64(Timestamp); err != nil {
+ return fmt.Errorf("GetUint64(Timestamp): %v", err)
+ }
+ if sth.TreeSize, err = msg.GetUint64(TreeSize); err != nil {
+ return fmt.Errorf("GetUint64(TreeSize): %v", err)
+ }
+ if sth.RootHash, err = msg.GetHash(RootHash); err != nil {
+ return fmt.Errorf("GetHash(RootHash): %v", err)
+ }
+
+ // SigIdent
+ signatures := msg.GetStrings(Signature)
+ if len(signatures) == 0 {
+ return fmt.Errorf("no signer")
+ }
+ keyHashes := msg.GetStrings(KeyHash)
+ if len(signatures) != len(keyHashes) {
+ return fmt.Errorf("mismatched signature-signer count")
+ }
+ sth.SigIdent = make([]*SigIdent, 0, len(signatures))
+ for i, n := 0, len(signatures); i < n; i++ {
+ var signature [SignatureSize]byte
+ if err := decodeHex(signatures[i], signature[:]); err != nil {
+ return fmt.Errorf("decodeHex: %v", err)
+ }
+ var hash [HashSize]byte
+ if err := decodeHex(keyHashes[i], hash[:]); err != nil {
+ return fmt.Errorf("decodeHex: %v", err)
+ }
+ sth.SigIdent = append(sth.SigIdent, &SigIdent{
+ Signature: &signature,
+ KeyHash: &hash,
+ })
+ }
+ return nil
+}
+
+func (p *InclusionProof) UnmarshalASCII(r io.Reader) error {
+ return nil
+}
+
+func (p *ConsistencyProof) UnmarshalASCII(r io.Reader) error {
+ return nil
+}
+
+func (req *InclusionProofRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldInclusionProofRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.LeafHash, err = msg.GetHash(LeafHash); err != nil {
+ return fmt.Errorf("GetHash(LeafHash): %v", err)
+ }
+ if req.TreeSize, err = msg.GetUint64(TreeSize); err != nil {
+ return fmt.Errorf("GetUint64(TreeSize): %v", err)
+ }
+ return nil
+}
+
+func (req *ConsistencyProofRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldConsistencyProofRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.NewSize, err = msg.GetUint64(NewSize); err != nil {
+ return fmt.Errorf("GetUint64(NewSize): %v", err)
+ }
+ if req.OldSize, err = msg.GetUint64(OldSize); err != nil {
+ return fmt.Errorf("GetUint64(OldSize): %v", err)
+ }
+ return nil
+}
+
+func (req *LeavesRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldLeavesRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.StartSize, err = msg.GetUint64(StartSize); err != nil {
+ return fmt.Errorf("GetUint64(StartSize): %v", err)
+ }
+ if req.EndSize, err = msg.GetUint64(EndSize); err != nil {
+ return fmt.Errorf("GetUint64(EndSize): %v", err)
+ }
+ return nil
+}
+
+func (req *LeafRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldLeafRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.ShardHint, err = msg.GetUint64(ShardHint); err != nil {
+ return fmt.Errorf("GetUint64(ShardHint): %v", err)
+ }
+ if req.Checksum, err = msg.GetHash(Checksum); err != nil {
+ return fmt.Errorf("GetHash(Checksum): %v", err)
+ }
+ if req.Signature, err = msg.GetSignature(SignatureOverMessage); err != nil {
+ return fmt.Errorf("GetSignature: %v", err)
+ }
+ if req.VerificationKey, err = msg.GetVerificationKey(VerificationKey); err != nil {
+ return fmt.Errorf("GetVerificationKey: %v", err)
+ }
+ if req.DomainHint, err = msg.GetString(DomainHint); err != nil {
+ return fmt.Errorf("GetString(DomainHint): %v", err)
+ }
+ return nil
+}
+
+func (req *CosignatureRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldCosignatureRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.Signature, err = msg.GetSignature(Signature); err != nil {
+ return fmt.Errorf("GetSignature: %v", err)
+ }
+ if req.KeyHash, err = msg.GetHash(KeyHash); err != nil {
+ return fmt.Errorf("GetHash(KeyHash): %v", err)
+ }
+ return nil
+}
diff --git a/pkg/types/ascii_test.go b/pkg/types/ascii_test.go
new file mode 100644
index 0000000..92732f9
--- /dev/null
+++ b/pkg/types/ascii_test.go
@@ -0,0 +1,465 @@
+package types
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "testing"
+)
+
+/*
+ *
+ * MessageASCII methods and helpers
+ *
+ */
+func TestNewMessageASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ input io.Reader
+ wantErr bool
+ wantMap map[string][]string
+ }{
+ {
+ description: "invalid: not enough lines",
+ input: bytes.NewBufferString(""),
+ wantErr: true,
+ },
+ {
+ description: "invalid: lines must end with new line",
+ input: bytes.NewBufferString("k1=v1\nk2=v2"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: lines must not be empty",
+ input: bytes.NewBufferString("k1=v1\n\nk2=v2\n"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: wrong number of fields",
+ input: bytes.NewBufferString("k1=v1\n"),
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ input: bytes.NewBufferString("k1=v1\nk2=v2\nk2=v3=4\n"),
+ wantMap: map[string][]string{
+ "k1": []string{"v1"},
+ "k2": []string{"v2", "v3=4"},
+ },
+ },
+ } {
+ msg, err := NewMessageASCII(table.input, len(table.wantMap))
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := msg.m, table.wantMap; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestNumField(t *testing.T) {}
+func TestGetStrings(t *testing.T) {}
+func TestGetString(t *testing.T) {}
+func TestGetUint64(t *testing.T) {}
+func TestGetHash(t *testing.T) {}
+func TestGetSignature(t *testing.T) {}
+func TestGetVerificationKey(t *testing.T) {}
+func TestDecodeHex(t *testing.T) {}
+
+/*
+ *
+ * MarshalASCII methods and helpers
+ *
+ */
+func TestLeafMarshalASCII(t *testing.T) {
+ description := "valid: two leaves"
+ leafList := []*Leaf{
+ &Leaf{
+ Message: Message{
+ ShardHint: 123,
+ Checksum: testBuffer32,
+ },
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ &Leaf{
+ Message: Message{
+ ShardHint: 456,
+ Checksum: testBuffer32,
+ },
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ }
+ wantBuf := bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+
+ "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s",
+ // Leaf 1
+ ShardHint, Delim, 123, EOL,
+ Checksum, Delim, testBuffer32[:], EOL,
+ SignatureOverMessage, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ // Leaf 2
+ ShardHint, Delim, 456, EOL,
+ Checksum, Delim, testBuffer32[:], EOL,
+ SignatureOverMessage, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ ))
+ buf := bytes.NewBuffer(nil)
+ for _, leaf := range leafList {
+ if err := leaf.MarshalASCII(buf); err != nil {
+ t.Errorf("expected error %v but got %v in test %q: %v", false, true, description, err)
+ return
+ }
+ }
+ if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description)
+ }
+}
+
+func TestSignedTreeHeadMarshalASCII(t *testing.T) {
+ description := "valid"
+ sth := &SignedTreeHead{
+ TreeHead: TreeHead{
+ Timestamp: 123,
+ TreeSize: 456,
+ RootHash: testBuffer32,
+ },
+ SigIdent: []*SigIdent{
+ &SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ &SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ }
+ wantBuf := bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s",
+ Timestamp, Delim, 123, EOL,
+ TreeSize, Delim, 456, EOL,
+ RootHash, Delim, testBuffer32[:], EOL,
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ ))
+ buf := bytes.NewBuffer(nil)
+ if err := sth.MarshalASCII(buf); err != nil {
+ t.Errorf("expected error %v but got %v in test %q", false, true, description)
+ return
+ }
+ if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description)
+ }
+}
+
+func TestInclusionProofMarshalASCII(t *testing.T) {
+ description := "valid"
+ proof := InclusionProof{
+ TreeSize: 321,
+ LeafIndex: 123,
+ Path: []*[HashSize]byte{
+ testBuffer32,
+ testBuffer32,
+ },
+ }
+ wantBuf := bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s",
+ TreeSize, Delim, 321, EOL,
+ LeafIndex, Delim, 123, EOL,
+ InclusionPath, Delim, testBuffer32[:], EOL,
+ InclusionPath, Delim, testBuffer32[:], EOL,
+ ))
+ buf := bytes.NewBuffer(nil)
+ if err := proof.MarshalASCII(buf); err != nil {
+ t.Errorf("expected error %v but got %v in test %q", false, true, description)
+ return
+ }
+ if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description)
+ }
+}
+
+func TestConsistencyProofMarshalASCII(t *testing.T) {
+ description := "valid"
+ proof := ConsistencyProof{
+ NewSize: 321,
+ OldSize: 123,
+ Path: []*[HashSize]byte{
+ testBuffer32,
+ testBuffer32,
+ },
+ }
+ wantBuf := bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s",
+ NewSize, Delim, 321, EOL,
+ OldSize, Delim, 123, EOL,
+ ConsistencyPath, Delim, testBuffer32[:], EOL,
+ ConsistencyPath, Delim, testBuffer32[:], EOL,
+ ))
+ buf := bytes.NewBuffer(nil)
+ if err := proof.MarshalASCII(buf); err != nil {
+ t.Errorf("expected error %v but got %v in test %q", false, true, description)
+ return
+ }
+ if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description)
+ }
+}
+
+func TestWriteASCII(t *testing.T) {
+}
+
+/*
+ *
+ * UnmarshalASCII methods and helpers
+ *
+ */
+func TestLeafListUnmarshalASCII(t *testing.T) {}
+
+func TestSignedTreeHeadUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantSth *SignedTreeHead
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s",
+ Timestamp, Delim, 123, EOL,
+ TreeSize, Delim, 456, EOL,
+ RootHash, Delim, testBuffer32[:], EOL,
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ )),
+ wantSth: &SignedTreeHead{
+ TreeHead: TreeHead{
+ Timestamp: 123,
+ TreeSize: 456,
+ RootHash: testBuffer32,
+ },
+ SigIdent: []*SigIdent{
+ &SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ &SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ },
+ },
+ } {
+ var sth SignedTreeHead
+ err := sth.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &sth, table.wantSth; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestInclusionProofUnmarshalASCII(t *testing.T) {}
+func TestConsistencyProofUnmarshalASCII(t *testing.T) {}
+
+func TestInclusionProofRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *InclusionProofRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%x%s"+"%s%s%d%s",
+ LeafHash, Delim, testBuffer32[:], EOL,
+ TreeSize, Delim, 123, EOL,
+ )),
+ wantReq: &InclusionProofRequest{
+ LeafHash: testBuffer32,
+ TreeSize: 123,
+ },
+ },
+ } {
+ var req InclusionProofRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestConsistencyProofRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *ConsistencyProofRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s",
+ NewSize, Delim, 321, EOL,
+ OldSize, Delim, 123, EOL,
+ )),
+ wantReq: &ConsistencyProofRequest{
+ NewSize: 321,
+ OldSize: 123,
+ },
+ },
+ } {
+ var req ConsistencyProofRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestLeavesRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *LeavesRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s",
+ StartSize, Delim, 123, EOL,
+ EndSize, Delim, 456, EOL,
+ )),
+ wantReq: &LeavesRequest{
+ StartSize: 123,
+ EndSize: 456,
+ },
+ },
+ } {
+ var req LeavesRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestLeafRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *LeafRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%s%s",
+ ShardHint, Delim, 123, EOL,
+ Checksum, Delim, testBuffer32[:], EOL,
+ SignatureOverMessage, Delim, testBuffer64[:], EOL,
+ VerificationKey, Delim, testBuffer32[:], EOL,
+ DomainHint, Delim, "example.com", EOL,
+ )),
+ wantReq: &LeafRequest{
+ Message: Message{
+ ShardHint: 123,
+ Checksum: testBuffer32,
+ },
+ Signature: testBuffer64,
+ VerificationKey: testBuffer32,
+ DomainHint: "example.com",
+ },
+ },
+ } {
+ var req LeafRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestCosignatureRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *CosignatureRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%x%s"+"%s%s%x%s",
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ )),
+ wantReq: &CosignatureRequest{
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ },
+ } {
+ var req CosignatureRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
diff --git a/pkg/types/trunnel.go b/pkg/types/trunnel.go
new file mode 100644
index 0000000..268f6f7
--- /dev/null
+++ b/pkg/types/trunnel.go
@@ -0,0 +1,60 @@
+package types
+
+import (
+ "encoding/binary"
+ "fmt"
+)
+
+const (
+ // MessageSize is the number of bytes in a Trunnel-encoded leaf message
+ MessageSize = 8 + HashSize
+ // LeafSize is the number of bytes in a Trunnel-encoded leaf
+ LeafSize = MessageSize + SignatureSize + HashSize
+)
+
+// Marshal returns a Trunnel-encoded message
+func (m *Message) Marshal() []byte {
+ buf := make([]byte, MessageSize)
+ binary.BigEndian.PutUint64(buf, m.ShardHint)
+ copy(buf[8:], m.Checksum[:])
+ return buf
+}
+
+// Marshal returns a Trunnel-encoded leaf
+func (l *Leaf) Marshal() []byte {
+ buf := l.Message.Marshal()
+ buf = append(buf, l.SigIdent.Signature[:]...)
+ buf = append(buf, l.SigIdent.KeyHash[:]...)
+ return buf
+}
+
+// Marshal returns a Trunnel-encoded tree head
+func (th *TreeHead) Marshal() []byte {
+ buf := make([]byte, 8+8+HashSize)
+ binary.BigEndian.PutUint64(buf[0:8], th.Timestamp)
+ binary.BigEndian.PutUint64(buf[8:16], th.TreeSize)
+ copy(buf[16:], th.RootHash[:])
+ return buf
+}
+
+// Unmarshal parses the Trunnel-encoded buffer as a leaf
+func (l *Leaf) Unmarshal(buf []byte) error {
+ if len(buf) != LeafSize {
+ return fmt.Errorf("invalid leaf size: %v", len(buf))
+ }
+ // Shard hint
+ l.ShardHint = binary.BigEndian.Uint64(buf)
+ offset := 8
+ // Checksum
+ l.Checksum = &[HashSize]byte{}
+ copy(l.Checksum[:], buf[offset:offset+HashSize])
+ offset += HashSize
+ // Signature
+ l.Signature = &[SignatureSize]byte{}
+ copy(l.Signature[:], buf[offset:offset+SignatureSize])
+ offset += SignatureSize
+ // KeyHash
+ l.KeyHash = &[HashSize]byte{}
+ copy(l.KeyHash[:], buf[offset:])
+ return nil
+}
diff --git a/pkg/types/trunnel_test.go b/pkg/types/trunnel_test.go
new file mode 100644
index 0000000..297578c
--- /dev/null
+++ b/pkg/types/trunnel_test.go
@@ -0,0 +1,114 @@
+package types
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+)
+
+var (
+ testBuffer32 = &[32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}
+ testBuffer64 = &[64]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}
+)
+
+func TestMarshalMessage(t *testing.T) {
+ description := "valid: shard hint 72623859790382856, checksum 0x00,0x01,..."
+ message := &Message{
+ ShardHint: 72623859790382856,
+ Checksum: testBuffer32,
+ }
+ want := bytes.Join([][]byte{
+ []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
+ testBuffer32[:],
+ }, nil)
+ if got := message.Marshal(); !bytes.Equal(got, want) {
+ t.Errorf("got message\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description)
+ }
+}
+
+func TestMarshalLeaf(t *testing.T) {
+ description := "valid: shard hint 72623859790382856, buffers 0x00,0x01,..."
+ leaf := &Leaf{
+ Message: Message{
+ ShardHint: 72623859790382856,
+ Checksum: testBuffer32,
+ },
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ }
+ want := bytes.Join([][]byte{
+ []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
+ testBuffer32[:], testBuffer64[:], testBuffer32[:],
+ }, nil)
+ if got := leaf.Marshal(); !bytes.Equal(got, want) {
+ t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description)
+ }
+}
+
+func TestMarshalTreeHead(t *testing.T) {
+ description := "valid: timestamp 16909060, tree size 72623859790382856, root hash 0x00,0x01,..."
+ th := &TreeHead{
+ Timestamp: 16909060,
+ TreeSize: 72623859790382856,
+ RootHash: testBuffer32,
+ }
+ want := bytes.Join([][]byte{
+ []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04},
+ []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
+ testBuffer32[:],
+ }, nil)
+ if got := th.Marshal(); !bytes.Equal(got, want) {
+ t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description)
+ }
+}
+
+func TestUnmarshalLeaf(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ serialized []byte
+ wantErr bool
+ want *Leaf
+ }{
+ {
+ description: "invalid: not enough bytes",
+ serialized: make([]byte, LeafSize-1),
+ wantErr: true,
+ },
+ {
+ description: "invalid: too many bytes",
+ serialized: make([]byte, LeafSize+1),
+ wantErr: true,
+ },
+ {
+ description: "valid: shard hint 72623859790382856, buffers 0x00,0x01,...",
+ serialized: bytes.Join([][]byte{
+ []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
+ testBuffer32[:], testBuffer64[:], testBuffer32[:],
+ }, nil),
+ want: &Leaf{
+ Message: Message{
+ ShardHint: 72623859790382856,
+ Checksum: testBuffer32,
+ },
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ },
+ } {
+ var leaf Leaf
+ err := leaf.Unmarshal(table.serialized)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &leaf, table.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, table.description)
+ }
+ }
+}
diff --git a/pkg/types/types.go b/pkg/types/types.go
new file mode 100644
index 0000000..9ca7db8
--- /dev/null
+++ b/pkg/types/types.go
@@ -0,0 +1,155 @@
+package types
+
+import (
+ "crypto"
+ "crypto/ed25519"
+ "crypto/sha256"
+ "fmt"
+ "strings"
+)
+
+const (
+ HashSize = sha256.Size
+ SignatureSize = ed25519.SignatureSize
+ VerificationKeySize = ed25519.PublicKeySize
+
+ EndpointAddLeaf = Endpoint("add-leaf")
+ EndpointAddCosignature = Endpoint("add-cosignature")
+ EndpointGetTreeHeadLatest = Endpoint("get-tree-head-latest")
+ EndpointGetTreeHeadToSign = Endpoint("get-tree-head-to-sign")
+ EndpointGetTreeHeadCosigned = Endpoint("get-tree-head-cosigned")
+ EndpointGetProofByHash = Endpoint("get-proof-by-hash")
+ EndpointGetConsistencyProof = Endpoint("get-consistency-proof")
+ EndpointGetLeaves = Endpoint("get-leaves")
+)
+
+// Endpoint is a named HTTP API endpoint
+type Endpoint string
+
+// Path joins a number of components to form a full endpoint path. For example,
+// EndpointAddLeaf.Path("example.com", "st/v0") -> example.com/st/v0/add-leaf.
+func (e Endpoint) Path(components ...string) string {
+ return strings.Join(append(components, string(e)), "/")
+}
+
+// Leaf is the log's Merkle tree leaf.
+type Leaf struct {
+ Message
+ SigIdent
+}
+
+// Message is composed of a shard hint and a checksum. The submitter selects
+// these values to fit the log's shard interval and the opaque data in question.
+type Message struct {
+ ShardHint uint64
+ Checksum *[HashSize]byte
+}
+
+// SigIdent is composed of a signature-signer pair. The signature is computed
+// over the Trunnel-serialized leaf message. KeyHash identifies the signer.
+type SigIdent struct {
+ Signature *[SignatureSize]byte
+ KeyHash *[HashSize]byte
+}
+
+// SignedTreeHead is composed of a tree head and a list of signature-signer
+// pairs. Each signature is computed over the Trunnel-serialized tree head.
+type SignedTreeHead struct {
+ TreeHead
+ SigIdent []*SigIdent
+}
+
+// TreeHead is the log's tree head.
+type TreeHead struct {
+ Timestamp uint64
+ TreeSize uint64
+ RootHash *[HashSize]byte
+}
+
+// ConsistencyProof is a consistency proof that proves the log's append-only
+// property.
+type ConsistencyProof struct {
+ NewSize uint64
+ OldSize uint64
+ Path []*[HashSize]byte
+}
+
+// InclusionProof is an inclusion proof that proves a leaf is included in the
+// log.
+type InclusionProof struct {
+ TreeSize uint64
+ LeafIndex uint64
+ Path []*[HashSize]byte
+}
+
+// LeafList is a list of leaves
+type LeafList []*Leaf
+
+// ConsistencyProofRequest is a get-consistency-proof request
+type ConsistencyProofRequest struct {
+ NewSize uint64
+ OldSize uint64
+}
+
+// InclusionProofRequest is a get-proof-by-hash request
+type InclusionProofRequest struct {
+ LeafHash *[HashSize]byte
+ TreeSize uint64
+}
+
+// LeavesRequest is a get-leaves request
+type LeavesRequest struct {
+ StartSize uint64
+ EndSize uint64
+}
+
+// LeafRequest is an add-leaf request
+type LeafRequest struct {
+ Message
+ Signature *[SignatureSize]byte
+ VerificationKey *[VerificationKeySize]byte
+ DomainHint string
+}
+
+// CosignatureRequest is an add-cosignature request
+type CosignatureRequest struct {
+ SigIdent
+}
+
+// Sign signs the tree head using the log's signature scheme
+func (th *TreeHead) Sign(signer crypto.Signer) (*SignedTreeHead, error) {
+ sig, err := signer.Sign(nil, th.Marshal(), crypto.Hash(0))
+ if err != nil {
+ return nil, fmt.Errorf("Sign: %v", err)
+ }
+
+ sigident := SigIdent{
+ KeyHash: Hash(signer.Public().(ed25519.PublicKey)[:]),
+ Signature: &[SignatureSize]byte{},
+ }
+ copy(sigident.Signature[:], sig)
+ return &SignedTreeHead{
+ TreeHead: *th,
+ SigIdent: []*SigIdent{
+ &sigident,
+ },
+ }, nil
+}
+
+// Verify verifies the tree head signature using the log's signature scheme
+func (th *TreeHead) Verify(vk *[VerificationKeySize]byte, sig *[SignatureSize]byte) error {
+ if !ed25519.Verify(ed25519.PublicKey(vk[:]), th.Marshal(), sig[:]) {
+ return fmt.Errorf("invalid tree head signature")
+ }
+ return nil
+}
+
+// Verify checks if a leaf is included in the log
+func (p *InclusionProof) Verify(leaf *Leaf, th *TreeHead) error { // TODO
+ return nil
+}
+
+// Verify checks if two tree heads are consistent
+func (p *ConsistencyProof) Verify(oldTH, newTH *TreeHead) error { // TODO
+ return nil
+}
diff --git a/pkg/types/types_test.go b/pkg/types/types_test.go
new file mode 100644
index 0000000..da89c59
--- /dev/null
+++ b/pkg/types/types_test.go
@@ -0,0 +1,58 @@
+package types
+
+import (
+ "testing"
+)
+
+func TestEndpointPath(t *testing.T) {
+ base, prefix, proto := "example.com", "log", "st/v0"
+ for _, table := range []struct {
+ endpoint Endpoint
+ want string
+ }{
+ {
+ endpoint: EndpointAddLeaf,
+ want: "example.com/log/st/v0/add-leaf",
+ },
+ {
+ endpoint: EndpointAddCosignature,
+ want: "example.com/log/st/v0/add-cosignature",
+ },
+ {
+ endpoint: EndpointGetTreeHeadLatest,
+ want: "example.com/log/st/v0/get-tree-head-latest",
+ },
+ {
+ endpoint: EndpointGetTreeHeadToSign,
+ want: "example.com/log/st/v0/get-tree-head-to-sign",
+ },
+ {
+ endpoint: EndpointGetTreeHeadCosigned,
+ want: "example.com/log/st/v0/get-tree-head-cosigned",
+ },
+ {
+ endpoint: EndpointGetConsistencyProof,
+ want: "example.com/log/st/v0/get-consistency-proof",
+ },
+ {
+ endpoint: EndpointGetProofByHash,
+ want: "example.com/log/st/v0/get-proof-by-hash",
+ },
+ {
+ endpoint: EndpointGetLeaves,
+ want: "example.com/log/st/v0/get-leaves",
+ },
+ } {
+ if got, want := table.endpoint.Path(base+"/"+prefix+"/"+proto), table.want; got != want {
+ t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\twith one component", got, want)
+ }
+ if got, want := table.endpoint.Path(base, prefix, proto), table.want; got != want {
+ t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\tmultiple components", got, want)
+ }
+ }
+}
+
+func TestTreeHeadSign(t *testing.T) {}
+func TestTreeHeadVerify(t *testing.T) {}
+func TestInclusionProofVerify(t *testing.T) {}
+func TestConsistencyProofVerify(t *testing.T) {}
diff --git a/pkg/types/util.go b/pkg/types/util.go
new file mode 100644
index 0000000..3cd7dfa
--- /dev/null
+++ b/pkg/types/util.go
@@ -0,0 +1,21 @@
+package types
+
+import (
+ "crypto/sha256"
+)
+
+const (
+ LeafHashPrefix = 0x00
+)
+
+func Hash(buf []byte) *[HashSize]byte {
+ var ret [HashSize]byte
+ hash := sha256.New()
+ hash.Write(buf)
+ copy(ret[:], hash.Sum(nil))
+ return &ret
+}
+
+func HashLeaf(buf []byte) *[HashSize]byte {
+ return Hash(append([]byte{LeafHashPrefix}, buf...))
+}