diff options
-rw-r--r-- | trillian/client.go | 29 | ||||
-rw-r--r-- | trillian/client_test.go | 135 |
2 files changed, 162 insertions, 2 deletions
diff --git a/trillian/client.go b/trillian/client.go index cbeb1ca..c619b03 100644 --- a/trillian/client.go +++ b/trillian/client.go @@ -139,5 +139,32 @@ func (c *Client) GetInclusionProof(ctx context.Context, req *types.InclusionProo } func (c *Client) GetLeaves(ctx context.Context, req *types.LeavesRequest) (*types.LeafList, error) { - return nil, fmt.Errorf("TODO") + rsp, err := c.GRPC.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{ + LogId: c.TreeID, + StartIndex: int64(req.StartSize), + Count: int64(req.EndSize-req.StartSize) + 1, + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if got, want := len(rsp.Leaves), int(req.EndSize-req.StartSize+1); got != want { + return nil, fmt.Errorf("unexpected number of leaves: %d", got) + } + var list types.LeafList + for i, leaf := range rsp.Leaves { + leafIndex := int64(req.StartSize + uint64(i)) + if leafIndex != leaf.LeafIndex { + return nil, fmt.Errorf("unexpected leaf(%d): got index %d", leafIndex, leaf.LeafIndex) + } + + var l types.Leaf + if err := l.Unmarshal(leaf.LeafValue); err != nil { + return nil, fmt.Errorf("unexpected leaf(%d): %v", leafIndex, err) + } + list = append(list[:], &l) + } + return &list, nil } diff --git a/trillian/client_test.go b/trillian/client_test.go index 001d1dc..bbc61ca 100644 --- a/trillian/client_test.go +++ b/trillian/client_test.go @@ -397,4 +397,137 @@ func TestGetInclusionProof(t *testing.T) { } } -func TestGetLeaves(t *testing.T) {} +func TestGetLeaves(t *testing.T) { + req := &types.LeavesRequest{ + StartSize: 1, + EndSize: 2, + } + firstLeaf := &types.Leaf{ + Message: types.Message{ + ShardHint: 0, + Checksum: &[types.HashSize]byte{}, + }, + SigIdent: types.SigIdent{ + Signature: &[types.SignatureSize]byte{}, + KeyHash: &[types.HashSize]byte{}, + }, + } + secondLeaf := &types.Leaf{ + Message: types.Message{ + ShardHint: 0, + Checksum: &[types.HashSize]byte{}, + }, + SigIdent: types.SigIdent{ + Signature: &[types.SignatureSize]byte{}, + KeyHash: &[types.HashSize]byte{}, + }, + } + + for _, table := range []struct { + description string + req *types.LeavesRequest + rsp *trillian.GetLeavesByRangeResponse + err error + wantErr bool + wantLeaves *types.LeafList + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: unexpected number of leaves", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + }, + }, + wantErr: true, + }, + { + description: "invalid: unexpected leaf (1/2)", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal(), + LeafIndex: 3, + }, + }, + }, + wantErr: true, + }, + { + description: "invalid: unexpected leaf (2/2)", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal()[1:], + LeafIndex: 2, + }, + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal(), + LeafIndex: 2, + }, + }, + }, + wantLeaves: &types.LeafList{ + firstLeaf, + secondLeaf, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mockclient.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := Client{GRPC: grpc} + + leaves, err := client.GetLeaves(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := leaves, table.wantLeaves; !reflect.DeepEqual(got, want) { + t.Errorf("got leaves\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} |