diff options
Diffstat (limited to 'pkg/trillian')
-rw-r--r-- | pkg/trillian/client.go | 178 | ||||
-rw-r--r-- | pkg/trillian/client_test.go | 533 | ||||
-rw-r--r-- | pkg/trillian/util.go | 33 |
3 files changed, 744 insertions, 0 deletions
diff --git a/pkg/trillian/client.go b/pkg/trillian/client.go new file mode 100644 index 0000000..9523e56 --- /dev/null +++ b/pkg/trillian/client.go @@ -0,0 +1,178 @@ +package trillian + +import ( + "context" + "fmt" + + "github.com/golang/glog" + "github.com/google/trillian" + ttypes "github.com/google/trillian/types" + "github.com/system-transparency/stfe/pkg/types" + "google.golang.org/grpc/codes" +) + +type Client interface { + AddLeaf(context.Context, *types.LeafRequest) error + GetConsistencyProof(context.Context, *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) + GetTreeHead(context.Context) (*types.TreeHead, error) + GetInclusionProof(context.Context, *types.InclusionProofRequest) (*types.InclusionProof, error) + GetLeaves(context.Context, *types.LeavesRequest) (*types.LeafList, error) +} + +// TrillianClient is a wrapper around the Trillian gRPC client. +type TrillianClient struct { + // TreeID is a Merkle tree identifier that Trillian uses + TreeID int64 + + // GRPC is a Trillian gRPC client + GRPC trillian.TrillianLogClient +} + +func (c *TrillianClient) AddLeaf(ctx context.Context, req *types.LeafRequest) error { + leaf := types.Leaf{ + Message: req.Message, + SigIdent: types.SigIdent{ + Signature: req.Signature, + KeyHash: types.Hash(req.VerificationKey[:]), + }, + } + serialized := leaf.Marshal() + + glog.V(3).Infof("queueing leaf request: %x", types.HashLeaf(serialized)) + rsp, err := c.GRPC.QueueLeaf(ctx, &trillian.QueueLeafRequest{ + LogId: c.TreeID, + Leaf: &trillian.LogLeaf{ + LeafValue: serialized, + }, + }) + if err != nil { + return fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return fmt.Errorf("no response") + } + if rsp.QueuedLeaf == nil { + return fmt.Errorf("no queued leaf") + } + if codes.Code(rsp.QueuedLeaf.GetStatus().GetCode()) == codes.AlreadyExists { + return fmt.Errorf("leaf is already queued or included") + } + return nil +} + +func (c *TrillianClient) GetTreeHead(ctx context.Context) (*types.TreeHead, error) { + rsp, err := c.GRPC.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{ + LogId: c.TreeID, + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if rsp.SignedLogRoot == nil { + return nil, fmt.Errorf("no signed log root") + } + if rsp.SignedLogRoot.LogRoot == nil { + return nil, fmt.Errorf("no log root") + } + var r ttypes.LogRootV1 + if err := r.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil { + return nil, fmt.Errorf("no log root: unmarshal failed: %v", err) + } + if len(r.RootHash) != types.HashSize { + return nil, fmt.Errorf("unexpected hash length: %d", len(r.RootHash)) + } + return treeHeadFromLogRoot(&r), nil +} + +func (c *TrillianClient) GetConsistencyProof(ctx context.Context, req *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) { + rsp, err := c.GRPC.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{ + LogId: c.TreeID, + FirstTreeSize: int64(req.OldSize), + SecondTreeSize: int64(req.NewSize), + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if rsp.Proof == nil { + return nil, fmt.Errorf("no consistency proof") + } + if len(rsp.Proof.Hashes) == 0 { + return nil, fmt.Errorf("not a consistency proof: empty") + } + path, err := nodePathFromHashes(rsp.Proof.Hashes) + if err != nil { + return nil, fmt.Errorf("not a consistency proof: %v", err) + } + return &types.ConsistencyProof{ + OldSize: req.OldSize, + NewSize: req.NewSize, + Path: path, + }, nil +} + +func (c *TrillianClient) GetInclusionProof(ctx context.Context, req *types.InclusionProofRequest) (*types.InclusionProof, error) { + rsp, err := c.GRPC.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{ + LogId: c.TreeID, + LeafHash: req.LeafHash[:], + TreeSize: int64(req.TreeSize), + OrderBySequence: true, + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if len(rsp.Proof) != 1 { + return nil, fmt.Errorf("bad proof count: %d", len(rsp.Proof)) + } + proof := rsp.Proof[0] + if len(proof.Hashes) == 0 { + return nil, fmt.Errorf("not an inclusion proof: empty") + } + path, err := nodePathFromHashes(proof.Hashes) + if err != nil { + return nil, fmt.Errorf("not an inclusion proof: %v", err) + } + return &types.InclusionProof{ + TreeSize: req.TreeSize, + LeafIndex: uint64(proof.LeafIndex), + Path: path, + }, nil +} + +func (c *TrillianClient) GetLeaves(ctx context.Context, req *types.LeavesRequest) (*types.LeafList, error) { + rsp, err := c.GRPC.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{ + LogId: c.TreeID, + StartIndex: int64(req.StartSize), + Count: int64(req.EndSize-req.StartSize) + 1, + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if got, want := len(rsp.Leaves), int(req.EndSize-req.StartSize+1); got != want { + return nil, fmt.Errorf("unexpected number of leaves: %d", got) + } + var list types.LeafList + for i, leaf := range rsp.Leaves { + leafIndex := int64(req.StartSize + uint64(i)) + if leafIndex != leaf.LeafIndex { + return nil, fmt.Errorf("unexpected leaf(%d): got index %d", leafIndex, leaf.LeafIndex) + } + + var l types.Leaf + if err := l.Unmarshal(leaf.LeafValue); err != nil { + return nil, fmt.Errorf("unexpected leaf(%d): %v", leafIndex, err) + } + list = append(list[:], &l) + } + return &list, nil +} diff --git a/pkg/trillian/client_test.go b/pkg/trillian/client_test.go new file mode 100644 index 0000000..6b3d881 --- /dev/null +++ b/pkg/trillian/client_test.go @@ -0,0 +1,533 @@ +package trillian + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/trillian" + ttypes "github.com/google/trillian/types" + "github.com/system-transparency/stfe/pkg/mocks" + "github.com/system-transparency/stfe/pkg/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestAddLeaf(t *testing.T) { + req := &types.LeafRequest{ + Message: types.Message{ + ShardHint: 0, + Checksum: &[types.HashSize]byte{}, + }, + Signature: &[types.SignatureSize]byte{}, + VerificationKey: &[types.VerificationKeySize]byte{}, + DomainHint: "example.com", + } + for _, table := range []struct { + description string + req *types.LeafRequest + rsp *trillian.QueueLeafResponse + err error + wantErr bool + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: no queued leaf", + req: req, + rsp: &trillian.QueueLeafResponse{}, + wantErr: true, + }, + { + description: "invalid: leaf is already queued or included", + req: req, + rsp: &trillian.QueueLeafResponse{ + QueuedLeaf: &trillian.QueuedLogLeaf{ + Leaf: &trillian.LogLeaf{ + LeafValue: req.Message.Marshal(), + }, + Status: status.New(codes.AlreadyExists, "duplicate").Proto(), + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.QueueLeafResponse{ + QueuedLeaf: &trillian.QueuedLogLeaf{ + Leaf: &trillian.LogLeaf{ + LeafValue: req.Message.Marshal(), + }, + Status: status.New(codes.OK, "ok").Proto(), + }, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().QueueLeaf(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + err := client.AddLeaf(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + }() + } +} + +func TestGetTreeHead(t *testing.T) { + // valid root + root := &ttypes.LogRootV1{ + TreeSize: 0, + RootHash: make([]byte, types.HashSize), + TimestampNanos: 1622585623133599429, + } + buf, err := root.MarshalBinary() + if err != nil { + t.Fatalf("must marshal log root: %v", err) + } + // invalid root + root.RootHash = make([]byte, types.HashSize+1) + bufBadHash, err := root.MarshalBinary() + if err != nil { + t.Fatalf("must marshal log root: %v", err) + } + + for _, table := range []struct { + description string + rsp *trillian.GetLatestSignedLogRootResponse + err error + wantErr bool + wantTh *types.TreeHead + }{ + { + description: "invalid: backend failure", + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + wantErr: true, + }, + { + description: "invalid: no signed log root", + rsp: &trillian.GetLatestSignedLogRootResponse{}, + wantErr: true, + }, + { + description: "invalid: no log root", + rsp: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: &trillian.SignedLogRoot{}, + }, + wantErr: true, + }, + { + description: "invalid: no log root: unmarshal failed", + rsp: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: &trillian.SignedLogRoot{ + LogRoot: buf[1:], + }, + }, + wantErr: true, + }, + { + description: "invalid: unexpected hash length", + rsp: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: &trillian.SignedLogRoot{ + LogRoot: bufBadHash, + }, + }, + wantErr: true, + }, + { + description: "valid", + rsp: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: &trillian.SignedLogRoot{ + LogRoot: buf, + }, + }, + wantTh: &types.TreeHead{ + Timestamp: 1622585623, + TreeSize: 0, + RootHash: &[types.HashSize]byte{}, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + th, err := client.GetTreeHead(context.Background()) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := th, table.wantTh; !reflect.DeepEqual(got, want) { + t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} + +func TestGetConsistencyProof(t *testing.T) { + req := &types.ConsistencyProofRequest{ + OldSize: 1, + NewSize: 3, + } + for _, table := range []struct { + description string + req *types.ConsistencyProofRequest + rsp *trillian.GetConsistencyProofResponse + err error + wantErr bool + wantProof *types.ConsistencyProof + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: no consistency proof", + req: req, + rsp: &trillian.GetConsistencyProofResponse{}, + wantErr: true, + }, + { + description: "invalid: not a consistency proof (1/2)", + req: req, + rsp: &trillian.GetConsistencyProofResponse{ + Proof: &trillian.Proof{ + Hashes: [][]byte{}, + }, + }, + wantErr: true, + }, + { + description: "invalid: not a consistency proof (2/2)", + req: req, + rsp: &trillian.GetConsistencyProofResponse{ + Proof: &trillian.Proof{ + Hashes: [][]byte{ + make([]byte, types.HashSize), + make([]byte, types.HashSize+1), + }, + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.GetConsistencyProofResponse{ + Proof: &trillian.Proof{ + Hashes: [][]byte{ + make([]byte, types.HashSize), + make([]byte, types.HashSize), + }, + }, + }, + wantProof: &types.ConsistencyProof{ + OldSize: 1, + NewSize: 3, + Path: []*[types.HashSize]byte{ + &[types.HashSize]byte{}, + &[types.HashSize]byte{}, + }, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + proof, err := client.GetConsistencyProof(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { + t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} + +func TestGetInclusionProof(t *testing.T) { + req := &types.InclusionProofRequest{ + TreeSize: 4, + LeafHash: &[types.HashSize]byte{}, + } + for _, table := range []struct { + description string + req *types.InclusionProofRequest + rsp *trillian.GetInclusionProofByHashResponse + err error + wantErr bool + wantProof *types.InclusionProof + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: bad proof count", + req: req, + rsp: &trillian.GetInclusionProofByHashResponse{ + Proof: []*trillian.Proof{ + &trillian.Proof{}, + &trillian.Proof{}, + }, + }, + wantErr: true, + }, + { + description: "invalid: not an inclusion proof (1/2)", + req: req, + rsp: &trillian.GetInclusionProofByHashResponse{ + Proof: []*trillian.Proof{ + &trillian.Proof{ + LeafIndex: 1, + Hashes: [][]byte{}, + }, + }, + }, + wantErr: true, + }, + { + description: "invalid: not an inclusion proof (2/2)", + req: req, + rsp: &trillian.GetInclusionProofByHashResponse{ + Proof: []*trillian.Proof{ + &trillian.Proof{ + LeafIndex: 1, + Hashes: [][]byte{ + make([]byte, types.HashSize), + make([]byte, types.HashSize+1), + }, + }, + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.GetInclusionProofByHashResponse{ + Proof: []*trillian.Proof{ + &trillian.Proof{ + LeafIndex: 1, + Hashes: [][]byte{ + make([]byte, types.HashSize), + make([]byte, types.HashSize), + }, + }, + }, + }, + wantProof: &types.InclusionProof{ + TreeSize: 4, + LeafIndex: 1, + Path: []*[types.HashSize]byte{ + &[types.HashSize]byte{}, + &[types.HashSize]byte{}, + }, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + proof, err := client.GetInclusionProof(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { + t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} + +func TestGetLeaves(t *testing.T) { + req := &types.LeavesRequest{ + StartSize: 1, + EndSize: 2, + } + firstLeaf := &types.Leaf{ + Message: types.Message{ + ShardHint: 0, + Checksum: &[types.HashSize]byte{}, + }, + SigIdent: types.SigIdent{ + Signature: &[types.SignatureSize]byte{}, + KeyHash: &[types.HashSize]byte{}, + }, + } + secondLeaf := &types.Leaf{ + Message: types.Message{ + ShardHint: 0, + Checksum: &[types.HashSize]byte{}, + }, + SigIdent: types.SigIdent{ + Signature: &[types.SignatureSize]byte{}, + KeyHash: &[types.HashSize]byte{}, + }, + } + + for _, table := range []struct { + description string + req *types.LeavesRequest + rsp *trillian.GetLeavesByRangeResponse + err error + wantErr bool + wantLeaves *types.LeafList + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: unexpected number of leaves", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + }, + }, + wantErr: true, + }, + { + description: "invalid: unexpected leaf (1/2)", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal(), + LeafIndex: 3, + }, + }, + }, + wantErr: true, + }, + { + description: "invalid: unexpected leaf (2/2)", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal()[1:], + LeafIndex: 2, + }, + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal(), + LeafIndex: 2, + }, + }, + }, + wantLeaves: &types.LeafList{ + firstLeaf, + secondLeaf, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + leaves, err := client.GetLeaves(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := leaves, table.wantLeaves; !reflect.DeepEqual(got, want) { + t.Errorf("got leaves\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} diff --git a/pkg/trillian/util.go b/pkg/trillian/util.go new file mode 100644 index 0000000..4cf31fb --- /dev/null +++ b/pkg/trillian/util.go @@ -0,0 +1,33 @@ +package trillian + +import ( + "fmt" + + trillian "github.com/google/trillian/types" + siglog "github.com/system-transparency/stfe/pkg/types" +) + +func treeHeadFromLogRoot(lr *trillian.LogRootV1) *siglog.TreeHead { + var hash [siglog.HashSize]byte + th := siglog.TreeHead{ + Timestamp: uint64(lr.TimestampNanos / 1000 / 1000 / 1000), + TreeSize: uint64(lr.TreeSize), + RootHash: &hash, + } + copy(th.RootHash[:], lr.RootHash) + return &th +} + +func nodePathFromHashes(hashes [][]byte) ([]*[siglog.HashSize]byte, error) { + var path []*[siglog.HashSize]byte + for _, hash := range hashes { + if len(hash) != siglog.HashSize { + return nil, fmt.Errorf("unexpected hash length: %v", len(hash)) + } + + var h [siglog.HashSize]byte + copy(h[:], hash) + path = append(path, &h) + } + return path, nil +} |