aboutsummaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
authorRasmus Dahlberg <rasmus.dahlberg@kau.se>2021-06-07 00:19:40 +0200
committerRasmus Dahlberg <rasmus.dahlberg@kau.se>2021-06-07 00:19:40 +0200
commit932d29fd08c8ff401e471b4f764537493ccbd483 (patch)
treee840a4c62db92e84201fe9ceaa0594d99176792c /pkg
parentbdf7a53d61cf044e526cc9123ca296615f838288 (diff)
parent345fe658fa8a4306caa74f72a618e499343675c2 (diff)
Merge branch 'design' into main
Diffstat (limited to 'pkg')
-rw-r--r--pkg/instance/endpoint.go122
-rw-r--r--pkg/instance/endpoint_test.go432
-rw-r--r--pkg/instance/instance.go159
-rw-r--r--pkg/instance/instance_test.go9
-rw-r--r--pkg/instance/metric.go19
-rw-r--r--pkg/mocks/crypto.go23
-rw-r--r--pkg/mocks/state.go107
-rw-r--r--pkg/mocks/stfe.go110
-rw-r--r--pkg/mocks/trillian.go317
-rw-r--r--pkg/state/state_manager.go154
-rw-r--r--pkg/state/state_manager_test.go393
-rw-r--r--pkg/trillian/client.go178
-rw-r--r--pkg/trillian/client_test.go533
-rw-r--r--pkg/trillian/util.go33
-rw-r--r--pkg/types/ascii.go421
-rw-r--r--pkg/types/ascii_test.go465
-rw-r--r--pkg/types/trunnel.go60
-rw-r--r--pkg/types/trunnel_test.go114
-rw-r--r--pkg/types/types.go155
-rw-r--r--pkg/types/types_test.go58
-rw-r--r--pkg/types/util.go21
21 files changed, 3883 insertions, 0 deletions
diff --git a/pkg/instance/endpoint.go b/pkg/instance/endpoint.go
new file mode 100644
index 0000000..5085c49
--- /dev/null
+++ b/pkg/instance/endpoint.go
@@ -0,0 +1,122 @@
+package stfe
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/golang/glog"
+)
+
+func addLeaf(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling add-entry request")
+ req, err := i.leafRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+ if err := i.Client.AddLeaf(ctx, req); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func addCosignature(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling add-cosignature request")
+ req, err := i.cosignatureRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+ vk := i.Witnesses[*req.KeyHash]
+ if err := i.Stateman.AddCosignature(ctx, &vk, req.Signature); err != nil {
+ return http.StatusBadRequest, err
+ }
+ return http.StatusOK, nil
+}
+
+func getTreeHeadLatest(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) {
+ glog.V(3).Info("handling get-tree-head-latest request")
+ sth, err := i.Stateman.Latest(ctx)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := sth.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getTreeHeadToSign(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) {
+ glog.V(3).Info("handling get-tree-head-to-sign request")
+ sth, err := i.Stateman.ToSign(ctx)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := sth.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getTreeHeadCosigned(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) {
+ glog.V(3).Info("handling get-tree-head-cosigned request")
+ sth, err := i.Stateman.Cosigned(ctx)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := sth.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getConsistencyProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling get-consistency-proof request")
+ req, err := i.consistencyProofRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+
+ proof, err := i.Client.GetConsistencyProof(ctx, req)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := proof.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getInclusionProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling get-proof-by-hash request")
+ req, err := i.inclusionProofRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+
+ proof, err := i.Client.GetInclusionProof(ctx, req)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ if err := proof.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ return http.StatusOK, nil
+}
+
+func getLeaves(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) {
+ glog.V(3).Info("handling get-leaves request")
+ req, err := i.leavesRequestFromHTTP(r)
+ if err != nil {
+ return http.StatusBadRequest, err
+ }
+
+ leaves, err := i.Client.GetLeaves(ctx, req)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ for _, leaf := range *leaves {
+ if err := leaf.MarshalASCII(w); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ }
+ return http.StatusOK, nil
+}
diff --git a/pkg/instance/endpoint_test.go b/pkg/instance/endpoint_test.go
new file mode 100644
index 0000000..efcd4c0
--- /dev/null
+++ b/pkg/instance/endpoint_test.go
@@ -0,0 +1,432 @@
+package stfe
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ "github.com/system-transparency/stfe/pkg/mocks"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+var (
+ testWitVK = [types.VerificationKeySize]byte{}
+ testConfig = Config{
+ LogID: hex.EncodeToString(types.Hash([]byte("logid"))[:]),
+ TreeID: 0,
+ Prefix: "testonly",
+ MaxRange: 3,
+ Deadline: 10,
+ Interval: 10,
+ Witnesses: map[[types.HashSize]byte][types.VerificationKeySize]byte{
+ *types.Hash(testWitVK[:]): testWitVK,
+ },
+ }
+ testSTH = &types.SignedTreeHead{
+ TreeHead: types.TreeHead{
+ Timestamp: 0,
+ TreeSize: 0,
+ RootHash: types.Hash(nil),
+ },
+ SigIdent: []*types.SigIdent{
+ &types.SigIdent{
+ Signature: &[types.SignatureSize]byte{},
+ KeyHash: &[types.HashSize]byte{},
+ },
+ },
+ }
+)
+
+func mustHandle(t *testing.T, i Instance, e types.Endpoint) Handler {
+ for _, handler := range i.Handlers() {
+ if handler.Endpoint == e {
+ return handler
+ }
+ }
+ t.Fatalf("must handle endpoint: %v", e)
+ return Handler{}
+}
+
+func TestAddLeaf(t *testing.T) {
+ buf := func() io.Reader {
+ // A valid leaf request that was created manually
+ return bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s",
+ types.ShardHint, types.Delim, "0", types.EOL,
+ types.Checksum, types.Delim, "0000000000000000000000000000000000000000000000000000000000000000", types.EOL,
+ types.SignatureOverMessage, types.Delim, "4cb410a4d48f52f761a7c01abcc28fd71811b84ded5403caed5e21b374f6aac9637cecd36828f17529fd503413d30ab66d7bb37a31dbf09a90d23b9241c45009", types.EOL,
+ types.VerificationKey, types.Delim, "f2b7a00b625469d32502e06e8b7fad1ef258d4ad0c6cd87b846142ab681957d5", types.EOL,
+ types.DomainHint, types.Delim, "example.com", types.EOL,
+ ))
+ }
+ for _, table := range []struct {
+ description string
+ ascii io.Reader // buffer used to populate HTTP request
+ expect bool // set if a mock answer is expected
+ err error // error from Trillian client
+ wantCode int // HTTP status ok
+ }{
+ {
+ description: "invalid: bad request (parser error)",
+ ascii: bytes.NewBufferString("key=value\n"),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: bad request (signature error)",
+ ascii: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s",
+ types.ShardHint, types.Delim, "1", types.EOL,
+ types.Checksum, types.Delim, "1111111111111111111111111111111111111111111111111111111111111111", types.EOL,
+ types.SignatureOverMessage, types.Delim, "4cb410a4d48f52f761a7c01abcc28fd71811b84ded5403caed5e21b374f6aac9637cecd36828f17529fd503413d30ab66d7bb37a31dbf09a90d23b9241c45009", types.EOL,
+ types.VerificationKey, types.Delim, "f2b7a00b625469d32502e06e8b7fad1ef258d4ad0c6cd87b846142ab681957d5", types.EOL,
+ types.DomainHint, types.Delim, "example.com", types.EOL,
+ )),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: backend failure",
+ ascii: buf(),
+ expect: true,
+ err: fmt.Errorf("something went wrong"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ ascii: buf(),
+ expect: true,
+ wantCode: http.StatusOK,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ client := mocks.NewMockClient(ctrl)
+ if table.expect {
+ client.EXPECT().AddLeaf(gomock.Any(), gomock.Any()).Return(table.err)
+ }
+ i := Instance{
+ Config: testConfig,
+ Client: client,
+ }
+
+ // Create HTTP request
+ url := types.EndpointAddLeaf.Path("http://example.com", i.Prefix)
+ req, err := http.NewRequest("POST", url, table.ascii)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+
+ // Run HTTP request
+ w := httptest.NewRecorder()
+ mustHandle(t, i, types.EndpointAddLeaf).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestAddCosignature(t *testing.T) {
+ buf := func() io.Reader {
+ return bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%x%s"+"%s%s%x%s",
+ types.Signature, types.Delim, make([]byte, types.SignatureSize), types.EOL,
+ types.KeyHash, types.Delim, *types.Hash(testWitVK[:]), types.EOL,
+ ))
+ }
+ for _, table := range []struct {
+ description string
+ ascii io.Reader // buffer used to populate HTTP request
+ expect bool // set if a mock answer is expected
+ err error // error from Trillian client
+ wantCode int // HTTP status ok
+ }{
+ {
+ description: "invalid: bad request (parser error)",
+ ascii: bytes.NewBufferString("key=value\n"),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: bad request (unknown witness)",
+ ascii: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%x%s"+"%s%s%x%s",
+ types.Signature, types.Delim, make([]byte, types.SignatureSize), types.EOL,
+ types.KeyHash, types.Delim, *types.Hash(testWitVK[1:]), types.EOL,
+ )),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "invalid: backend failure",
+ ascii: buf(),
+ expect: true,
+ err: fmt.Errorf("something went wrong"),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "valid",
+ ascii: buf(),
+ expect: true,
+ wantCode: http.StatusOK,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ stateman := mocks.NewMockStateManager(ctrl)
+ if table.expect {
+ stateman.EXPECT().AddCosignature(gomock.Any(), gomock.Any(), gomock.Any()).Return(table.err)
+ }
+ i := Instance{
+ Config: testConfig,
+ Stateman: stateman,
+ }
+
+ // Create HTTP request
+ url := types.EndpointAddCosignature.Path("http://example.com", i.Prefix)
+ req, err := http.NewRequest("POST", url, table.ascii)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+
+ // Run HTTP request
+ w := httptest.NewRecorder()
+ mustHandle(t, i, types.EndpointAddCosignature).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetTreeHeadLatest(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ expect bool // set if a mock answer is expected
+ rsp *types.SignedTreeHead // signed tree head from Trillian client
+ err error // error from Trillian client
+ wantCode int // HTTP status ok
+ }{
+ {
+ description: "invalid: backend failure",
+ expect: true,
+ err: fmt.Errorf("something went wrong"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ expect: true,
+ rsp: testSTH,
+ wantCode: http.StatusOK,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ stateman := mocks.NewMockStateManager(ctrl)
+ if table.expect {
+ stateman.EXPECT().Latest(gomock.Any()).Return(table.rsp, table.err)
+ }
+ i := Instance{
+ Config: testConfig,
+ Stateman: stateman,
+ }
+
+ // Create HTTP request
+ url := types.EndpointGetTreeHeadLatest.Path("http://example.com", i.Prefix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+
+ // Run HTTP request
+ w := httptest.NewRecorder()
+ mustHandle(t, i, types.EndpointGetTreeHeadLatest).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetTreeToSign(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ expect bool // set if a mock answer is expected
+ rsp *types.SignedTreeHead // signed tree head from Trillian client
+ err error // error from Trillian client
+ wantCode int // HTTP status ok
+ }{
+ {
+ description: "invalid: backend failure",
+ expect: true,
+ err: fmt.Errorf("something went wrong"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ expect: true,
+ rsp: testSTH,
+ wantCode: http.StatusOK,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ stateman := mocks.NewMockStateManager(ctrl)
+ if table.expect {
+ stateman.EXPECT().ToSign(gomock.Any()).Return(table.rsp, table.err)
+ }
+ i := Instance{
+ Config: testConfig,
+ Stateman: stateman,
+ }
+
+ // Create HTTP request
+ url := types.EndpointGetTreeHeadToSign.Path("http://example.com", i.Prefix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+
+ // Run HTTP request
+ w := httptest.NewRecorder()
+ mustHandle(t, i, types.EndpointGetTreeHeadToSign).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetTreeCosigned(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ expect bool // set if a mock answer is expected
+ rsp *types.SignedTreeHead // signed tree head from Trillian client
+ err error // error from Trillian client
+ wantCode int // HTTP status ok
+ }{
+ {
+ description: "invalid: backend failure",
+ expect: true,
+ err: fmt.Errorf("something went wrong"),
+ wantCode: http.StatusInternalServerError,
+ },
+ {
+ description: "valid",
+ expect: true,
+ rsp: testSTH,
+ wantCode: http.StatusOK,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ stateman := mocks.NewMockStateManager(ctrl)
+ if table.expect {
+ stateman.EXPECT().Cosigned(gomock.Any()).Return(table.rsp, table.err)
+ }
+ i := Instance{
+ Config: testConfig,
+ Stateman: stateman,
+ }
+
+ // Create HTTP request
+ url := types.EndpointGetTreeHeadCosigned.Path("http://example.com", i.Prefix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+
+ // Run HTTP request
+ w := httptest.NewRecorder()
+ mustHandle(t, i, types.EndpointGetTreeHeadCosigned).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetConsistencyProof(t *testing.T) {
+ buf := func(oldSize, newSize int) io.Reader {
+ return bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s",
+ types.OldSize, types.Delim, oldSize, types.EOL,
+ types.NewSize, types.Delim, newSize, types.EOL,
+ ))
+ }
+ // values in testProof are not relevant for the test, just need a path
+ testProof := &types.ConsistencyProof{
+ OldSize: 1,
+ NewSize: 2,
+ Path: []*[types.HashSize]byte{
+ types.Hash(nil),
+ },
+ }
+ for _, table := range []struct {
+ description string
+ ascii io.Reader // buffer used to populate HTTP request
+ expect bool // set if a mock answer is expected
+ rsp *types.ConsistencyProof // consistency proof from Trillian client
+ err error // error from Trillian client
+ wantCode int // HTTP status ok
+ }{
+ {
+ description: "invalid: bad request (parser error)",
+ ascii: bytes.NewBufferString("key=value\n"),
+ wantCode: http.StatusBadRequest,
+ },
+ {
+ description: "valid",
+ ascii: buf(1, 2),
+ expect: true,
+ rsp: testProof,
+ wantCode: http.StatusOK,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ client := mocks.NewMockClient(ctrl)
+ if table.expect {
+ client.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ }
+ i := Instance{
+ Config: testConfig,
+ Client: client,
+ }
+
+ // Create HTTP request
+ url := types.EndpointGetConsistencyProof.Path("http://example.com", i.Prefix)
+ req, err := http.NewRequest("POST", url, table.ascii)
+ if err != nil {
+ t.Fatalf("must create http request: %v", err)
+ }
+
+ // Run HTTP request
+ w := httptest.NewRecorder()
+ mustHandle(t, i, types.EndpointGetConsistencyProof).ServeHTTP(w, req)
+ if got, want := w.Code, table.wantCode; got != want {
+ t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetInclusionProof(t *testing.T) {
+}
+
+func TestGetLeaves(t *testing.T) {
+}
diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go
new file mode 100644
index 0000000..c2fe8fa
--- /dev/null
+++ b/pkg/instance/instance.go
@@ -0,0 +1,159 @@
+package stfe
+
+import (
+ "context"
+ "crypto"
+ "crypto/ed25519"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/golang/glog"
+ "github.com/system-transparency/stfe/pkg/state"
+ "github.com/system-transparency/stfe/pkg/trillian"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+// Config is a collection of log parameters
+type Config struct {
+ LogID string // H(public key), then hex-encoded
+ TreeID int64 // Merkle tree identifier used by Trillian
+ Prefix string // The portion between base URL and st/v0 (may be "")
+ MaxRange int64 // Maximum number of leaves per get-leaves request
+ Deadline time.Duration // Deadline used for gRPC requests
+ Interval time.Duration // Cosigning frequency
+
+ // Witnesses map trusted witness identifiers to public verification keys
+ Witnesses map[[types.HashSize]byte][types.VerificationKeySize]byte
+}
+
+// Instance is an instance of the log's front-end
+type Instance struct {
+ Config // configuration parameters
+ Client trillian.Client // provides access to the Trillian backend
+ Signer crypto.Signer // provides access to Ed25519 private key
+ Stateman state.StateManager // coordinates access to (co)signed tree heads
+}
+
+// Handler implements the http.Handler interface, and contains a reference
+// to an STFE server instance as well as a function that uses it.
+type Handler struct {
+ Instance *Instance
+ Endpoint types.Endpoint
+ Method string
+ Handler func(context.Context, *Instance, http.ResponseWriter, *http.Request) (int, error)
+}
+
+// Handlers returns a list of STFE handlers
+func (i *Instance) Handlers() []Handler {
+ return []Handler{
+ Handler{Instance: i, Handler: addLeaf, Endpoint: types.EndpointAddLeaf, Method: http.MethodPost},
+ Handler{Instance: i, Handler: addCosignature, Endpoint: types.EndpointAddCosignature, Method: http.MethodPost},
+ Handler{Instance: i, Handler: getTreeHeadLatest, Endpoint: types.EndpointGetTreeHeadLatest, Method: http.MethodGet},
+ Handler{Instance: i, Handler: getTreeHeadToSign, Endpoint: types.EndpointGetTreeHeadToSign, Method: http.MethodGet},
+ Handler{Instance: i, Handler: getTreeHeadCosigned, Endpoint: types.EndpointGetTreeHeadCosigned, Method: http.MethodGet},
+ Handler{Instance: i, Handler: getConsistencyProof, Endpoint: types.EndpointGetConsistencyProof, Method: http.MethodPost},
+ Handler{Instance: i, Handler: getInclusionProof, Endpoint: types.EndpointGetProofByHash, Method: http.MethodPost},
+ Handler{Instance: i, Handler: getLeaves, Endpoint: types.EndpointGetLeaves, Method: http.MethodPost},
+ }
+}
+
+// Path returns a path that should be configured for this handler
+func (h Handler) Path() string {
+ return h.Endpoint.Path(h.Instance.Prefix, "st", "v0")
+}
+
+// ServeHTTP is part of the http.Handler interface
+func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // export prometheus metrics
+ var now time.Time = time.Now()
+ var statusCode int
+ defer func() {
+ rspcnt.Inc(a.Instance.LogID, string(a.Endpoint), fmt.Sprintf("%d", statusCode))
+ latency.Observe(time.Now().Sub(now).Seconds(), a.Instance.LogID, string(a.Endpoint), fmt.Sprintf("%d", statusCode))
+ }()
+ reqcnt.Inc(a.Instance.LogID, string(a.Endpoint))
+
+ ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.Instance.Deadline))
+ defer cancel()
+
+ if r.Method != a.Method {
+ glog.Warningf("%s/%s: got HTTP %s, wanted HTTP %s", a.Instance.Prefix, string(a.Endpoint), r.Method, a.Method)
+ http.Error(w, "", http.StatusMethodNotAllowed)
+ return
+ }
+
+ statusCode, err := a.Handler(ctx, a.Instance, w, r)
+ if err != nil {
+ glog.Warningf("handler error %s/%s: %v", a.Instance.Prefix, a.Endpoint, err)
+ http.Error(w, fmt.Sprintf("%s%s%s%s", "Error", types.Delim, err.Error(), types.EOL), statusCode)
+ }
+}
+
+func (i *Instance) leafRequestFromHTTP(r *http.Request) (*types.LeafRequest, error) {
+ var req types.LeafRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("UnmarshalASCII: %v", err)
+ }
+
+ vk := ed25519.PublicKey(req.VerificationKey[:])
+ msg := req.Message.Marshal()
+ sig := req.Signature[:]
+ if !ed25519.Verify(vk, msg, sig) {
+ return nil, fmt.Errorf("invalid signature")
+ }
+ // TODO: check shard hint
+ // TODO: check domain hint
+ return &req, nil
+}
+
+func (i *Instance) cosignatureRequestFromHTTP(r *http.Request) (*types.CosignatureRequest, error) {
+ var req types.CosignatureRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("unpackOctetPost: %v", err)
+ }
+ if _, ok := i.Witnesses[*req.KeyHash]; !ok {
+ return nil, fmt.Errorf("Unknown witness: %x", req.KeyHash)
+ }
+ return &req, nil
+}
+
+func (i *Instance) consistencyProofRequestFromHTTP(r *http.Request) (*types.ConsistencyProofRequest, error) {
+ var req types.ConsistencyProofRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("UnmarshalASCII: %v", err)
+ }
+ if req.OldSize < 1 {
+ return nil, fmt.Errorf("OldSize(%d) must be larger than zero", req.OldSize)
+ }
+ if req.NewSize <= req.OldSize {
+ return nil, fmt.Errorf("NewSize(%d) must be larger than OldSize(%d)", req.NewSize, req.OldSize)
+ }
+ return &req, nil
+}
+
+func (i *Instance) inclusionProofRequestFromHTTP(r *http.Request) (*types.InclusionProofRequest, error) {
+ var req types.InclusionProofRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("UnmarshalASCII: %v", err)
+ }
+ if req.TreeSize < 1 {
+ return nil, fmt.Errorf("TreeSize(%d) must be larger than zero", req.TreeSize)
+ }
+ return &req, nil
+}
+
+func (i *Instance) leavesRequestFromHTTP(r *http.Request) (*types.LeavesRequest, error) {
+ var req types.LeavesRequest
+ if err := req.UnmarshalASCII(r.Body); err != nil {
+ return nil, fmt.Errorf("UnmarshalASCII: %v", err)
+ }
+
+ if req.StartSize > req.EndSize {
+ return nil, fmt.Errorf("StartSize(%d) must be less than or equal to EndSize(%d)", req.StartSize, req.EndSize)
+ }
+ if req.EndSize-req.StartSize+1 > uint64(i.MaxRange) {
+ req.EndSize = req.StartSize + uint64(i.MaxRange) - 1
+ }
+ return &req, nil
+}
diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go
new file mode 100644
index 0000000..45a2837
--- /dev/null
+++ b/pkg/instance/instance_test.go
@@ -0,0 +1,9 @@
+package stfe
+
+import (
+ "testing"
+)
+
+func TestHandlers(t *testing.T) {}
+func TestPath(t *testing.T) {}
+func TestServeHTTP(t *testing.T) {}
diff --git a/pkg/instance/metric.go b/pkg/instance/metric.go
new file mode 100644
index 0000000..db11bd2
--- /dev/null
+++ b/pkg/instance/metric.go
@@ -0,0 +1,19 @@
+package stfe
+
+import (
+ "github.com/google/trillian/monitoring"
+ "github.com/google/trillian/monitoring/prometheus"
+)
+
+var (
+ reqcnt monitoring.Counter // number of incoming http requests
+ rspcnt monitoring.Counter // number of valid http responses
+ latency monitoring.Histogram // request-response latency
+)
+
+func init() {
+ mf := prometheus.MetricFactory{}
+ reqcnt = mf.NewCounter("http_req", "number of http requests", "logid", "endpoint")
+ rspcnt = mf.NewCounter("http_rsp", "number of http requests", "logid", "endpoint", "status")
+ latency = mf.NewHistogram("http_latency", "http request-response latency", "logid", "endpoint", "status")
+}
diff --git a/pkg/mocks/crypto.go b/pkg/mocks/crypto.go
new file mode 100644
index 0000000..87c883a
--- /dev/null
+++ b/pkg/mocks/crypto.go
@@ -0,0 +1,23 @@
+package mocks
+
+import (
+ "crypto"
+ "crypto/ed25519"
+ "io"
+)
+
+// TestSign implements the signer interface. It can be used to mock an Ed25519
+// signer that always return the same public key, signature, and error.
+type TestSigner struct {
+ PublicKey *[ed25519.PublicKeySize]byte
+ Signature *[ed25519.SignatureSize]byte
+ Error error
+}
+
+func (ts *TestSigner) Public() crypto.PublicKey {
+ return ed25519.PublicKey(ts.PublicKey[:])
+}
+
+func (ts *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
+ return ts.Signature[:], ts.Error
+}
diff --git a/pkg/mocks/state.go b/pkg/mocks/state.go
new file mode 100644
index 0000000..41d8d08
--- /dev/null
+++ b/pkg/mocks/state.go
@@ -0,0 +1,107 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/system-transparency/stfe/pkg/state (interfaces: StateManager)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+ context "context"
+ reflect "reflect"
+
+ gomock "github.com/golang/mock/gomock"
+ types "github.com/system-transparency/stfe/pkg/types"
+)
+
+// MockStateManager is a mock of StateManager interface.
+type MockStateManager struct {
+ ctrl *gomock.Controller
+ recorder *MockStateManagerMockRecorder
+}
+
+// MockStateManagerMockRecorder is the mock recorder for MockStateManager.
+type MockStateManagerMockRecorder struct {
+ mock *MockStateManager
+}
+
+// NewMockStateManager creates a new mock instance.
+func NewMockStateManager(ctrl *gomock.Controller) *MockStateManager {
+ mock := &MockStateManager{ctrl: ctrl}
+ mock.recorder = &MockStateManagerMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockStateManager) EXPECT() *MockStateManagerMockRecorder {
+ return m.recorder
+}
+
+// AddCosignature mocks base method.
+func (m *MockStateManager) AddCosignature(arg0 context.Context, arg1 *[32]byte, arg2 *[64]byte) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AddCosignature", arg0, arg1, arg2)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// AddCosignature indicates an expected call of AddCosignature.
+func (mr *MockStateManagerMockRecorder) AddCosignature(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddCosignature", reflect.TypeOf((*MockStateManager)(nil).AddCosignature), arg0, arg1, arg2)
+}
+
+// Cosigned mocks base method.
+func (m *MockStateManager) Cosigned(arg0 context.Context) (*types.SignedTreeHead, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Cosigned", arg0)
+ ret0, _ := ret[0].(*types.SignedTreeHead)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Cosigned indicates an expected call of Cosigned.
+func (mr *MockStateManagerMockRecorder) Cosigned(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cosigned", reflect.TypeOf((*MockStateManager)(nil).Cosigned), arg0)
+}
+
+// Latest mocks base method.
+func (m *MockStateManager) Latest(arg0 context.Context) (*types.SignedTreeHead, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Latest", arg0)
+ ret0, _ := ret[0].(*types.SignedTreeHead)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Latest indicates an expected call of Latest.
+func (mr *MockStateManagerMockRecorder) Latest(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Latest", reflect.TypeOf((*MockStateManager)(nil).Latest), arg0)
+}
+
+// Run mocks base method.
+func (m *MockStateManager) Run(arg0 context.Context) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "Run", arg0)
+}
+
+// Run indicates an expected call of Run.
+func (mr *MockStateManagerMockRecorder) Run(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockStateManager)(nil).Run), arg0)
+}
+
+// ToSign mocks base method.
+func (m *MockStateManager) ToSign(arg0 context.Context) (*types.SignedTreeHead, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ToSign", arg0)
+ ret0, _ := ret[0].(*types.SignedTreeHead)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// ToSign indicates an expected call of ToSign.
+func (mr *MockStateManagerMockRecorder) ToSign(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ToSign", reflect.TypeOf((*MockStateManager)(nil).ToSign), arg0)
+}
diff --git a/pkg/mocks/stfe.go b/pkg/mocks/stfe.go
new file mode 100644
index 0000000..def5bc6
--- /dev/null
+++ b/pkg/mocks/stfe.go
@@ -0,0 +1,110 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/system-transparency/stfe/trillian (interfaces: Client)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+ context "context"
+ reflect "reflect"
+
+ gomock "github.com/golang/mock/gomock"
+ types "github.com/system-transparency/stfe/pkg/types"
+)
+
+// MockClient is a mock of Client interface.
+type MockClient struct {
+ ctrl *gomock.Controller
+ recorder *MockClientMockRecorder
+}
+
+// MockClientMockRecorder is the mock recorder for MockClient.
+type MockClientMockRecorder struct {
+ mock *MockClient
+}
+
+// NewMockClient creates a new mock instance.
+func NewMockClient(ctrl *gomock.Controller) *MockClient {
+ mock := &MockClient{ctrl: ctrl}
+ mock.recorder = &MockClientMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockClient) EXPECT() *MockClientMockRecorder {
+ return m.recorder
+}
+
+// AddLeaf mocks base method.
+func (m *MockClient) AddLeaf(arg0 context.Context, arg1 *types.LeafRequest) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AddLeaf", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// AddLeaf indicates an expected call of AddLeaf.
+func (mr *MockClientMockRecorder) AddLeaf(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLeaf", reflect.TypeOf((*MockClient)(nil).AddLeaf), arg0, arg1)
+}
+
+// GetConsistencyProof mocks base method.
+func (m *MockClient) GetConsistencyProof(arg0 context.Context, arg1 *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetConsistencyProof", arg0, arg1)
+ ret0, _ := ret[0].(*types.ConsistencyProof)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetConsistencyProof indicates an expected call of GetConsistencyProof.
+func (mr *MockClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockClient)(nil).GetConsistencyProof), arg0, arg1)
+}
+
+// GetInclusionProof mocks base method.
+func (m *MockClient) GetInclusionProof(arg0 context.Context, arg1 *types.InclusionProofRequest) (*types.InclusionProof, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetInclusionProof", arg0, arg1)
+ ret0, _ := ret[0].(*types.InclusionProof)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetInclusionProof indicates an expected call of GetInclusionProof.
+func (mr *MockClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockClient)(nil).GetInclusionProof), arg0, arg1)
+}
+
+// GetLeaves mocks base method.
+func (m *MockClient) GetLeaves(arg0 context.Context, arg1 *types.LeavesRequest) (*types.LeafList, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetLeaves", arg0, arg1)
+ ret0, _ := ret[0].(*types.LeafList)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLeaves indicates an expected call of GetLeaves.
+func (mr *MockClientMockRecorder) GetLeaves(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeaves", reflect.TypeOf((*MockClient)(nil).GetLeaves), arg0, arg1)
+}
+
+// GetTreeHead mocks base method.
+func (m *MockClient) GetTreeHead(arg0 context.Context) (*types.TreeHead, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetTreeHead", arg0)
+ ret0, _ := ret[0].(*types.TreeHead)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetTreeHead indicates an expected call of GetTreeHead.
+func (mr *MockClientMockRecorder) GetTreeHead(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTreeHead", reflect.TypeOf((*MockClient)(nil).GetTreeHead), arg0)
+}
diff --git a/pkg/mocks/trillian.go b/pkg/mocks/trillian.go
new file mode 100644
index 0000000..8aa3a58
--- /dev/null
+++ b/pkg/mocks/trillian.go
@@ -0,0 +1,317 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/google/trillian (interfaces: TrillianLogClient)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+ context "context"
+ reflect "reflect"
+
+ gomock "github.com/golang/mock/gomock"
+ trillian "github.com/google/trillian"
+ grpc "google.golang.org/grpc"
+)
+
+// MockTrillianLogClient is a mock of TrillianLogClient interface.
+type MockTrillianLogClient struct {
+ ctrl *gomock.Controller
+ recorder *MockTrillianLogClientMockRecorder
+}
+
+// MockTrillianLogClientMockRecorder is the mock recorder for MockTrillianLogClient.
+type MockTrillianLogClientMockRecorder struct {
+ mock *MockTrillianLogClient
+}
+
+// NewMockTrillianLogClient creates a new mock instance.
+func NewMockTrillianLogClient(ctrl *gomock.Controller) *MockTrillianLogClient {
+ mock := &MockTrillianLogClient{ctrl: ctrl}
+ mock.recorder = &MockTrillianLogClientMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockTrillianLogClient) EXPECT() *MockTrillianLogClientMockRecorder {
+ return m.recorder
+}
+
+// AddSequencedLeaf mocks base method.
+func (m *MockTrillianLogClient) AddSequencedLeaf(arg0 context.Context, arg1 *trillian.AddSequencedLeafRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeafResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "AddSequencedLeaf", varargs...)
+ ret0, _ := ret[0].(*trillian.AddSequencedLeafResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// AddSequencedLeaf indicates an expected call of AddSequencedLeaf.
+func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaf), varargs...)
+}
+
+// AddSequencedLeaves mocks base method.
+func (m *MockTrillianLogClient) AddSequencedLeaves(arg0 context.Context, arg1 *trillian.AddSequencedLeavesRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeavesResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "AddSequencedLeaves", varargs...)
+ ret0, _ := ret[0].(*trillian.AddSequencedLeavesResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// AddSequencedLeaves indicates an expected call of AddSequencedLeaves.
+func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaves), varargs...)
+}
+
+// GetConsistencyProof mocks base method.
+func (m *MockTrillianLogClient) GetConsistencyProof(arg0 context.Context, arg1 *trillian.GetConsistencyProofRequest, arg2 ...grpc.CallOption) (*trillian.GetConsistencyProofResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetConsistencyProof", varargs...)
+ ret0, _ := ret[0].(*trillian.GetConsistencyProofResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetConsistencyProof indicates an expected call of GetConsistencyProof.
+func (mr *MockTrillianLogClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetConsistencyProof), varargs...)
+}
+
+// GetEntryAndProof mocks base method.
+func (m *MockTrillianLogClient) GetEntryAndProof(arg0 context.Context, arg1 *trillian.GetEntryAndProofRequest, arg2 ...grpc.CallOption) (*trillian.GetEntryAndProofResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetEntryAndProof", varargs...)
+ ret0, _ := ret[0].(*trillian.GetEntryAndProofResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetEntryAndProof indicates an expected call of GetEntryAndProof.
+func (mr *MockTrillianLogClientMockRecorder) GetEntryAndProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntryAndProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetEntryAndProof), varargs...)
+}
+
+// GetInclusionProof mocks base method.
+func (m *MockTrillianLogClient) GetInclusionProof(arg0 context.Context, arg1 *trillian.GetInclusionProofRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetInclusionProof", varargs...)
+ ret0, _ := ret[0].(*trillian.GetInclusionProofResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetInclusionProof indicates an expected call of GetInclusionProof.
+func (mr *MockTrillianLogClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProof), varargs...)
+}
+
+// GetInclusionProofByHash mocks base method.
+func (m *MockTrillianLogClient) GetInclusionProofByHash(arg0 context.Context, arg1 *trillian.GetInclusionProofByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofByHashResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetInclusionProofByHash", varargs...)
+ ret0, _ := ret[0].(*trillian.GetInclusionProofByHashResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetInclusionProofByHash indicates an expected call of GetInclusionProofByHash.
+func (mr *MockTrillianLogClientMockRecorder) GetInclusionProofByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProofByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProofByHash), varargs...)
+}
+
+// GetLatestSignedLogRoot mocks base method.
+func (m *MockTrillianLogClient) GetLatestSignedLogRoot(arg0 context.Context, arg1 *trillian.GetLatestSignedLogRootRequest, arg2 ...grpc.CallOption) (*trillian.GetLatestSignedLogRootResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetLatestSignedLogRoot", varargs...)
+ ret0, _ := ret[0].(*trillian.GetLatestSignedLogRootResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLatestSignedLogRoot indicates an expected call of GetLatestSignedLogRoot.
+func (mr *MockTrillianLogClientMockRecorder) GetLatestSignedLogRoot(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestSignedLogRoot", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLatestSignedLogRoot), varargs...)
+}
+
+// GetLeavesByHash mocks base method.
+func (m *MockTrillianLogClient) GetLeavesByHash(arg0 context.Context, arg1 *trillian.GetLeavesByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByHashResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetLeavesByHash", varargs...)
+ ret0, _ := ret[0].(*trillian.GetLeavesByHashResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLeavesByHash indicates an expected call of GetLeavesByHash.
+func (mr *MockTrillianLogClientMockRecorder) GetLeavesByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByHash), varargs...)
+}
+
+// GetLeavesByIndex mocks base method.
+func (m *MockTrillianLogClient) GetLeavesByIndex(arg0 context.Context, arg1 *trillian.GetLeavesByIndexRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByIndexResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetLeavesByIndex", varargs...)
+ ret0, _ := ret[0].(*trillian.GetLeavesByIndexResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLeavesByIndex indicates an expected call of GetLeavesByIndex.
+func (mr *MockTrillianLogClientMockRecorder) GetLeavesByIndex(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByIndex", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByIndex), varargs...)
+}
+
+// GetLeavesByRange mocks base method.
+func (m *MockTrillianLogClient) GetLeavesByRange(arg0 context.Context, arg1 *trillian.GetLeavesByRangeRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByRangeResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetLeavesByRange", varargs...)
+ ret0, _ := ret[0].(*trillian.GetLeavesByRangeResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLeavesByRange indicates an expected call of GetLeavesByRange.
+func (mr *MockTrillianLogClientMockRecorder) GetLeavesByRange(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByRange", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByRange), varargs...)
+}
+
+// GetSequencedLeafCount mocks base method.
+func (m *MockTrillianLogClient) GetSequencedLeafCount(arg0 context.Context, arg1 *trillian.GetSequencedLeafCountRequest, arg2 ...grpc.CallOption) (*trillian.GetSequencedLeafCountResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetSequencedLeafCount", varargs...)
+ ret0, _ := ret[0].(*trillian.GetSequencedLeafCountResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetSequencedLeafCount indicates an expected call of GetSequencedLeafCount.
+func (mr *MockTrillianLogClientMockRecorder) GetSequencedLeafCount(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSequencedLeafCount", reflect.TypeOf((*MockTrillianLogClient)(nil).GetSequencedLeafCount), varargs...)
+}
+
+// InitLog mocks base method.
+func (m *MockTrillianLogClient) InitLog(arg0 context.Context, arg1 *trillian.InitLogRequest, arg2 ...grpc.CallOption) (*trillian.InitLogResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "InitLog", varargs...)
+ ret0, _ := ret[0].(*trillian.InitLogResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// InitLog indicates an expected call of InitLog.
+func (mr *MockTrillianLogClientMockRecorder) InitLog(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitLog", reflect.TypeOf((*MockTrillianLogClient)(nil).InitLog), varargs...)
+}
+
+// QueueLeaf mocks base method.
+func (m *MockTrillianLogClient) QueueLeaf(arg0 context.Context, arg1 *trillian.QueueLeafRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeafResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "QueueLeaf", varargs...)
+ ret0, _ := ret[0].(*trillian.QueueLeafResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// QueueLeaf indicates an expected call of QueueLeaf.
+func (mr *MockTrillianLogClientMockRecorder) QueueLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaf), varargs...)
+}
+
+// QueueLeaves mocks base method.
+func (m *MockTrillianLogClient) QueueLeaves(arg0 context.Context, arg1 *trillian.QueueLeavesRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeavesResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "QueueLeaves", varargs...)
+ ret0, _ := ret[0].(*trillian.QueueLeavesResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// QueueLeaves indicates an expected call of QueueLeaves.
+func (mr *MockTrillianLogClientMockRecorder) QueueLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaves), varargs...)
+}
diff --git a/pkg/state/state_manager.go b/pkg/state/state_manager.go
new file mode 100644
index 0000000..dfa73f4
--- /dev/null
+++ b/pkg/state/state_manager.go
@@ -0,0 +1,154 @@
+package state
+
+import (
+ "context"
+ "crypto"
+ "fmt"
+ "reflect"
+ "sync"
+ "time"
+
+ "github.com/golang/glog"
+ "github.com/google/certificate-transparency-go/schedule"
+ "github.com/system-transparency/stfe/pkg/trillian"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+// StateManager coordinates access to the log's tree heads and (co)signatures
+type StateManager interface {
+ Latest(context.Context) (*types.SignedTreeHead, error)
+ ToSign(context.Context) (*types.SignedTreeHead, error)
+ Cosigned(context.Context) (*types.SignedTreeHead, error)
+ AddCosignature(context.Context, *[types.VerificationKeySize]byte, *[types.SignatureSize]byte) error
+ Run(context.Context)
+}
+
+// StateManagerSingle implements the StateManager interface. It is assumed that
+// the log server is running on a single-instance machine. So, no coordination.
+type StateManagerSingle struct {
+ client trillian.Client
+ signer crypto.Signer
+ interval time.Duration
+ deadline time.Duration
+ sync.RWMutex
+
+ // cosigned is the current cosigned tree head that is being served
+ cosigned types.SignedTreeHead
+
+ // tosign is the current tree head that is being cosigned by witnesses
+ tosign types.SignedTreeHead
+
+ // cosignature keeps track of all cosignatures for the tosign tree head
+ cosignature map[[types.HashSize]byte]*types.SigIdent
+}
+
+func NewStateManagerSingle(client trillian.Client, signer crypto.Signer, interval, deadline time.Duration) (*StateManagerSingle, error) {
+ sm := &StateManagerSingle{
+ client: client,
+ signer: signer,
+ interval: interval,
+ deadline: deadline,
+ }
+
+ ctx, _ := context.WithTimeout(context.Background(), sm.deadline)
+ sth, err := sm.Latest(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("Latest: %v", err)
+ }
+
+ sm.cosigned = *sth
+ sm.tosign = *sth
+ sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{
+ *sth.SigIdent[0].KeyHash: sth.SigIdent[0], // log signature
+ }
+ return sm, nil
+}
+
+func (sm *StateManagerSingle) Run(ctx context.Context) {
+ schedule.Every(ctx, sm.interval, func(ctx context.Context) {
+ ictx, _ := context.WithTimeout(ctx, sm.deadline)
+ nextTreeHead, err := sm.Latest(ictx)
+ if err != nil {
+ glog.Warningf("rotate failed: Latest: %v", err)
+ return
+ }
+
+ sm.Lock()
+ defer sm.Unlock()
+ sm.rotate(nextTreeHead)
+ })
+}
+
+func (sm *StateManagerSingle) Latest(ctx context.Context) (*types.SignedTreeHead, error) {
+ th, err := sm.client.GetTreeHead(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("LatestTreeHead: %v", err)
+ }
+ sth, err := th.Sign(sm.signer)
+ if err != nil {
+ return nil, fmt.Errorf("sign: %v", err)
+ }
+ return sth, nil
+}
+
+func (sm *StateManagerSingle) ToSign(_ context.Context) (*types.SignedTreeHead, error) {
+ sm.RLock()
+ defer sm.RUnlock()
+ return &sm.tosign, nil
+}
+
+func (sm *StateManagerSingle) Cosigned(_ context.Context) (*types.SignedTreeHead, error) {
+ sm.RLock()
+ defer sm.RUnlock()
+ return &sm.cosigned, nil
+}
+
+func (sm *StateManagerSingle) AddCosignature(_ context.Context, vk *[types.VerificationKeySize]byte, sig *[types.SignatureSize]byte) error {
+ sm.Lock()
+ defer sm.Unlock()
+
+ if err := sm.tosign.TreeHead.Verify(vk, sig); err != nil {
+ return fmt.Errorf("Verify: %v", err)
+ }
+ witness := types.Hash(vk[:])
+ if _, ok := sm.cosignature[*witness]; ok {
+ return fmt.Errorf("signature-signer pair is a duplicate")
+ }
+ sm.cosignature[*witness] = &types.SigIdent{
+ Signature: sig,
+ KeyHash: witness,
+ }
+
+ glog.V(3).Infof("accepted new cosignature from witness: %x", *witness)
+ return nil
+}
+
+// rotate rotates the log's cosigned and stable STH. The caller must aquire the
+// source's read-write lock if there are concurrent reads and/or writes.
+func (sm *StateManagerSingle) rotate(next *types.SignedTreeHead) {
+ if reflect.DeepEqual(sm.cosigned.TreeHead, sm.tosign.TreeHead) {
+ // cosigned and tosign are the same. So, we need to merge all
+ // cosignatures that we already had with the new collected ones.
+ for _, sigident := range sm.cosigned.SigIdent {
+ if _, ok := sm.cosignature[*sigident.KeyHash]; !ok {
+ sm.cosignature[*sigident.KeyHash] = sigident
+ }
+ }
+ glog.V(3).Infof("cosigned tree head repeated, merged signatures")
+ }
+ var cosignatures []*types.SigIdent
+ for _, sigident := range sm.cosignature {
+ cosignatures = append(cosignatures, sigident)
+ }
+
+ // Update cosigned tree head
+ sm.cosigned.TreeHead = sm.tosign.TreeHead
+ sm.cosigned.SigIdent = cosignatures
+
+ // Update to-sign tree head
+ sm.tosign = *next
+ sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{
+ *next.SigIdent[0].KeyHash: next.SigIdent[0], // log signature
+ }
+ glog.V(3).Infof("rotated tree heads")
+}
diff --git a/pkg/state/state_manager_test.go b/pkg/state/state_manager_test.go
new file mode 100644
index 0000000..08990cc
--- /dev/null
+++ b/pkg/state/state_manager_test.go
@@ -0,0 +1,393 @@
+package state
+
+import (
+ "bytes"
+ "context"
+ "crypto"
+ "crypto/ed25519"
+ "crypto/rand"
+ "fmt"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/golang/mock/gomock"
+ "github.com/system-transparency/stfe/pkg/mocks"
+ "github.com/system-transparency/stfe/pkg/types"
+)
+
+var (
+ testSig = &[types.SignatureSize]byte{}
+ testPub = &[types.VerificationKeySize]byte{}
+ testTH = &types.TreeHead{
+ Timestamp: 0,
+ TreeSize: 0,
+ RootHash: types.Hash(nil),
+ }
+ testSigIdent = &types.SigIdent{
+ Signature: testSig,
+ KeyHash: types.Hash(testPub[:]),
+ }
+ testSTH = &types.SignedTreeHead{
+ TreeHead: *testTH,
+ SigIdent: []*types.SigIdent{testSigIdent},
+ }
+ testSignerOK = &mocks.TestSigner{testPub, testSig, nil}
+ testSignerErr = &mocks.TestSigner{testPub, testSig, fmt.Errorf("something went wrong")}
+)
+
+func TestNewStateManagerSingle(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ signer crypto.Signer
+ rsp *types.TreeHead
+ err error
+ wantErr bool
+ wantSth *types.SignedTreeHead
+ }{
+ {
+ description: "invalid: backend failure",
+ signer: testSignerOK,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ signer: testSignerOK,
+ rsp: testTH,
+ wantSth: testSTH,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ client := mocks.NewMockClient(ctrl)
+ client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err)
+
+ sm, err := NewStateManagerSingle(client, table.signer, time.Duration(0), time.Duration(0))
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := &sm.cosigned, table.wantSth; !reflect.DeepEqual(got, want) {
+ t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ if got, want := &sm.tosign, table.wantSth; !reflect.DeepEqual(got, want) {
+ t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ // we only have log signature on startup
+ if got, want := len(sm.cosignature), 1; got != want {
+ t.Errorf("got %d cosignatures but wanted %d in test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestLatest(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ signer crypto.Signer
+ rsp *types.TreeHead
+ err error
+ wantErr bool
+ wantSth *types.SignedTreeHead
+ }{
+ {
+ description: "invalid: backend failure",
+ signer: testSignerOK,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: signature failure",
+ rsp: testTH,
+ signer: testSignerErr,
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ signer: testSignerOK,
+ rsp: testTH,
+ wantSth: testSTH,
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ client := mocks.NewMockClient(ctrl)
+ client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err)
+ sm := StateManagerSingle{
+ client: client,
+ signer: table.signer,
+ }
+
+ sth, err := sm.Latest(context.Background())
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := sth, table.wantSth; !reflect.DeepEqual(got, want) {
+ t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestToSign(t *testing.T) {
+ description := "valid"
+ sm := StateManagerSingle{
+ tosign: *testSTH,
+ }
+ sth, err := sm.ToSign(context.Background())
+ if err != nil {
+ t.Errorf("ToSign should not fail with error: %v", err)
+ return
+ }
+ if got, want := sth, testSTH; !reflect.DeepEqual(got, want) {
+ t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description)
+ }
+}
+
+func TestCosigned(t *testing.T) {
+ description := "valid"
+ sm := StateManagerSingle{
+ cosigned: *testSTH,
+ }
+ sth, err := sm.Cosigned(context.Background())
+ if err != nil {
+ t.Errorf("Cosigned should not fail with error: %v", err)
+ return
+ }
+ if got, want := sth, testSTH; !reflect.DeepEqual(got, want) {
+ t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description)
+ }
+}
+
+func TestAddCosignature(t *testing.T) {
+ vk, sk, err := ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ t.Fatalf("GenerateKey: %v", err)
+ }
+ if bytes.Equal(vk[:], testPub[:]) {
+ t.Fatalf("Sampled same key as testPub, aborting...")
+ }
+ var vkArray [types.VerificationKeySize]byte
+ copy(vkArray[:], vk[:])
+
+ for _, table := range []struct {
+ description string
+ signer crypto.Signer
+ vk *[types.VerificationKeySize]byte
+ th *types.TreeHead
+ wantErr bool
+ }{
+ {
+ description: "invalid: signature error",
+ signer: sk,
+ vk: testPub, // wrong key for message
+ th: testTH,
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ signer: sk,
+ vk: &vkArray,
+ th: testTH,
+ },
+ } {
+ sth, _ := table.th.Sign(testSignerOK)
+ logKeyHash := sth.SigIdent[0].KeyHash
+ logSigIdent := sth.SigIdent[0]
+ sm := &StateManagerSingle{
+ signer: testSignerOK,
+ cosigned: *sth,
+ tosign: *sth,
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *logKeyHash: logSigIdent,
+ },
+ }
+
+ // Prepare witness signature
+ sth, err := table.th.Sign(table.signer)
+ if err != nil {
+ t.Fatalf("Sign: %v", err)
+ }
+ witnessKeyHash := sth.SigIdent[0].KeyHash
+ witnessSigIdent := sth.SigIdent[0]
+
+ // Add witness signature
+ err = sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+
+ // We should have two signatures (log + witness)
+ if got, want := len(sm.cosignature), 2; got != want {
+ t.Errorf("got %d cosignatures but wanted %v in test %q", got, want, table.description)
+ continue
+ }
+ // check that log signature is there
+ sigident, ok := sm.cosignature[*logKeyHash]
+ if !ok {
+ t.Errorf("log signature is missing")
+ continue
+ }
+ if got, want := sigident, logSigIdent; !reflect.DeepEqual(got, want) {
+ t.Errorf("got log sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ // check that witness signature is there
+ sigident, ok = sm.cosignature[*witnessKeyHash]
+ if !ok {
+ t.Errorf("witness signature is missing")
+ continue
+ }
+ if got, want := sigident, witnessSigIdent; !reflect.DeepEqual(got, want) {
+ t.Errorf("got witness sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ continue
+ }
+
+ // Adding a duplicate signature should give an error
+ if err := sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature); err == nil {
+ t.Errorf("duplicate witness signature accepted as valid")
+ }
+ }
+}
+
+func TestRotate(t *testing.T) {
+ log := testSigIdent
+ wit1 := &types.SigIdent{
+ Signature: testSig,
+ KeyHash: types.Hash([]byte("wit1 key")),
+ }
+ wit2 := &types.SigIdent{
+ Signature: testSig,
+ KeyHash: types.Hash([]byte("wit2 key")),
+ }
+ th0 := testTH
+ th1 := &types.TreeHead{
+ Timestamp: 1,
+ TreeSize: 1,
+ RootHash: types.Hash([]byte("1")),
+ }
+ th2 := &types.TreeHead{
+ Timestamp: 2,
+ TreeSize: 2,
+ RootHash: types.Hash([]byte("2")),
+ }
+
+ for _, table := range []struct {
+ description string
+ before, after *StateManagerSingle
+ next *types.SignedTreeHead
+ }{
+ {
+ description: "tosign tree head repated, but got one new witnes signature",
+ before: &StateManagerSingle{
+ cosigned: types.SignedTreeHead{
+ TreeHead: *th0,
+ SigIdent: []*types.SigIdent{log, wit1},
+ },
+ tosign: types.SignedTreeHead{
+ TreeHead: *th0,
+ SigIdent: []*types.SigIdent{log},
+ },
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *log.KeyHash: log,
+ *wit2.KeyHash: wit2, // the new witness signature
+ },
+ },
+ next: &types.SignedTreeHead{
+ TreeHead: *th1,
+ SigIdent: []*types.SigIdent{log},
+ },
+ after: &StateManagerSingle{
+ cosigned: types.SignedTreeHead{
+ TreeHead: *th0,
+ SigIdent: []*types.SigIdent{log, wit1, wit2},
+ },
+ tosign: types.SignedTreeHead{
+ TreeHead: *th1,
+ SigIdent: []*types.SigIdent{log},
+ },
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *log.KeyHash: log, // after rotate we always have log sig
+ },
+ },
+ },
+ {
+ description: "tosign tree head did not repeat, it got one witness signature",
+ before: &StateManagerSingle{
+ cosigned: types.SignedTreeHead{
+ TreeHead: *th0,
+ SigIdent: []*types.SigIdent{log, wit1},
+ },
+ tosign: types.SignedTreeHead{
+ TreeHead: *th1,
+ SigIdent: []*types.SigIdent{log},
+ },
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *log.KeyHash: log,
+ *wit2.KeyHash: wit2, // the only witness that signed tosign
+ },
+ },
+ next: &types.SignedTreeHead{
+ TreeHead: *th2,
+ SigIdent: []*types.SigIdent{log},
+ },
+ after: &StateManagerSingle{
+ cosigned: types.SignedTreeHead{
+ TreeHead: *th1,
+ SigIdent: []*types.SigIdent{log, wit2},
+ },
+ tosign: types.SignedTreeHead{
+ TreeHead: *th2,
+ SigIdent: []*types.SigIdent{log},
+ },
+ cosignature: map[[types.HashSize]byte]*types.SigIdent{
+ *log.KeyHash: log, // after rotate we always have log sig
+ },
+ },
+ },
+ } {
+ table.before.rotate(table.next)
+ if got, want := table.before.cosigned.TreeHead, table.after.cosigned.TreeHead; !reflect.DeepEqual(got, want) {
+ t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ checkWitnessList(t, table.description, table.before.cosigned.SigIdent, table.after.cosigned.SigIdent)
+ if got, want := table.before.tosign.TreeHead, table.after.tosign.TreeHead; !reflect.DeepEqual(got, want) {
+ t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ checkWitnessList(t, table.description, table.before.tosign.SigIdent, table.after.tosign.SigIdent)
+ if got, want := table.before.cosignature, table.after.cosignature; !reflect.DeepEqual(got, want) {
+ t.Errorf("got cosignature map\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func checkWitnessList(t *testing.T, description string, got, want []*types.SigIdent) {
+ t.Helper()
+ for _, si := range got {
+ found := false
+ for _, sj := range want {
+ if reflect.DeepEqual(si, sj) {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("got unexpected signature-signer pair with key hash in test %q: %x", description, si.KeyHash[:])
+ }
+ }
+ if len(got) != len(want) {
+ t.Errorf("got %d signature-signer pairs but wanted %d in test %q", len(got), len(want), description)
+ }
+}
diff --git a/pkg/trillian/client.go b/pkg/trillian/client.go
new file mode 100644
index 0000000..9523e56
--- /dev/null
+++ b/pkg/trillian/client.go
@@ -0,0 +1,178 @@
+package trillian
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/golang/glog"
+ "github.com/google/trillian"
+ ttypes "github.com/google/trillian/types"
+ "github.com/system-transparency/stfe/pkg/types"
+ "google.golang.org/grpc/codes"
+)
+
+type Client interface {
+ AddLeaf(context.Context, *types.LeafRequest) error
+ GetConsistencyProof(context.Context, *types.ConsistencyProofRequest) (*types.ConsistencyProof, error)
+ GetTreeHead(context.Context) (*types.TreeHead, error)
+ GetInclusionProof(context.Context, *types.InclusionProofRequest) (*types.InclusionProof, error)
+ GetLeaves(context.Context, *types.LeavesRequest) (*types.LeafList, error)
+}
+
+// TrillianClient is a wrapper around the Trillian gRPC client.
+type TrillianClient struct {
+ // TreeID is a Merkle tree identifier that Trillian uses
+ TreeID int64
+
+ // GRPC is a Trillian gRPC client
+ GRPC trillian.TrillianLogClient
+}
+
+func (c *TrillianClient) AddLeaf(ctx context.Context, req *types.LeafRequest) error {
+ leaf := types.Leaf{
+ Message: req.Message,
+ SigIdent: types.SigIdent{
+ Signature: req.Signature,
+ KeyHash: types.Hash(req.VerificationKey[:]),
+ },
+ }
+ serialized := leaf.Marshal()
+
+ glog.V(3).Infof("queueing leaf request: %x", types.HashLeaf(serialized))
+ rsp, err := c.GRPC.QueueLeaf(ctx, &trillian.QueueLeafRequest{
+ LogId: c.TreeID,
+ Leaf: &trillian.LogLeaf{
+ LeafValue: serialized,
+ },
+ })
+ if err != nil {
+ return fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return fmt.Errorf("no response")
+ }
+ if rsp.QueuedLeaf == nil {
+ return fmt.Errorf("no queued leaf")
+ }
+ if codes.Code(rsp.QueuedLeaf.GetStatus().GetCode()) == codes.AlreadyExists {
+ return fmt.Errorf("leaf is already queued or included")
+ }
+ return nil
+}
+
+func (c *TrillianClient) GetTreeHead(ctx context.Context) (*types.TreeHead, error) {
+ rsp, err := c.GRPC.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{
+ LogId: c.TreeID,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("no response")
+ }
+ if rsp.SignedLogRoot == nil {
+ return nil, fmt.Errorf("no signed log root")
+ }
+ if rsp.SignedLogRoot.LogRoot == nil {
+ return nil, fmt.Errorf("no log root")
+ }
+ var r ttypes.LogRootV1
+ if err := r.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil {
+ return nil, fmt.Errorf("no log root: unmarshal failed: %v", err)
+ }
+ if len(r.RootHash) != types.HashSize {
+ return nil, fmt.Errorf("unexpected hash length: %d", len(r.RootHash))
+ }
+ return treeHeadFromLogRoot(&r), nil
+}
+
+func (c *TrillianClient) GetConsistencyProof(ctx context.Context, req *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) {
+ rsp, err := c.GRPC.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{
+ LogId: c.TreeID,
+ FirstTreeSize: int64(req.OldSize),
+ SecondTreeSize: int64(req.NewSize),
+ })
+ if err != nil {
+ return nil, fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("no response")
+ }
+ if rsp.Proof == nil {
+ return nil, fmt.Errorf("no consistency proof")
+ }
+ if len(rsp.Proof.Hashes) == 0 {
+ return nil, fmt.Errorf("not a consistency proof: empty")
+ }
+ path, err := nodePathFromHashes(rsp.Proof.Hashes)
+ if err != nil {
+ return nil, fmt.Errorf("not a consistency proof: %v", err)
+ }
+ return &types.ConsistencyProof{
+ OldSize: req.OldSize,
+ NewSize: req.NewSize,
+ Path: path,
+ }, nil
+}
+
+func (c *TrillianClient) GetInclusionProof(ctx context.Context, req *types.InclusionProofRequest) (*types.InclusionProof, error) {
+ rsp, err := c.GRPC.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{
+ LogId: c.TreeID,
+ LeafHash: req.LeafHash[:],
+ TreeSize: int64(req.TreeSize),
+ OrderBySequence: true,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("no response")
+ }
+ if len(rsp.Proof) != 1 {
+ return nil, fmt.Errorf("bad proof count: %d", len(rsp.Proof))
+ }
+ proof := rsp.Proof[0]
+ if len(proof.Hashes) == 0 {
+ return nil, fmt.Errorf("not an inclusion proof: empty")
+ }
+ path, err := nodePathFromHashes(proof.Hashes)
+ if err != nil {
+ return nil, fmt.Errorf("not an inclusion proof: %v", err)
+ }
+ return &types.InclusionProof{
+ TreeSize: req.TreeSize,
+ LeafIndex: uint64(proof.LeafIndex),
+ Path: path,
+ }, nil
+}
+
+func (c *TrillianClient) GetLeaves(ctx context.Context, req *types.LeavesRequest) (*types.LeafList, error) {
+ rsp, err := c.GRPC.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{
+ LogId: c.TreeID,
+ StartIndex: int64(req.StartSize),
+ Count: int64(req.EndSize-req.StartSize) + 1,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("backend failure: %v", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("no response")
+ }
+ if got, want := len(rsp.Leaves), int(req.EndSize-req.StartSize+1); got != want {
+ return nil, fmt.Errorf("unexpected number of leaves: %d", got)
+ }
+ var list types.LeafList
+ for i, leaf := range rsp.Leaves {
+ leafIndex := int64(req.StartSize + uint64(i))
+ if leafIndex != leaf.LeafIndex {
+ return nil, fmt.Errorf("unexpected leaf(%d): got index %d", leafIndex, leaf.LeafIndex)
+ }
+
+ var l types.Leaf
+ if err := l.Unmarshal(leaf.LeafValue); err != nil {
+ return nil, fmt.Errorf("unexpected leaf(%d): %v", leafIndex, err)
+ }
+ list = append(list[:], &l)
+ }
+ return &list, nil
+}
diff --git a/pkg/trillian/client_test.go b/pkg/trillian/client_test.go
new file mode 100644
index 0000000..6b3d881
--- /dev/null
+++ b/pkg/trillian/client_test.go
@@ -0,0 +1,533 @@
+package trillian
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ "github.com/google/trillian"
+ ttypes "github.com/google/trillian/types"
+ "github.com/system-transparency/stfe/pkg/mocks"
+ "github.com/system-transparency/stfe/pkg/types"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+func TestAddLeaf(t *testing.T) {
+ req := &types.LeafRequest{
+ Message: types.Message{
+ ShardHint: 0,
+ Checksum: &[types.HashSize]byte{},
+ },
+ Signature: &[types.SignatureSize]byte{},
+ VerificationKey: &[types.VerificationKeySize]byte{},
+ DomainHint: "example.com",
+ }
+ for _, table := range []struct {
+ description string
+ req *types.LeafRequest
+ rsp *trillian.QueueLeafResponse
+ err error
+ wantErr bool
+ }{
+ {
+ description: "invalid: backend failure",
+ req: req,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ req: req,
+ wantErr: true,
+ },
+ {
+ description: "invalid: no queued leaf",
+ req: req,
+ rsp: &trillian.QueueLeafResponse{},
+ wantErr: true,
+ },
+ {
+ description: "invalid: leaf is already queued or included",
+ req: req,
+ rsp: &trillian.QueueLeafResponse{
+ QueuedLeaf: &trillian.QueuedLogLeaf{
+ Leaf: &trillian.LogLeaf{
+ LeafValue: req.Message.Marshal(),
+ },
+ Status: status.New(codes.AlreadyExists, "duplicate").Proto(),
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: req,
+ rsp: &trillian.QueueLeafResponse{
+ QueuedLeaf: &trillian.QueuedLogLeaf{
+ Leaf: &trillian.LogLeaf{
+ LeafValue: req.Message.Marshal(),
+ },
+ Status: status.New(codes.OK, "ok").Proto(),
+ },
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().QueueLeaf(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ err := client.AddLeaf(context.Background(), table.req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ }()
+ }
+}
+
+func TestGetTreeHead(t *testing.T) {
+ // valid root
+ root := &ttypes.LogRootV1{
+ TreeSize: 0,
+ RootHash: make([]byte, types.HashSize),
+ TimestampNanos: 1622585623133599429,
+ }
+ buf, err := root.MarshalBinary()
+ if err != nil {
+ t.Fatalf("must marshal log root: %v", err)
+ }
+ // invalid root
+ root.RootHash = make([]byte, types.HashSize+1)
+ bufBadHash, err := root.MarshalBinary()
+ if err != nil {
+ t.Fatalf("must marshal log root: %v", err)
+ }
+
+ for _, table := range []struct {
+ description string
+ rsp *trillian.GetLatestSignedLogRootResponse
+ err error
+ wantErr bool
+ wantTh *types.TreeHead
+ }{
+ {
+ description: "invalid: backend failure",
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ wantErr: true,
+ },
+ {
+ description: "invalid: no signed log root",
+ rsp: &trillian.GetLatestSignedLogRootResponse{},
+ wantErr: true,
+ },
+ {
+ description: "invalid: no log root",
+ rsp: &trillian.GetLatestSignedLogRootResponse{
+ SignedLogRoot: &trillian.SignedLogRoot{},
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: no log root: unmarshal failed",
+ rsp: &trillian.GetLatestSignedLogRootResponse{
+ SignedLogRoot: &trillian.SignedLogRoot{
+ LogRoot: buf[1:],
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: unexpected hash length",
+ rsp: &trillian.GetLatestSignedLogRootResponse{
+ SignedLogRoot: &trillian.SignedLogRoot{
+ LogRoot: bufBadHash,
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ rsp: &trillian.GetLatestSignedLogRootResponse{
+ SignedLogRoot: &trillian.SignedLogRoot{
+ LogRoot: buf,
+ },
+ },
+ wantTh: &types.TreeHead{
+ Timestamp: 1622585623,
+ TreeSize: 0,
+ RootHash: &[types.HashSize]byte{},
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ th, err := client.GetTreeHead(context.Background())
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := th, table.wantTh; !reflect.DeepEqual(got, want) {
+ t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetConsistencyProof(t *testing.T) {
+ req := &types.ConsistencyProofRequest{
+ OldSize: 1,
+ NewSize: 3,
+ }
+ for _, table := range []struct {
+ description string
+ req *types.ConsistencyProofRequest
+ rsp *trillian.GetConsistencyProofResponse
+ err error
+ wantErr bool
+ wantProof *types.ConsistencyProof
+ }{
+ {
+ description: "invalid: backend failure",
+ req: req,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ req: req,
+ wantErr: true,
+ },
+ {
+ description: "invalid: no consistency proof",
+ req: req,
+ rsp: &trillian.GetConsistencyProofResponse{},
+ wantErr: true,
+ },
+ {
+ description: "invalid: not a consistency proof (1/2)",
+ req: req,
+ rsp: &trillian.GetConsistencyProofResponse{
+ Proof: &trillian.Proof{
+ Hashes: [][]byte{},
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: not a consistency proof (2/2)",
+ req: req,
+ rsp: &trillian.GetConsistencyProofResponse{
+ Proof: &trillian.Proof{
+ Hashes: [][]byte{
+ make([]byte, types.HashSize),
+ make([]byte, types.HashSize+1),
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: req,
+ rsp: &trillian.GetConsistencyProofResponse{
+ Proof: &trillian.Proof{
+ Hashes: [][]byte{
+ make([]byte, types.HashSize),
+ make([]byte, types.HashSize),
+ },
+ },
+ },
+ wantProof: &types.ConsistencyProof{
+ OldSize: 1,
+ NewSize: 3,
+ Path: []*[types.HashSize]byte{
+ &[types.HashSize]byte{},
+ &[types.HashSize]byte{},
+ },
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ proof, err := client.GetConsistencyProof(context.Background(), table.req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) {
+ t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetInclusionProof(t *testing.T) {
+ req := &types.InclusionProofRequest{
+ TreeSize: 4,
+ LeafHash: &[types.HashSize]byte{},
+ }
+ for _, table := range []struct {
+ description string
+ req *types.InclusionProofRequest
+ rsp *trillian.GetInclusionProofByHashResponse
+ err error
+ wantErr bool
+ wantProof *types.InclusionProof
+ }{
+ {
+ description: "invalid: backend failure",
+ req: req,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ req: req,
+ wantErr: true,
+ },
+ {
+ description: "invalid: bad proof count",
+ req: req,
+ rsp: &trillian.GetInclusionProofByHashResponse{
+ Proof: []*trillian.Proof{
+ &trillian.Proof{},
+ &trillian.Proof{},
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: not an inclusion proof (1/2)",
+ req: req,
+ rsp: &trillian.GetInclusionProofByHashResponse{
+ Proof: []*trillian.Proof{
+ &trillian.Proof{
+ LeafIndex: 1,
+ Hashes: [][]byte{},
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: not an inclusion proof (2/2)",
+ req: req,
+ rsp: &trillian.GetInclusionProofByHashResponse{
+ Proof: []*trillian.Proof{
+ &trillian.Proof{
+ LeafIndex: 1,
+ Hashes: [][]byte{
+ make([]byte, types.HashSize),
+ make([]byte, types.HashSize+1),
+ },
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: req,
+ rsp: &trillian.GetInclusionProofByHashResponse{
+ Proof: []*trillian.Proof{
+ &trillian.Proof{
+ LeafIndex: 1,
+ Hashes: [][]byte{
+ make([]byte, types.HashSize),
+ make([]byte, types.HashSize),
+ },
+ },
+ },
+ },
+ wantProof: &types.InclusionProof{
+ TreeSize: 4,
+ LeafIndex: 1,
+ Path: []*[types.HashSize]byte{
+ &[types.HashSize]byte{},
+ &[types.HashSize]byte{},
+ },
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ proof, err := client.GetInclusionProof(context.Background(), table.req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) {
+ t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
+
+func TestGetLeaves(t *testing.T) {
+ req := &types.LeavesRequest{
+ StartSize: 1,
+ EndSize: 2,
+ }
+ firstLeaf := &types.Leaf{
+ Message: types.Message{
+ ShardHint: 0,
+ Checksum: &[types.HashSize]byte{},
+ },
+ SigIdent: types.SigIdent{
+ Signature: &[types.SignatureSize]byte{},
+ KeyHash: &[types.HashSize]byte{},
+ },
+ }
+ secondLeaf := &types.Leaf{
+ Message: types.Message{
+ ShardHint: 0,
+ Checksum: &[types.HashSize]byte{},
+ },
+ SigIdent: types.SigIdent{
+ Signature: &[types.SignatureSize]byte{},
+ KeyHash: &[types.HashSize]byte{},
+ },
+ }
+
+ for _, table := range []struct {
+ description string
+ req *types.LeavesRequest
+ rsp *trillian.GetLeavesByRangeResponse
+ err error
+ wantErr bool
+ wantLeaves *types.LeafList
+ }{
+ {
+ description: "invalid: backend failure",
+ req: req,
+ err: fmt.Errorf("something went wrong"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: no response",
+ req: req,
+ wantErr: true,
+ },
+ {
+ description: "invalid: unexpected number of leaves",
+ req: req,
+ rsp: &trillian.GetLeavesByRangeResponse{
+ Leaves: []*trillian.LogLeaf{
+ &trillian.LogLeaf{
+ LeafValue: firstLeaf.Marshal(),
+ LeafIndex: 1,
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: unexpected leaf (1/2)",
+ req: req,
+ rsp: &trillian.GetLeavesByRangeResponse{
+ Leaves: []*trillian.LogLeaf{
+ &trillian.LogLeaf{
+ LeafValue: firstLeaf.Marshal(),
+ LeafIndex: 1,
+ },
+ &trillian.LogLeaf{
+ LeafValue: secondLeaf.Marshal(),
+ LeafIndex: 3,
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "invalid: unexpected leaf (2/2)",
+ req: req,
+ rsp: &trillian.GetLeavesByRangeResponse{
+ Leaves: []*trillian.LogLeaf{
+ &trillian.LogLeaf{
+ LeafValue: firstLeaf.Marshal(),
+ LeafIndex: 1,
+ },
+ &trillian.LogLeaf{
+ LeafValue: secondLeaf.Marshal()[1:],
+ LeafIndex: 2,
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ req: req,
+ rsp: &trillian.GetLeavesByRangeResponse{
+ Leaves: []*trillian.LogLeaf{
+ &trillian.LogLeaf{
+ LeafValue: firstLeaf.Marshal(),
+ LeafIndex: 1,
+ },
+ &trillian.LogLeaf{
+ LeafValue: secondLeaf.Marshal(),
+ LeafIndex: 2,
+ },
+ },
+ },
+ wantLeaves: &types.LeafList{
+ firstLeaf,
+ secondLeaf,
+ },
+ },
+ } {
+ // Run deferred functions at the end of each iteration
+ func() {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ grpc := mocks.NewMockTrillianLogClient(ctrl)
+ grpc.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).Return(table.rsp, table.err)
+ client := TrillianClient{GRPC: grpc}
+
+ leaves, err := client.GetLeaves(context.Background(), table.req)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ return
+ }
+ if got, want := leaves, table.wantLeaves; !reflect.DeepEqual(got, want) {
+ t.Errorf("got leaves\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }()
+ }
+}
diff --git a/pkg/trillian/util.go b/pkg/trillian/util.go
new file mode 100644
index 0000000..4cf31fb
--- /dev/null
+++ b/pkg/trillian/util.go
@@ -0,0 +1,33 @@
+package trillian
+
+import (
+ "fmt"
+
+ trillian "github.com/google/trillian/types"
+ siglog "github.com/system-transparency/stfe/pkg/types"
+)
+
+func treeHeadFromLogRoot(lr *trillian.LogRootV1) *siglog.TreeHead {
+ var hash [siglog.HashSize]byte
+ th := siglog.TreeHead{
+ Timestamp: uint64(lr.TimestampNanos / 1000 / 1000 / 1000),
+ TreeSize: uint64(lr.TreeSize),
+ RootHash: &hash,
+ }
+ copy(th.RootHash[:], lr.RootHash)
+ return &th
+}
+
+func nodePathFromHashes(hashes [][]byte) ([]*[siglog.HashSize]byte, error) {
+ var path []*[siglog.HashSize]byte
+ for _, hash := range hashes {
+ if len(hash) != siglog.HashSize {
+ return nil, fmt.Errorf("unexpected hash length: %v", len(hash))
+ }
+
+ var h [siglog.HashSize]byte
+ copy(h[:], hash)
+ path = append(path, &h)
+ }
+ return path, nil
+}
diff --git a/pkg/types/ascii.go b/pkg/types/ascii.go
new file mode 100644
index 0000000..d27d79b
--- /dev/null
+++ b/pkg/types/ascii.go
@@ -0,0 +1,421 @@
+package types
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "strconv"
+)
+
+const (
+ // Delim is a key-value separator
+ Delim = "="
+
+ // EOL is a line sepator
+ EOL = "\n"
+
+ // NumField* is the number of unique keys in an incoming ASCII message
+ NumFieldLeaf = 4
+ NumFieldSignedTreeHead = 5
+ NumFieldConsistencyProof = 3
+ NumFieldInclusionProof = 3
+ NumFieldLeavesRequest = 2
+ NumFieldInclusionProofRequest = 2
+ NumFieldConsistencyProofRequest = 2
+ NumFieldLeafRequest = 5
+ NumFieldCosignatureRequest = 2
+
+ // New leaf keys
+ ShardHint = "shard_hint"
+ Checksum = "checksum"
+ SignatureOverMessage = "signature_over_message"
+ VerificationKey = "verification_key"
+ DomainHint = "domain_hint"
+
+ // Inclusion proof keys
+ LeafHash = "leaf_hash"
+ LeafIndex = "leaf_index"
+ InclusionPath = "inclusion_path"
+
+ // Consistency proof keys
+ NewSize = "new_size"
+ OldSize = "old_size"
+ ConsistencyPath = "consistency_path"
+
+ // Range of leaves keys
+ StartSize = "start_size"
+ EndSize = "end_size"
+
+ // Tree head keys
+ Timestamp = "timestamp"
+ TreeSize = "tree_size"
+ RootHash = "root_hash"
+
+ // Signature and signer-identity keys
+ Signature = "signature"
+ KeyHash = "key_hash"
+)
+
+// MessageASCI is a wrapper that manages ASCII key-value pairs
+type MessageASCII struct {
+ m map[string][]string
+}
+
+// NewMessageASCII unpacks an incoming ASCII message
+func NewMessageASCII(r io.Reader, numFieldExpected int) (*MessageASCII, error) {
+ buf, err := ioutil.ReadAll(r)
+ if err != nil {
+ return nil, fmt.Errorf("ReadAll: %v", err)
+ }
+ lines := bytes.Split(buf, []byte(EOL))
+ if len(lines) <= 1 {
+ return nil, fmt.Errorf("Not enough lines: empty")
+ }
+ lines = lines[:len(lines)-1] // valid message => split gives empty last line
+
+ msg := MessageASCII{make(map[string][]string)}
+ for _, line := range lines {
+ split := bytes.Index(line, []byte(Delim))
+ if split == -1 {
+ return nil, fmt.Errorf("invalid line: %v", string(line))
+ }
+
+ key := string(line[:split])
+ value := string(line[split+len(Delim):])
+ values, ok := msg.m[key]
+ if !ok {
+ values = nil
+ msg.m[key] = values
+ }
+ msg.m[key] = append(values, value)
+ }
+
+ if msg.NumField() != numFieldExpected {
+ return nil, fmt.Errorf("Unexpected number of keys: %v", msg.NumField())
+ }
+ return &msg, nil
+}
+
+// NumField returns the number of unique keys
+func (msg *MessageASCII) NumField() int {
+ return len(msg.m)
+}
+
+// GetStrings returns a list of strings
+func (msg *MessageASCII) GetStrings(key string) []string {
+ strs, ok := msg.m[key]
+ if !ok {
+ return nil
+ }
+ return strs
+}
+
+// GetString unpacks a string
+func (msg *MessageASCII) GetString(key string) (string, error) {
+ strs := msg.GetStrings(key)
+ if len(strs) != 1 {
+ return "", fmt.Errorf("expected one string: %v", strs)
+ }
+ return strs[0], nil
+}
+
+// GetUint64 unpacks an uint64
+func (msg *MessageASCII) GetUint64(key string) (uint64, error) {
+ str, err := msg.GetString(key)
+ if err != nil {
+ return 0, fmt.Errorf("GetString: %v", err)
+ }
+ num, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ return 0, fmt.Errorf("ParseUint: %v", err)
+ }
+ return num, nil
+}
+
+// GetHash unpacks a hash
+func (msg *MessageASCII) GetHash(key string) (*[HashSize]byte, error) {
+ str, err := msg.GetString(key)
+ if err != nil {
+ return nil, fmt.Errorf("GetString: %v", err)
+ }
+
+ var hash [HashSize]byte
+ if err := decodeHex(str, hash[:]); err != nil {
+ return nil, fmt.Errorf("decodeHex: %v", err)
+ }
+ return &hash, nil
+}
+
+// GetSignature unpacks a signature
+func (msg *MessageASCII) GetSignature(key string) (*[SignatureSize]byte, error) {
+ str, err := msg.GetString(key)
+ if err != nil {
+ return nil, fmt.Errorf("GetString: %v", err)
+ }
+
+ var signature [SignatureSize]byte
+ if err := decodeHex(str, signature[:]); err != nil {
+ return nil, fmt.Errorf("decodeHex: %v", err)
+ }
+ return &signature, nil
+}
+
+// GetVerificationKey unpacks a verification key
+func (msg *MessageASCII) GetVerificationKey(key string) (*[VerificationKeySize]byte, error) {
+ str, err := msg.GetString(key)
+ if err != nil {
+ return nil, fmt.Errorf("GetString: %v", err)
+ }
+
+ var vk [VerificationKeySize]byte
+ if err := decodeHex(str, vk[:]); err != nil {
+ return nil, fmt.Errorf("decodeHex: %v", err)
+ }
+ return &vk, nil
+}
+
+// decodeHex decodes a hex-encoded string into an already-sized byte slice
+func decodeHex(str string, out []byte) error {
+ buf, err := hex.DecodeString(str)
+ if err != nil {
+ return fmt.Errorf("DecodeString: %v", err)
+ }
+ if len(buf) != len(out) {
+ return fmt.Errorf("invalid length: %v", len(buf))
+ }
+ copy(out, buf)
+ return nil
+}
+
+/*
+ *
+ * MarshalASCII wrappers for types that the log server outputs
+ *
+ */
+func (l *Leaf) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, ShardHint, strconv.FormatUint(l.ShardHint, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, Checksum, hex.EncodeToString(l.Checksum[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, SignatureOverMessage, hex.EncodeToString(l.Signature[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, KeyHash, hex.EncodeToString(l.KeyHash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ return nil
+}
+
+func (sth *SignedTreeHead) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, Timestamp, strconv.FormatUint(sth.Timestamp, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, TreeSize, strconv.FormatUint(sth.TreeSize, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, RootHash, hex.EncodeToString(sth.RootHash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ for _, sigident := range sth.SigIdent {
+ if err := sigident.MarshalASCII(w); err != nil {
+ return fmt.Errorf("MarshalASCII: %v", err)
+ }
+ }
+ return nil
+}
+
+func (si *SigIdent) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, Signature, hex.EncodeToString(si.Signature[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, KeyHash, hex.EncodeToString(si.KeyHash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ return nil
+}
+
+func (p *ConsistencyProof) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, NewSize, strconv.FormatUint(p.NewSize, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, OldSize, strconv.FormatUint(p.OldSize, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ for _, hash := range p.Path {
+ if err := writeASCII(w, ConsistencyPath, hex.EncodeToString(hash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ }
+ return nil
+}
+
+func (p *InclusionProof) MarshalASCII(w io.Writer) error {
+ if err := writeASCII(w, TreeSize, strconv.FormatUint(p.TreeSize, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ if err := writeASCII(w, LeafIndex, strconv.FormatUint(p.LeafIndex, 10)); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ for _, hash := range p.Path {
+ if err := writeASCII(w, InclusionPath, hex.EncodeToString(hash[:])); err != nil {
+ return fmt.Errorf("writeASCII: %v", err)
+ }
+ }
+ return nil
+}
+
+func writeASCII(w io.Writer, key, value string) error {
+ if _, err := fmt.Fprintf(w, "%s%s%s%s", key, Delim, value, EOL); err != nil {
+ return fmt.Errorf("Fprintf: %v", err)
+ }
+ return nil
+}
+
+/*
+ *
+ * Unmarshal ASCII wrappers that the log server and/or log clients receive.
+ *
+ */
+func (ll *LeafList) UnmarshalASCII(r io.Reader) error {
+ return nil
+}
+
+func (sth *SignedTreeHead) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldSignedTreeHead)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ // TreeHead
+ if sth.Timestamp, err = msg.GetUint64(Timestamp); err != nil {
+ return fmt.Errorf("GetUint64(Timestamp): %v", err)
+ }
+ if sth.TreeSize, err = msg.GetUint64(TreeSize); err != nil {
+ return fmt.Errorf("GetUint64(TreeSize): %v", err)
+ }
+ if sth.RootHash, err = msg.GetHash(RootHash); err != nil {
+ return fmt.Errorf("GetHash(RootHash): %v", err)
+ }
+
+ // SigIdent
+ signatures := msg.GetStrings(Signature)
+ if len(signatures) == 0 {
+ return fmt.Errorf("no signer")
+ }
+ keyHashes := msg.GetStrings(KeyHash)
+ if len(signatures) != len(keyHashes) {
+ return fmt.Errorf("mismatched signature-signer count")
+ }
+ sth.SigIdent = make([]*SigIdent, 0, len(signatures))
+ for i, n := 0, len(signatures); i < n; i++ {
+ var signature [SignatureSize]byte
+ if err := decodeHex(signatures[i], signature[:]); err != nil {
+ return fmt.Errorf("decodeHex: %v", err)
+ }
+ var hash [HashSize]byte
+ if err := decodeHex(keyHashes[i], hash[:]); err != nil {
+ return fmt.Errorf("decodeHex: %v", err)
+ }
+ sth.SigIdent = append(sth.SigIdent, &SigIdent{
+ Signature: &signature,
+ KeyHash: &hash,
+ })
+ }
+ return nil
+}
+
+func (p *InclusionProof) UnmarshalASCII(r io.Reader) error {
+ return nil
+}
+
+func (p *ConsistencyProof) UnmarshalASCII(r io.Reader) error {
+ return nil
+}
+
+func (req *InclusionProofRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldInclusionProofRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.LeafHash, err = msg.GetHash(LeafHash); err != nil {
+ return fmt.Errorf("GetHash(LeafHash): %v", err)
+ }
+ if req.TreeSize, err = msg.GetUint64(TreeSize); err != nil {
+ return fmt.Errorf("GetUint64(TreeSize): %v", err)
+ }
+ return nil
+}
+
+func (req *ConsistencyProofRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldConsistencyProofRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.NewSize, err = msg.GetUint64(NewSize); err != nil {
+ return fmt.Errorf("GetUint64(NewSize): %v", err)
+ }
+ if req.OldSize, err = msg.GetUint64(OldSize); err != nil {
+ return fmt.Errorf("GetUint64(OldSize): %v", err)
+ }
+ return nil
+}
+
+func (req *LeavesRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldLeavesRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.StartSize, err = msg.GetUint64(StartSize); err != nil {
+ return fmt.Errorf("GetUint64(StartSize): %v", err)
+ }
+ if req.EndSize, err = msg.GetUint64(EndSize); err != nil {
+ return fmt.Errorf("GetUint64(EndSize): %v", err)
+ }
+ return nil
+}
+
+func (req *LeafRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldLeafRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.ShardHint, err = msg.GetUint64(ShardHint); err != nil {
+ return fmt.Errorf("GetUint64(ShardHint): %v", err)
+ }
+ if req.Checksum, err = msg.GetHash(Checksum); err != nil {
+ return fmt.Errorf("GetHash(Checksum): %v", err)
+ }
+ if req.Signature, err = msg.GetSignature(SignatureOverMessage); err != nil {
+ return fmt.Errorf("GetSignature: %v", err)
+ }
+ if req.VerificationKey, err = msg.GetVerificationKey(VerificationKey); err != nil {
+ return fmt.Errorf("GetVerificationKey: %v", err)
+ }
+ if req.DomainHint, err = msg.GetString(DomainHint); err != nil {
+ return fmt.Errorf("GetString(DomainHint): %v", err)
+ }
+ return nil
+}
+
+func (req *CosignatureRequest) UnmarshalASCII(r io.Reader) error {
+ msg, err := NewMessageASCII(r, NumFieldCosignatureRequest)
+ if err != nil {
+ return fmt.Errorf("NewMessageASCII: %v", err)
+ }
+
+ if req.Signature, err = msg.GetSignature(Signature); err != nil {
+ return fmt.Errorf("GetSignature: %v", err)
+ }
+ if req.KeyHash, err = msg.GetHash(KeyHash); err != nil {
+ return fmt.Errorf("GetHash(KeyHash): %v", err)
+ }
+ return nil
+}
diff --git a/pkg/types/ascii_test.go b/pkg/types/ascii_test.go
new file mode 100644
index 0000000..92732f9
--- /dev/null
+++ b/pkg/types/ascii_test.go
@@ -0,0 +1,465 @@
+package types
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "testing"
+)
+
+/*
+ *
+ * MessageASCII methods and helpers
+ *
+ */
+func TestNewMessageASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ input io.Reader
+ wantErr bool
+ wantMap map[string][]string
+ }{
+ {
+ description: "invalid: not enough lines",
+ input: bytes.NewBufferString(""),
+ wantErr: true,
+ },
+ {
+ description: "invalid: lines must end with new line",
+ input: bytes.NewBufferString("k1=v1\nk2=v2"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: lines must not be empty",
+ input: bytes.NewBufferString("k1=v1\n\nk2=v2\n"),
+ wantErr: true,
+ },
+ {
+ description: "invalid: wrong number of fields",
+ input: bytes.NewBufferString("k1=v1\n"),
+ wantErr: true,
+ },
+ {
+ description: "valid",
+ input: bytes.NewBufferString("k1=v1\nk2=v2\nk2=v3=4\n"),
+ wantMap: map[string][]string{
+ "k1": []string{"v1"},
+ "k2": []string{"v2", "v3=4"},
+ },
+ },
+ } {
+ msg, err := NewMessageASCII(table.input, len(table.wantMap))
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := msg.m, table.wantMap; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestNumField(t *testing.T) {}
+func TestGetStrings(t *testing.T) {}
+func TestGetString(t *testing.T) {}
+func TestGetUint64(t *testing.T) {}
+func TestGetHash(t *testing.T) {}
+func TestGetSignature(t *testing.T) {}
+func TestGetVerificationKey(t *testing.T) {}
+func TestDecodeHex(t *testing.T) {}
+
+/*
+ *
+ * MarshalASCII methods and helpers
+ *
+ */
+func TestLeafMarshalASCII(t *testing.T) {
+ description := "valid: two leaves"
+ leafList := []*Leaf{
+ &Leaf{
+ Message: Message{
+ ShardHint: 123,
+ Checksum: testBuffer32,
+ },
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ &Leaf{
+ Message: Message{
+ ShardHint: 456,
+ Checksum: testBuffer32,
+ },
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ }
+ wantBuf := bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+
+ "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s",
+ // Leaf 1
+ ShardHint, Delim, 123, EOL,
+ Checksum, Delim, testBuffer32[:], EOL,
+ SignatureOverMessage, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ // Leaf 2
+ ShardHint, Delim, 456, EOL,
+ Checksum, Delim, testBuffer32[:], EOL,
+ SignatureOverMessage, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ ))
+ buf := bytes.NewBuffer(nil)
+ for _, leaf := range leafList {
+ if err := leaf.MarshalASCII(buf); err != nil {
+ t.Errorf("expected error %v but got %v in test %q: %v", false, true, description, err)
+ return
+ }
+ }
+ if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description)
+ }
+}
+
+func TestSignedTreeHeadMarshalASCII(t *testing.T) {
+ description := "valid"
+ sth := &SignedTreeHead{
+ TreeHead: TreeHead{
+ Timestamp: 123,
+ TreeSize: 456,
+ RootHash: testBuffer32,
+ },
+ SigIdent: []*SigIdent{
+ &SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ &SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ }
+ wantBuf := bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s",
+ Timestamp, Delim, 123, EOL,
+ TreeSize, Delim, 456, EOL,
+ RootHash, Delim, testBuffer32[:], EOL,
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ ))
+ buf := bytes.NewBuffer(nil)
+ if err := sth.MarshalASCII(buf); err != nil {
+ t.Errorf("expected error %v but got %v in test %q", false, true, description)
+ return
+ }
+ if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description)
+ }
+}
+
+func TestInclusionProofMarshalASCII(t *testing.T) {
+ description := "valid"
+ proof := InclusionProof{
+ TreeSize: 321,
+ LeafIndex: 123,
+ Path: []*[HashSize]byte{
+ testBuffer32,
+ testBuffer32,
+ },
+ }
+ wantBuf := bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s",
+ TreeSize, Delim, 321, EOL,
+ LeafIndex, Delim, 123, EOL,
+ InclusionPath, Delim, testBuffer32[:], EOL,
+ InclusionPath, Delim, testBuffer32[:], EOL,
+ ))
+ buf := bytes.NewBuffer(nil)
+ if err := proof.MarshalASCII(buf); err != nil {
+ t.Errorf("expected error %v but got %v in test %q", false, true, description)
+ return
+ }
+ if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description)
+ }
+}
+
+func TestConsistencyProofMarshalASCII(t *testing.T) {
+ description := "valid"
+ proof := ConsistencyProof{
+ NewSize: 321,
+ OldSize: 123,
+ Path: []*[HashSize]byte{
+ testBuffer32,
+ testBuffer32,
+ },
+ }
+ wantBuf := bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s",
+ NewSize, Delim, 321, EOL,
+ OldSize, Delim, 123, EOL,
+ ConsistencyPath, Delim, testBuffer32[:], EOL,
+ ConsistencyPath, Delim, testBuffer32[:], EOL,
+ ))
+ buf := bytes.NewBuffer(nil)
+ if err := proof.MarshalASCII(buf); err != nil {
+ t.Errorf("expected error %v but got %v in test %q", false, true, description)
+ return
+ }
+ if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description)
+ }
+}
+
+func TestWriteASCII(t *testing.T) {
+}
+
+/*
+ *
+ * UnmarshalASCII methods and helpers
+ *
+ */
+func TestLeafListUnmarshalASCII(t *testing.T) {}
+
+func TestSignedTreeHeadUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantSth *SignedTreeHead
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s",
+ Timestamp, Delim, 123, EOL,
+ TreeSize, Delim, 456, EOL,
+ RootHash, Delim, testBuffer32[:], EOL,
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ )),
+ wantSth: &SignedTreeHead{
+ TreeHead: TreeHead{
+ Timestamp: 123,
+ TreeSize: 456,
+ RootHash: testBuffer32,
+ },
+ SigIdent: []*SigIdent{
+ &SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ &SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ },
+ },
+ } {
+ var sth SignedTreeHead
+ err := sth.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &sth, table.wantSth; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestInclusionProofUnmarshalASCII(t *testing.T) {}
+func TestConsistencyProofUnmarshalASCII(t *testing.T) {}
+
+func TestInclusionProofRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *InclusionProofRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%x%s"+"%s%s%d%s",
+ LeafHash, Delim, testBuffer32[:], EOL,
+ TreeSize, Delim, 123, EOL,
+ )),
+ wantReq: &InclusionProofRequest{
+ LeafHash: testBuffer32,
+ TreeSize: 123,
+ },
+ },
+ } {
+ var req InclusionProofRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestConsistencyProofRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *ConsistencyProofRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s",
+ NewSize, Delim, 321, EOL,
+ OldSize, Delim, 123, EOL,
+ )),
+ wantReq: &ConsistencyProofRequest{
+ NewSize: 321,
+ OldSize: 123,
+ },
+ },
+ } {
+ var req ConsistencyProofRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestLeavesRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *LeavesRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%d%s",
+ StartSize, Delim, 123, EOL,
+ EndSize, Delim, 456, EOL,
+ )),
+ wantReq: &LeavesRequest{
+ StartSize: 123,
+ EndSize: 456,
+ },
+ },
+ } {
+ var req LeavesRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestLeafRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *LeafRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%s%s",
+ ShardHint, Delim, 123, EOL,
+ Checksum, Delim, testBuffer32[:], EOL,
+ SignatureOverMessage, Delim, testBuffer64[:], EOL,
+ VerificationKey, Delim, testBuffer32[:], EOL,
+ DomainHint, Delim, "example.com", EOL,
+ )),
+ wantReq: &LeafRequest{
+ Message: Message{
+ ShardHint: 123,
+ Checksum: testBuffer32,
+ },
+ Signature: testBuffer64,
+ VerificationKey: testBuffer32,
+ DomainHint: "example.com",
+ },
+ },
+ } {
+ var req LeafRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
+
+func TestCosignatureRequestUnmarshalASCII(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ buf io.Reader
+ wantErr bool
+ wantReq *CosignatureRequest
+ }{
+ {
+ description: "valid",
+ buf: bytes.NewBufferString(fmt.Sprintf(
+ "%s%s%x%s"+"%s%s%x%s",
+ Signature, Delim, testBuffer64[:], EOL,
+ KeyHash, Delim, testBuffer32[:], EOL,
+ )),
+ wantReq: &CosignatureRequest{
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ },
+ } {
+ var req CosignatureRequest
+ err := req.UnmarshalASCII(table.buf)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) {
+ t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description)
+ }
+ }
+}
diff --git a/pkg/types/trunnel.go b/pkg/types/trunnel.go
new file mode 100644
index 0000000..268f6f7
--- /dev/null
+++ b/pkg/types/trunnel.go
@@ -0,0 +1,60 @@
+package types
+
+import (
+ "encoding/binary"
+ "fmt"
+)
+
+const (
+ // MessageSize is the number of bytes in a Trunnel-encoded leaf message
+ MessageSize = 8 + HashSize
+ // LeafSize is the number of bytes in a Trunnel-encoded leaf
+ LeafSize = MessageSize + SignatureSize + HashSize
+)
+
+// Marshal returns a Trunnel-encoded message
+func (m *Message) Marshal() []byte {
+ buf := make([]byte, MessageSize)
+ binary.BigEndian.PutUint64(buf, m.ShardHint)
+ copy(buf[8:], m.Checksum[:])
+ return buf
+}
+
+// Marshal returns a Trunnel-encoded leaf
+func (l *Leaf) Marshal() []byte {
+ buf := l.Message.Marshal()
+ buf = append(buf, l.SigIdent.Signature[:]...)
+ buf = append(buf, l.SigIdent.KeyHash[:]...)
+ return buf
+}
+
+// Marshal returns a Trunnel-encoded tree head
+func (th *TreeHead) Marshal() []byte {
+ buf := make([]byte, 8+8+HashSize)
+ binary.BigEndian.PutUint64(buf[0:8], th.Timestamp)
+ binary.BigEndian.PutUint64(buf[8:16], th.TreeSize)
+ copy(buf[16:], th.RootHash[:])
+ return buf
+}
+
+// Unmarshal parses the Trunnel-encoded buffer as a leaf
+func (l *Leaf) Unmarshal(buf []byte) error {
+ if len(buf) != LeafSize {
+ return fmt.Errorf("invalid leaf size: %v", len(buf))
+ }
+ // Shard hint
+ l.ShardHint = binary.BigEndian.Uint64(buf)
+ offset := 8
+ // Checksum
+ l.Checksum = &[HashSize]byte{}
+ copy(l.Checksum[:], buf[offset:offset+HashSize])
+ offset += HashSize
+ // Signature
+ l.Signature = &[SignatureSize]byte{}
+ copy(l.Signature[:], buf[offset:offset+SignatureSize])
+ offset += SignatureSize
+ // KeyHash
+ l.KeyHash = &[HashSize]byte{}
+ copy(l.KeyHash[:], buf[offset:])
+ return nil
+}
diff --git a/pkg/types/trunnel_test.go b/pkg/types/trunnel_test.go
new file mode 100644
index 0000000..297578c
--- /dev/null
+++ b/pkg/types/trunnel_test.go
@@ -0,0 +1,114 @@
+package types
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+)
+
+var (
+ testBuffer32 = &[32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}
+ testBuffer64 = &[64]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}
+)
+
+func TestMarshalMessage(t *testing.T) {
+ description := "valid: shard hint 72623859790382856, checksum 0x00,0x01,..."
+ message := &Message{
+ ShardHint: 72623859790382856,
+ Checksum: testBuffer32,
+ }
+ want := bytes.Join([][]byte{
+ []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
+ testBuffer32[:],
+ }, nil)
+ if got := message.Marshal(); !bytes.Equal(got, want) {
+ t.Errorf("got message\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description)
+ }
+}
+
+func TestMarshalLeaf(t *testing.T) {
+ description := "valid: shard hint 72623859790382856, buffers 0x00,0x01,..."
+ leaf := &Leaf{
+ Message: Message{
+ ShardHint: 72623859790382856,
+ Checksum: testBuffer32,
+ },
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ }
+ want := bytes.Join([][]byte{
+ []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
+ testBuffer32[:], testBuffer64[:], testBuffer32[:],
+ }, nil)
+ if got := leaf.Marshal(); !bytes.Equal(got, want) {
+ t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description)
+ }
+}
+
+func TestMarshalTreeHead(t *testing.T) {
+ description := "valid: timestamp 16909060, tree size 72623859790382856, root hash 0x00,0x01,..."
+ th := &TreeHead{
+ Timestamp: 16909060,
+ TreeSize: 72623859790382856,
+ RootHash: testBuffer32,
+ }
+ want := bytes.Join([][]byte{
+ []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04},
+ []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
+ testBuffer32[:],
+ }, nil)
+ if got := th.Marshal(); !bytes.Equal(got, want) {
+ t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description)
+ }
+}
+
+func TestUnmarshalLeaf(t *testing.T) {
+ for _, table := range []struct {
+ description string
+ serialized []byte
+ wantErr bool
+ want *Leaf
+ }{
+ {
+ description: "invalid: not enough bytes",
+ serialized: make([]byte, LeafSize-1),
+ wantErr: true,
+ },
+ {
+ description: "invalid: too many bytes",
+ serialized: make([]byte, LeafSize+1),
+ wantErr: true,
+ },
+ {
+ description: "valid: shard hint 72623859790382856, buffers 0x00,0x01,...",
+ serialized: bytes.Join([][]byte{
+ []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
+ testBuffer32[:], testBuffer64[:], testBuffer32[:],
+ }, nil),
+ want: &Leaf{
+ Message: Message{
+ ShardHint: 72623859790382856,
+ Checksum: testBuffer32,
+ },
+ SigIdent: SigIdent{
+ Signature: testBuffer64,
+ KeyHash: testBuffer32,
+ },
+ },
+ },
+ } {
+ var leaf Leaf
+ err := leaf.Unmarshal(table.serialized)
+ if got, want := err != nil, table.wantErr; got != want {
+ t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err)
+ }
+ if err != nil {
+ continue
+ }
+ if got, want := &leaf, table.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, table.description)
+ }
+ }
+}
diff --git a/pkg/types/types.go b/pkg/types/types.go
new file mode 100644
index 0000000..9ca7db8
--- /dev/null
+++ b/pkg/types/types.go
@@ -0,0 +1,155 @@
+package types
+
+import (
+ "crypto"
+ "crypto/ed25519"
+ "crypto/sha256"
+ "fmt"
+ "strings"
+)
+
+const (
+ HashSize = sha256.Size
+ SignatureSize = ed25519.SignatureSize
+ VerificationKeySize = ed25519.PublicKeySize
+
+ EndpointAddLeaf = Endpoint("add-leaf")
+ EndpointAddCosignature = Endpoint("add-cosignature")
+ EndpointGetTreeHeadLatest = Endpoint("get-tree-head-latest")
+ EndpointGetTreeHeadToSign = Endpoint("get-tree-head-to-sign")
+ EndpointGetTreeHeadCosigned = Endpoint("get-tree-head-cosigned")
+ EndpointGetProofByHash = Endpoint("get-proof-by-hash")
+ EndpointGetConsistencyProof = Endpoint("get-consistency-proof")
+ EndpointGetLeaves = Endpoint("get-leaves")
+)
+
+// Endpoint is a named HTTP API endpoint
+type Endpoint string
+
+// Path joins a number of components to form a full endpoint path. For example,
+// EndpointAddLeaf.Path("example.com", "st/v0") -> example.com/st/v0/add-leaf.
+func (e Endpoint) Path(components ...string) string {
+ return strings.Join(append(components, string(e)), "/")
+}
+
+// Leaf is the log's Merkle tree leaf.
+type Leaf struct {
+ Message
+ SigIdent
+}
+
+// Message is composed of a shard hint and a checksum. The submitter selects
+// these values to fit the log's shard interval and the opaque data in question.
+type Message struct {
+ ShardHint uint64
+ Checksum *[HashSize]byte
+}
+
+// SigIdent is composed of a signature-signer pair. The signature is computed
+// over the Trunnel-serialized leaf message. KeyHash identifies the signer.
+type SigIdent struct {
+ Signature *[SignatureSize]byte
+ KeyHash *[HashSize]byte
+}
+
+// SignedTreeHead is composed of a tree head and a list of signature-signer
+// pairs. Each signature is computed over the Trunnel-serialized tree head.
+type SignedTreeHead struct {
+ TreeHead
+ SigIdent []*SigIdent
+}
+
+// TreeHead is the log's tree head.
+type TreeHead struct {
+ Timestamp uint64
+ TreeSize uint64
+ RootHash *[HashSize]byte
+}
+
+// ConsistencyProof is a consistency proof that proves the log's append-only
+// property.
+type ConsistencyProof struct {
+ NewSize uint64
+ OldSize uint64
+ Path []*[HashSize]byte
+}
+
+// InclusionProof is an inclusion proof that proves a leaf is included in the
+// log.
+type InclusionProof struct {
+ TreeSize uint64
+ LeafIndex uint64
+ Path []*[HashSize]byte
+}
+
+// LeafList is a list of leaves
+type LeafList []*Leaf
+
+// ConsistencyProofRequest is a get-consistency-proof request
+type ConsistencyProofRequest struct {
+ NewSize uint64
+ OldSize uint64
+}
+
+// InclusionProofRequest is a get-proof-by-hash request
+type InclusionProofRequest struct {
+ LeafHash *[HashSize]byte
+ TreeSize uint64
+}
+
+// LeavesRequest is a get-leaves request
+type LeavesRequest struct {
+ StartSize uint64
+ EndSize uint64
+}
+
+// LeafRequest is an add-leaf request
+type LeafRequest struct {
+ Message
+ Signature *[SignatureSize]byte
+ VerificationKey *[VerificationKeySize]byte
+ DomainHint string
+}
+
+// CosignatureRequest is an add-cosignature request
+type CosignatureRequest struct {
+ SigIdent
+}
+
+// Sign signs the tree head using the log's signature scheme
+func (th *TreeHead) Sign(signer crypto.Signer) (*SignedTreeHead, error) {
+ sig, err := signer.Sign(nil, th.Marshal(), crypto.Hash(0))
+ if err != nil {
+ return nil, fmt.Errorf("Sign: %v", err)
+ }
+
+ sigident := SigIdent{
+ KeyHash: Hash(signer.Public().(ed25519.PublicKey)[:]),
+ Signature: &[SignatureSize]byte{},
+ }
+ copy(sigident.Signature[:], sig)
+ return &SignedTreeHead{
+ TreeHead: *th,
+ SigIdent: []*SigIdent{
+ &sigident,
+ },
+ }, nil
+}
+
+// Verify verifies the tree head signature using the log's signature scheme
+func (th *TreeHead) Verify(vk *[VerificationKeySize]byte, sig *[SignatureSize]byte) error {
+ if !ed25519.Verify(ed25519.PublicKey(vk[:]), th.Marshal(), sig[:]) {
+ return fmt.Errorf("invalid tree head signature")
+ }
+ return nil
+}
+
+// Verify checks if a leaf is included in the log
+func (p *InclusionProof) Verify(leaf *Leaf, th *TreeHead) error { // TODO
+ return nil
+}
+
+// Verify checks if two tree heads are consistent
+func (p *ConsistencyProof) Verify(oldTH, newTH *TreeHead) error { // TODO
+ return nil
+}
diff --git a/pkg/types/types_test.go b/pkg/types/types_test.go
new file mode 100644
index 0000000..da89c59
--- /dev/null
+++ b/pkg/types/types_test.go
@@ -0,0 +1,58 @@
+package types
+
+import (
+ "testing"
+)
+
+func TestEndpointPath(t *testing.T) {
+ base, prefix, proto := "example.com", "log", "st/v0"
+ for _, table := range []struct {
+ endpoint Endpoint
+ want string
+ }{
+ {
+ endpoint: EndpointAddLeaf,
+ want: "example.com/log/st/v0/add-leaf",
+ },
+ {
+ endpoint: EndpointAddCosignature,
+ want: "example.com/log/st/v0/add-cosignature",
+ },
+ {
+ endpoint: EndpointGetTreeHeadLatest,
+ want: "example.com/log/st/v0/get-tree-head-latest",
+ },
+ {
+ endpoint: EndpointGetTreeHeadToSign,
+ want: "example.com/log/st/v0/get-tree-head-to-sign",
+ },
+ {
+ endpoint: EndpointGetTreeHeadCosigned,
+ want: "example.com/log/st/v0/get-tree-head-cosigned",
+ },
+ {
+ endpoint: EndpointGetConsistencyProof,
+ want: "example.com/log/st/v0/get-consistency-proof",
+ },
+ {
+ endpoint: EndpointGetProofByHash,
+ want: "example.com/log/st/v0/get-proof-by-hash",
+ },
+ {
+ endpoint: EndpointGetLeaves,
+ want: "example.com/log/st/v0/get-leaves",
+ },
+ } {
+ if got, want := table.endpoint.Path(base+"/"+prefix+"/"+proto), table.want; got != want {
+ t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\twith one component", got, want)
+ }
+ if got, want := table.endpoint.Path(base, prefix, proto), table.want; got != want {
+ t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\tmultiple components", got, want)
+ }
+ }
+}
+
+func TestTreeHeadSign(t *testing.T) {}
+func TestTreeHeadVerify(t *testing.T) {}
+func TestInclusionProofVerify(t *testing.T) {}
+func TestConsistencyProofVerify(t *testing.T) {}
diff --git a/pkg/types/util.go b/pkg/types/util.go
new file mode 100644
index 0000000..3cd7dfa
--- /dev/null
+++ b/pkg/types/util.go
@@ -0,0 +1,21 @@
+package types
+
+import (
+ "crypto/sha256"
+)
+
+const (
+ LeafHashPrefix = 0x00
+)
+
+func Hash(buf []byte) *[HashSize]byte {
+ var ret [HashSize]byte
+ hash := sha256.New()
+ hash.Write(buf)
+ copy(ret[:], hash.Sum(nil))
+ return &ret
+}
+
+func HashLeaf(buf []byte) *[HashSize]byte {
+ return Hash(append([]byte{LeafHashPrefix}, buf...))
+}