diff options
-rw-r--r-- | handler.go | 58 | ||||
-rw-r--r-- | handler_test.go | 225 | ||||
-rw-r--r-- | instance.go | 37 | ||||
-rw-r--r-- | instance_test.go | 61 | ||||
-rw-r--r-- | reqres.go | 41 | ||||
-rw-r--r-- | reqres_test.go | 41 | ||||
-rw-r--r-- | server/main.go | 157 | ||||
-rw-r--r-- | sth.go | 159 | ||||
-rw-r--r-- | sth_test.go | 454 | ||||
-rw-r--r-- | type.go | 33 | ||||
-rw-r--r-- | type_test.go | 16 |
11 files changed, 1175 insertions, 107 deletions
@@ -9,7 +9,6 @@ import ( "github.com/golang/glog" "github.com/google/trillian" - "github.com/google/trillian/types" ) // Handler implements the http.Handler interface, and contains a reference @@ -37,7 +36,7 @@ func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() reqcnt.Inc(a.instance.LogParameters.id(), string(a.endpoint)) - ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.instance.Deadline)) + ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.instance.LogParameters.Deadline)) defer cancel() if r.Method != a.method { @@ -57,6 +56,7 @@ func (a Handler) sendHTTPError(w http.ResponseWriter, statusCode int, err error) http.Error(w, http.StatusText(statusCode), statusCode) } +// addEntry accepts log entries from trusted submitters func addEntry(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { glog.V(3).Info("handling add-entry request") req, err := i.LogParameters.newAddEntryRequest(r) @@ -184,17 +184,26 @@ func getConsistencyProof(ctx context.Context, i *Instance, w http.ResponseWriter // getSth provides the most recent STH func getSth(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { glog.V(3).Info("handling get-sth request") - trsp, err := i.Client.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{ - LogId: i.LogParameters.TreeId, - }) - var lr types.LogRootV1 - if errInner := checkGetLatestSignedLogRoot(i.LogParameters, trsp, err, &lr); errInner != nil { - return http.StatusInternalServerError, fmt.Errorf("bad GetLatestSignedLogRootResponse: %v", errInner) + sth, err := i.SthSource.Latest(ctx) + if err != nil { + return http.StatusInternalServerError, fmt.Errorf("Latest: %v", err) + } + rsp, err := sth.MarshalB64() + if err != nil { + return http.StatusInternalServerError, err + } + if err := writeJsonResponse(rsp, w); err != nil { + return http.StatusInternalServerError, err } + return http.StatusOK, nil +} - sth, err := i.LogParameters.genV1Sth(NewTreeHeadV1(&lr)) +// getStableSth provides an STH that is stable for a fixed period of time +func getStableSth(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { + glog.V(3).Info("handling get-stable-sth request") + sth, err := i.SthSource.Stable(ctx) if err != nil { - return http.StatusInternalServerError, fmt.Errorf("failed creating signed tree head: %v", err) + return http.StatusInternalServerError, fmt.Errorf("Latest: %v", err) } rsp, err := sth.MarshalB64() if err != nil { @@ -205,3 +214,32 @@ func getSth(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Req } return http.StatusOK, nil } + +// getCosi provides a cosigned STH +func getCosi(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { + costh, err := i.SthSource.Cosigned(ctx) + if err != nil { + return http.StatusInternalServerError, fmt.Errorf("Cosigned: %v", err) + } + rsp, err := costh.MarshalB64() + if err != nil { + return http.StatusInternalServerError, err + } + if err := writeJsonResponse(rsp, w); err != nil { + return http.StatusInternalServerError, err + } + return http.StatusOK, nil +} + +// addCosi accepts cosigned STHs from trusted witnesses +func addCosi(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { + glog.V(3).Info("handling add-cosignature request") + costh, err := i.LogParameters.newAddCosignatureRequest(r) + if err != nil { + return http.StatusBadRequest, err + } + if err := i.SthSource.AddCosignature(ctx, costh); err != nil { + return http.StatusBadRequest, err + } + return http.StatusOK, nil +} diff --git a/handler_test.go b/handler_test.go index dd32c37..daa1a6c 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,5 +1,7 @@ package stfe +// TODO: refactor tests + import ( "bytes" "context" @@ -8,7 +10,6 @@ import ( "testing" "crypto/ed25519" - //"crypto/tls" "encoding/base64" "encoding/json" "net/http" @@ -28,16 +29,31 @@ type testHandler struct { instance *Instance } -func newTestHandler(t *testing.T, signer crypto.Signer) *testHandler { +func newTestHandler(t *testing.T, signer crypto.Signer, sth *StItem) *testHandler { ctrl := gomock.NewController(t) client := mockclient.NewMockTrillianLogClient(ctrl) + lp := makeTestLogParameters(t, signer) + source := &ActiveSthSource{ + client: client, + logParameters: lp, + } + if sth != nil { + source.currSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), + Signature: testSignature, + }, + }) + source.nextSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil) + source.cosignatureFrom = make(map[string]bool) + } return &testHandler{ mockCtrl: ctrl, client: client, instance: &Instance{ - Deadline: testDeadline, Client: client, - LogParameters: makeTestLogParameters(t, signer), + LogParameters: lp, + SthSource: source, }, } } @@ -49,6 +65,8 @@ func (th *testHandler) getHandlers(t *testing.T) map[Endpoint]Handler { EndpointGetProofByHash: Handler{instance: th.instance, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet}, EndpointGetAnchors: Handler{instance: th.instance, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet}, EndpointGetEntries: Handler{instance: th.instance, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet}, + EndpointGetStableSth: Handler{instance: th.instance, handler: getStableSth, endpoint: EndpointGetStableSth, method: http.MethodGet}, + EndpointGetCosi: Handler{instance: th.instance, handler: getCosi, endpoint: EndpointGetCosi, method: http.MethodGet}, } } @@ -63,6 +81,7 @@ func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) Handler { func (th *testHandler) postHandlers(t *testing.T) map[Endpoint]Handler { return map[Endpoint]Handler{ EndpointAddEntry: Handler{instance: th.instance, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost}, + EndpointAddCosi: Handler{instance: th.instance, handler: addCosi, endpoint: EndpointAddCosi, method: http.MethodPost}, } } @@ -76,7 +95,7 @@ func (th *testHandler) postHandler(t *testing.T, endpoint Endpoint) Handler { // TestGetHandlersRejectPost checks that all get handlers reject post requests func TestGetHandlersRejectPost(t *testing.T) { - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() for endpoint, handler := range th.getHandlers(t) { @@ -96,7 +115,7 @@ func TestGetHandlersRejectPost(t *testing.T) { // TestPostHandlersRejectGet checks that all post handlers reject get requests func TestPostHandlersRejectGet(t *testing.T) { - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() for endpoint, handler := range th.postHandlers(t) { @@ -196,7 +215,7 @@ func TestGetEntries(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetEntries.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -298,7 +317,7 @@ func TestAddEntry(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, table.signer) + th := newTestHandler(t, table.signer, nil) defer th.mockCtrl.Finish() url := EndpointAddEntry.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -392,7 +411,7 @@ func TestGetSth(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, table.signer) + th := newTestHandler(t, table.signer, nil) defer th.mockCtrl.Finish() url := EndpointGetSth.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -496,7 +515,7 @@ func TestGetConsistencyProof(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetConsistencyProof.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -605,7 +624,7 @@ func TestGetProofByHash(t *testing.T) { }, } { func() { // run deferred functions at the end of each iteration - th := newTestHandler(t, nil) + th := newTestHandler(t, nil, nil) defer th.mockCtrl.Finish() url := EndpointGetProofByHash.Path("http://example.com", th.instance.LogParameters.Prefix) @@ -671,6 +690,166 @@ func TestGetProofByHash(t *testing.T) { } } +func TestGetStableSth(t *testing.T) { + for _, table := range cosiTestCases(t) { + func() { // run deferred functions at the end of each iteration + th := newTestHandler(t, nil, table.sth) + defer th.mockCtrl.Finish() + + // Setup and run client query + url := EndpointGetStableSth.Path("http://example.com", th.instance.LogParameters.Prefix) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + w := httptest.NewRecorder() + th.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req) + + // Check response code + if w.Code != table.wantCode { + t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) + } + if w.Code != http.StatusOK { + return + } + // Check response bytes + var gotBytes []byte + if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil { + t.Errorf("failed unmarshaling json: %v, wanted ok", err) + return + } + wantBytes, _ := table.sth.Marshal() + if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) { + t.Errorf("wanted response %X but got %X in test %q", got, want, table.description) + } + }() + } +} + +func TestGetCosi(t *testing.T) { + for _, table := range cosiTestCases(t) { + func() { // run deferred functions at the end of each iteration + th := newTestHandler(t, nil, table.sth) + defer th.mockCtrl.Finish() + + // Setup and run client query + url := EndpointGetCosi.Path("http://example.com", th.instance.LogParameters.Prefix) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + w := httptest.NewRecorder() + th.getHandler(t, EndpointGetCosi).ServeHTTP(w, req) + + // Check response code + if w.Code != table.wantCode { + t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) + } + if w.Code != http.StatusOK { + return + } + // Check response bytes + var gotBytes []byte + if err := json.Unmarshal([]byte(w.Body.String()), &gotBytes); err != nil { + t.Errorf("failed unmarshaling json: %v, wanted ok", err) + return + } + wantCosth := NewCosignedTreeHeadV1(table.sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), + Signature: testSignature, + }, + }) + wantBytes, _ := wantCosth.Marshal() + if got, want := gotBytes, wantBytes; !bytes.Equal(got, want) { + t.Errorf("wanted response %X but got %X in test %q", got, want, table.description) + } + }() + } +} + +func TestAddCosi(t *testing.T) { + validSth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) + validSth2 := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp+1000000, testTreeSize, testNodeHash)), testLogId, testSignature) + for _, table := range []struct { + description string + sth *StItem + breq *bytes.Buffer + wantCode int + }{ + { + description: "invalid request: untrusted witness", // more specific tests can be found in TestNewAddCosignatureRequest + sth: validSth, + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk2, testdata.Ed25519Vk2, validSth), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid request: cosigned wrong sth", // more specific tests can be found in TestAddCosignature + sth: validSth, + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth2), + wantCode: http.StatusBadRequest, + }, + { + description: "valid", + sth: validSth, + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth), + wantCode: http.StatusOK, + }, + } { + func() { // run deferred functions at the end of each iteration + th := newTestHandler(t, nil, table.sth) + defer th.mockCtrl.Finish() + + // Setup and run client query + url := EndpointAddCosi.Path("http://example.com", th.instance.LogParameters.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + th.postHandler(t, EndpointAddCosi).ServeHTTP(w, req) + if w.Code != table.wantCode { + t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) + } + + // Check response + if w.Code != table.wantCode { + t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) + } + }() + } +} + +type cosiTestCase struct { + description string + sth *StItem + wantCode int +} + +// cosiTestCases returns test cases used by TestGetStableSth and TestGetCosi +func cosiTestCases(t *testing.T) []cosiTestCase { + t.Helper() + return []cosiTestCase{ + { + description: "no cosigned/stable sth", + sth: nil, + wantCode: http.StatusInternalServerError, + }, + { + description: "malformed cosigned/stable sth", + sth: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), []byte("not a log id"), testSignature), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + sth: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature), + wantCode: http.StatusOK, + }, + } +} + // mustMakeEd25519ChecksumV1 creates an ed25519-signed ChecksumV1 leaf func mustMakeEd25519ChecksumV1(t *testing.T, id, checksum, vk, sk []byte) ([]byte, []byte) { t.Helper() @@ -697,6 +876,30 @@ func mustMakeEd25519ChecksumV1Buffer(t *testing.T, identifier, checksum, vk, sk return bytes.NewBuffer(data) } +// mustMakeAddCosiBuffer creates an add-cosi data buffer +func mustMakeAddCosiBuffer(t *testing.T, sk, vk []byte, sth *StItem) *bytes.Buffer { + t.Helper() + msg, err := sth.Marshal() + if err != nil { + t.Fatalf("must marshal sth: %v", err) + } + costh := NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *mustNewNamespaceEd25519V1(t, vk), + Signature: ed25519.Sign(ed25519.PrivateKey(sk), msg), + }, + }) + item, err := costh.Marshal() + if err != nil { + t.Fatalf("must marshal costh: %v", err) + } + data, err := json.Marshal(AddCosignatureRequest{item}) + if err != nil { + t.Fatalf("must marshal add-cosi request: %v", err) + } + return bytes.NewBuffer(data) +} + // deadlineMatcher implements gomock.Matcher, such that an error is raised if // there is no context.Context deadline set type deadlineMatcher struct{} diff --git a/instance.go b/instance.go index 3ca14b8..a67307f 100644 --- a/instance.go +++ b/instance.go @@ -14,9 +14,9 @@ import ( // Instance is an instance of a particular log front-end type Instance struct { - LogParameters *LogParameters Client trillian.TrillianLogClient - Deadline time.Duration + SthSource SthSource + LogParameters *LogParameters } // LogParameters is a collection of log parameters @@ -25,9 +25,12 @@ type LogParameters struct { TreeId int64 // used internally by Trillian Prefix string // e.g., "test" for <base>/test MaxRange int64 // max entries per get-entries request - Namespaces *namespace.NamespacePool // trust namespaces - Signer crypto.Signer - HashType crypto.Hash // hash function used by Trillian + Submitters *namespace.NamespacePool // trusted submitters + Witnesses *namespace.NamespacePool // trusted witnesses + Deadline time.Duration // gRPC deadline + Interval time.Duration // cosigning sth frequency + Signer crypto.Signer // interface to access private key + HashType crypto.Hash // hash function used by Trillian } // Endpoint is a named HTTP API endpoint @@ -35,19 +38,22 @@ type Endpoint string const ( EndpointAddEntry = Endpoint("add-entry") + EndpointAddCosi = Endpoint("add-cosi") // TODO: name? EndpointGetEntries = Endpoint("get-entries") EndpointGetAnchors = Endpoint("get-anchors") EndpointGetProofByHash = Endpoint("get-proof-by-hash") EndpointGetConsistencyProof = Endpoint("get-consistency-proof") EndpointGetSth = Endpoint("get-sth") + EndpointGetStableSth = Endpoint("get-stable-sth") // TODO: name? + EndpointGetCosi = Endpoint("get-cosi") // TODO: name? ) func (i Instance) String() string { - return fmt.Sprintf("%s Deadline(%v)\n", i.LogParameters, i.Deadline) + return fmt.Sprintf("%s\n", i.LogParameters) } func (lp LogParameters) String() string { - return fmt.Sprintf("LogId(%s) TreeId(%d) Prefix(%s) MaxRange(%d) Namespaces(%d)", lp.id(), lp.TreeId, lp.Prefix, lp.MaxRange, len(lp.Namespaces.List())) + return fmt.Sprintf("LogId(%s) TreeId(%d) Prefix(%s) MaxRange(%d) Submitters(%d) Witnesses(%d) Deadline(%v) Interval(%v)", lp.id(), lp.TreeId, lp.Prefix, lp.MaxRange, len(lp.Submitters.List()), len(lp.Witnesses.List()), lp.Deadline, lp.Interval) } func (e Endpoint) String() string { @@ -55,26 +61,23 @@ func (e Endpoint) String() string { } // NewInstance creates a new STFE instance -func NewInstance(lp *LogParameters, client trillian.TrillianLogClient, deadline time.Duration) *Instance { +func NewInstance(lp *LogParameters, client trillian.TrillianLogClient, source SthSource) *Instance { return &Instance{ LogParameters: lp, Client: client, - Deadline: deadline, + SthSource: source, } } // NewLogParameters creates new log parameters. Note that the signer is // assumed to be an ed25519 signing key. Could be fixed at some point. -func NewLogParameters(signer crypto.Signer, logId *namespace.Namespace, treeId int64, prefix string, namespaces *namespace.NamespacePool, maxRange int64) (*LogParameters, error) { +func NewLogParameters(signer crypto.Signer, logId *namespace.Namespace, treeId int64, prefix string, submitters, witnesses *namespace.NamespacePool, maxRange int64, interval, deadline time.Duration) (*LogParameters, error) { if signer == nil { return nil, fmt.Errorf("need a signer but got none") } if maxRange < 1 { return nil, fmt.Errorf("max range must be at least one") } - if len(namespaces.List()) < 1 { - return nil, fmt.Errorf("need at least one trusted namespace") - } lid, err := logId.Marshal() if err != nil { return nil, fmt.Errorf("failed encoding log identifier: %v", err) @@ -84,7 +87,10 @@ func NewLogParameters(signer crypto.Signer, logId *namespace.Namespace, treeId i TreeId: treeId, Prefix: prefix, MaxRange: maxRange, - Namespaces: namespaces, + Submitters: submitters, + Witnesses: witnesses, + Deadline: deadline, + Interval: interval, Signer: signer, HashType: crypto.SHA256, // STFE assumes RFC 6962 hashing }, nil @@ -94,11 +100,14 @@ func NewLogParameters(signer crypto.Signer, logId *namespace.Namespace, treeId i func (i *Instance) Handlers() []Handler { return []Handler{ Handler{instance: i, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost}, + Handler{instance: i, handler: addCosi, endpoint: EndpointAddCosi, method: http.MethodPost}, Handler{instance: i, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet}, Handler{instance: i, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet}, Handler{instance: i, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet}, Handler{instance: i, handler: getConsistencyProof, endpoint: EndpointGetConsistencyProof, method: http.MethodGet}, Handler{instance: i, handler: getSth, endpoint: EndpointGetSth, method: http.MethodGet}, + Handler{instance: i, handler: getStableSth, endpoint: EndpointGetStableSth, method: http.MethodGet}, + Handler{instance: i, handler: getCosi, endpoint: EndpointGetCosi, method: http.MethodGet}, } } diff --git a/instance_test.go b/instance_test.go index 3e55b5b..50302fb 100644 --- a/instance_test.go +++ b/instance_test.go @@ -32,20 +32,23 @@ var ( } testIndex = uint64(0) testHashLen = 31 - testDeadline = time.Second * 10 + testDeadline = time.Second * 5 + testInterval = time.Second * 10 ) +// TestNewLogParamters checks that invalid ones are rejected and that a valid +// set of parameters are accepted. func TestNewLogParameters(t *testing.T) { testLogId := mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk3) namespaces := mustNewNamespacePool(t, []*namespace.Namespace{ mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), }) + witnesses := mustNewNamespacePool(t, []*namespace.Namespace{}) signer := ed25519.PrivateKey(testdata.Ed25519Sk) for _, table := range []struct { description string logId *namespace.Namespace maxRange int64 - namespaces *namespace.NamespacePool signer crypto.Signer wantErr bool }{ @@ -53,7 +56,6 @@ func TestNewLogParameters(t *testing.T) { description: "invalid signer: nil", logId: testLogId, maxRange: testMaxRange, - namespaces: namespaces, signer: nil, wantErr: true, }, @@ -61,15 +63,6 @@ func TestNewLogParameters(t *testing.T) { description: "invalid max range", logId: testLogId, maxRange: 0, - namespaces: namespaces, - signer: signer, - wantErr: true, - }, - { - description: "no namespaces", - logId: testLogId, - maxRange: testMaxRange, - namespaces: mustNewNamespacePool(t, []*namespace.Namespace{}), signer: signer, wantErr: true, }, @@ -81,20 +74,18 @@ func TestNewLogParameters(t *testing.T) { Namespace: make([]byte, 31), // too short }, }, - maxRange: testMaxRange, - namespaces: namespaces, - signer: signer, - wantErr: true, + maxRange: testMaxRange, + signer: signer, + wantErr: true, }, { description: "valid log parameters", logId: testLogId, maxRange: testMaxRange, - namespaces: namespaces, signer: signer, }, } { - lp, err := NewLogParameters(table.signer, table.logId, testTreeId, testPrefix, table.namespaces, table.maxRange) + lp, err := NewLogParameters(table.signer, table.logId, testTreeId, testPrefix, namespaces, witnesses, table.maxRange, testInterval, testDeadline) if got, want := err != nil, table.wantErr; got != want { t.Errorf("got error=%v but wanted %v in test %q: %v", got, want, table.description, err) } @@ -118,7 +109,10 @@ func TestNewLogParameters(t *testing.T) { if got, want := lp.MaxRange, testMaxRange; got != want { t.Errorf("got max range %d but wanted %d in test %q", got, want, table.description) } - if got, want := len(lp.Namespaces.List()), len(namespaces.List()); got != want { + if got, want := len(lp.Submitters.List()), len(namespaces.List()); got != want { + t.Errorf("got %d anchors but wanted %d in test %q", got, want, table.description) + } + if got, want := len(lp.Witnesses.List()), len(witnesses.List()); got != want { t.Errorf("got %d anchors but wanted %d in test %q", got, want, table.description) } } @@ -134,8 +128,11 @@ func TestHandlers(t *testing.T) { EndpointGetProofByHash: false, EndpointGetConsistencyProof: false, EndpointGetAnchors: false, + EndpointGetStableSth: false, + EndpointGetCosi: false, + EndpointAddCosi: false, } - i := NewInstance(makeTestLogParameters(t, nil), nil, testDeadline) + i := NewInstance(makeTestLogParameters(t, nil), nil, nil) for _, handler := range i.Handlers() { if _, ok := endpoints[handler.endpoint]; !ok { t.Errorf("got unexpected endpoint: %s", handler.endpoint) @@ -149,6 +146,7 @@ func TestHandlers(t *testing.T) { } } +// TestEndpointPath checks that the endpoint path builder works as expected func TestEndpointPath(t *testing.T) { base, prefix := "http://example.com", "test" for _, table := range []struct { @@ -179,6 +177,18 @@ func TestEndpointPath(t *testing.T) { endpoint: EndpointGetAnchors, want: "http://example.com/test/get-anchors", }, + { + endpoint: EndpointGetStableSth, + want: "http://example.com/test/get-stable-sth", + }, + { + endpoint: EndpointGetCosi, + want: "http://example.com/test/get-cosi", + }, + { + endpoint: EndpointAddCosi, + want: "http://example.com/test/add-cosi", + }, } { if got, want := table.endpoint.Path(base, prefix), table.want; got != want { t.Errorf("got %s but wanted %s with multiple components", got, want) @@ -216,8 +226,9 @@ func mustNewNamespacePool(t *testing.T, anchors []*namespace.Namespace) *namespa // makeTestLogParameters makes a collection of test log parameters. // // The log's identity is based on testdata.Ed25519{Vk3,Sk3}. The log's accepted -// namespaces is based on testdata.Ed25519{Vk,Vk2}. The remaining log -// parameters are based on the global test* variables in this file. +// submitters are based on testdata.Ed25519Vk. The log's accepted witnesses are +// based on testdata.Ed25519Vk. The remaining log parameters are based on the +// global test* variables in this file. // // For convenience the passed signer is optional (i.e., it may be nil). func makeTestLogParameters(t *testing.T, signer crypto.Signer) *LogParameters { @@ -226,9 +237,11 @@ func makeTestLogParameters(t *testing.T, signer crypto.Signer) *LogParameters { TreeId: testTreeId, Prefix: testPrefix, MaxRange: testMaxRange, - Namespaces: mustNewNamespacePool(t, []*namespace.Namespace{ + Submitters: mustNewNamespacePool(t, []*namespace.Namespace{ + mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), + }), + Witnesses: mustNewNamespacePool(t, []*namespace.Namespace{ mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), - mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk2), }), Signer: signer, HashType: testHashType, @@ -40,6 +40,41 @@ type GetConsistencyProofRequest struct { // is identical to the add-entry request that the log once accepted. type GetEntryResponse AddEntryRequest +// AddCosignatureRequest encapsulates a cosignature request +type AddCosignatureRequest struct { + Item []byte `json:"item"` +} + +// newAddCosignatureRequest parses and verifies an STH cosignature request +func (lp *LogParameters) newAddCosignatureRequest(r *http.Request) (*StItem, error) { + var req AddCosignatureRequest + if err := unpackJsonPost(r, &req); err != nil { + return nil, fmt.Errorf("unpackJsonPost: %v", err) + } + + // Try decoding as CosignedTreeHeadV1 + var item StItem + if err := item.Unmarshal(req.Item); err != nil { + return nil, fmt.Errorf("Unmarshal: %v", err) + } + if item.Format != StFormatCosignedTreeHeadV1 { + return nil, fmt.Errorf("invalid StItem format: %v", item.Format) + } + + // Check that witness namespace is valid + sth := &StItem{Format: StFormatSignedTreeHeadV1, SignedTreeHeadV1: &item.CosignedTreeHeadV1.SignedTreeHeadV1} + if len(item.CosignedTreeHeadV1.SignatureV1) != 1 { + return nil, fmt.Errorf("invalid number of cosignatures") + } else if namespace, ok := lp.Witnesses.Find(&item.CosignedTreeHeadV1.SignatureV1[0].Namespace); !ok { + return nil, fmt.Errorf("unknown witness") + } else if msg, err := sth.Marshal(); err != nil { + return nil, fmt.Errorf("Marshal: %v", err) + } else if err := namespace.Verify(msg, item.CosignedTreeHeadV1.SignatureV1[0].Signature); err != nil { + return nil, fmt.Errorf("Verify: %v", err) + } + return &item, nil +} + // newAddEntryRequest parses and sanitizes the JSON-encoded add-entry // parameters from an incoming HTTP post. The request is returned if it is // a checksumV1 entry that is signed by a valid namespace. @@ -59,7 +94,7 @@ func (lp *LogParameters) newAddEntryRequest(r *http.Request) (*AddEntryRequest, } // Check that namespace is valid for item - if namespace, ok := lp.Namespaces.Find(&item.ChecksumV1.Namespace); !ok { + if namespace, ok := lp.Submitters.Find(&item.ChecksumV1.Namespace); !ok { return nil, fmt.Errorf("unknown namespace: %s", item.ChecksumV1.Namespace.String()) } else if err := namespace.Verify(entry.Item, entry.Signature); err != nil { return nil, fmt.Errorf("invalid namespace: %v", err) @@ -154,8 +189,8 @@ func (lp *LogParameters) newGetEntriesResponse(leaves []*trillian.LogLeaf) ([]*G // newGetAnchorsResponse assembles a get-anchors response func (lp *LogParameters) newGetAnchorsResponse() [][]byte { - namespaces := make([][]byte, 0, len(lp.Namespaces.List())) - for _, namespace := range lp.Namespaces.List() { + namespaces := make([][]byte, 0, len(lp.Submitters.List())) + for _, namespace := range lp.Submitters.List() { raw, err := namespace.Marshal() if err != nil { fmt.Printf("TODO: fix me and entire func\n") diff --git a/reqres_test.go b/reqres_test.go index fab0e29..ce0c7b6 100644 --- a/reqres_test.go +++ b/reqres_test.go @@ -7,9 +7,48 @@ import ( "testing" "net/http" - // "github.com/google/trillian" + + "github.com/system-transparency/stfe/namespace/testdata" ) +func TestNewAddCosignatureRequest(t *testing.T) { + lp := makeTestLogParameters(t, nil) + validSth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) + for _, table := range []struct { + description string + breq *bytes.Buffer + wantErr bool + }{ + // TODO: test cases for all errors + add wantBytes for valid cases + { + description: "invalid: unknown witness", + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk2, testdata.Ed25519Vk2, validSth), + wantErr: true, + }, + { + description: "invalid: bad signature", + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk2, validSth), + wantErr: true, + }, + { + description: "valid", + breq: mustMakeAddCosiBuffer(t, testdata.Ed25519Sk, testdata.Ed25519Vk, validSth), + }, + } { + url := EndpointAddCosi.Path("http://example.com", lp.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + _, err = lp.newAddCosignatureRequest(req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) + } + } +} + // TODO: TestNewAddEntryRequest func TestNewAddEntryRequest(t *testing.T) { } diff --git a/server/main.go b/server/main.go index 8f5ab48..7402fb3 100644 --- a/server/main.go +++ b/server/main.go @@ -2,13 +2,19 @@ package main import ( + "context" "flag" + "fmt" + "os" "strings" + "sync" + "syscall" "time" "crypto/ed25519" "encoding/base64" "net/http" + "os/signal" "github.com/golang/glog" "github.com/google/trillian" @@ -23,78 +29,145 @@ var ( rpcBackend = flag.String("log_rpc_server", "localhost:6962", "host:port specification of where Trillian serves clients") prefix = flag.String("prefix", "st/v1", "a prefix that proceeds each endpoint path") trillianID = flag.Int64("trillian_id", 5991359069696313945, "log identifier in the Trillian database") - rpcDeadline = flag.Duration("rpc_deadline", time.Second*10, "deadline for backend RPC requests") + deadline = flag.Duration("deadline", time.Second*10, "deadline for backend requests") key = flag.String("key", "8gzezwrU/2eTrO6tEYyLKsoqn5V54URvKIL9cTE7jUYUqXVX4neJvcBq/zpSAYPsZFG1woh0OGBzQbi9UP9MZw==", "base64-encoded Ed25519 signing key") - namespaces = flag.String("namespaces", "AAEgHOQFUkKNWpjYAhNKTyWCzahlI7RDtf5123kHD2LACj0=,AAEgLqrWb9JwQUTk/SwTNDdMH8aRmy3mbmhwEepO5WSgb+A=", "comma-separated list of trusted namespaces in base64 (default: testdata.Ed25519{Vk,Vk2})") + submitters = flag.String("submitters", "AAEgHOQFUkKNWpjYAhNKTyWCzahlI7RDtf5123kHD2LACj0=,AAEgLqrWb9JwQUTk/SwTNDdMH8aRmy3mbmhwEepO5WSgb+A=", "comma-separated list of trusted submitter namespaces in base64 (default: testdata.Ed25519{Vk,Vk2})") + witnesses = flag.String("witnesses", "", "comma-separated list of trusted submitter namespaces in base64 (default: none") maxRange = flag.Int64("max_range", 2, "maximum number of entries that can be retrived in a single request") + interval = flag.Duration("interval", time.Second*30, "interval used to rotate the log's cosigned STH") ) func main() { flag.Parse() + defer glog.Flush() - glog.Info("Dialling Trillian gRPC log server") - dialOpts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(*rpcDeadline)} + // wait for clean-up before exit + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + glog.V(3).Infof("configuring stfe instance...") + instance, err := setupInstanceFromFlags() + if err != nil { + glog.Errorf("setupInstance: %v", err) + return + } + + glog.V(3).Infof("spawning SthSource") + go func() { + wg.Add(1) + defer wg.Done() + instance.SthSource.Run(ctx) + glog.Errorf("SthSource shutdown") + cancel() // must have SthSource running + }() + + glog.V(3).Infof("spawning await") + server := http.Server{Addr: *httpEndpoint} + go await(ctx, func() { + wg.Add(1) + defer wg.Done() + ctxInner, _ := context.WithTimeout(ctx, time.Second*60) + glog.Infof("Shutting down HTTP server...") + server.Shutdown(ctxInner) + glog.V(3).Infof("HTTP server shutdown") + glog.Infof("Shutting down spawned go routines...") + cancel() + }) + + glog.Infof("Serving on %v/%v", *httpEndpoint, *prefix) + if err = server.ListenAndServe(); err != http.ErrServerClosed { + glog.Errorf("ListenAndServe: %v", err) + } +} + +// SetupInstance sets up a new STFE instance from flags +func setupInstanceFromFlags() (*stfe.Instance, error) { + // Trillian gRPC connection + dialOpts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(*deadline)} conn, err := grpc.Dial(*rpcBackend, dialOpts...) if err != nil { - glog.Fatal(err) + return nil, fmt.Errorf("Dial: %v", err) } client := trillian.NewTrillianLogClient(conn) - - glog.Info("Creating HTTP request multiplexer") + // HTTP multiplexer mux := http.NewServeMux() http.Handle("/", mux) - - glog.Info("Adding prometheus handler on path: /metrics") + // Prometheus metrics + glog.V(3).Infof("Adding prometheus handler on path: /metrics") http.Handle("/metrics", promhttp.Handler()) - - glog.Infof("Creating namespace pool") - var anchors []*namespace.Namespace - for _, b64 := range strings.Split(*namespaces, ",") { - b, err := base64.StdEncoding.DecodeString(b64) - if err != nil { - glog.Fatalf("invalid namespace: %s: %v", b64, err) - } - var namespace namespace.Namespace - if err := namespace.Unmarshal(b); err != nil { - glog.Fatalf("invalid namespace: %s: %v", b64, err) - } - anchors = append(anchors, &namespace) + // Trusted submitters + submitters, err := newNamespacePoolFromString(*submitters) + if err != nil { + return nil, fmt.Errorf("submitters: newNamespacePoolFromString: %v", err) } - pool, err := namespace.NewNamespacePool(anchors) + // Trusted witnesses + witnesses, err := newNamespacePoolFromString(*witnesses) if err != nil { - glog.Fatalf("invalid namespace pool: %v", err) + return nil, fmt.Errorf("witnesses: NewNamespacePool: %v", err) } - - glog.Infof("Creating log signer and identifier") + // Log identity sk, err := base64.StdEncoding.DecodeString(*key) if err != nil { - glog.Fatalf("invalid signing key: %v", err) + return nil, fmt.Errorf("sk: DecodeString: %v", err) } signer := ed25519.PrivateKey(sk) logId, err := namespace.NewNamespaceEd25519V1([]byte(ed25519.PrivateKey(sk).Public().(ed25519.PublicKey))) if err != nil { - glog.Fatalf("failed creating log id from secret key: %v", err) + return nil, fmt.Errorf("NewNamespaceEd25519V1: %v", err) } - - glog.Infof("Initializing log parameters") - lp, err := stfe.NewLogParameters(signer, logId, *trillianID, *prefix, pool, *maxRange) + // Setup log parameters + lp, err := stfe.NewLogParameters(signer, logId, *trillianID, *prefix, submitters, witnesses, *maxRange, *interval, *deadline) if err != nil { - glog.Fatalf("failed setting up log parameters: %v", err) + return nil, fmt.Errorf("NewLogParameters: %v", err) } - - i := stfe.NewInstance(lp, client, *rpcDeadline) + // Setup STH source + source, err := stfe.NewActiveSthSource(client, lp) + if err != nil { + return nil, fmt.Errorf("NewActiveSthSource: %v", err) + } + // Setup log instance + i := stfe.NewInstance(lp, client, source) for _, handler := range i.Handlers() { - glog.Infof("adding handler: %s", handler.Path()) + glog.V(3).Infof("adding handler: %s", handler.Path()) mux.Handle(handler.Path(), handler) } - glog.Infof("Configured: %s", i) + return i, nil +} - glog.Infof("Serving on %v/%v", *httpEndpoint, *prefix) - srv := http.Server{Addr: *httpEndpoint} - err = srv.ListenAndServe() - if err != http.ErrServerClosed { - glog.Warningf("Server exited: %v", err) +// newNamespacePoolFromString creates a new namespace pool from a +// comma-separated list of serialized and base64-encoded namespaces. +func newNamespacePoolFromString(str string) (*namespace.NamespacePool, error) { + var namespaces []*namespace.Namespace + if len(str) > 0 { + for _, b64 := range strings.Split(str, ",") { + b, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return nil, fmt.Errorf("DecodeString: %v", err) + } + var namespace namespace.Namespace + if err := namespace.Unmarshal(b); err != nil { + return nil, fmt.Errorf("Unmarshal: %v", err) + } + namespaces = append(namespaces, &namespace) + } } + pool, err := namespace.NewNamespacePool(namespaces) + if err != nil { + return nil, fmt.Errorf("NewNamespacePool: %v", err) + } + return pool, nil +} - glog.Flush() +// await waits for a shutdown signal and then runs a clean-up function +func await(ctx context.Context, done func()) { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + select { + case <-sigs: + case <-ctx.Done(): + } + glog.V(3).Info("received shutdown signal") + done() } @@ -0,0 +1,159 @@ +package stfe + +import ( + "bytes" + "context" + "fmt" + "reflect" + "sync" + + "github.com/golang/glog" + "github.com/google/certificate-transparency-go/schedule" + "github.com/google/trillian" + "github.com/google/trillian/types" +) + +// SthSource provides access to the log's STHs. +type SthSource interface { + // Latest returns the most reccent signed_tree_head_v*. + Latest(context.Context) (*StItem, error) + // Stable returns the most recent signed_tree_head_v* that is stable for + // some period of time, e.g., 10 minutes. + Stable(context.Context) (*StItem, error) + // Cosigned returns the most recent cosigned_tree_head_v*. + Cosigned(context.Context) (*StItem, error) + // AddCosignature attempts to add a cosignature to the stable STH. The + // passed cosigned_tree_head_v* must have a single verified cosignature. + AddCosignature(context.Context, *StItem) error + // Run keeps the STH source updated until cancelled + Run(context.Context) +} + +// ActiveSthSource implements the SthSource interface for an STFE instance that +// accepts new logging requests, i.e., the log is running in read+write mode. +type ActiveSthSource struct { + client trillian.TrillianLogClient + logParameters *LogParameters + currSth *StItem // current cosigned_tree_head_v1 (already finalized) + nextSth *StItem // next cosigned_tree_head_v1 (under preparation) + cosignatureFrom map[string]bool // track who we got cosignatures from + mutex sync.RWMutex +} + +// NewActiveSthSource returns an initialized ActiveSthSource +func NewActiveSthSource(cli trillian.TrillianLogClient, lp *LogParameters) (*ActiveSthSource, error) { + s := ActiveSthSource{ + client: cli, + logParameters: lp, + } + + ctx, _ := context.WithTimeout(context.Background(), lp.Deadline) + sth, err := s.Latest(ctx) + if err != nil { + return nil, fmt.Errorf("Latest: %v", err) + } + // TODO: load peristed cosigned STH? + s.currSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil) + s.nextSth = NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil) + s.cosignatureFrom = make(map[string]bool) + return &s, nil +} + +func (s *ActiveSthSource) Latest(ctx context.Context) (*StItem, error) { + trsp, err := s.client.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{ + LogId: s.logParameters.TreeId, + }) + var lr types.LogRootV1 + if errInner := checkGetLatestSignedLogRoot(s.logParameters, trsp, err, &lr); errInner != nil { + return nil, fmt.Errorf("invalid signed log root response: %v", errInner) + } + return s.logParameters.genV1Sth(NewTreeHeadV1(&lr)) +} + +func (s *ActiveSthSource) Stable(_ context.Context) (*StItem, error) { + s.mutex.RLock() + defer s.mutex.RUnlock() + if s.nextSth == nil { + return nil, fmt.Errorf("no stable sth available") + } + return &StItem{ + Format: StFormatSignedTreeHeadV1, + SignedTreeHeadV1: &s.nextSth.CosignedTreeHeadV1.SignedTreeHeadV1, + }, nil +} + +func (s *ActiveSthSource) Cosigned(_ context.Context) (*StItem, error) { + s.mutex.RLock() + defer s.mutex.RUnlock() + if s.currSth == nil || len(s.currSth.CosignedTreeHeadV1.SignatureV1) == 0 { + return nil, fmt.Errorf("no cosigned sth available") + } + return s.currSth, nil +} + +func (a *SignedTreeHeadV1) Equals(b *SignedTreeHeadV1) bool { + return bytes.Equal(a.LogId, b.LogId) && + bytes.Equal(a.Signature, b.Signature) && + a.TreeHead.Timestamp == b.TreeHead.Timestamp && + a.TreeHead.TreeSize == b.TreeHead.TreeSize && + bytes.Equal(a.TreeHead.RootHash.Data, b.TreeHead.RootHash.Data) && + bytes.Equal(a.TreeHead.Extension, b.TreeHead.Extension) + // TODO: why reflect.DeepEqual(a, b) gives a different result? Fixme. +} + +func (s *ActiveSthSource) AddCosignature(_ context.Context, costh *StItem) error { + s.mutex.Lock() + defer s.mutex.Unlock() + //if !reflect.DeepEqual(s.nextSth.CosignedTreeHeadV1.SignedTreeHeadV1, costh.CosignedTreeHeadV1.SignedTreeHeadV1) { + if !(&s.nextSth.CosignedTreeHeadV1.SignedTreeHeadV1).Equals(&costh.CosignedTreeHeadV1.SignedTreeHeadV1) { + return fmt.Errorf("cosignature covers a different tree head") + } + witness := costh.CosignedTreeHeadV1.SignatureV1[0].Namespace.String() + if _, ok := s.cosignatureFrom[witness]; ok { + return nil // duplicate + } + s.cosignatureFrom[witness] = true + s.nextSth.CosignedTreeHeadV1.SignatureV1 = append(s.nextSth.CosignedTreeHeadV1.SignatureV1, costh.CosignedTreeHeadV1.SignatureV1[0]) + return nil +} + +func (s *ActiveSthSource) Run(ctx context.Context) { + schedule.Every(ctx, s.logParameters.Interval, func(ctx context.Context) { + // get the next stable sth + ictx, _ := context.WithTimeout(ctx, s.logParameters.Deadline) + sth, err := s.Latest(ictx) + if err != nil { + glog.Warningf("cannot rotate without new sth: Latest: %v", err) + return + } + // rotate + s.mutex.Lock() + defer s.mutex.Unlock() + s.rotate(sth) + // TODO: persist cosigned STH? + }) +} + +// rotate rotates the log's cosigned and stable STH. The caller must aquire the +// source's read-write lock if there are concurrent reads and/or writes. +func (s *ActiveSthSource) rotate(fixedSth *StItem) { + // rotate stable -> cosigned + if reflect.DeepEqual(&s.currSth.CosignedTreeHeadV1.SignedTreeHeadV1, &s.nextSth.CosignedTreeHeadV1.SignedTreeHeadV1) { + for _, sigv1 := range s.currSth.CosignedTreeHeadV1.SignatureV1 { + witness := sigv1.Namespace.String() + if _, ok := s.cosignatureFrom[witness]; !ok { + s.cosignatureFrom[witness] = true + s.nextSth.CosignedTreeHeadV1.SignatureV1 = append(s.nextSth.CosignedTreeHeadV1.SignatureV1, sigv1) + } + } + } + s.currSth.CosignedTreeHeadV1.SignedTreeHeadV1 = s.nextSth.CosignedTreeHeadV1.SignedTreeHeadV1 + s.currSth.CosignedTreeHeadV1.SignatureV1 = make([]SignatureV1, len(s.nextSth.CosignedTreeHeadV1.SignatureV1)) + copy(s.currSth.CosignedTreeHeadV1.SignatureV1, s.nextSth.CosignedTreeHeadV1.SignatureV1) + + // rotate new stable -> stable + if !reflect.DeepEqual(&s.nextSth.CosignedTreeHeadV1.SignedTreeHeadV1, fixedSth.SignedTreeHeadV1) { + s.nextSth = NewCosignedTreeHeadV1(fixedSth.SignedTreeHeadV1, nil) + s.cosignatureFrom = make(map[string]bool) + } +} diff --git a/sth_test.go b/sth_test.go new file mode 100644 index 0000000..3b84b8c --- /dev/null +++ b/sth_test.go @@ -0,0 +1,454 @@ +package stfe + +import ( + "context" + "crypto" + "fmt" + "reflect" + "testing" + + "github.com/golang/mock/gomock" + cttestdata "github.com/google/certificate-transparency-go/trillian/testdata" + "github.com/google/trillian" + "github.com/system-transparency/stfe/namespace" + "github.com/system-transparency/stfe/namespace/testdata" +) + +func TestLatest(t *testing.T) { + for _, table := range []struct { + description string + signer crypto.Signer + trsp *trillian.GetLatestSignedLogRootResponse + terr error + wantErr bool + wantRsp *StItem + }{ + { + description: "invalid trillian response", + signer: cttestdata.NewSignerWithFixedSig(nil, testSignature), + terr: fmt.Errorf("internal server error"), + wantErr: true, + }, + { + description: "signature failure", + signer: cttestdata.NewSignerWithErr(nil, fmt.Errorf("signing failed")), + terr: fmt.Errorf("internal server error"), + wantErr: true, + }, + { + description: "valid", + signer: cttestdata.NewSignerWithFixedSig(nil, testSignature), + trsp: makeLatestSignedLogRootResponse(t, testTimestamp, testTreeSize, testNodeHash), + wantRsp: NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature), + }, + } { + func() { // run deferred functions at the end of each iteration + th := newTestHandler(t, table.signer, nil) + defer th.mockCtrl.Finish() + th.client.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.trsp, table.terr) + sth, err := th.instance.SthSource.Latest(context.Background()) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := sth, table.wantRsp; !reflect.DeepEqual(got, want) { + t.Errorf("got %v but wanted %v in test %q", got, want, table.description) + } + }() + } +} + +func TestStable(t *testing.T) { + sth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) + for _, table := range []struct { + description string + source SthSource + wantRsp *StItem + wantErr bool + }{ + { + description: "no stable sth", + source: &ActiveSthSource{}, + wantErr: true, + }, + { + description: "valid", + source: &ActiveSthSource{ + nextSth: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil)}, + wantRsp: sth, + }, + } { + sth, err := table.source.Stable(context.Background()) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := sth, table.wantRsp; !reflect.DeepEqual(got, want) { + t.Errorf("got %v but wanted %v in test %q", got, want, table.description) + } + } +} + +func TestCosigned(t *testing.T) { + sth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) + sigs := []SignatureV1{ + SignatureV1{ + Namespace: *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), + Signature: testSignature, + }, + } + for _, table := range []struct { + description string + source SthSource + wantRsp *StItem + wantErr bool + }{ + { + description: "no cosigned sth: nil", + source: &ActiveSthSource{}, + wantErr: true, + }, + { + description: "no cosigned sth: nil signatures", + source: &ActiveSthSource{ + currSth: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil), + }, + wantErr: true, + }, + { + description: "valid", + source: &ActiveSthSource{ + currSth: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, sigs), + }, + wantRsp: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, sigs), + }, + } { + cosi, err := table.source.Cosigned(context.Background()) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := cosi, table.wantRsp; !reflect.DeepEqual(got, want) { + t.Errorf("got %v but wanted %v in test %q", got, want, table.description) + } + } +} + +func TestAddCosignature(t *testing.T) { + sth := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) + wit1 := mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk) + wit2 := mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk2) + for _, table := range []struct { + description string + source *ActiveSthSource + req *StItem + wantWit []*namespace.Namespace + wantErr bool + }{ + { + description: "invalid: cosignature must target the stable sth", + source: &ActiveSthSource{ + nextSth: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil), + cosignatureFrom: make(map[string]bool), + }, + req: NewCosignedTreeHeadV1(NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp+1000000, testTreeSize, testNodeHash)), testLogId, testSignature).SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantErr: true, + }, + { + description: "valid: adding duplicate into a pool of cosignatures", + source: &ActiveSthSource{ + nextSth: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + cosignatureFrom: map[string]bool{ + wit1.String(): true, + }, + }, + req: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantWit: []*namespace.Namespace{wit1}, + }, + { + description: "valid: adding into an empty pool of cosignatures", + source: &ActiveSthSource{ + nextSth: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, nil), + cosignatureFrom: make(map[string]bool), + }, + req: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantWit: []*namespace.Namespace{wit1}, + }, + { + description: "valid: adding into a pool of cosignatures", + source: &ActiveSthSource{ + nextSth: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + cosignatureFrom: map[string]bool{ + wit1.String(): true, + }, + }, + req: NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit2, + Signature: testSignature, + }, + }), + wantWit: []*namespace.Namespace{wit1, wit2}, + }, + } { + err := table.source.AddCosignature(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + + // Check that the next cosigned sth is updated + var sigs []SignatureV1 + for _, wit := range table.wantWit { + sigs = append(sigs, SignatureV1{ + Namespace: *wit, + Signature: testSignature, + }) + } + if got, want := table.source.nextSth, NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, sigs); !reflect.DeepEqual(got, want) { + t.Errorf("got %v but wanted %v in test %q", got, want, table.description) + } + // Check that the map tracking witness signatures is updated + if got, want := len(table.source.cosignatureFrom), len(table.wantWit); got != want { + t.Errorf("witness map got %d cosignatures but wanted %d in test %q", got, want, table.description) + } else { + for _, wit := range table.wantWit { + if _, ok := table.source.cosignatureFrom[wit.String()]; !ok { + t.Errorf("missing signature from witness %X in test %q", wit.String(), table.description) + } + } + } + } +} + +func TestRotate(t *testing.T) { + sth1 := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature) + sth2 := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp+1000000, testTreeSize+1, testNodeHash)), testLogId, testSignature) + sth3 := NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp+2000000, testTreeSize+2, testNodeHash)), testLogId, testSignature) + wit1 := mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk) + wit2 := mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk2) + wit3 := mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk3) + for _, table := range []struct { + description string + source *ActiveSthSource + fixedSth *StItem + wantCurrSth *StItem + wantNextSth *StItem + wantWit []*namespace.Namespace + }{ + { + description: "not repeated cosigned and not repeated stable", + source: &ActiveSthSource{ + currSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, nil), + nextSth: NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + cosignatureFrom: map[string]bool{ + wit1.String(): true, + }, + }, + fixedSth: sth3, + wantCurrSth: NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantNextSth: NewCosignedTreeHeadV1(sth3.SignedTreeHeadV1, nil), + wantWit: nil, // no cosignatures for the next stable sth yet + }, + { + description: "not repeated cosigned and repeated stable", + source: &ActiveSthSource{ + currSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, nil), + nextSth: NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + cosignatureFrom: map[string]bool{ + wit1.String(): true, + }, + }, + fixedSth: sth2, + wantCurrSth: NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantNextSth: NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantWit: []*namespace.Namespace{wit1}, + }, + { + description: "repeated cosigned and not repeated stable", + source: &ActiveSthSource{ + currSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit2, + Signature: testSignature, + }, + }), + nextSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit2, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit3, + Signature: testSignature, + }, + }), + cosignatureFrom: map[string]bool{ + wit2.String(): true, + wit3.String(): true, + }, + }, + fixedSth: sth3, + wantCurrSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit2, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit3, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantNextSth: NewCosignedTreeHeadV1(sth3.SignedTreeHeadV1, nil), + wantWit: nil, // no cosignatures for the next stable sth yet + }, + { + description: "repeated cosigned and repeated stable", + source: &ActiveSthSource{ + currSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit2, + Signature: testSignature, + }, + }), + nextSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit2, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit3, + Signature: testSignature, + }, + }), + cosignatureFrom: map[string]bool{ + wit2.String(): true, + wit3.String(): true, + }, + }, + fixedSth: sth1, + wantCurrSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit2, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit3, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantNextSth: NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []SignatureV1{ + SignatureV1{ + Namespace: *wit2, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit3, + Signature: testSignature, + }, + SignatureV1{ + Namespace: *wit1, + Signature: testSignature, + }, + }), + wantWit: []*namespace.Namespace{wit1, wit2, wit3}, + }, + } { + table.source.rotate(table.fixedSth) + if got, want := table.source.currSth, table.wantCurrSth; !reflect.DeepEqual(got, want) { + t.Errorf("got currSth %X but wanted %X in test %q", got, want, table.description) + } + if got, want := table.source.nextSth, table.wantNextSth; !reflect.DeepEqual(got, want) { + t.Errorf("got nextSth %X but wanted %X in test %q", got, want, table.description) + } + if got, want := len(table.source.cosignatureFrom), len(table.wantWit); got != want { + t.Errorf("witness map got %d cosignatures but wanted %d in test %q", got, want, table.description) + } else { + for _, wit := range table.wantWit { + if _, ok := table.source.cosignatureFrom[wit.String()]; !ok { + t.Errorf("missing signature from witness %X in test %q", wit.String(), table.description) + } + } + } + // check that adding cosignatures to stable will not effect cosigned sth + wantLen := len(table.source.currSth.CosignedTreeHeadV1.SignatureV1) + table.source.nextSth.CosignedTreeHeadV1.SignatureV1 = append(table.source.nextSth.CosignedTreeHeadV1.SignatureV1, SignatureV1{Namespace: *wit1, Signature: testSignature}) + if gotLen := len(table.source.currSth.CosignedTreeHeadV1.SignatureV1); gotLen != wantLen { + t.Errorf("adding cosignatures to the stable sth modifies the fixated cosigned sth in test %q", table.description) + } + } +} @@ -21,6 +21,7 @@ const ( StFormatConsistencyProofV1 StFormat = 3 StFormatInclusionProofV1 StFormat = 4 StFormatChecksumV1 = 5 + StFormatCosignedTreeHeadV1 = 6 ) // StItem references a versioned item based on a given format specifier @@ -31,6 +32,7 @@ type StItem struct { ConsistencyProofV1 *ConsistencyProofV1 `tls:"selector:Format,val:3"` InclusionProofV1 *InclusionProofV1 `tls:"selector:Format,val:4"` ChecksumV1 *ChecksumV1 `tls:"selector:Format,val:5"` + CosignedTreeHeadV1 *CosignedTreeHeadV1 `tls:"selector:Format,val:6"` } // SignedTreeHeadV1 is a signed tree head as defined by RFC 6962/bis, §4.10 @@ -79,6 +81,18 @@ type TreeHeadV1 struct { Extension []byte `tls:"minlen:0,maxlen:65535"` } +// CosignedTreeheadV1 is a cosigned STH +type CosignedTreeHeadV1 struct { + SignedTreeHeadV1 SignedTreeHeadV1 + SignatureV1 []SignatureV1 `tls:"minlen:0,maxlen:4294967295"` +} + +// SignatureV1 is a detached signature that was produced by a namespace +type SignatureV1 struct { + Namespace namespace.Namespace + Signature []byte `tls:"minlen:1,maxlen:65535"` +} + // NodeHash is a Merkle tree hash as defined by RFC 6962/bis, §4.9 type NodeHash struct { Data []byte `tls:"minlen:32,maxlen:255"` @@ -103,6 +117,8 @@ func (f StFormat) String() string { return "inclusion_proof_v1" case StFormatChecksumV1: return "checksum_v1" + case StFormatCosignedTreeHeadV1: + return "cosigned_tree_head_v1" default: return fmt.Sprintf("Unknown StFormat: %d", f) } @@ -120,6 +136,8 @@ func (i StItem) String() string { return fmt.Sprintf("Format(%s): %s", i.Format, i.SignedDebugInfoV1) case StFormatSignedTreeHeadV1: return fmt.Sprintf("Format(%s): %s", i.Format, i.SignedTreeHeadV1) + case StFormatCosignedTreeHeadV1: + return fmt.Sprintf("Format(%s): %s", i.Format, i.CosignedTreeHeadV1) default: return fmt.Sprintf("unknown StItem: %s", i.Format) } @@ -149,6 +167,10 @@ func (th TreeHeadV1) String() string { return fmt.Sprintf("Timestamp(%s) TreeSize(%d) RootHash(%s)", time.Unix(int64(th.Timestamp/1000), 0), th.TreeSize, b64(th.RootHash.Data)) } +func (i CosignedTreeHeadV1) String() string { + return fmt.Sprintf("SignedTreeHead(%s) #Cosignatures(%d)", i.SignedTreeHeadV1.String(), len(i.SignatureV1)) +} + // Marshal serializes an Stitem as defined by RFC 5246 func (i *StItem) Marshal() ([]byte, error) { serialized, err := tls.Marshal(*i) @@ -264,6 +286,17 @@ func NewTreeHeadV1(lr *types.LogRootV1) *TreeHeadV1 { } } +// NewCosignedTreeHeadV1 creates a new StItem of type cosigned_tree_head_v1 +func NewCosignedTreeHeadV1(sth *SignedTreeHeadV1, sigs []SignatureV1) *StItem { + return &StItem{ + Format: StFormatCosignedTreeHeadV1, + CosignedTreeHeadV1: &CosignedTreeHeadV1{ + SignedTreeHeadV1: *sth, + SignatureV1: sigs, + }, + } +} + func b64(b []byte) string { return base64.StdEncoding.EncodeToString(b) } diff --git a/type_test.go b/type_test.go index 2e0f4b6..6ac3b29 100644 --- a/type_test.go +++ b/type_test.go @@ -222,7 +222,7 @@ func TestEncDecStItem(t *testing.T) { description: "too large checksum", item: NewChecksumV1(testPackage, make([]byte, checksumMax+1), mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk)), wantErr: true, - }, // namespace (un)marshal is already tested in its own package (skip) + }, { description: "ok checksum: min", item: NewChecksumV1(testPackage, make([]byte, checksumMin), mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk)), @@ -230,7 +230,19 @@ func TestEncDecStItem(t *testing.T) { { description: "ok checksum: max", item: NewChecksumV1(testPackage, make([]byte, checksumMax), mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk)), - }, + }, // namespace (un)marshal is already tested in its own package (skip) + { + description: "ok cosigned sth", + item: NewCosignedTreeHeadV1( + NewSignedTreeHeadV1(NewTreeHeadV1(makeTrillianLogRoot(t, testTimestamp, testTreeSize, testNodeHash)), testLogId, testSignature).SignedTreeHeadV1, + []SignatureV1{ + SignatureV1{ + *mustNewNamespaceEd25519V1(t, testdata.Ed25519Vk), + testSignature, + }, + }, + ), + }, // TODO: the only thing that is not tested elsewhere for cosigned sth is bound on signature. Unify signature into a type => some tests go away. } { b, err := table.item.MarshalB64() if err != nil && !table.wantErr { |