diff options
| author | Rasmus Dahlberg <rasmus.dahlberg@kau.se> | 2021-06-07 00:19:40 +0200 | 
|---|---|---|
| committer | Rasmus Dahlberg <rasmus.dahlberg@kau.se> | 2021-06-07 00:19:40 +0200 | 
| commit | 932d29fd08c8ff401e471b4f764537493ccbd483 (patch) | |
| tree | e840a4c62db92e84201fe9ceaa0594d99176792c /pkg | |
| parent | bdf7a53d61cf044e526cc9123ca296615f838288 (diff) | |
| parent | 345fe658fa8a4306caa74f72a618e499343675c2 (diff) | |
Merge branch 'design' into main
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/instance/endpoint.go | 122 | ||||
| -rw-r--r-- | pkg/instance/endpoint_test.go | 432 | ||||
| -rw-r--r-- | pkg/instance/instance.go | 159 | ||||
| -rw-r--r-- | pkg/instance/instance_test.go | 9 | ||||
| -rw-r--r-- | pkg/instance/metric.go | 19 | ||||
| -rw-r--r-- | pkg/mocks/crypto.go | 23 | ||||
| -rw-r--r-- | pkg/mocks/state.go | 107 | ||||
| -rw-r--r-- | pkg/mocks/stfe.go | 110 | ||||
| -rw-r--r-- | pkg/mocks/trillian.go | 317 | ||||
| -rw-r--r-- | pkg/state/state_manager.go | 154 | ||||
| -rw-r--r-- | pkg/state/state_manager_test.go | 393 | ||||
| -rw-r--r-- | pkg/trillian/client.go | 178 | ||||
| -rw-r--r-- | pkg/trillian/client_test.go | 533 | ||||
| -rw-r--r-- | pkg/trillian/util.go | 33 | ||||
| -rw-r--r-- | pkg/types/ascii.go | 421 | ||||
| -rw-r--r-- | pkg/types/ascii_test.go | 465 | ||||
| -rw-r--r-- | pkg/types/trunnel.go | 60 | ||||
| -rw-r--r-- | pkg/types/trunnel_test.go | 114 | ||||
| -rw-r--r-- | pkg/types/types.go | 155 | ||||
| -rw-r--r-- | pkg/types/types_test.go | 58 | ||||
| -rw-r--r-- | pkg/types/util.go | 21 | 
21 files changed, 3883 insertions, 0 deletions
| diff --git a/pkg/instance/endpoint.go b/pkg/instance/endpoint.go new file mode 100644 index 0000000..5085c49 --- /dev/null +++ b/pkg/instance/endpoint.go @@ -0,0 +1,122 @@ +package stfe + +import ( +	"context" +	"net/http" + +	"github.com/golang/glog" +) + +func addLeaf(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { +	glog.V(3).Info("handling add-entry request") +	req, err := i.leafRequestFromHTTP(r) +	if err != nil { +		return http.StatusBadRequest, err +	} +	if err := i.Client.AddLeaf(ctx, req); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func addCosignature(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { +	glog.V(3).Info("handling add-cosignature request") +	req, err := i.cosignatureRequestFromHTTP(r) +	if err != nil { +		return http.StatusBadRequest, err +	} +	vk := i.Witnesses[*req.KeyHash] +	if err := i.Stateman.AddCosignature(ctx, &vk, req.Signature); err != nil { +		return http.StatusBadRequest, err +	} +	return http.StatusOK, nil +} + +func getTreeHeadLatest(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { +	glog.V(3).Info("handling get-tree-head-latest request") +	sth, err := i.Stateman.Latest(ctx) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := sth.MarshalASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getTreeHeadToSign(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { +	glog.V(3).Info("handling get-tree-head-to-sign request") +	sth, err := i.Stateman.ToSign(ctx) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := sth.MarshalASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getTreeHeadCosigned(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { +	glog.V(3).Info("handling get-tree-head-cosigned request") +	sth, err := i.Stateman.Cosigned(ctx) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := sth.MarshalASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getConsistencyProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { +	glog.V(3).Info("handling get-consistency-proof request") +	req, err := i.consistencyProofRequestFromHTTP(r) +	if err != nil { +		return http.StatusBadRequest, err +	} + +	proof, err := i.Client.GetConsistencyProof(ctx, req) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := proof.MarshalASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getInclusionProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { +	glog.V(3).Info("handling get-proof-by-hash request") +	req, err := i.inclusionProofRequestFromHTTP(r) +	if err != nil { +		return http.StatusBadRequest, err +	} + +	proof, err := i.Client.GetInclusionProof(ctx, req) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := proof.MarshalASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getLeaves(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { +	glog.V(3).Info("handling get-leaves request") +	req, err := i.leavesRequestFromHTTP(r) +	if err != nil { +		return http.StatusBadRequest, err +	} + +	leaves, err := i.Client.GetLeaves(ctx, req) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	for _, leaf := range *leaves { +		if err := leaf.MarshalASCII(w); err != nil { +			return http.StatusInternalServerError, err +		} +	} +	return http.StatusOK, nil +} diff --git a/pkg/instance/endpoint_test.go b/pkg/instance/endpoint_test.go new file mode 100644 index 0000000..efcd4c0 --- /dev/null +++ b/pkg/instance/endpoint_test.go @@ -0,0 +1,432 @@ +package stfe + +import ( +	"bytes" +	"encoding/hex" +	"fmt" +	"io" +	"net/http" +	"net/http/httptest" +	"testing" + +	"github.com/golang/mock/gomock" +	"github.com/system-transparency/stfe/pkg/mocks" +	"github.com/system-transparency/stfe/pkg/types" +) + +var ( +	testWitVK  = [types.VerificationKeySize]byte{} +	testConfig = Config{ +		LogID:    hex.EncodeToString(types.Hash([]byte("logid"))[:]), +		TreeID:   0, +		Prefix:   "testonly", +		MaxRange: 3, +		Deadline: 10, +		Interval: 10, +		Witnesses: map[[types.HashSize]byte][types.VerificationKeySize]byte{ +			*types.Hash(testWitVK[:]): testWitVK, +		}, +	} +	testSTH = &types.SignedTreeHead{ +		TreeHead: types.TreeHead{ +			Timestamp: 0, +			TreeSize:  0, +			RootHash:  types.Hash(nil), +		}, +		SigIdent: []*types.SigIdent{ +			&types.SigIdent{ +				Signature: &[types.SignatureSize]byte{}, +				KeyHash:   &[types.HashSize]byte{}, +			}, +		}, +	} +) + +func mustHandle(t *testing.T, i Instance, e types.Endpoint) Handler { +	for _, handler := range i.Handlers() { +		if handler.Endpoint == e { +			return handler +		} +	} +	t.Fatalf("must handle endpoint: %v", e) +	return Handler{} +} + +func TestAddLeaf(t *testing.T) { +	buf := func() io.Reader { +		// A valid leaf request that was created manually +		return bytes.NewBufferString(fmt.Sprintf( +			"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s", +			types.ShardHint, types.Delim, "0", types.EOL, +			types.Checksum, types.Delim, "0000000000000000000000000000000000000000000000000000000000000000", types.EOL, +			types.SignatureOverMessage, types.Delim, "4cb410a4d48f52f761a7c01abcc28fd71811b84ded5403caed5e21b374f6aac9637cecd36828f17529fd503413d30ab66d7bb37a31dbf09a90d23b9241c45009", types.EOL, +			types.VerificationKey, types.Delim, "f2b7a00b625469d32502e06e8b7fad1ef258d4ad0c6cd87b846142ab681957d5", types.EOL, +			types.DomainHint, types.Delim, "example.com", types.EOL, +		)) +	} +	for _, table := range []struct { +		description string +		ascii       io.Reader // buffer used to populate HTTP request +		expect      bool      // set if a mock answer is expected +		err         error     // error from Trillian client +		wantCode    int       // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			ascii:       bytes.NewBufferString("key=value\n"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (signature error)", +			ascii: bytes.NewBufferString(fmt.Sprintf( +				"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s"+"%s%s%s%s", +				types.ShardHint, types.Delim, "1", types.EOL, +				types.Checksum, types.Delim, "1111111111111111111111111111111111111111111111111111111111111111", types.EOL, +				types.SignatureOverMessage, types.Delim, "4cb410a4d48f52f761a7c01abcc28fd71811b84ded5403caed5e21b374f6aac9637cecd36828f17529fd503413d30ab66d7bb37a31dbf09a90d23b9241c45009", types.EOL, +				types.VerificationKey, types.Delim, "f2b7a00b625469d32502e06e8b7fad1ef258d4ad0c6cd87b846142ab681957d5", types.EOL, +				types.DomainHint, types.Delim, "example.com", types.EOL, +			)), +			wantCode: http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			ascii:       buf(), +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			ascii:       buf(), +			expect:      true, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocks.NewMockClient(ctrl) +			if table.expect { +				client.EXPECT().AddLeaf(gomock.Any(), gomock.Any()).Return(table.err) +			} +			i := Instance{ +				Config: testConfig, +				Client: client, +			} + +			// Create HTTP request +			url := types.EndpointAddLeaf.Path("http://example.com", i.Prefix) +			req, err := http.NewRequest("POST", url, table.ascii) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandle(t, i, types.EndpointAddLeaf).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestAddCosignature(t *testing.T) { +	buf := func() io.Reader { +		return bytes.NewBufferString(fmt.Sprintf( +			"%s%s%x%s"+"%s%s%x%s", +			types.Signature, types.Delim, make([]byte, types.SignatureSize), types.EOL, +			types.KeyHash, types.Delim, *types.Hash(testWitVK[:]), types.EOL, +		)) +	} +	for _, table := range []struct { +		description string +		ascii       io.Reader // buffer used to populate HTTP request +		expect      bool      // set if a mock answer is expected +		err         error     // error from Trillian client +		wantCode    int       // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			ascii:       bytes.NewBufferString("key=value\n"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (unknown witness)", +			ascii: bytes.NewBufferString(fmt.Sprintf( +				"%s%s%x%s"+"%s%s%x%s", +				types.Signature, types.Delim, make([]byte, types.SignatureSize), types.EOL, +				types.KeyHash, types.Delim, *types.Hash(testWitVK[1:]), types.EOL, +			)), +			wantCode: http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			ascii:       buf(), +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "valid", +			ascii:       buf(), +			expect:      true, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			stateman := mocks.NewMockStateManager(ctrl) +			if table.expect { +				stateman.EXPECT().AddCosignature(gomock.Any(), gomock.Any(), gomock.Any()).Return(table.err) +			} +			i := Instance{ +				Config:   testConfig, +				Stateman: stateman, +			} + +			// Create HTTP request +			url := types.EndpointAddCosignature.Path("http://example.com", i.Prefix) +			req, err := http.NewRequest("POST", url, table.ascii) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandle(t, i, types.EndpointAddCosignature).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetTreeHeadLatest(t *testing.T) { +	for _, table := range []struct { +		description string +		expect      bool                  // set if a mock answer is expected +		rsp         *types.SignedTreeHead // signed tree head from Trillian client +		err         error                 // error from Trillian client +		wantCode    int                   // HTTP status ok +	}{ +		{ +			description: "invalid: backend failure", +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			expect:      true, +			rsp:         testSTH, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			stateman := mocks.NewMockStateManager(ctrl) +			if table.expect { +				stateman.EXPECT().Latest(gomock.Any()).Return(table.rsp, table.err) +			} +			i := Instance{ +				Config:   testConfig, +				Stateman: stateman, +			} + +			// Create HTTP request +			url := types.EndpointGetTreeHeadLatest.Path("http://example.com", i.Prefix) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandle(t, i, types.EndpointGetTreeHeadLatest).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetTreeToSign(t *testing.T) { +	for _, table := range []struct { +		description string +		expect      bool                  // set if a mock answer is expected +		rsp         *types.SignedTreeHead // signed tree head from Trillian client +		err         error                 // error from Trillian client +		wantCode    int                   // HTTP status ok +	}{ +		{ +			description: "invalid: backend failure", +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			expect:      true, +			rsp:         testSTH, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			stateman := mocks.NewMockStateManager(ctrl) +			if table.expect { +				stateman.EXPECT().ToSign(gomock.Any()).Return(table.rsp, table.err) +			} +			i := Instance{ +				Config:   testConfig, +				Stateman: stateman, +			} + +			// Create HTTP request +			url := types.EndpointGetTreeHeadToSign.Path("http://example.com", i.Prefix) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandle(t, i, types.EndpointGetTreeHeadToSign).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetTreeCosigned(t *testing.T) { +	for _, table := range []struct { +		description string +		expect      bool                  // set if a mock answer is expected +		rsp         *types.SignedTreeHead // signed tree head from Trillian client +		err         error                 // error from Trillian client +		wantCode    int                   // HTTP status ok +	}{ +		{ +			description: "invalid: backend failure", +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			expect:      true, +			rsp:         testSTH, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			stateman := mocks.NewMockStateManager(ctrl) +			if table.expect { +				stateman.EXPECT().Cosigned(gomock.Any()).Return(table.rsp, table.err) +			} +			i := Instance{ +				Config:   testConfig, +				Stateman: stateman, +			} + +			// Create HTTP request +			url := types.EndpointGetTreeHeadCosigned.Path("http://example.com", i.Prefix) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandle(t, i, types.EndpointGetTreeHeadCosigned).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetConsistencyProof(t *testing.T) { +	buf := func(oldSize, newSize int) io.Reader { +		return bytes.NewBufferString(fmt.Sprintf( +			"%s%s%d%s"+"%s%s%d%s", +			types.OldSize, types.Delim, oldSize, types.EOL, +			types.NewSize, types.Delim, newSize, types.EOL, +		)) +	} +	// values in testProof are not relevant for the test, just need a path +	testProof := &types.ConsistencyProof{ +		OldSize: 1, +		NewSize: 2, +		Path: []*[types.HashSize]byte{ +			types.Hash(nil), +		}, +	} +	for _, table := range []struct { +		description string +		ascii       io.Reader               // buffer used to populate HTTP request +		expect      bool                    // set if a mock answer is expected +		rsp         *types.ConsistencyProof // consistency proof from Trillian client +		err         error                   // error from Trillian client +		wantCode    int                     // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			ascii:       bytes.NewBufferString("key=value\n"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "valid", +			ascii:       buf(1, 2), +			expect:      true, +			rsp:         testProof, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocks.NewMockClient(ctrl) +			if table.expect { +				client.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			} +			i := Instance{ +				Config: testConfig, +				Client: client, +			} + +			// Create HTTP request +			url := types.EndpointGetConsistencyProof.Path("http://example.com", i.Prefix) +			req, err := http.NewRequest("POST", url, table.ascii) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandle(t, i, types.EndpointGetConsistencyProof).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetInclusionProof(t *testing.T) { +} + +func TestGetLeaves(t *testing.T) { +} diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go new file mode 100644 index 0000000..c2fe8fa --- /dev/null +++ b/pkg/instance/instance.go @@ -0,0 +1,159 @@ +package stfe + +import ( +	"context" +	"crypto" +	"crypto/ed25519" +	"fmt" +	"net/http" +	"time" + +	"github.com/golang/glog" +	"github.com/system-transparency/stfe/pkg/state" +	"github.com/system-transparency/stfe/pkg/trillian" +	"github.com/system-transparency/stfe/pkg/types" +) + +// Config is a collection of log parameters +type Config struct { +	LogID    string        // H(public key), then hex-encoded +	TreeID   int64         // Merkle tree identifier used by Trillian +	Prefix   string        // The portion between base URL and st/v0 (may be "") +	MaxRange int64         // Maximum number of leaves per get-leaves request +	Deadline time.Duration // Deadline used for gRPC requests +	Interval time.Duration // Cosigning frequency + +	// Witnesses map trusted witness identifiers to public verification keys +	Witnesses map[[types.HashSize]byte][types.VerificationKeySize]byte +} + +// Instance is an instance of the log's front-end +type Instance struct { +	Config                      // configuration parameters +	Client   trillian.Client    // provides access to the Trillian backend +	Signer   crypto.Signer      // provides access to Ed25519 private key +	Stateman state.StateManager // coordinates access to (co)signed tree heads +} + +// Handler implements the http.Handler interface, and contains a reference +// to an STFE server instance as well as a function that uses it. +type Handler struct { +	Instance *Instance +	Endpoint types.Endpoint +	Method   string +	Handler  func(context.Context, *Instance, http.ResponseWriter, *http.Request) (int, error) +} + +// Handlers returns a list of STFE handlers +func (i *Instance) Handlers() []Handler { +	return []Handler{ +		Handler{Instance: i, Handler: addLeaf, Endpoint: types.EndpointAddLeaf, Method: http.MethodPost}, +		Handler{Instance: i, Handler: addCosignature, Endpoint: types.EndpointAddCosignature, Method: http.MethodPost}, +		Handler{Instance: i, Handler: getTreeHeadLatest, Endpoint: types.EndpointGetTreeHeadLatest, Method: http.MethodGet}, +		Handler{Instance: i, Handler: getTreeHeadToSign, Endpoint: types.EndpointGetTreeHeadToSign, Method: http.MethodGet}, +		Handler{Instance: i, Handler: getTreeHeadCosigned, Endpoint: types.EndpointGetTreeHeadCosigned, Method: http.MethodGet}, +		Handler{Instance: i, Handler: getConsistencyProof, Endpoint: types.EndpointGetConsistencyProof, Method: http.MethodPost}, +		Handler{Instance: i, Handler: getInclusionProof, Endpoint: types.EndpointGetProofByHash, Method: http.MethodPost}, +		Handler{Instance: i, Handler: getLeaves, Endpoint: types.EndpointGetLeaves, Method: http.MethodPost}, +	} +} + +// Path returns a path that should be configured for this handler +func (h Handler) Path() string { +	return h.Endpoint.Path(h.Instance.Prefix, "st", "v0") +} + +// ServeHTTP is part of the http.Handler interface +func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +	// export prometheus metrics +	var now time.Time = time.Now() +	var statusCode int +	defer func() { +		rspcnt.Inc(a.Instance.LogID, string(a.Endpoint), fmt.Sprintf("%d", statusCode)) +		latency.Observe(time.Now().Sub(now).Seconds(), a.Instance.LogID, string(a.Endpoint), fmt.Sprintf("%d", statusCode)) +	}() +	reqcnt.Inc(a.Instance.LogID, string(a.Endpoint)) + +	ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.Instance.Deadline)) +	defer cancel() + +	if r.Method != a.Method { +		glog.Warningf("%s/%s: got HTTP %s, wanted HTTP %s", a.Instance.Prefix, string(a.Endpoint), r.Method, a.Method) +		http.Error(w, "", http.StatusMethodNotAllowed) +		return +	} + +	statusCode, err := a.Handler(ctx, a.Instance, w, r) +	if err != nil { +		glog.Warningf("handler error %s/%s: %v", a.Instance.Prefix, a.Endpoint, err) +		http.Error(w, fmt.Sprintf("%s%s%s%s", "Error", types.Delim, err.Error(), types.EOL), statusCode) +	} +} + +func (i *Instance) leafRequestFromHTTP(r *http.Request) (*types.LeafRequest, error) { +	var req types.LeafRequest +	if err := req.UnmarshalASCII(r.Body); err != nil { +		return nil, fmt.Errorf("UnmarshalASCII: %v", err) +	} + +	vk := ed25519.PublicKey(req.VerificationKey[:]) +	msg := req.Message.Marshal() +	sig := req.Signature[:] +	if !ed25519.Verify(vk, msg, sig) { +		return nil, fmt.Errorf("invalid signature") +	} +	// TODO: check shard hint +	// TODO: check domain hint +	return &req, nil +} + +func (i *Instance) cosignatureRequestFromHTTP(r *http.Request) (*types.CosignatureRequest, error) { +	var req types.CosignatureRequest +	if err := req.UnmarshalASCII(r.Body); err != nil { +		return nil, fmt.Errorf("unpackOctetPost: %v", err) +	} +	if _, ok := i.Witnesses[*req.KeyHash]; !ok { +		return nil, fmt.Errorf("Unknown witness: %x", req.KeyHash) +	} +	return &req, nil +} + +func (i *Instance) consistencyProofRequestFromHTTP(r *http.Request) (*types.ConsistencyProofRequest, error) { +	var req types.ConsistencyProofRequest +	if err := req.UnmarshalASCII(r.Body); err != nil { +		return nil, fmt.Errorf("UnmarshalASCII: %v", err) +	} +	if req.OldSize < 1 { +		return nil, fmt.Errorf("OldSize(%d) must be larger than zero", req.OldSize) +	} +	if req.NewSize <= req.OldSize { +		return nil, fmt.Errorf("NewSize(%d) must be larger than OldSize(%d)", req.NewSize, req.OldSize) +	} +	return &req, nil +} + +func (i *Instance) inclusionProofRequestFromHTTP(r *http.Request) (*types.InclusionProofRequest, error) { +	var req types.InclusionProofRequest +	if err := req.UnmarshalASCII(r.Body); err != nil { +		return nil, fmt.Errorf("UnmarshalASCII: %v", err) +	} +	if req.TreeSize < 1 { +		return nil, fmt.Errorf("TreeSize(%d) must be larger than zero", req.TreeSize) +	} +	return &req, nil +} + +func (i *Instance) leavesRequestFromHTTP(r *http.Request) (*types.LeavesRequest, error) { +	var req types.LeavesRequest +	if err := req.UnmarshalASCII(r.Body); err != nil { +		return nil, fmt.Errorf("UnmarshalASCII: %v", err) +	} + +	if req.StartSize > req.EndSize { +		return nil, fmt.Errorf("StartSize(%d) must be less than or equal to EndSize(%d)", req.StartSize, req.EndSize) +	} +	if req.EndSize-req.StartSize+1 > uint64(i.MaxRange) { +		req.EndSize = req.StartSize + uint64(i.MaxRange) - 1 +	} +	return &req, nil +} diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go new file mode 100644 index 0000000..45a2837 --- /dev/null +++ b/pkg/instance/instance_test.go @@ -0,0 +1,9 @@ +package stfe + +import ( +	"testing" +) + +func TestHandlers(t *testing.T)  {} +func TestPath(t *testing.T)      {} +func TestServeHTTP(t *testing.T) {} diff --git a/pkg/instance/metric.go b/pkg/instance/metric.go new file mode 100644 index 0000000..db11bd2 --- /dev/null +++ b/pkg/instance/metric.go @@ -0,0 +1,19 @@ +package stfe + +import ( +	"github.com/google/trillian/monitoring" +	"github.com/google/trillian/monitoring/prometheus" +) + +var ( +	reqcnt  monitoring.Counter   // number of incoming http requests +	rspcnt  monitoring.Counter   // number of valid http responses +	latency monitoring.Histogram // request-response latency +) + +func init() { +	mf := prometheus.MetricFactory{} +	reqcnt = mf.NewCounter("http_req", "number of http requests", "logid", "endpoint") +	rspcnt = mf.NewCounter("http_rsp", "number of http requests", "logid", "endpoint", "status") +	latency = mf.NewHistogram("http_latency", "http request-response latency", "logid", "endpoint", "status") +} diff --git a/pkg/mocks/crypto.go b/pkg/mocks/crypto.go new file mode 100644 index 0000000..87c883a --- /dev/null +++ b/pkg/mocks/crypto.go @@ -0,0 +1,23 @@ +package mocks + +import ( +	"crypto" +	"crypto/ed25519" +	"io" +) + +// TestSign implements the signer interface.  It can be used to mock an Ed25519 +// signer that always return the same public key, signature, and error. +type TestSigner struct { +	PublicKey *[ed25519.PublicKeySize]byte +	Signature *[ed25519.SignatureSize]byte +	Error     error +} + +func (ts *TestSigner) Public() crypto.PublicKey { +	return ed25519.PublicKey(ts.PublicKey[:]) +} + +func (ts *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { +	return ts.Signature[:], ts.Error +} diff --git a/pkg/mocks/state.go b/pkg/mocks/state.go new file mode 100644 index 0000000..41d8d08 --- /dev/null +++ b/pkg/mocks/state.go @@ -0,0 +1,107 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/system-transparency/stfe/pkg/state (interfaces: StateManager) + +// Package mocks is a generated GoMock package. +package mocks + +import ( +	context "context" +	reflect "reflect" + +	gomock "github.com/golang/mock/gomock" +	types "github.com/system-transparency/stfe/pkg/types" +) + +// MockStateManager is a mock of StateManager interface. +type MockStateManager struct { +	ctrl     *gomock.Controller +	recorder *MockStateManagerMockRecorder +} + +// MockStateManagerMockRecorder is the mock recorder for MockStateManager. +type MockStateManagerMockRecorder struct { +	mock *MockStateManager +} + +// NewMockStateManager creates a new mock instance. +func NewMockStateManager(ctrl *gomock.Controller) *MockStateManager { +	mock := &MockStateManager{ctrl: ctrl} +	mock.recorder = &MockStateManagerMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStateManager) EXPECT() *MockStateManagerMockRecorder { +	return m.recorder +} + +// AddCosignature mocks base method. +func (m *MockStateManager) AddCosignature(arg0 context.Context, arg1 *[32]byte, arg2 *[64]byte) error { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "AddCosignature", arg0, arg1, arg2) +	ret0, _ := ret[0].(error) +	return ret0 +} + +// AddCosignature indicates an expected call of AddCosignature. +func (mr *MockStateManagerMockRecorder) AddCosignature(arg0, arg1, arg2 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddCosignature", reflect.TypeOf((*MockStateManager)(nil).AddCosignature), arg0, arg1, arg2) +} + +// Cosigned mocks base method. +func (m *MockStateManager) Cosigned(arg0 context.Context) (*types.SignedTreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Cosigned", arg0) +	ret0, _ := ret[0].(*types.SignedTreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// Cosigned indicates an expected call of Cosigned. +func (mr *MockStateManagerMockRecorder) Cosigned(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cosigned", reflect.TypeOf((*MockStateManager)(nil).Cosigned), arg0) +} + +// Latest mocks base method. +func (m *MockStateManager) Latest(arg0 context.Context) (*types.SignedTreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Latest", arg0) +	ret0, _ := ret[0].(*types.SignedTreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// Latest indicates an expected call of Latest. +func (mr *MockStateManagerMockRecorder) Latest(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Latest", reflect.TypeOf((*MockStateManager)(nil).Latest), arg0) +} + +// Run mocks base method. +func (m *MockStateManager) Run(arg0 context.Context) { +	m.ctrl.T.Helper() +	m.ctrl.Call(m, "Run", arg0) +} + +// Run indicates an expected call of Run. +func (mr *MockStateManagerMockRecorder) Run(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockStateManager)(nil).Run), arg0) +} + +// ToSign mocks base method. +func (m *MockStateManager) ToSign(arg0 context.Context) (*types.SignedTreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "ToSign", arg0) +	ret0, _ := ret[0].(*types.SignedTreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// ToSign indicates an expected call of ToSign. +func (mr *MockStateManagerMockRecorder) ToSign(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ToSign", reflect.TypeOf((*MockStateManager)(nil).ToSign), arg0) +} diff --git a/pkg/mocks/stfe.go b/pkg/mocks/stfe.go new file mode 100644 index 0000000..def5bc6 --- /dev/null +++ b/pkg/mocks/stfe.go @@ -0,0 +1,110 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/system-transparency/stfe/trillian (interfaces: Client) + +// Package mocks is a generated GoMock package. +package mocks + +import ( +	context "context" +	reflect "reflect" + +	gomock "github.com/golang/mock/gomock" +	types "github.com/system-transparency/stfe/pkg/types" +) + +// MockClient is a mock of Client interface. +type MockClient struct { +	ctrl     *gomock.Controller +	recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { +	mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { +	mock := &MockClient{ctrl: ctrl} +	mock.recorder = &MockClientMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { +	return m.recorder +} + +// AddLeaf mocks base method. +func (m *MockClient) AddLeaf(arg0 context.Context, arg1 *types.LeafRequest) error { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "AddLeaf", arg0, arg1) +	ret0, _ := ret[0].(error) +	return ret0 +} + +// AddLeaf indicates an expected call of AddLeaf. +func (mr *MockClientMockRecorder) AddLeaf(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLeaf", reflect.TypeOf((*MockClient)(nil).AddLeaf), arg0, arg1) +} + +// GetConsistencyProof mocks base method. +func (m *MockClient) GetConsistencyProof(arg0 context.Context, arg1 *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetConsistencyProof", arg0, arg1) +	ret0, _ := ret[0].(*types.ConsistencyProof) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetConsistencyProof indicates an expected call of GetConsistencyProof. +func (mr *MockClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockClient)(nil).GetConsistencyProof), arg0, arg1) +} + +// GetInclusionProof mocks base method. +func (m *MockClient) GetInclusionProof(arg0 context.Context, arg1 *types.InclusionProofRequest) (*types.InclusionProof, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetInclusionProof", arg0, arg1) +	ret0, _ := ret[0].(*types.InclusionProof) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetInclusionProof indicates an expected call of GetInclusionProof. +func (mr *MockClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockClient)(nil).GetInclusionProof), arg0, arg1) +} + +// GetLeaves mocks base method. +func (m *MockClient) GetLeaves(arg0 context.Context, arg1 *types.LeavesRequest) (*types.LeafList, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetLeaves", arg0, arg1) +	ret0, _ := ret[0].(*types.LeafList) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeaves indicates an expected call of GetLeaves. +func (mr *MockClientMockRecorder) GetLeaves(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeaves", reflect.TypeOf((*MockClient)(nil).GetLeaves), arg0, arg1) +} + +// GetTreeHead mocks base method. +func (m *MockClient) GetTreeHead(arg0 context.Context) (*types.TreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetTreeHead", arg0) +	ret0, _ := ret[0].(*types.TreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetTreeHead indicates an expected call of GetTreeHead. +func (mr *MockClientMockRecorder) GetTreeHead(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTreeHead", reflect.TypeOf((*MockClient)(nil).GetTreeHead), arg0) +} diff --git a/pkg/mocks/trillian.go b/pkg/mocks/trillian.go new file mode 100644 index 0000000..8aa3a58 --- /dev/null +++ b/pkg/mocks/trillian.go @@ -0,0 +1,317 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/google/trillian (interfaces: TrillianLogClient) + +// Package mocks is a generated GoMock package. +package mocks + +import ( +	context "context" +	reflect "reflect" + +	gomock "github.com/golang/mock/gomock" +	trillian "github.com/google/trillian" +	grpc "google.golang.org/grpc" +) + +// MockTrillianLogClient is a mock of TrillianLogClient interface. +type MockTrillianLogClient struct { +	ctrl     *gomock.Controller +	recorder *MockTrillianLogClientMockRecorder +} + +// MockTrillianLogClientMockRecorder is the mock recorder for MockTrillianLogClient. +type MockTrillianLogClientMockRecorder struct { +	mock *MockTrillianLogClient +} + +// NewMockTrillianLogClient creates a new mock instance. +func NewMockTrillianLogClient(ctrl *gomock.Controller) *MockTrillianLogClient { +	mock := &MockTrillianLogClient{ctrl: ctrl} +	mock.recorder = &MockTrillianLogClientMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrillianLogClient) EXPECT() *MockTrillianLogClientMockRecorder { +	return m.recorder +} + +// AddSequencedLeaf mocks base method. +func (m *MockTrillianLogClient) AddSequencedLeaf(arg0 context.Context, arg1 *trillian.AddSequencedLeafRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeafResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "AddSequencedLeaf", varargs...) +	ret0, _ := ret[0].(*trillian.AddSequencedLeafResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// AddSequencedLeaf indicates an expected call of AddSequencedLeaf. +func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaf), varargs...) +} + +// AddSequencedLeaves mocks base method. +func (m *MockTrillianLogClient) AddSequencedLeaves(arg0 context.Context, arg1 *trillian.AddSequencedLeavesRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeavesResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "AddSequencedLeaves", varargs...) +	ret0, _ := ret[0].(*trillian.AddSequencedLeavesResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// AddSequencedLeaves indicates an expected call of AddSequencedLeaves. +func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaves), varargs...) +} + +// GetConsistencyProof mocks base method. +func (m *MockTrillianLogClient) GetConsistencyProof(arg0 context.Context, arg1 *trillian.GetConsistencyProofRequest, arg2 ...grpc.CallOption) (*trillian.GetConsistencyProofResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetConsistencyProof", varargs...) +	ret0, _ := ret[0].(*trillian.GetConsistencyProofResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetConsistencyProof indicates an expected call of GetConsistencyProof. +func (mr *MockTrillianLogClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetConsistencyProof), varargs...) +} + +// GetEntryAndProof mocks base method. +func (m *MockTrillianLogClient) GetEntryAndProof(arg0 context.Context, arg1 *trillian.GetEntryAndProofRequest, arg2 ...grpc.CallOption) (*trillian.GetEntryAndProofResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetEntryAndProof", varargs...) +	ret0, _ := ret[0].(*trillian.GetEntryAndProofResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetEntryAndProof indicates an expected call of GetEntryAndProof. +func (mr *MockTrillianLogClientMockRecorder) GetEntryAndProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntryAndProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetEntryAndProof), varargs...) +} + +// GetInclusionProof mocks base method. +func (m *MockTrillianLogClient) GetInclusionProof(arg0 context.Context, arg1 *trillian.GetInclusionProofRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetInclusionProof", varargs...) +	ret0, _ := ret[0].(*trillian.GetInclusionProofResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetInclusionProof indicates an expected call of GetInclusionProof. +func (mr *MockTrillianLogClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProof), varargs...) +} + +// GetInclusionProofByHash mocks base method. +func (m *MockTrillianLogClient) GetInclusionProofByHash(arg0 context.Context, arg1 *trillian.GetInclusionProofByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofByHashResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetInclusionProofByHash", varargs...) +	ret0, _ := ret[0].(*trillian.GetInclusionProofByHashResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetInclusionProofByHash indicates an expected call of GetInclusionProofByHash. +func (mr *MockTrillianLogClientMockRecorder) GetInclusionProofByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProofByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProofByHash), varargs...) +} + +// GetLatestSignedLogRoot mocks base method. +func (m *MockTrillianLogClient) GetLatestSignedLogRoot(arg0 context.Context, arg1 *trillian.GetLatestSignedLogRootRequest, arg2 ...grpc.CallOption) (*trillian.GetLatestSignedLogRootResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetLatestSignedLogRoot", varargs...) +	ret0, _ := ret[0].(*trillian.GetLatestSignedLogRootResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLatestSignedLogRoot indicates an expected call of GetLatestSignedLogRoot. +func (mr *MockTrillianLogClientMockRecorder) GetLatestSignedLogRoot(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestSignedLogRoot", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLatestSignedLogRoot), varargs...) +} + +// GetLeavesByHash mocks base method. +func (m *MockTrillianLogClient) GetLeavesByHash(arg0 context.Context, arg1 *trillian.GetLeavesByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByHashResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetLeavesByHash", varargs...) +	ret0, _ := ret[0].(*trillian.GetLeavesByHashResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeavesByHash indicates an expected call of GetLeavesByHash. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByHash), varargs...) +} + +// GetLeavesByIndex mocks base method. +func (m *MockTrillianLogClient) GetLeavesByIndex(arg0 context.Context, arg1 *trillian.GetLeavesByIndexRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByIndexResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetLeavesByIndex", varargs...) +	ret0, _ := ret[0].(*trillian.GetLeavesByIndexResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeavesByIndex indicates an expected call of GetLeavesByIndex. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByIndex(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByIndex", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByIndex), varargs...) +} + +// GetLeavesByRange mocks base method. +func (m *MockTrillianLogClient) GetLeavesByRange(arg0 context.Context, arg1 *trillian.GetLeavesByRangeRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByRangeResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetLeavesByRange", varargs...) +	ret0, _ := ret[0].(*trillian.GetLeavesByRangeResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeavesByRange indicates an expected call of GetLeavesByRange. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByRange(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByRange", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByRange), varargs...) +} + +// GetSequencedLeafCount mocks base method. +func (m *MockTrillianLogClient) GetSequencedLeafCount(arg0 context.Context, arg1 *trillian.GetSequencedLeafCountRequest, arg2 ...grpc.CallOption) (*trillian.GetSequencedLeafCountResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetSequencedLeafCount", varargs...) +	ret0, _ := ret[0].(*trillian.GetSequencedLeafCountResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetSequencedLeafCount indicates an expected call of GetSequencedLeafCount. +func (mr *MockTrillianLogClientMockRecorder) GetSequencedLeafCount(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSequencedLeafCount", reflect.TypeOf((*MockTrillianLogClient)(nil).GetSequencedLeafCount), varargs...) +} + +// InitLog mocks base method. +func (m *MockTrillianLogClient) InitLog(arg0 context.Context, arg1 *trillian.InitLogRequest, arg2 ...grpc.CallOption) (*trillian.InitLogResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "InitLog", varargs...) +	ret0, _ := ret[0].(*trillian.InitLogResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// InitLog indicates an expected call of InitLog. +func (mr *MockTrillianLogClientMockRecorder) InitLog(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitLog", reflect.TypeOf((*MockTrillianLogClient)(nil).InitLog), varargs...) +} + +// QueueLeaf mocks base method. +func (m *MockTrillianLogClient) QueueLeaf(arg0 context.Context, arg1 *trillian.QueueLeafRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeafResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "QueueLeaf", varargs...) +	ret0, _ := ret[0].(*trillian.QueueLeafResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// QueueLeaf indicates an expected call of QueueLeaf. +func (mr *MockTrillianLogClientMockRecorder) QueueLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaf), varargs...) +} + +// QueueLeaves mocks base method. +func (m *MockTrillianLogClient) QueueLeaves(arg0 context.Context, arg1 *trillian.QueueLeavesRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeavesResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "QueueLeaves", varargs...) +	ret0, _ := ret[0].(*trillian.QueueLeavesResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// QueueLeaves indicates an expected call of QueueLeaves. +func (mr *MockTrillianLogClientMockRecorder) QueueLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaves), varargs...) +} diff --git a/pkg/state/state_manager.go b/pkg/state/state_manager.go new file mode 100644 index 0000000..dfa73f4 --- /dev/null +++ b/pkg/state/state_manager.go @@ -0,0 +1,154 @@ +package state + +import ( +	"context" +	"crypto" +	"fmt" +	"reflect" +	"sync" +	"time" + +	"github.com/golang/glog" +	"github.com/google/certificate-transparency-go/schedule" +	"github.com/system-transparency/stfe/pkg/trillian" +	"github.com/system-transparency/stfe/pkg/types" +) + +// StateManager coordinates access to the log's tree heads and (co)signatures +type StateManager interface { +	Latest(context.Context) (*types.SignedTreeHead, error) +	ToSign(context.Context) (*types.SignedTreeHead, error) +	Cosigned(context.Context) (*types.SignedTreeHead, error) +	AddCosignature(context.Context, *[types.VerificationKeySize]byte, *[types.SignatureSize]byte) error +	Run(context.Context) +} + +// StateManagerSingle implements the StateManager interface.  It is assumed that +// the log server is running on a single-instance machine.  So, no coordination. +type StateManagerSingle struct { +	client   trillian.Client +	signer   crypto.Signer +	interval time.Duration +	deadline time.Duration +	sync.RWMutex + +	// cosigned is the current cosigned tree head that is being served +	cosigned types.SignedTreeHead + +	// tosign is the current tree head that is being cosigned by witnesses +	tosign types.SignedTreeHead + +	// cosignature keeps track of all cosignatures for the tosign tree head +	cosignature map[[types.HashSize]byte]*types.SigIdent +} + +func NewStateManagerSingle(client trillian.Client, signer crypto.Signer, interval, deadline time.Duration) (*StateManagerSingle, error) { +	sm := &StateManagerSingle{ +		client:   client, +		signer:   signer, +		interval: interval, +		deadline: deadline, +	} + +	ctx, _ := context.WithTimeout(context.Background(), sm.deadline) +	sth, err := sm.Latest(ctx) +	if err != nil { +		return nil, fmt.Errorf("Latest: %v", err) +	} + +	sm.cosigned = *sth +	sm.tosign = *sth +	sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{ +		*sth.SigIdent[0].KeyHash: sth.SigIdent[0], // log signature +	} +	return sm, nil +} + +func (sm *StateManagerSingle) Run(ctx context.Context) { +	schedule.Every(ctx, sm.interval, func(ctx context.Context) { +		ictx, _ := context.WithTimeout(ctx, sm.deadline) +		nextTreeHead, err := sm.Latest(ictx) +		if err != nil { +			glog.Warningf("rotate failed: Latest: %v", err) +			return +		} + +		sm.Lock() +		defer sm.Unlock() +		sm.rotate(nextTreeHead) +	}) +} + +func (sm *StateManagerSingle) Latest(ctx context.Context) (*types.SignedTreeHead, error) { +	th, err := sm.client.GetTreeHead(ctx) +	if err != nil { +		return nil, fmt.Errorf("LatestTreeHead: %v", err) +	} +	sth, err := th.Sign(sm.signer) +	if err != nil { +		return nil, fmt.Errorf("sign: %v", err) +	} +	return sth, nil +} + +func (sm *StateManagerSingle) ToSign(_ context.Context) (*types.SignedTreeHead, error) { +	sm.RLock() +	defer sm.RUnlock() +	return &sm.tosign, nil +} + +func (sm *StateManagerSingle) Cosigned(_ context.Context) (*types.SignedTreeHead, error) { +	sm.RLock() +	defer sm.RUnlock() +	return &sm.cosigned, nil +} + +func (sm *StateManagerSingle) AddCosignature(_ context.Context, vk *[types.VerificationKeySize]byte, sig *[types.SignatureSize]byte) error { +	sm.Lock() +	defer sm.Unlock() + +	if err := sm.tosign.TreeHead.Verify(vk, sig); err != nil { +		return fmt.Errorf("Verify: %v", err) +	} +	witness := types.Hash(vk[:]) +	if _, ok := sm.cosignature[*witness]; ok { +		return fmt.Errorf("signature-signer pair is a duplicate") +	} +	sm.cosignature[*witness] = &types.SigIdent{ +		Signature: sig, +		KeyHash:   witness, +	} + +	glog.V(3).Infof("accepted new cosignature from witness: %x", *witness) +	return nil +} + +// rotate rotates the log's cosigned and stable STH.  The caller must aquire the +// source's read-write lock if there are concurrent reads and/or writes. +func (sm *StateManagerSingle) rotate(next *types.SignedTreeHead) { +	if reflect.DeepEqual(sm.cosigned.TreeHead, sm.tosign.TreeHead) { +		// cosigned and tosign are the same.  So, we need to merge all +		// cosignatures that we already had with the new collected ones. +		for _, sigident := range sm.cosigned.SigIdent { +			if _, ok := sm.cosignature[*sigident.KeyHash]; !ok { +				sm.cosignature[*sigident.KeyHash] = sigident +			} +		} +		glog.V(3).Infof("cosigned tree head repeated, merged signatures") +	} +	var cosignatures []*types.SigIdent +	for _, sigident := range sm.cosignature { +		cosignatures = append(cosignatures, sigident) +	} + +	// Update cosigned tree head +	sm.cosigned.TreeHead = sm.tosign.TreeHead +	sm.cosigned.SigIdent = cosignatures + +	// Update to-sign tree head +	sm.tosign = *next +	sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{ +		*next.SigIdent[0].KeyHash: next.SigIdent[0], // log signature +	} +	glog.V(3).Infof("rotated tree heads") +} diff --git a/pkg/state/state_manager_test.go b/pkg/state/state_manager_test.go new file mode 100644 index 0000000..08990cc --- /dev/null +++ b/pkg/state/state_manager_test.go @@ -0,0 +1,393 @@ +package state + +import ( +	"bytes" +	"context" +	"crypto" +	"crypto/ed25519" +	"crypto/rand" +	"fmt" +	"reflect" +	"testing" +	"time" + +	"github.com/golang/mock/gomock" +	"github.com/system-transparency/stfe/pkg/mocks" +	"github.com/system-transparency/stfe/pkg/types" +) + +var ( +	testSig = &[types.SignatureSize]byte{} +	testPub = &[types.VerificationKeySize]byte{} +	testTH  = &types.TreeHead{ +		Timestamp: 0, +		TreeSize:  0, +		RootHash:  types.Hash(nil), +	} +	testSigIdent = &types.SigIdent{ +		Signature: testSig, +		KeyHash:   types.Hash(testPub[:]), +	} +	testSTH = &types.SignedTreeHead{ +		TreeHead: *testTH, +		SigIdent: []*types.SigIdent{testSigIdent}, +	} +	testSignerOK  = &mocks.TestSigner{testPub, testSig, nil} +	testSignerErr = &mocks.TestSigner{testPub, testSig, fmt.Errorf("something went wrong")} +) + +func TestNewStateManagerSingle(t *testing.T) { +	for _, table := range []struct { +		description string +		signer      crypto.Signer +		rsp         *types.TreeHead +		err         error +		wantErr     bool +		wantSth     *types.SignedTreeHead +	}{ +		{ +			description: "invalid: backend failure", +			signer:      testSignerOK, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "valid", +			signer:      testSignerOK, +			rsp:         testTH, +			wantSth:     testSTH, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocks.NewMockClient(ctrl) +			client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err) + +			sm, err := NewStateManagerSingle(client, table.signer, time.Duration(0), time.Duration(0)) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := &sm.cosigned, table.wantSth; !reflect.DeepEqual(got, want) { +				t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +			if got, want := &sm.tosign, table.wantSth; !reflect.DeepEqual(got, want) { +				t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +			// we only have log signature on startup +			if got, want := len(sm.cosignature), 1; got != want { +				t.Errorf("got %d cosignatures but wanted %d in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestLatest(t *testing.T) { +	for _, table := range []struct { +		description string +		signer      crypto.Signer +		rsp         *types.TreeHead +		err         error +		wantErr     bool +		wantSth     *types.SignedTreeHead +	}{ +		{ +			description: "invalid: backend failure", +			signer:      testSignerOK, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: signature failure", +			rsp:         testTH, +			signer:      testSignerErr, +			wantErr:     true, +		}, +		{ +			description: "valid", +			signer:      testSignerOK, +			rsp:         testTH, +			wantSth:     testSTH, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocks.NewMockClient(ctrl) +			client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err) +			sm := StateManagerSingle{ +				client: client, +				signer: table.signer, +			} + +			sth, err := sm.Latest(context.Background()) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := sth, table.wantSth; !reflect.DeepEqual(got, want) { +				t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} + +func TestToSign(t *testing.T) { +	description := "valid" +	sm := StateManagerSingle{ +		tosign: *testSTH, +	} +	sth, err := sm.ToSign(context.Background()) +	if err != nil { +		t.Errorf("ToSign should not fail with error: %v", err) +		return +	} +	if got, want := sth, testSTH; !reflect.DeepEqual(got, want) { +		t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description) +	} +} + +func TestCosigned(t *testing.T) { +	description := "valid" +	sm := StateManagerSingle{ +		cosigned: *testSTH, +	} +	sth, err := sm.Cosigned(context.Background()) +	if err != nil { +		t.Errorf("Cosigned should not fail with error: %v", err) +		return +	} +	if got, want := sth, testSTH; !reflect.DeepEqual(got, want) { +		t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description) +	} +} + +func TestAddCosignature(t *testing.T) { +	vk, sk, err := ed25519.GenerateKey(rand.Reader) +	if err != nil { +		t.Fatalf("GenerateKey: %v", err) +	} +	if bytes.Equal(vk[:], testPub[:]) { +		t.Fatalf("Sampled same key as testPub, aborting...") +	} +	var vkArray [types.VerificationKeySize]byte +	copy(vkArray[:], vk[:]) + +	for _, table := range []struct { +		description string +		signer      crypto.Signer +		vk          *[types.VerificationKeySize]byte +		th          *types.TreeHead +		wantErr     bool +	}{ +		{ +			description: "invalid: signature error", +			signer:      sk, +			vk:          testPub, // wrong key for message +			th:          testTH, +			wantErr:     true, +		}, +		{ +			description: "valid", +			signer:      sk, +			vk:          &vkArray, +			th:          testTH, +		}, +	} { +		sth, _ := table.th.Sign(testSignerOK) +		logKeyHash := sth.SigIdent[0].KeyHash +		logSigIdent := sth.SigIdent[0] +		sm := &StateManagerSingle{ +			signer:   testSignerOK, +			cosigned: *sth, +			tosign:   *sth, +			cosignature: map[[types.HashSize]byte]*types.SigIdent{ +				*logKeyHash: logSigIdent, +			}, +		} + +		// Prepare witness signature +		sth, err := table.th.Sign(table.signer) +		if err != nil { +			t.Fatalf("Sign: %v", err) +		} +		witnessKeyHash := sth.SigIdent[0].KeyHash +		witnessSigIdent := sth.SigIdent[0] + +		// Add witness signature +		err = sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} + +		// We should have two signatures (log + witness) +		if got, want := len(sm.cosignature), 2; got != want { +			t.Errorf("got %d cosignatures but wanted %v in test %q", got, want, table.description) +			continue +		} +		// check that log signature is there +		sigident, ok := sm.cosignature[*logKeyHash] +		if !ok { +			t.Errorf("log signature is missing") +			continue +		} +		if got, want := sigident, logSigIdent; !reflect.DeepEqual(got, want) { +			t.Errorf("got log sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +		// check that witness signature is there +		sigident, ok = sm.cosignature[*witnessKeyHash] +		if !ok { +			t.Errorf("witness signature is missing") +			continue +		} +		if got, want := sigident, witnessSigIdent; !reflect.DeepEqual(got, want) { +			t.Errorf("got witness sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			continue +		} + +		// Adding a duplicate signature should give an error +		if err := sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature); err == nil { +			t.Errorf("duplicate witness signature accepted as valid") +		} +	} +} + +func TestRotate(t *testing.T) { +	log := testSigIdent +	wit1 := &types.SigIdent{ +		Signature: testSig, +		KeyHash:   types.Hash([]byte("wit1 key")), +	} +	wit2 := &types.SigIdent{ +		Signature: testSig, +		KeyHash:   types.Hash([]byte("wit2 key")), +	} +	th0 := testTH +	th1 := &types.TreeHead{ +		Timestamp: 1, +		TreeSize:  1, +		RootHash:  types.Hash([]byte("1")), +	} +	th2 := &types.TreeHead{ +		Timestamp: 2, +		TreeSize:  2, +		RootHash:  types.Hash([]byte("2")), +	} + +	for _, table := range []struct { +		description   string +		before, after *StateManagerSingle +		next          *types.SignedTreeHead +	}{ +		{ +			description: "tosign tree head repated, but got one new witnes signature", +			before: &StateManagerSingle{ +				cosigned: types.SignedTreeHead{ +					TreeHead: *th0, +					SigIdent: []*types.SigIdent{log, wit1}, +				}, +				tosign: types.SignedTreeHead{ +					TreeHead: *th0, +					SigIdent: []*types.SigIdent{log}, +				}, +				cosignature: map[[types.HashSize]byte]*types.SigIdent{ +					*log.KeyHash:  log, +					*wit2.KeyHash: wit2, // the new witness signature +				}, +			}, +			next: &types.SignedTreeHead{ +				TreeHead: *th1, +				SigIdent: []*types.SigIdent{log}, +			}, +			after: &StateManagerSingle{ +				cosigned: types.SignedTreeHead{ +					TreeHead: *th0, +					SigIdent: []*types.SigIdent{log, wit1, wit2}, +				}, +				tosign: types.SignedTreeHead{ +					TreeHead: *th1, +					SigIdent: []*types.SigIdent{log}, +				}, +				cosignature: map[[types.HashSize]byte]*types.SigIdent{ +					*log.KeyHash: log, // after rotate we always have log sig +				}, +			}, +		}, +		{ +			description: "tosign tree head did not repeat, it got one witness signature", +			before: &StateManagerSingle{ +				cosigned: types.SignedTreeHead{ +					TreeHead: *th0, +					SigIdent: []*types.SigIdent{log, wit1}, +				}, +				tosign: types.SignedTreeHead{ +					TreeHead: *th1, +					SigIdent: []*types.SigIdent{log}, +				}, +				cosignature: map[[types.HashSize]byte]*types.SigIdent{ +					*log.KeyHash:  log, +					*wit2.KeyHash: wit2, // the only witness that signed tosign +				}, +			}, +			next: &types.SignedTreeHead{ +				TreeHead: *th2, +				SigIdent: []*types.SigIdent{log}, +			}, +			after: &StateManagerSingle{ +				cosigned: types.SignedTreeHead{ +					TreeHead: *th1, +					SigIdent: []*types.SigIdent{log, wit2}, +				}, +				tosign: types.SignedTreeHead{ +					TreeHead: *th2, +					SigIdent: []*types.SigIdent{log}, +				}, +				cosignature: map[[types.HashSize]byte]*types.SigIdent{ +					*log.KeyHash: log, // after rotate we always have log sig +				}, +			}, +		}, +	} { +		table.before.rotate(table.next) +		if got, want := table.before.cosigned.TreeHead, table.after.cosigned.TreeHead; !reflect.DeepEqual(got, want) { +			t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +		checkWitnessList(t, table.description, table.before.cosigned.SigIdent, table.after.cosigned.SigIdent) +		if got, want := table.before.tosign.TreeHead, table.after.tosign.TreeHead; !reflect.DeepEqual(got, want) { +			t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +		checkWitnessList(t, table.description, table.before.tosign.SigIdent, table.after.tosign.SigIdent) +		if got, want := table.before.cosignature, table.after.cosignature; !reflect.DeepEqual(got, want) { +			t.Errorf("got cosignature map\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +	} +} + +func checkWitnessList(t *testing.T, description string, got, want []*types.SigIdent) { +	t.Helper() +	for _, si := range got { +		found := false +		for _, sj := range want { +			if reflect.DeepEqual(si, sj) { +				found = true +				break +			} +		} +		if !found { +			t.Errorf("got unexpected signature-signer pair with key hash in test %q: %x", description, si.KeyHash[:]) +		} +	} +	if len(got) != len(want) { +		t.Errorf("got %d signature-signer pairs but wanted %d in test %q", len(got), len(want), description) +	} +} diff --git a/pkg/trillian/client.go b/pkg/trillian/client.go new file mode 100644 index 0000000..9523e56 --- /dev/null +++ b/pkg/trillian/client.go @@ -0,0 +1,178 @@ +package trillian + +import ( +	"context" +	"fmt" + +	"github.com/golang/glog" +	"github.com/google/trillian" +	ttypes "github.com/google/trillian/types" +	"github.com/system-transparency/stfe/pkg/types" +	"google.golang.org/grpc/codes" +) + +type Client interface { +	AddLeaf(context.Context, *types.LeafRequest) error +	GetConsistencyProof(context.Context, *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) +	GetTreeHead(context.Context) (*types.TreeHead, error) +	GetInclusionProof(context.Context, *types.InclusionProofRequest) (*types.InclusionProof, error) +	GetLeaves(context.Context, *types.LeavesRequest) (*types.LeafList, error) +} + +// TrillianClient is a wrapper around the Trillian gRPC client. +type TrillianClient struct { +	// TreeID is a Merkle tree identifier that Trillian uses +	TreeID int64 + +	// GRPC is a Trillian gRPC client +	GRPC trillian.TrillianLogClient +} + +func (c *TrillianClient) AddLeaf(ctx context.Context, req *types.LeafRequest) error { +	leaf := types.Leaf{ +		Message: req.Message, +		SigIdent: types.SigIdent{ +			Signature: req.Signature, +			KeyHash:   types.Hash(req.VerificationKey[:]), +		}, +	} +	serialized := leaf.Marshal() + +	glog.V(3).Infof("queueing leaf request: %x", types.HashLeaf(serialized)) +	rsp, err := c.GRPC.QueueLeaf(ctx, &trillian.QueueLeafRequest{ +		LogId: c.TreeID, +		Leaf: &trillian.LogLeaf{ +			LeafValue: serialized, +		}, +	}) +	if err != nil { +		return fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return fmt.Errorf("no response") +	} +	if rsp.QueuedLeaf == nil { +		return fmt.Errorf("no queued leaf") +	} +	if codes.Code(rsp.QueuedLeaf.GetStatus().GetCode()) == codes.AlreadyExists { +		return fmt.Errorf("leaf is already queued or included") +	} +	return nil +} + +func (c *TrillianClient) GetTreeHead(ctx context.Context) (*types.TreeHead, error) { +	rsp, err := c.GRPC.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{ +		LogId: c.TreeID, +	}) +	if err != nil { +		return nil, fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return nil, fmt.Errorf("no response") +	} +	if rsp.SignedLogRoot == nil { +		return nil, fmt.Errorf("no signed log root") +	} +	if rsp.SignedLogRoot.LogRoot == nil { +		return nil, fmt.Errorf("no log root") +	} +	var r ttypes.LogRootV1 +	if err := r.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil { +		return nil, fmt.Errorf("no log root: unmarshal failed: %v", err) +	} +	if len(r.RootHash) != types.HashSize { +		return nil, fmt.Errorf("unexpected hash length: %d", len(r.RootHash)) +	} +	return treeHeadFromLogRoot(&r), nil +} + +func (c *TrillianClient) GetConsistencyProof(ctx context.Context, req *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) { +	rsp, err := c.GRPC.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{ +		LogId:          c.TreeID, +		FirstTreeSize:  int64(req.OldSize), +		SecondTreeSize: int64(req.NewSize), +	}) +	if err != nil { +		return nil, fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return nil, fmt.Errorf("no response") +	} +	if rsp.Proof == nil { +		return nil, fmt.Errorf("no consistency proof") +	} +	if len(rsp.Proof.Hashes) == 0 { +		return nil, fmt.Errorf("not a consistency proof: empty") +	} +	path, err := nodePathFromHashes(rsp.Proof.Hashes) +	if err != nil { +		return nil, fmt.Errorf("not a consistency proof: %v", err) +	} +	return &types.ConsistencyProof{ +		OldSize: req.OldSize, +		NewSize: req.NewSize, +		Path:    path, +	}, nil +} + +func (c *TrillianClient) GetInclusionProof(ctx context.Context, req *types.InclusionProofRequest) (*types.InclusionProof, error) { +	rsp, err := c.GRPC.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{ +		LogId:           c.TreeID, +		LeafHash:        req.LeafHash[:], +		TreeSize:        int64(req.TreeSize), +		OrderBySequence: true, +	}) +	if err != nil { +		return nil, fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return nil, fmt.Errorf("no response") +	} +	if len(rsp.Proof) != 1 { +		return nil, fmt.Errorf("bad proof count: %d", len(rsp.Proof)) +	} +	proof := rsp.Proof[0] +	if len(proof.Hashes) == 0 { +		return nil, fmt.Errorf("not an inclusion proof: empty") +	} +	path, err := nodePathFromHashes(proof.Hashes) +	if err != nil { +		return nil, fmt.Errorf("not an inclusion proof: %v", err) +	} +	return &types.InclusionProof{ +		TreeSize:  req.TreeSize, +		LeafIndex: uint64(proof.LeafIndex), +		Path:      path, +	}, nil +} + +func (c *TrillianClient) GetLeaves(ctx context.Context, req *types.LeavesRequest) (*types.LeafList, error) { +	rsp, err := c.GRPC.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{ +		LogId:      c.TreeID, +		StartIndex: int64(req.StartSize), +		Count:      int64(req.EndSize-req.StartSize) + 1, +	}) +	if err != nil { +		return nil, fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return nil, fmt.Errorf("no response") +	} +	if got, want := len(rsp.Leaves), int(req.EndSize-req.StartSize+1); got != want { +		return nil, fmt.Errorf("unexpected number of leaves: %d", got) +	} +	var list types.LeafList +	for i, leaf := range rsp.Leaves { +		leafIndex := int64(req.StartSize + uint64(i)) +		if leafIndex != leaf.LeafIndex { +			return nil, fmt.Errorf("unexpected leaf(%d): got index %d", leafIndex, leaf.LeafIndex) +		} + +		var l types.Leaf +		if err := l.Unmarshal(leaf.LeafValue); err != nil { +			return nil, fmt.Errorf("unexpected leaf(%d): %v", leafIndex, err) +		} +		list = append(list[:], &l) +	} +	return &list, nil +} diff --git a/pkg/trillian/client_test.go b/pkg/trillian/client_test.go new file mode 100644 index 0000000..6b3d881 --- /dev/null +++ b/pkg/trillian/client_test.go @@ -0,0 +1,533 @@ +package trillian + +import ( +	"context" +	"fmt" +	"reflect" +	"testing" + +	"github.com/golang/mock/gomock" +	"github.com/google/trillian" +	ttypes "github.com/google/trillian/types" +	"github.com/system-transparency/stfe/pkg/mocks" +	"github.com/system-transparency/stfe/pkg/types" +	"google.golang.org/grpc/codes" +	"google.golang.org/grpc/status" +) + +func TestAddLeaf(t *testing.T) { +	req := &types.LeafRequest{ +		Message: types.Message{ +			ShardHint: 0, +			Checksum:  &[types.HashSize]byte{}, +		}, +		Signature:       &[types.SignatureSize]byte{}, +		VerificationKey: &[types.VerificationKeySize]byte{}, +		DomainHint:      "example.com", +	} +	for _, table := range []struct { +		description string +		req         *types.LeafRequest +		rsp         *trillian.QueueLeafResponse +		err         error +		wantErr     bool +	}{ +		{ +			description: "invalid: backend failure", +			req:         req, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			req:         req, +			wantErr:     true, +		}, +		{ +			description: "invalid: no queued leaf", +			req:         req, +			rsp:         &trillian.QueueLeafResponse{}, +			wantErr:     true, +		}, +		{ +			description: "invalid: leaf is already queued or included", +			req:         req, +			rsp: &trillian.QueueLeafResponse{ +				QueuedLeaf: &trillian.QueuedLogLeaf{ +					Leaf: &trillian.LogLeaf{ +						LeafValue: req.Message.Marshal(), +					}, +					Status: status.New(codes.AlreadyExists, "duplicate").Proto(), +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			req:         req, +			rsp: &trillian.QueueLeafResponse{ +				QueuedLeaf: &trillian.QueuedLogLeaf{ +					Leaf: &trillian.LogLeaf{ +						LeafValue: req.Message.Marshal(), +					}, +					Status: status.New(codes.OK, "ok").Proto(), +				}, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocks.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().QueueLeaf(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			err := client.AddLeaf(context.Background(), table.req) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +		}() +	} +} + +func TestGetTreeHead(t *testing.T) { +	// valid root +	root := &ttypes.LogRootV1{ +		TreeSize:       0, +		RootHash:       make([]byte, types.HashSize), +		TimestampNanos: 1622585623133599429, +	} +	buf, err := root.MarshalBinary() +	if err != nil { +		t.Fatalf("must marshal log root: %v", err) +	} +	// invalid root +	root.RootHash = make([]byte, types.HashSize+1) +	bufBadHash, err := root.MarshalBinary() +	if err != nil { +		t.Fatalf("must marshal log root: %v", err) +	} + +	for _, table := range []struct { +		description string +		rsp         *trillian.GetLatestSignedLogRootResponse +		err         error +		wantErr     bool +		wantTh      *types.TreeHead +	}{ +		{ +			description: "invalid: backend failure", +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			wantErr:     true, +		}, +		{ +			description: "invalid: no signed log root", +			rsp:         &trillian.GetLatestSignedLogRootResponse{}, +			wantErr:     true, +		}, +		{ +			description: "invalid: no log root", +			rsp: &trillian.GetLatestSignedLogRootResponse{ +				SignedLogRoot: &trillian.SignedLogRoot{}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: no log root: unmarshal failed", +			rsp: &trillian.GetLatestSignedLogRootResponse{ +				SignedLogRoot: &trillian.SignedLogRoot{ +					LogRoot: buf[1:], +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: unexpected hash length", +			rsp: &trillian.GetLatestSignedLogRootResponse{ +				SignedLogRoot: &trillian.SignedLogRoot{ +					LogRoot: bufBadHash, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			rsp: &trillian.GetLatestSignedLogRootResponse{ +				SignedLogRoot: &trillian.SignedLogRoot{ +					LogRoot: buf, +				}, +			}, +			wantTh: &types.TreeHead{ +				Timestamp: 1622585623, +				TreeSize:  0, +				RootHash:  &[types.HashSize]byte{}, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocks.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			th, err := client.GetTreeHead(context.Background()) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := th, table.wantTh; !reflect.DeepEqual(got, want) { +				t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetConsistencyProof(t *testing.T) { +	req := &types.ConsistencyProofRequest{ +		OldSize: 1, +		NewSize: 3, +	} +	for _, table := range []struct { +		description string +		req         *types.ConsistencyProofRequest +		rsp         *trillian.GetConsistencyProofResponse +		err         error +		wantErr     bool +		wantProof   *types.ConsistencyProof +	}{ +		{ +			description: "invalid: backend failure", +			req:         req, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			req:         req, +			wantErr:     true, +		}, +		{ +			description: "invalid: no consistency proof", +			req:         req, +			rsp:         &trillian.GetConsistencyProofResponse{}, +			wantErr:     true, +		}, +		{ +			description: "invalid: not a consistency proof (1/2)", +			req:         req, +			rsp: &trillian.GetConsistencyProofResponse{ +				Proof: &trillian.Proof{ +					Hashes: [][]byte{}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: not a consistency proof (2/2)", +			req:         req, +			rsp: &trillian.GetConsistencyProofResponse{ +				Proof: &trillian.Proof{ +					Hashes: [][]byte{ +						make([]byte, types.HashSize), +						make([]byte, types.HashSize+1), +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			req:         req, +			rsp: &trillian.GetConsistencyProofResponse{ +				Proof: &trillian.Proof{ +					Hashes: [][]byte{ +						make([]byte, types.HashSize), +						make([]byte, types.HashSize), +					}, +				}, +			}, +			wantProof: &types.ConsistencyProof{ +				OldSize: 1, +				NewSize: 3, +				Path: []*[types.HashSize]byte{ +					&[types.HashSize]byte{}, +					&[types.HashSize]byte{}, +				}, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocks.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			proof, err := client.GetConsistencyProof(context.Background(), table.req) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { +				t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetInclusionProof(t *testing.T) { +	req := &types.InclusionProofRequest{ +		TreeSize: 4, +		LeafHash: &[types.HashSize]byte{}, +	} +	for _, table := range []struct { +		description string +		req         *types.InclusionProofRequest +		rsp         *trillian.GetInclusionProofByHashResponse +		err         error +		wantErr     bool +		wantProof   *types.InclusionProof +	}{ +		{ +			description: "invalid: backend failure", +			req:         req, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			req:         req, +			wantErr:     true, +		}, +		{ +			description: "invalid: bad proof count", +			req:         req, +			rsp: &trillian.GetInclusionProofByHashResponse{ +				Proof: []*trillian.Proof{ +					&trillian.Proof{}, +					&trillian.Proof{}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: not an inclusion proof (1/2)", +			req:         req, +			rsp: &trillian.GetInclusionProofByHashResponse{ +				Proof: []*trillian.Proof{ +					&trillian.Proof{ +						LeafIndex: 1, +						Hashes:    [][]byte{}, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: not an inclusion proof (2/2)", +			req:         req, +			rsp: &trillian.GetInclusionProofByHashResponse{ +				Proof: []*trillian.Proof{ +					&trillian.Proof{ +						LeafIndex: 1, +						Hashes: [][]byte{ +							make([]byte, types.HashSize), +							make([]byte, types.HashSize+1), +						}, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			req:         req, +			rsp: &trillian.GetInclusionProofByHashResponse{ +				Proof: []*trillian.Proof{ +					&trillian.Proof{ +						LeafIndex: 1, +						Hashes: [][]byte{ +							make([]byte, types.HashSize), +							make([]byte, types.HashSize), +						}, +					}, +				}, +			}, +			wantProof: &types.InclusionProof{ +				TreeSize:  4, +				LeafIndex: 1, +				Path: []*[types.HashSize]byte{ +					&[types.HashSize]byte{}, +					&[types.HashSize]byte{}, +				}, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocks.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			proof, err := client.GetInclusionProof(context.Background(), table.req) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { +				t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetLeaves(t *testing.T) { +	req := &types.LeavesRequest{ +		StartSize: 1, +		EndSize:   2, +	} +	firstLeaf := &types.Leaf{ +		Message: types.Message{ +			ShardHint: 0, +			Checksum:  &[types.HashSize]byte{}, +		}, +		SigIdent: types.SigIdent{ +			Signature: &[types.SignatureSize]byte{}, +			KeyHash:   &[types.HashSize]byte{}, +		}, +	} +	secondLeaf := &types.Leaf{ +		Message: types.Message{ +			ShardHint: 0, +			Checksum:  &[types.HashSize]byte{}, +		}, +		SigIdent: types.SigIdent{ +			Signature: &[types.SignatureSize]byte{}, +			KeyHash:   &[types.HashSize]byte{}, +		}, +	} + +	for _, table := range []struct { +		description string +		req         *types.LeavesRequest +		rsp         *trillian.GetLeavesByRangeResponse +		err         error +		wantErr     bool +		wantLeaves  *types.LeafList +	}{ +		{ +			description: "invalid: backend failure", +			req:         req, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			req:         req, +			wantErr:     true, +		}, +		{ +			description: "invalid: unexpected number of leaves", +			req:         req, +			rsp: &trillian.GetLeavesByRangeResponse{ +				Leaves: []*trillian.LogLeaf{ +					&trillian.LogLeaf{ +						LeafValue: firstLeaf.Marshal(), +						LeafIndex: 1, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: unexpected leaf (1/2)", +			req:         req, +			rsp: &trillian.GetLeavesByRangeResponse{ +				Leaves: []*trillian.LogLeaf{ +					&trillian.LogLeaf{ +						LeafValue: firstLeaf.Marshal(), +						LeafIndex: 1, +					}, +					&trillian.LogLeaf{ +						LeafValue: secondLeaf.Marshal(), +						LeafIndex: 3, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: unexpected leaf (2/2)", +			req:         req, +			rsp: &trillian.GetLeavesByRangeResponse{ +				Leaves: []*trillian.LogLeaf{ +					&trillian.LogLeaf{ +						LeafValue: firstLeaf.Marshal(), +						LeafIndex: 1, +					}, +					&trillian.LogLeaf{ +						LeafValue: secondLeaf.Marshal()[1:], +						LeafIndex: 2, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			req:         req, +			rsp: &trillian.GetLeavesByRangeResponse{ +				Leaves: []*trillian.LogLeaf{ +					&trillian.LogLeaf{ +						LeafValue: firstLeaf.Marshal(), +						LeafIndex: 1, +					}, +					&trillian.LogLeaf{ +						LeafValue: secondLeaf.Marshal(), +						LeafIndex: 2, +					}, +				}, +			}, +			wantLeaves: &types.LeafList{ +				firstLeaf, +				secondLeaf, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocks.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			leaves, err := client.GetLeaves(context.Background(), table.req) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := leaves, table.wantLeaves; !reflect.DeepEqual(got, want) { +				t.Errorf("got leaves\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} diff --git a/pkg/trillian/util.go b/pkg/trillian/util.go new file mode 100644 index 0000000..4cf31fb --- /dev/null +++ b/pkg/trillian/util.go @@ -0,0 +1,33 @@ +package trillian + +import ( +	"fmt" + +	trillian "github.com/google/trillian/types" +	siglog "github.com/system-transparency/stfe/pkg/types" +) + +func treeHeadFromLogRoot(lr *trillian.LogRootV1) *siglog.TreeHead { +	var hash [siglog.HashSize]byte +	th := siglog.TreeHead{ +		Timestamp: uint64(lr.TimestampNanos / 1000 / 1000 / 1000), +		TreeSize:  uint64(lr.TreeSize), +		RootHash:  &hash, +	} +	copy(th.RootHash[:], lr.RootHash) +	return &th +} + +func nodePathFromHashes(hashes [][]byte) ([]*[siglog.HashSize]byte, error) { +	var path []*[siglog.HashSize]byte +	for _, hash := range hashes { +		if len(hash) != siglog.HashSize { +			return nil, fmt.Errorf("unexpected hash length: %v", len(hash)) +		} + +		var h [siglog.HashSize]byte +		copy(h[:], hash) +		path = append(path, &h) +	} +	return path, nil +} diff --git a/pkg/types/ascii.go b/pkg/types/ascii.go new file mode 100644 index 0000000..d27d79b --- /dev/null +++ b/pkg/types/ascii.go @@ -0,0 +1,421 @@ +package types + +import ( +	"bytes" +	"encoding/hex" +	"fmt" +	"io" +	"io/ioutil" +	"strconv" +) + +const ( +	// Delim is a key-value separator +	Delim = "=" + +	// EOL is a line sepator +	EOL = "\n" + +	// NumField* is the number of unique keys in an incoming ASCII message +	NumFieldLeaf                    = 4 +	NumFieldSignedTreeHead          = 5 +	NumFieldConsistencyProof        = 3 +	NumFieldInclusionProof          = 3 +	NumFieldLeavesRequest           = 2 +	NumFieldInclusionProofRequest   = 2 +	NumFieldConsistencyProofRequest = 2 +	NumFieldLeafRequest             = 5 +	NumFieldCosignatureRequest      = 2 + +	// New leaf keys +	ShardHint            = "shard_hint" +	Checksum             = "checksum" +	SignatureOverMessage = "signature_over_message" +	VerificationKey      = "verification_key" +	DomainHint           = "domain_hint" + +	// Inclusion proof keys +	LeafHash      = "leaf_hash" +	LeafIndex     = "leaf_index" +	InclusionPath = "inclusion_path" + +	// Consistency proof keys +	NewSize         = "new_size" +	OldSize         = "old_size" +	ConsistencyPath = "consistency_path" + +	// Range of leaves keys +	StartSize = "start_size" +	EndSize   = "end_size" + +	// Tree head keys +	Timestamp = "timestamp" +	TreeSize  = "tree_size" +	RootHash  = "root_hash" + +	// Signature and signer-identity keys +	Signature = "signature" +	KeyHash   = "key_hash" +) + +// MessageASCI is a wrapper that manages ASCII key-value pairs +type MessageASCII struct { +	m map[string][]string +} + +// NewMessageASCII unpacks an incoming ASCII message +func NewMessageASCII(r io.Reader, numFieldExpected int) (*MessageASCII, error) { +	buf, err := ioutil.ReadAll(r) +	if err != nil { +		return nil, fmt.Errorf("ReadAll: %v", err) +	} +	lines := bytes.Split(buf, []byte(EOL)) +	if len(lines) <= 1 { +		return nil, fmt.Errorf("Not enough lines: empty") +	} +	lines = lines[:len(lines)-1] // valid message => split gives empty last line + +	msg := MessageASCII{make(map[string][]string)} +	for _, line := range lines { +		split := bytes.Index(line, []byte(Delim)) +		if split == -1 { +			return nil, fmt.Errorf("invalid line: %v", string(line)) +		} + +		key := string(line[:split]) +		value := string(line[split+len(Delim):]) +		values, ok := msg.m[key] +		if !ok { +			values = nil +			msg.m[key] = values +		} +		msg.m[key] = append(values, value) +	} + +	if msg.NumField() != numFieldExpected { +		return nil, fmt.Errorf("Unexpected number of keys: %v", msg.NumField()) +	} +	return &msg, nil +} + +// NumField returns the number of unique keys +func (msg *MessageASCII) NumField() int { +	return len(msg.m) +} + +// GetStrings returns a list of strings +func (msg *MessageASCII) GetStrings(key string) []string { +	strs, ok := msg.m[key] +	if !ok { +		return nil +	} +	return strs +} + +// GetString unpacks a string +func (msg *MessageASCII) GetString(key string) (string, error) { +	strs := msg.GetStrings(key) +	if len(strs) != 1 { +		return "", fmt.Errorf("expected one string: %v", strs) +	} +	return strs[0], nil +} + +// GetUint64 unpacks an uint64 +func (msg *MessageASCII) GetUint64(key string) (uint64, error) { +	str, err := msg.GetString(key) +	if err != nil { +		return 0, fmt.Errorf("GetString: %v", err) +	} +	num, err := strconv.ParseUint(str, 10, 64) +	if err != nil { +		return 0, fmt.Errorf("ParseUint: %v", err) +	} +	return num, nil +} + +// GetHash unpacks a hash +func (msg *MessageASCII) GetHash(key string) (*[HashSize]byte, error) { +	str, err := msg.GetString(key) +	if err != nil { +		return nil, fmt.Errorf("GetString: %v", err) +	} + +	var hash [HashSize]byte +	if err := decodeHex(str, hash[:]); err != nil { +		return nil, fmt.Errorf("decodeHex: %v", err) +	} +	return &hash, nil +} + +// GetSignature unpacks a signature +func (msg *MessageASCII) GetSignature(key string) (*[SignatureSize]byte, error) { +	str, err := msg.GetString(key) +	if err != nil { +		return nil, fmt.Errorf("GetString: %v", err) +	} + +	var signature [SignatureSize]byte +	if err := decodeHex(str, signature[:]); err != nil { +		return nil, fmt.Errorf("decodeHex: %v", err) +	} +	return &signature, nil +} + +// GetVerificationKey unpacks a verification key +func (msg *MessageASCII) GetVerificationKey(key string) (*[VerificationKeySize]byte, error) { +	str, err := msg.GetString(key) +	if err != nil { +		return nil, fmt.Errorf("GetString: %v", err) +	} + +	var vk [VerificationKeySize]byte +	if err := decodeHex(str, vk[:]); err != nil { +		return nil, fmt.Errorf("decodeHex: %v", err) +	} +	return &vk, nil +} + +// decodeHex decodes a hex-encoded string into an already-sized byte slice +func decodeHex(str string, out []byte) error { +	buf, err := hex.DecodeString(str) +	if err != nil { +		return fmt.Errorf("DecodeString: %v", err) +	} +	if len(buf) != len(out) { +		return fmt.Errorf("invalid length: %v", len(buf)) +	} +	copy(out, buf) +	return nil +} + +/* + * + * MarshalASCII wrappers for types that the log server outputs + * + */ +func (l *Leaf) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, ShardHint, strconv.FormatUint(l.ShardHint, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, Checksum, hex.EncodeToString(l.Checksum[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, SignatureOverMessage, hex.EncodeToString(l.Signature[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, KeyHash, hex.EncodeToString(l.KeyHash[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	return nil +} + +func (sth *SignedTreeHead) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, Timestamp, strconv.FormatUint(sth.Timestamp, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, TreeSize, strconv.FormatUint(sth.TreeSize, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, RootHash, hex.EncodeToString(sth.RootHash[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	for _, sigident := range sth.SigIdent { +		if err := sigident.MarshalASCII(w); err != nil { +			return fmt.Errorf("MarshalASCII: %v", err) +		} +	} +	return nil +} + +func (si *SigIdent) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, Signature, hex.EncodeToString(si.Signature[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, KeyHash, hex.EncodeToString(si.KeyHash[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	return nil +} + +func (p *ConsistencyProof) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, NewSize, strconv.FormatUint(p.NewSize, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, OldSize, strconv.FormatUint(p.OldSize, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	for _, hash := range p.Path { +		if err := writeASCII(w, ConsistencyPath, hex.EncodeToString(hash[:])); err != nil { +			return fmt.Errorf("writeASCII: %v", err) +		} +	} +	return nil +} + +func (p *InclusionProof) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, TreeSize, strconv.FormatUint(p.TreeSize, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, LeafIndex, strconv.FormatUint(p.LeafIndex, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	for _, hash := range p.Path { +		if err := writeASCII(w, InclusionPath, hex.EncodeToString(hash[:])); err != nil { +			return fmt.Errorf("writeASCII: %v", err) +		} +	} +	return nil +} + +func writeASCII(w io.Writer, key, value string) error { +	if _, err := fmt.Fprintf(w, "%s%s%s%s", key, Delim, value, EOL); err != nil { +		return fmt.Errorf("Fprintf: %v", err) +	} +	return nil +} + +/* + * + * Unmarshal ASCII wrappers that the log server and/or log clients receive. + * + */ +func (ll *LeafList) UnmarshalASCII(r io.Reader) error { +	return nil +} + +func (sth *SignedTreeHead) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldSignedTreeHead) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	// TreeHead +	if sth.Timestamp, err = msg.GetUint64(Timestamp); err != nil { +		return fmt.Errorf("GetUint64(Timestamp): %v", err) +	} +	if sth.TreeSize, err = msg.GetUint64(TreeSize); err != nil { +		return fmt.Errorf("GetUint64(TreeSize): %v", err) +	} +	if sth.RootHash, err = msg.GetHash(RootHash); err != nil { +		return fmt.Errorf("GetHash(RootHash): %v", err) +	} + +	// SigIdent +	signatures := msg.GetStrings(Signature) +	if len(signatures) == 0 { +		return fmt.Errorf("no signer") +	} +	keyHashes := msg.GetStrings(KeyHash) +	if len(signatures) != len(keyHashes) { +		return fmt.Errorf("mismatched signature-signer count") +	} +	sth.SigIdent = make([]*SigIdent, 0, len(signatures)) +	for i, n := 0, len(signatures); i < n; i++ { +		var signature [SignatureSize]byte +		if err := decodeHex(signatures[i], signature[:]); err != nil { +			return fmt.Errorf("decodeHex: %v", err) +		} +		var hash [HashSize]byte +		if err := decodeHex(keyHashes[i], hash[:]); err != nil { +			return fmt.Errorf("decodeHex: %v", err) +		} +		sth.SigIdent = append(sth.SigIdent, &SigIdent{ +			Signature: &signature, +			KeyHash:   &hash, +		}) +	} +	return nil +} + +func (p *InclusionProof) UnmarshalASCII(r io.Reader) error { +	return nil +} + +func (p *ConsistencyProof) UnmarshalASCII(r io.Reader) error { +	return nil +} + +func (req *InclusionProofRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldInclusionProofRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.LeafHash, err = msg.GetHash(LeafHash); err != nil { +		return fmt.Errorf("GetHash(LeafHash): %v", err) +	} +	if req.TreeSize, err = msg.GetUint64(TreeSize); err != nil { +		return fmt.Errorf("GetUint64(TreeSize): %v", err) +	} +	return nil +} + +func (req *ConsistencyProofRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldConsistencyProofRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.NewSize, err = msg.GetUint64(NewSize); err != nil { +		return fmt.Errorf("GetUint64(NewSize): %v", err) +	} +	if req.OldSize, err = msg.GetUint64(OldSize); err != nil { +		return fmt.Errorf("GetUint64(OldSize): %v", err) +	} +	return nil +} + +func (req *LeavesRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldLeavesRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.StartSize, err = msg.GetUint64(StartSize); err != nil { +		return fmt.Errorf("GetUint64(StartSize): %v", err) +	} +	if req.EndSize, err = msg.GetUint64(EndSize); err != nil { +		return fmt.Errorf("GetUint64(EndSize): %v", err) +	} +	return nil +} + +func (req *LeafRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldLeafRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.ShardHint, err = msg.GetUint64(ShardHint); err != nil { +		return fmt.Errorf("GetUint64(ShardHint): %v", err) +	} +	if req.Checksum, err = msg.GetHash(Checksum); err != nil { +		return fmt.Errorf("GetHash(Checksum): %v", err) +	} +	if req.Signature, err = msg.GetSignature(SignatureOverMessage); err != nil { +		return fmt.Errorf("GetSignature: %v", err) +	} +	if req.VerificationKey, err = msg.GetVerificationKey(VerificationKey); err != nil { +		return fmt.Errorf("GetVerificationKey: %v", err) +	} +	if req.DomainHint, err = msg.GetString(DomainHint); err != nil { +		return fmt.Errorf("GetString(DomainHint): %v", err) +	} +	return nil +} + +func (req *CosignatureRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldCosignatureRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.Signature, err = msg.GetSignature(Signature); err != nil { +		return fmt.Errorf("GetSignature: %v", err) +	} +	if req.KeyHash, err = msg.GetHash(KeyHash); err != nil { +		return fmt.Errorf("GetHash(KeyHash): %v", err) +	} +	return nil +} diff --git a/pkg/types/ascii_test.go b/pkg/types/ascii_test.go new file mode 100644 index 0000000..92732f9 --- /dev/null +++ b/pkg/types/ascii_test.go @@ -0,0 +1,465 @@ +package types + +import ( +	"bytes" +	"fmt" +	"io" +	"reflect" +	"testing" +) + +/* + * + * MessageASCII methods and helpers + * + */ +func TestNewMessageASCII(t *testing.T) { +	for _, table := range []struct { +		description string +		input       io.Reader +		wantErr     bool +		wantMap     map[string][]string +	}{ +		{ +			description: "invalid: not enough lines", +			input:       bytes.NewBufferString(""), +			wantErr:     true, +		}, +		{ +			description: "invalid: lines must end with new line", +			input:       bytes.NewBufferString("k1=v1\nk2=v2"), +			wantErr:     true, +		}, +		{ +			description: "invalid: lines must not be empty", +			input:       bytes.NewBufferString("k1=v1\n\nk2=v2\n"), +			wantErr:     true, +		}, +		{ +			description: "invalid: wrong number of fields", +			input:       bytes.NewBufferString("k1=v1\n"), +			wantErr:     true, +		}, +		{ +			description: "valid", +			input:       bytes.NewBufferString("k1=v1\nk2=v2\nk2=v3=4\n"), +			wantMap: map[string][]string{ +				"k1": []string{"v1"}, +				"k2": []string{"v2", "v3=4"}, +			}, +		}, +	} { +		msg, err := NewMessageASCII(table.input, len(table.wantMap)) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} +		if got, want := msg.m, table.wantMap; !reflect.DeepEqual(got, want) { +			t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +	} +} + +func TestNumField(t *testing.T)           {} +func TestGetStrings(t *testing.T)         {} +func TestGetString(t *testing.T)          {} +func TestGetUint64(t *testing.T)          {} +func TestGetHash(t *testing.T)            {} +func TestGetSignature(t *testing.T)       {} +func TestGetVerificationKey(t *testing.T) {} +func TestDecodeHex(t *testing.T)          {} + +/* + * + * MarshalASCII methods and helpers + * + */ +func TestLeafMarshalASCII(t *testing.T) { +	description := "valid: two leaves" +	leafList := []*Leaf{ +		&Leaf{ +			Message: Message{ +				ShardHint: 123, +				Checksum:  testBuffer32, +			}, +			SigIdent: SigIdent{ +				Signature: testBuffer64, +				KeyHash:   testBuffer32, +			}, +		}, +		&Leaf{ +			Message: Message{ +				ShardHint: 456, +				Checksum:  testBuffer32, +			}, +			SigIdent: SigIdent{ +				Signature: testBuffer64, +				KeyHash:   testBuffer32, +			}, +		}, +	} +	wantBuf := bytes.NewBufferString(fmt.Sprintf( +		"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+ +			"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", +		// Leaf 1 +		ShardHint, Delim, 123, EOL, +		Checksum, Delim, testBuffer32[:], EOL, +		SignatureOverMessage, Delim, testBuffer64[:], EOL, +		KeyHash, Delim, testBuffer32[:], EOL, +		// Leaf 2 +		ShardHint, Delim, 456, EOL, +		Checksum, Delim, testBuffer32[:], EOL, +		SignatureOverMessage, Delim, testBuffer64[:], EOL, +		KeyHash, Delim, testBuffer32[:], EOL, +	)) +	buf := bytes.NewBuffer(nil) +	for _, leaf := range leafList { +		if err := leaf.MarshalASCII(buf); err != nil { +			t.Errorf("expected error %v but got %v in test %q: %v", false, true, description, err) +			return +		} +	} +	if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { +		t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) +	} +} + +func TestSignedTreeHeadMarshalASCII(t *testing.T) { +	description := "valid" +	sth := &SignedTreeHead{ +		TreeHead: TreeHead{ +			Timestamp: 123, +			TreeSize:  456, +			RootHash:  testBuffer32, +		}, +		SigIdent: []*SigIdent{ +			&SigIdent{ +				Signature: testBuffer64, +				KeyHash:   testBuffer32, +			}, +			&SigIdent{ +				Signature: testBuffer64, +				KeyHash:   testBuffer32, +			}, +		}, +	} +	wantBuf := bytes.NewBufferString(fmt.Sprintf( +		"%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", +		Timestamp, Delim, 123, EOL, +		TreeSize, Delim, 456, EOL, +		RootHash, Delim, testBuffer32[:], EOL, +		Signature, Delim, testBuffer64[:], EOL, +		KeyHash, Delim, testBuffer32[:], EOL, +		Signature, Delim, testBuffer64[:], EOL, +		KeyHash, Delim, testBuffer32[:], EOL, +	)) +	buf := bytes.NewBuffer(nil) +	if err := sth.MarshalASCII(buf); err != nil { +		t.Errorf("expected error %v but got %v in test %q", false, true, description) +		return +	} +	if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { +		t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) +	} +} + +func TestInclusionProofMarshalASCII(t *testing.T) { +	description := "valid" +	proof := InclusionProof{ +		TreeSize:  321, +		LeafIndex: 123, +		Path: []*[HashSize]byte{ +			testBuffer32, +			testBuffer32, +		}, +	} +	wantBuf := bytes.NewBufferString(fmt.Sprintf( +		"%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s", +		TreeSize, Delim, 321, EOL, +		LeafIndex, Delim, 123, EOL, +		InclusionPath, Delim, testBuffer32[:], EOL, +		InclusionPath, Delim, testBuffer32[:], EOL, +	)) +	buf := bytes.NewBuffer(nil) +	if err := proof.MarshalASCII(buf); err != nil { +		t.Errorf("expected error %v but got %v in test %q", false, true, description) +		return +	} +	if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { +		t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) +	} +} + +func TestConsistencyProofMarshalASCII(t *testing.T) { +	description := "valid" +	proof := ConsistencyProof{ +		NewSize: 321, +		OldSize: 123, +		Path: []*[HashSize]byte{ +			testBuffer32, +			testBuffer32, +		}, +	} +	wantBuf := bytes.NewBufferString(fmt.Sprintf( +		"%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s", +		NewSize, Delim, 321, EOL, +		OldSize, Delim, 123, EOL, +		ConsistencyPath, Delim, testBuffer32[:], EOL, +		ConsistencyPath, Delim, testBuffer32[:], EOL, +	)) +	buf := bytes.NewBuffer(nil) +	if err := proof.MarshalASCII(buf); err != nil { +		t.Errorf("expected error %v but got %v in test %q", false, true, description) +		return +	} +	if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { +		t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) +	} +} + +func TestWriteASCII(t *testing.T) { +} + +/* + * + * UnmarshalASCII methods and helpers + * + */ +func TestLeafListUnmarshalASCII(t *testing.T) {} + +func TestSignedTreeHeadUnmarshalASCII(t *testing.T) { +	for _, table := range []struct { +		description string +		buf         io.Reader +		wantErr     bool +		wantSth     *SignedTreeHead +	}{ +		{ +			description: "valid", +			buf: bytes.NewBufferString(fmt.Sprintf( +				"%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", +				Timestamp, Delim, 123, EOL, +				TreeSize, Delim, 456, EOL, +				RootHash, Delim, testBuffer32[:], EOL, +				Signature, Delim, testBuffer64[:], EOL, +				KeyHash, Delim, testBuffer32[:], EOL, +				Signature, Delim, testBuffer64[:], EOL, +				KeyHash, Delim, testBuffer32[:], EOL, +			)), +			wantSth: &SignedTreeHead{ +				TreeHead: TreeHead{ +					Timestamp: 123, +					TreeSize:  456, +					RootHash:  testBuffer32, +				}, +				SigIdent: []*SigIdent{ +					&SigIdent{ +						Signature: testBuffer64, +						KeyHash:   testBuffer32, +					}, +					&SigIdent{ +						Signature: testBuffer64, +						KeyHash:   testBuffer32, +					}, +				}, +			}, +		}, +	} { +		var sth SignedTreeHead +		err := sth.UnmarshalASCII(table.buf) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} +		if got, want := &sth, table.wantSth; !reflect.DeepEqual(got, want) { +			t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +	} +} + +func TestInclusionProofUnmarshalASCII(t *testing.T)   {} +func TestConsistencyProofUnmarshalASCII(t *testing.T) {} + +func TestInclusionProofRequestUnmarshalASCII(t *testing.T) { +	for _, table := range []struct { +		description string +		buf         io.Reader +		wantErr     bool +		wantReq     *InclusionProofRequest +	}{ +		{ +			description: "valid", +			buf: bytes.NewBufferString(fmt.Sprintf( +				"%s%s%x%s"+"%s%s%d%s", +				LeafHash, Delim, testBuffer32[:], EOL, +				TreeSize, Delim, 123, EOL, +			)), +			wantReq: &InclusionProofRequest{ +				LeafHash: testBuffer32, +				TreeSize: 123, +			}, +		}, +	} { +		var req InclusionProofRequest +		err := req.UnmarshalASCII(table.buf) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} +		if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { +			t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +	} +} + +func TestConsistencyProofRequestUnmarshalASCII(t *testing.T) { +	for _, table := range []struct { +		description string +		buf         io.Reader +		wantErr     bool +		wantReq     *ConsistencyProofRequest +	}{ +		{ +			description: "valid", +			buf: bytes.NewBufferString(fmt.Sprintf( +				"%s%s%d%s"+"%s%s%d%s", +				NewSize, Delim, 321, EOL, +				OldSize, Delim, 123, EOL, +			)), +			wantReq: &ConsistencyProofRequest{ +				NewSize: 321, +				OldSize: 123, +			}, +		}, +	} { +		var req ConsistencyProofRequest +		err := req.UnmarshalASCII(table.buf) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} +		if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { +			t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +	} +} + +func TestLeavesRequestUnmarshalASCII(t *testing.T) { +	for _, table := range []struct { +		description string +		buf         io.Reader +		wantErr     bool +		wantReq     *LeavesRequest +	}{ +		{ +			description: "valid", +			buf: bytes.NewBufferString(fmt.Sprintf( +				"%s%s%d%s"+"%s%s%d%s", +				StartSize, Delim, 123, EOL, +				EndSize, Delim, 456, EOL, +			)), +			wantReq: &LeavesRequest{ +				StartSize: 123, +				EndSize:   456, +			}, +		}, +	} { +		var req LeavesRequest +		err := req.UnmarshalASCII(table.buf) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} +		if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { +			t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +	} +} + +func TestLeafRequestUnmarshalASCII(t *testing.T) { +	for _, table := range []struct { +		description string +		buf         io.Reader +		wantErr     bool +		wantReq     *LeafRequest +	}{ +		{ +			description: "valid", +			buf: bytes.NewBufferString(fmt.Sprintf( +				"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%s%s", +				ShardHint, Delim, 123, EOL, +				Checksum, Delim, testBuffer32[:], EOL, +				SignatureOverMessage, Delim, testBuffer64[:], EOL, +				VerificationKey, Delim, testBuffer32[:], EOL, +				DomainHint, Delim, "example.com", EOL, +			)), +			wantReq: &LeafRequest{ +				Message: Message{ +					ShardHint: 123, +					Checksum:  testBuffer32, +				}, +				Signature:       testBuffer64, +				VerificationKey: testBuffer32, +				DomainHint:      "example.com", +			}, +		}, +	} { +		var req LeafRequest +		err := req.UnmarshalASCII(table.buf) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} +		if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { +			t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +	} +} + +func TestCosignatureRequestUnmarshalASCII(t *testing.T) { +	for _, table := range []struct { +		description string +		buf         io.Reader +		wantErr     bool +		wantReq     *CosignatureRequest +	}{ +		{ +			description: "valid", +			buf: bytes.NewBufferString(fmt.Sprintf( +				"%s%s%x%s"+"%s%s%x%s", +				Signature, Delim, testBuffer64[:], EOL, +				KeyHash, Delim, testBuffer32[:], EOL, +			)), +			wantReq: &CosignatureRequest{ +				SigIdent: SigIdent{ +					Signature: testBuffer64, +					KeyHash:   testBuffer32, +				}, +			}, +		}, +	} { +		var req CosignatureRequest +		err := req.UnmarshalASCII(table.buf) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} +		if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { +			t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +		} +	} +} diff --git a/pkg/types/trunnel.go b/pkg/types/trunnel.go new file mode 100644 index 0000000..268f6f7 --- /dev/null +++ b/pkg/types/trunnel.go @@ -0,0 +1,60 @@ +package types + +import ( +	"encoding/binary" +	"fmt" +) + +const ( +	// MessageSize is the number of bytes in a Trunnel-encoded leaf message +	MessageSize = 8 + HashSize +	// LeafSize is the number of bytes in a Trunnel-encoded leaf +	LeafSize = MessageSize + SignatureSize + HashSize +) + +// Marshal returns a Trunnel-encoded message +func (m *Message) Marshal() []byte { +	buf := make([]byte, MessageSize) +	binary.BigEndian.PutUint64(buf, m.ShardHint) +	copy(buf[8:], m.Checksum[:]) +	return buf +} + +// Marshal returns a Trunnel-encoded leaf +func (l *Leaf) Marshal() []byte { +	buf := l.Message.Marshal() +	buf = append(buf, l.SigIdent.Signature[:]...) +	buf = append(buf, l.SigIdent.KeyHash[:]...) +	return buf +} + +// Marshal returns a Trunnel-encoded tree head +func (th *TreeHead) Marshal() []byte { +	buf := make([]byte, 8+8+HashSize) +	binary.BigEndian.PutUint64(buf[0:8], th.Timestamp) +	binary.BigEndian.PutUint64(buf[8:16], th.TreeSize) +	copy(buf[16:], th.RootHash[:]) +	return buf +} + +// Unmarshal parses the Trunnel-encoded buffer as a leaf +func (l *Leaf) Unmarshal(buf []byte) error { +	if len(buf) != LeafSize { +		return fmt.Errorf("invalid leaf size: %v", len(buf)) +	} +	// Shard hint +	l.ShardHint = binary.BigEndian.Uint64(buf) +	offset := 8 +	// Checksum +	l.Checksum = &[HashSize]byte{} +	copy(l.Checksum[:], buf[offset:offset+HashSize]) +	offset += HashSize +	// Signature +	l.Signature = &[SignatureSize]byte{} +	copy(l.Signature[:], buf[offset:offset+SignatureSize]) +	offset += SignatureSize +	// KeyHash +	l.KeyHash = &[HashSize]byte{} +	copy(l.KeyHash[:], buf[offset:]) +	return nil +} diff --git a/pkg/types/trunnel_test.go b/pkg/types/trunnel_test.go new file mode 100644 index 0000000..297578c --- /dev/null +++ b/pkg/types/trunnel_test.go @@ -0,0 +1,114 @@ +package types + +import ( +	"bytes" +	"reflect" +	"testing" +) + +var ( +	testBuffer32 = &[32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31} +	testBuffer64 = &[64]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63} +) + +func TestMarshalMessage(t *testing.T) { +	description := "valid: shard hint 72623859790382856, checksum 0x00,0x01,..." +	message := &Message{ +		ShardHint: 72623859790382856, +		Checksum:  testBuffer32, +	} +	want := bytes.Join([][]byte{ +		[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, +		testBuffer32[:], +	}, nil) +	if got := message.Marshal(); !bytes.Equal(got, want) { +		t.Errorf("got message\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) +	} +} + +func TestMarshalLeaf(t *testing.T) { +	description := "valid: shard hint 72623859790382856, buffers 0x00,0x01,..." +	leaf := &Leaf{ +		Message: Message{ +			ShardHint: 72623859790382856, +			Checksum:  testBuffer32, +		}, +		SigIdent: SigIdent{ +			Signature: testBuffer64, +			KeyHash:   testBuffer32, +		}, +	} +	want := bytes.Join([][]byte{ +		[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, +		testBuffer32[:], testBuffer64[:], testBuffer32[:], +	}, nil) +	if got := leaf.Marshal(); !bytes.Equal(got, want) { +		t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) +	} +} + +func TestMarshalTreeHead(t *testing.T) { +	description := "valid: timestamp 16909060, tree size 72623859790382856, root hash 0x00,0x01,..." +	th := &TreeHead{ +		Timestamp: 16909060, +		TreeSize:  72623859790382856, +		RootHash:  testBuffer32, +	} +	want := bytes.Join([][]byte{ +		[]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04}, +		[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, +		testBuffer32[:], +	}, nil) +	if got := th.Marshal(); !bytes.Equal(got, want) { +		t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) +	} +} + +func TestUnmarshalLeaf(t *testing.T) { +	for _, table := range []struct { +		description string +		serialized  []byte +		wantErr     bool +		want        *Leaf +	}{ +		{ +			description: "invalid: not enough bytes", +			serialized:  make([]byte, LeafSize-1), +			wantErr:     true, +		}, +		{ +			description: "invalid: too many bytes", +			serialized:  make([]byte, LeafSize+1), +			wantErr:     true, +		}, +		{ +			description: "valid: shard hint 72623859790382856, buffers 0x00,0x01,...", +			serialized: bytes.Join([][]byte{ +				[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, +				testBuffer32[:], testBuffer64[:], testBuffer32[:], +			}, nil), +			want: &Leaf{ +				Message: Message{ +					ShardHint: 72623859790382856, +					Checksum:  testBuffer32, +				}, +				SigIdent: SigIdent{ +					Signature: testBuffer64, +					KeyHash:   testBuffer32, +				}, +			}, +		}, +	} { +		var leaf Leaf +		err := leaf.Unmarshal(table.serialized) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +		} +		if err != nil { +			continue +		} +		if got, want := &leaf, table.want; !reflect.DeepEqual(got, want) { +			t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, table.description) +		} +	} +} diff --git a/pkg/types/types.go b/pkg/types/types.go new file mode 100644 index 0000000..9ca7db8 --- /dev/null +++ b/pkg/types/types.go @@ -0,0 +1,155 @@ +package types + +import ( +	"crypto" +	"crypto/ed25519" +	"crypto/sha256" +	"fmt" +	"strings" +) + +const ( +	HashSize            = sha256.Size +	SignatureSize       = ed25519.SignatureSize +	VerificationKeySize = ed25519.PublicKeySize + +	EndpointAddLeaf             = Endpoint("add-leaf") +	EndpointAddCosignature      = Endpoint("add-cosignature") +	EndpointGetTreeHeadLatest   = Endpoint("get-tree-head-latest") +	EndpointGetTreeHeadToSign   = Endpoint("get-tree-head-to-sign") +	EndpointGetTreeHeadCosigned = Endpoint("get-tree-head-cosigned") +	EndpointGetProofByHash      = Endpoint("get-proof-by-hash") +	EndpointGetConsistencyProof = Endpoint("get-consistency-proof") +	EndpointGetLeaves           = Endpoint("get-leaves") +) + +// Endpoint is a named HTTP API endpoint +type Endpoint string + +// Path joins a number of components to form a full endpoint path.  For example, +// EndpointAddLeaf.Path("example.com", "st/v0") -> example.com/st/v0/add-leaf. +func (e Endpoint) Path(components ...string) string { +	return strings.Join(append(components, string(e)), "/") +} + +// Leaf is the log's Merkle tree leaf. +type Leaf struct { +	Message +	SigIdent +} + +// Message is composed of a shard hint and a checksum.  The submitter selects +// these values to fit the log's shard interval and the opaque data in question. +type Message struct { +	ShardHint uint64 +	Checksum  *[HashSize]byte +} + +// SigIdent is composed of a signature-signer pair.  The signature is computed +// over the Trunnel-serialized leaf message.  KeyHash identifies the signer. +type SigIdent struct { +	Signature *[SignatureSize]byte +	KeyHash   *[HashSize]byte +} + +// SignedTreeHead is composed of a tree head and a list of signature-signer +// pairs.  Each signature is computed over the Trunnel-serialized tree head. +type SignedTreeHead struct { +	TreeHead +	SigIdent []*SigIdent +} + +// TreeHead is the log's tree head. +type TreeHead struct { +	Timestamp uint64 +	TreeSize  uint64 +	RootHash  *[HashSize]byte +} + +// ConsistencyProof is a consistency proof that proves the log's append-only +// property. +type ConsistencyProof struct { +	NewSize uint64 +	OldSize uint64 +	Path    []*[HashSize]byte +} + +// InclusionProof is an inclusion proof that proves a leaf is included in the +// log. +type InclusionProof struct { +	TreeSize  uint64 +	LeafIndex uint64 +	Path      []*[HashSize]byte +} + +// LeafList is a list of leaves +type LeafList []*Leaf + +// ConsistencyProofRequest is a get-consistency-proof request +type ConsistencyProofRequest struct { +	NewSize uint64 +	OldSize uint64 +} + +// InclusionProofRequest is a get-proof-by-hash request +type InclusionProofRequest struct { +	LeafHash *[HashSize]byte +	TreeSize uint64 +} + +// LeavesRequest is a get-leaves request +type LeavesRequest struct { +	StartSize uint64 +	EndSize   uint64 +} + +// LeafRequest is an add-leaf request +type LeafRequest struct { +	Message +	Signature       *[SignatureSize]byte +	VerificationKey *[VerificationKeySize]byte +	DomainHint      string +} + +// CosignatureRequest is an add-cosignature request +type CosignatureRequest struct { +	SigIdent +} + +// Sign signs the tree head using the log's signature scheme +func (th *TreeHead) Sign(signer crypto.Signer) (*SignedTreeHead, error) { +	sig, err := signer.Sign(nil, th.Marshal(), crypto.Hash(0)) +	if err != nil { +		return nil, fmt.Errorf("Sign: %v", err) +	} + +	sigident := SigIdent{ +		KeyHash:   Hash(signer.Public().(ed25519.PublicKey)[:]), +		Signature: &[SignatureSize]byte{}, +	} +	copy(sigident.Signature[:], sig) +	return &SignedTreeHead{ +		TreeHead: *th, +		SigIdent: []*SigIdent{ +			&sigident, +		}, +	}, nil +} + +// Verify verifies the tree head signature using the log's signature scheme +func (th *TreeHead) Verify(vk *[VerificationKeySize]byte, sig *[SignatureSize]byte) error { +	if !ed25519.Verify(ed25519.PublicKey(vk[:]), th.Marshal(), sig[:]) { +		return fmt.Errorf("invalid tree head signature") +	} +	return nil +} + +// Verify checks if a leaf is included in the log +func (p *InclusionProof) Verify(leaf *Leaf, th *TreeHead) error { // TODO +	return nil +} + +// Verify checks if two tree heads are consistent +func (p *ConsistencyProof) Verify(oldTH, newTH *TreeHead) error { // TODO +	return nil +} diff --git a/pkg/types/types_test.go b/pkg/types/types_test.go new file mode 100644 index 0000000..da89c59 --- /dev/null +++ b/pkg/types/types_test.go @@ -0,0 +1,58 @@ +package types + +import ( +	"testing" +) + +func TestEndpointPath(t *testing.T) { +	base, prefix, proto := "example.com", "log", "st/v0" +	for _, table := range []struct { +		endpoint Endpoint +		want     string +	}{ +		{ +			endpoint: EndpointAddLeaf, +			want:     "example.com/log/st/v0/add-leaf", +		}, +		{ +			endpoint: EndpointAddCosignature, +			want:     "example.com/log/st/v0/add-cosignature", +		}, +		{ +			endpoint: EndpointGetTreeHeadLatest, +			want:     "example.com/log/st/v0/get-tree-head-latest", +		}, +		{ +			endpoint: EndpointGetTreeHeadToSign, +			want:     "example.com/log/st/v0/get-tree-head-to-sign", +		}, +		{ +			endpoint: EndpointGetTreeHeadCosigned, +			want:     "example.com/log/st/v0/get-tree-head-cosigned", +		}, +		{ +			endpoint: EndpointGetConsistencyProof, +			want:     "example.com/log/st/v0/get-consistency-proof", +		}, +		{ +			endpoint: EndpointGetProofByHash, +			want:     "example.com/log/st/v0/get-proof-by-hash", +		}, +		{ +			endpoint: EndpointGetLeaves, +			want:     "example.com/log/st/v0/get-leaves", +		}, +	} { +		if got, want := table.endpoint.Path(base+"/"+prefix+"/"+proto), table.want; got != want { +			t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\twith one component", got, want) +		} +		if got, want := table.endpoint.Path(base, prefix, proto), table.want; got != want { +			t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\tmultiple components", got, want) +		} +	} +} + +func TestTreeHeadSign(t *testing.T)           {} +func TestTreeHeadVerify(t *testing.T)         {} +func TestInclusionProofVerify(t *testing.T)   {} +func TestConsistencyProofVerify(t *testing.T) {} diff --git a/pkg/types/util.go b/pkg/types/util.go new file mode 100644 index 0000000..3cd7dfa --- /dev/null +++ b/pkg/types/util.go @@ -0,0 +1,21 @@ +package types + +import ( +	"crypto/sha256" +) + +const ( +	LeafHashPrefix = 0x00 +) + +func Hash(buf []byte) *[HashSize]byte { +	var ret [HashSize]byte +	hash := sha256.New() +	hash.Write(buf) +	copy(ret[:], hash.Sum(nil)) +	return &ret +} + +func HashLeaf(buf []byte) *[HashSize]byte { +	return Hash(append([]byte{LeafHashPrefix}, buf...)) +} | 
