diff options
Diffstat (limited to 'internal')
28 files changed, 4207 insertions, 0 deletions
| diff --git a/internal/db/client.go b/internal/db/client.go new file mode 100644 index 0000000..ce3bb2b --- /dev/null +++ b/internal/db/client.go @@ -0,0 +1,18 @@ +package db + +import ( +	"context" + +	"git.sigsum.org/sigsum-go/pkg/requests" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +// Client is an interface that interacts with a log's database backend +type Client interface { +	AddLeaf(context.Context, *requests.Leaf, uint64) (bool, error) +	AddSequencedLeaves(ctx context.Context, leaves types.Leaves, index int64) error +	GetTreeHead(context.Context) (*types.TreeHead, error) +	GetConsistencyProof(context.Context, *requests.ConsistencyProof) (*types.ConsistencyProof, error) +	GetInclusionProof(context.Context, *requests.InclusionProof) (*types.InclusionProof, error) +	GetLeaves(context.Context, *requests.Leaves) (*types.Leaves, error) +} diff --git a/internal/db/trillian.go b/internal/db/trillian.go new file mode 100644 index 0000000..e8a9945 --- /dev/null +++ b/internal/db/trillian.go @@ -0,0 +1,232 @@ +package db + +import ( +	"context" +	"fmt" +	"time" + +	"git.sigsum.org/sigsum-go/pkg/log" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/requests" +	"git.sigsum.org/sigsum-go/pkg/types" +	"github.com/google/trillian" +	trillianTypes "github.com/google/trillian/types" +	"google.golang.org/grpc/codes" +	"google.golang.org/grpc/status" +) + +// TrillianClient implements the Client interface for Trillian's gRPC backend +type TrillianClient struct { +	// TreeID is a Merkle tree identifier that Trillian uses +	TreeID int64 + +	// GRPC is a Trillian gRPC client +	GRPC trillian.TrillianLogClient +} + +// AddLeaf adds a leaf to the tree and returns true if the leaf has +// been sequenced into the tree of size treeSize. +func (c *TrillianClient) AddLeaf(ctx context.Context, req *requests.Leaf, treeSize uint64) (bool, error) { +	leaf := types.Leaf{ +		Statement: types.Statement{ +			ShardHint: req.ShardHint, +			Checksum:  *merkle.HashFn(req.Message[:]), +		}, +		Signature: req.Signature, +		KeyHash:   *merkle.HashFn(req.PublicKey[:]), +	} +	serialized := leaf.ToBinary() + +	log.Debug("queueing leaf request: %x", merkle.HashLeafNode(serialized)) +	_, err := c.GRPC.QueueLeaf(ctx, &trillian.QueueLeafRequest{ +		LogId: c.TreeID, +		Leaf: &trillian.LogLeaf{ +			LeafValue: serialized, +		}, +	}) +	switch status.Code(err) { +	case codes.OK: +	case codes.AlreadyExists: +	default: +		log.Warning("gRPC error: %v", err) +		return false, fmt.Errorf("back-end failure") +	} +	_, err = c.GetInclusionProof(ctx, &requests.InclusionProof{treeSize, *merkle.HashLeafNode(serialized)}) +	return err == nil, nil +} + +// AddSequencedLeaves adds a set of already sequenced leaves to the tree. +func (c *TrillianClient) AddSequencedLeaves(ctx context.Context, leaves types.Leaves, index int64) error { +	trilLeaves := make([]*trillian.LogLeaf, len(leaves)) +	for i, leaf := range leaves { +		trilLeaves[i] = &trillian.LogLeaf{ +			LeafValue: leaf.ToBinary(), +			LeafIndex: index + int64(i), +		} +	} + +	req := trillian.AddSequencedLeavesRequest{ +		LogId:  c.TreeID, +		Leaves: trilLeaves, +	} +	log.Debug("adding sequenced leaves: count %d", len(trilLeaves)) +	var err error +	for wait := 1; wait < 30; wait *= 2 { +		var rsp *trillian.AddSequencedLeavesResponse +		rsp, err = c.GRPC.AddSequencedLeaves(ctx, &req) +		switch status.Code(err) { +		case codes.ResourceExhausted: +			log.Info("waiting %d seconds before retrying to add %d leaves, reason: %v", wait, len(trilLeaves), err) +			time.Sleep(time.Second * time.Duration(wait)) +			continue +		case codes.OK: +			if rsp == nil { +				return fmt.Errorf("GRPC.AddSequencedLeaves no response") +			} +			// FIXME: check rsp.Results.QueuedLogLeaf +			return nil +		default: +			return fmt.Errorf("GRPC.AddSequencedLeaves error: %v", err) +		} +	} + +	return fmt.Errorf("giving up on adding %d leaves", len(trilLeaves)) +} + +func (c *TrillianClient) GetTreeHead(ctx context.Context) (*types.TreeHead, error) { +	rsp, err := c.GRPC.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{ +		LogId: c.TreeID, +	}) +	if err != nil { +		return nil, fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return nil, fmt.Errorf("no response") +	} +	if rsp.SignedLogRoot == nil { +		return nil, fmt.Errorf("no signed log root") +	} +	if rsp.SignedLogRoot.LogRoot == nil { +		return nil, fmt.Errorf("no log root") +	} +	var r trillianTypes.LogRootV1 +	if err := r.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil { +		return nil, fmt.Errorf("no log root: unmarshal failed: %v", err) +	} +	if len(r.RootHash) != merkle.HashSize { +		return nil, fmt.Errorf("unexpected hash length: %d", len(r.RootHash)) +	} +	return treeHeadFromLogRoot(&r), nil +} + +func (c *TrillianClient) GetConsistencyProof(ctx context.Context, req *requests.ConsistencyProof) (*types.ConsistencyProof, error) { +	rsp, err := c.GRPC.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{ +		LogId:          c.TreeID, +		FirstTreeSize:  int64(req.OldSize), +		SecondTreeSize: int64(req.NewSize), +	}) +	if err != nil { +		return nil, fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return nil, fmt.Errorf("no response") +	} +	if rsp.Proof == nil { +		return nil, fmt.Errorf("no consistency proof") +	} +	if len(rsp.Proof.Hashes) == 0 { +		return nil, fmt.Errorf("not a consistency proof: empty") +	} +	path, err := nodePathFromHashes(rsp.Proof.Hashes) +	if err != nil { +		return nil, fmt.Errorf("not a consistency proof: %v", err) +	} +	return &types.ConsistencyProof{ +		OldSize: req.OldSize, +		NewSize: req.NewSize, +		Path:    path, +	}, nil +} + +func (c *TrillianClient) GetInclusionProof(ctx context.Context, req *requests.InclusionProof) (*types.InclusionProof, error) { +	rsp, err := c.GRPC.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{ +		LogId:           c.TreeID, +		LeafHash:        req.LeafHash[:], +		TreeSize:        int64(req.TreeSize), +		OrderBySequence: true, +	}) +	if err != nil { +		return nil, fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return nil, fmt.Errorf("no response") +	} +	if len(rsp.Proof) != 1 { +		return nil, fmt.Errorf("bad proof count: %d", len(rsp.Proof)) +	} +	proof := rsp.Proof[0] +	if len(proof.Hashes) == 0 { +		return nil, fmt.Errorf("not an inclusion proof: empty") +	} +	path, err := nodePathFromHashes(proof.Hashes) +	if err != nil { +		return nil, fmt.Errorf("not an inclusion proof: %v", err) +	} +	return &types.InclusionProof{ +		TreeSize:  req.TreeSize, +		LeafIndex: uint64(proof.LeafIndex), +		Path:      path, +	}, nil +} + +func (c *TrillianClient) GetLeaves(ctx context.Context, req *requests.Leaves) (*types.Leaves, error) { +	rsp, err := c.GRPC.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{ +		LogId:      c.TreeID, +		StartIndex: int64(req.StartSize), +		Count:      int64(req.EndSize-req.StartSize) + 1, +	}) +	if err != nil { +		return nil, fmt.Errorf("backend failure: %v", err) +	} +	if rsp == nil { +		return nil, fmt.Errorf("no response") +	} +	if got, want := len(rsp.Leaves), int(req.EndSize-req.StartSize+1); got != want { +		return nil, fmt.Errorf("unexpected number of leaves: %d", got) +	} +	var list types.Leaves = make([]types.Leaf, 0, len(rsp.Leaves)) +	for i, leaf := range rsp.Leaves { +		leafIndex := int64(req.StartSize + uint64(i)) +		if leafIndex != leaf.LeafIndex { +			return nil, fmt.Errorf("unexpected leaf(%d): got index %d", leafIndex, leaf.LeafIndex) +		} + +		var l types.Leaf +		if err := l.FromBinary(leaf.LeafValue); err != nil { +			return nil, fmt.Errorf("unexpected leaf(%d): %v", leafIndex, err) +		} +		list = append(list[:], l) +	} +	return &list, nil +} + +func treeHeadFromLogRoot(lr *trillianTypes.LogRootV1) *types.TreeHead { +	th := types.TreeHead{ +		Timestamp: uint64(time.Now().Unix()), +		TreeSize:  uint64(lr.TreeSize), +	} +	copy(th.RootHash[:], lr.RootHash) +	return &th +} + +func nodePathFromHashes(hashes [][]byte) ([]merkle.Hash, error) { +	path := make([]merkle.Hash, len(hashes)) +	for i := 0; i < len(hashes); i++ { +		if len(hashes[i]) != merkle.HashSize { +			return nil, fmt.Errorf("unexpected hash length: %v", len(hashes[i])) +		} + +		copy(path[i][:], hashes[i]) +	} +	return path, nil +} diff --git a/internal/db/trillian_test.go b/internal/db/trillian_test.go new file mode 100644 index 0000000..9ae682e --- /dev/null +++ b/internal/db/trillian_test.go @@ -0,0 +1,543 @@ +package db + +import ( +	"bytes" +	"context" +	"fmt" +	"reflect" +	"testing" +	"time" + +	mocksTrillian "git.sigsum.org/log-go/internal/mocks/trillian" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/requests" +	"git.sigsum.org/sigsum-go/pkg/types" +	"github.com/golang/mock/gomock" +	"github.com/google/trillian" +	ttypes "github.com/google/trillian/types" +	//"google.golang.org/grpc/codes" +	//"google.golang.org/grpc/status" +) + +// TODO: Add TestAddSequencedLeaves +// TODO: Update TestAddLeaf +//func TestAddLeaf(t *testing.T) { +//	req := &requests.Leaf{ +//		ShardHint:  0, +//		Message:    merkle.Hash{}, +//		Signature:  types.Signature{}, +//		PublicKey:  types.PublicKey{}, +//		DomainHint: "example.com", +//	} +//	for _, table := range []struct { +//		description string +//		req         *requests.Leaf +//		rsp         *trillian.QueueLeafResponse +//		err         error +//		wantErr     bool +//	}{ +//		{ +//			description: "invalid: backend failure", +//			req:         req, +//			err:         fmt.Errorf("something went wrong"), +//			wantErr:     true, +//		}, +//		{ +//			description: "invalid: no response", +//			req:         req, +//			wantErr:     true, +//		}, +//		{ +//			description: "invalid: no queued leaf", +//			req:         req, +//			rsp:         &trillian.QueueLeafResponse{}, +//			wantErr:     true, +//		}, +//		{ +//			description: "invalid: leaf is already queued or included", +//			req:         req, +//			rsp: &trillian.QueueLeafResponse{ +//				QueuedLeaf: &trillian.QueuedLogLeaf{ +//					Leaf: &trillian.LogLeaf{ +//						LeafValue: []byte{0}, // does not matter for test +//					}, +//					Status: status.New(codes.AlreadyExists, "duplicate").Proto(), +//				}, +//			}, +//			wantErr: true, +//		}, +//		{ +//			description: "valid", +//			req:         req, +//			rsp: &trillian.QueueLeafResponse{ +//				QueuedLeaf: &trillian.QueuedLogLeaf{ +//					Leaf: &trillian.LogLeaf{ +//						LeafValue: []byte{0}, // does not matter for test +//					}, +//					Status: status.New(codes.OK, "ok").Proto(), +//				}, +//			}, +//		}, +//	} { +//		// Run deferred functions at the end of each iteration +//		func() { +//			ctrl := gomock.NewController(t) +//			defer ctrl.Finish() +//			grpc := mocksTrillian.NewMockTrillianLogClient(ctrl) +//			grpc.EXPECT().QueueLeaf(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +//			client := TrillianClient{GRPC: grpc} +// +//			_, err := client.AddLeaf(context.Background(), table.req, 0) +//			if got, want := err != nil, table.wantErr; got != want { +//				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +//			} +//		}() +//	} +//} + +func TestGetTreeHead(t *testing.T) { +	// valid root +	root := &ttypes.LogRootV1{ +		TreeSize:       0, +		RootHash:       make([]byte, merkle.HashSize), +		TimestampNanos: 1622585623133599429, +	} +	buf, err := root.MarshalBinary() +	if err != nil { +		t.Fatalf("must marshal log root: %v", err) +	} +	// invalid root +	root.RootHash = make([]byte, merkle.HashSize+1) +	bufBadHash, err := root.MarshalBinary() +	if err != nil { +		t.Fatalf("must marshal log root: %v", err) +	} + +	for _, table := range []struct { +		description string +		rsp         *trillian.GetLatestSignedLogRootResponse +		err         error +		wantErr     bool +		wantTh      *types.TreeHead +	}{ +		{ +			description: "invalid: backend failure", +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			wantErr:     true, +		}, +		{ +			description: "invalid: no signed log root", +			rsp:         &trillian.GetLatestSignedLogRootResponse{}, +			wantErr:     true, +		}, +		{ +			description: "invalid: no log root", +			rsp: &trillian.GetLatestSignedLogRootResponse{ +				SignedLogRoot: &trillian.SignedLogRoot{}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: no log root: unmarshal failed", +			rsp: &trillian.GetLatestSignedLogRootResponse{ +				SignedLogRoot: &trillian.SignedLogRoot{ +					LogRoot: buf[1:], +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: unexpected hash length", +			rsp: &trillian.GetLatestSignedLogRootResponse{ +				SignedLogRoot: &trillian.SignedLogRoot{ +					LogRoot: bufBadHash, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			rsp: &trillian.GetLatestSignedLogRootResponse{ +				SignedLogRoot: &trillian.SignedLogRoot{ +					LogRoot: buf, +				}, +			}, +			wantTh: &types.TreeHead{ +				Timestamp: 1622585623, +				TreeSize:  0, +				RootHash:  merkle.Hash{}, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocksTrillian.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			th, err := client.GetTreeHead(context.Background()) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} + +			// we would need a clock that can be mocked to make a nicer test +			now := uint64(time.Now().Unix()) +			if got, wantLow, wantHigh := th.Timestamp, now-5, now+5; got < wantLow || got > wantHigh { +				t.Errorf("got tree head with timestamp %d but wanted between [%d, %d] in test %q", +					got, wantLow, wantHigh, table.description) +			} +			if got, want := th.TreeSize, table.wantTh.TreeSize; got != want { +				t.Errorf("got tree head with tree size %d but wanted %d in test %q", got, want, table.description) +			} +			if got, want := th.RootHash[:], table.wantTh.RootHash[:]; !bytes.Equal(got, want) { +				t.Errorf("got root hash %x but wanted %x in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetConsistencyProof(t *testing.T) { +	req := &requests.ConsistencyProof{ +		OldSize: 1, +		NewSize: 3, +	} +	for _, table := range []struct { +		description string +		req         *requests.ConsistencyProof +		rsp         *trillian.GetConsistencyProofResponse +		err         error +		wantErr     bool +		wantProof   *types.ConsistencyProof +	}{ +		{ +			description: "invalid: backend failure", +			req:         req, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			req:         req, +			wantErr:     true, +		}, +		{ +			description: "invalid: no consistency proof", +			req:         req, +			rsp:         &trillian.GetConsistencyProofResponse{}, +			wantErr:     true, +		}, +		{ +			description: "invalid: not a consistency proof (1/2)", +			req:         req, +			rsp: &trillian.GetConsistencyProofResponse{ +				Proof: &trillian.Proof{ +					Hashes: [][]byte{}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: not a consistency proof (2/2)", +			req:         req, +			rsp: &trillian.GetConsistencyProofResponse{ +				Proof: &trillian.Proof{ +					Hashes: [][]byte{ +						make([]byte, merkle.HashSize), +						make([]byte, merkle.HashSize+1), +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			req:         req, +			rsp: &trillian.GetConsistencyProofResponse{ +				Proof: &trillian.Proof{ +					Hashes: [][]byte{ +						make([]byte, merkle.HashSize), +						make([]byte, merkle.HashSize), +					}, +				}, +			}, +			wantProof: &types.ConsistencyProof{ +				OldSize: 1, +				NewSize: 3, +				Path: []merkle.Hash{ +					merkle.Hash{}, +					merkle.Hash{}, +				}, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocksTrillian.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			proof, err := client.GetConsistencyProof(context.Background(), table.req) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { +				t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetInclusionProof(t *testing.T) { +	req := &requests.InclusionProof{ +		TreeSize: 4, +		LeafHash: merkle.Hash{}, +	} +	for _, table := range []struct { +		description string +		req         *requests.InclusionProof +		rsp         *trillian.GetInclusionProofByHashResponse +		err         error +		wantErr     bool +		wantProof   *types.InclusionProof +	}{ +		{ +			description: "invalid: backend failure", +			req:         req, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			req:         req, +			wantErr:     true, +		}, +		{ +			description: "invalid: bad proof count", +			req:         req, +			rsp: &trillian.GetInclusionProofByHashResponse{ +				Proof: []*trillian.Proof{ +					&trillian.Proof{}, +					&trillian.Proof{}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: not an inclusion proof (1/2)", +			req:         req, +			rsp: &trillian.GetInclusionProofByHashResponse{ +				Proof: []*trillian.Proof{ +					&trillian.Proof{ +						LeafIndex: 1, +						Hashes:    [][]byte{}, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: not an inclusion proof (2/2)", +			req:         req, +			rsp: &trillian.GetInclusionProofByHashResponse{ +				Proof: []*trillian.Proof{ +					&trillian.Proof{ +						LeafIndex: 1, +						Hashes: [][]byte{ +							make([]byte, merkle.HashSize), +							make([]byte, merkle.HashSize+1), +						}, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			req:         req, +			rsp: &trillian.GetInclusionProofByHashResponse{ +				Proof: []*trillian.Proof{ +					&trillian.Proof{ +						LeafIndex: 1, +						Hashes: [][]byte{ +							make([]byte, merkle.HashSize), +							make([]byte, merkle.HashSize), +						}, +					}, +				}, +			}, +			wantProof: &types.InclusionProof{ +				TreeSize:  4, +				LeafIndex: 1, +				Path: []merkle.Hash{ +					merkle.Hash{}, +					merkle.Hash{}, +				}, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocksTrillian.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			proof, err := client.GetInclusionProof(context.Background(), table.req) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { +				t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetLeaves(t *testing.T) { +	req := &requests.Leaves{ +		StartSize: 1, +		EndSize:   2, +	} +	firstLeaf := &types.Leaf{ +		Statement: types.Statement{ +			ShardHint: 0, +			Checksum:  merkle.Hash{}, +		}, +		Signature: types.Signature{}, +		KeyHash:   merkle.Hash{}, +	} +	secondLeaf := &types.Leaf{ +		Statement: types.Statement{ +			ShardHint: 0, +			Checksum:  merkle.Hash{}, +		}, +		Signature: types.Signature{}, +		KeyHash:   merkle.Hash{}, +	} + +	for _, table := range []struct { +		description string +		req         *requests.Leaves +		rsp         *trillian.GetLeavesByRangeResponse +		err         error +		wantErr     bool +		wantLeaves  *types.Leaves +	}{ +		{ +			description: "invalid: backend failure", +			req:         req, +			err:         fmt.Errorf("something went wrong"), +			wantErr:     true, +		}, +		{ +			description: "invalid: no response", +			req:         req, +			wantErr:     true, +		}, +		{ +			description: "invalid: unexpected number of leaves", +			req:         req, +			rsp: &trillian.GetLeavesByRangeResponse{ +				Leaves: []*trillian.LogLeaf{ +					&trillian.LogLeaf{ +						LeafValue: firstLeaf.ToBinary(), +						LeafIndex: 1, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: unexpected leaf (1/2)", +			req:         req, +			rsp: &trillian.GetLeavesByRangeResponse{ +				Leaves: []*trillian.LogLeaf{ +					&trillian.LogLeaf{ +						LeafValue: firstLeaf.ToBinary(), +						LeafIndex: 1, +					}, +					&trillian.LogLeaf{ +						LeafValue: secondLeaf.ToBinary(), +						LeafIndex: 3, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "invalid: unexpected leaf (2/2)", +			req:         req, +			rsp: &trillian.GetLeavesByRangeResponse{ +				Leaves: []*trillian.LogLeaf{ +					&trillian.LogLeaf{ +						LeafValue: firstLeaf.ToBinary(), +						LeafIndex: 1, +					}, +					&trillian.LogLeaf{ +						LeafValue: secondLeaf.ToBinary()[1:], +						LeafIndex: 2, +					}, +				}, +			}, +			wantErr: true, +		}, +		{ +			description: "valid", +			req:         req, +			rsp: &trillian.GetLeavesByRangeResponse{ +				Leaves: []*trillian.LogLeaf{ +					&trillian.LogLeaf{ +						LeafValue: firstLeaf.ToBinary(), +						LeafIndex: 1, +					}, +					&trillian.LogLeaf{ +						LeafValue: secondLeaf.ToBinary(), +						LeafIndex: 2, +					}, +				}, +			}, +			wantLeaves: &types.Leaves{ +				*firstLeaf, +				*secondLeaf, +			}, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			grpc := mocksTrillian.NewMockTrillianLogClient(ctrl) +			grpc.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			client := TrillianClient{GRPC: grpc} + +			leaves, err := client.GetLeaves(context.Background(), table.req) +			if got, want := err != nil, table.wantErr; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} +			if got, want := leaves, table.wantLeaves; !reflect.DeepEqual(got, want) { +				t.Errorf("got leaves\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} diff --git a/internal/mocks/client/client.go b/internal/mocks/client/client.go new file mode 100644 index 0000000..6cb3b3d --- /dev/null +++ b/internal/mocks/client/client.go @@ -0,0 +1,170 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: git.sigsum.org/sigsum-go/pkg/client (interfaces: Client) + +// Package client is a generated GoMock package. +package client + +import ( +	context "context" +	reflect "reflect" + +	requests "git.sigsum.org/sigsum-go/pkg/requests" +	types "git.sigsum.org/sigsum-go/pkg/types" +	gomock "github.com/golang/mock/gomock" +) + +// MockClient is a mock of Client interface. +type MockClient struct { +	ctrl     *gomock.Controller +	recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { +	mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { +	mock := &MockClient{ctrl: ctrl} +	mock.recorder = &MockClientMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { +	return m.recorder +} + +// AddCosignature mocks base method. +func (m *MockClient) AddCosignature(arg0 context.Context, arg1 requests.Cosignature) error { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "AddCosignature", arg0, arg1) +	ret0, _ := ret[0].(error) +	return ret0 +} + +// AddCosignature indicates an expected call of AddCosignature. +func (mr *MockClientMockRecorder) AddCosignature(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddCosignature", reflect.TypeOf((*MockClient)(nil).AddCosignature), arg0, arg1) +} + +// AddLeaf mocks base method. +func (m *MockClient) AddLeaf(arg0 context.Context, arg1 requests.Leaf) (bool, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "AddLeaf", arg0, arg1) +	ret0, _ := ret[0].(bool) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// AddLeaf indicates an expected call of AddLeaf. +func (mr *MockClientMockRecorder) AddLeaf(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLeaf", reflect.TypeOf((*MockClient)(nil).AddLeaf), arg0, arg1) +} + +// GetConsistencyProof mocks base method. +func (m *MockClient) GetConsistencyProof(arg0 context.Context, arg1 requests.ConsistencyProof) (types.ConsistencyProof, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetConsistencyProof", arg0, arg1) +	ret0, _ := ret[0].(types.ConsistencyProof) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetConsistencyProof indicates an expected call of GetConsistencyProof. +func (mr *MockClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockClient)(nil).GetConsistencyProof), arg0, arg1) +} + +// GetCosignedTreeHead mocks base method. +func (m *MockClient) GetCosignedTreeHead(arg0 context.Context) (types.CosignedTreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetCosignedTreeHead", arg0) +	ret0, _ := ret[0].(types.CosignedTreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetCosignedTreeHead indicates an expected call of GetCosignedTreeHead. +func (mr *MockClientMockRecorder) GetCosignedTreeHead(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCosignedTreeHead", reflect.TypeOf((*MockClient)(nil).GetCosignedTreeHead), arg0) +} + +// GetInclusionProof mocks base method. +func (m *MockClient) GetInclusionProof(arg0 context.Context, arg1 requests.InclusionProof) (types.InclusionProof, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetInclusionProof", arg0, arg1) +	ret0, _ := ret[0].(types.InclusionProof) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetInclusionProof indicates an expected call of GetInclusionProof. +func (mr *MockClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockClient)(nil).GetInclusionProof), arg0, arg1) +} + +// GetLeaves mocks base method. +func (m *MockClient) GetLeaves(arg0 context.Context, arg1 requests.Leaves) (types.Leaves, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetLeaves", arg0, arg1) +	ret0, _ := ret[0].(types.Leaves) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeaves indicates an expected call of GetLeaves. +func (mr *MockClientMockRecorder) GetLeaves(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeaves", reflect.TypeOf((*MockClient)(nil).GetLeaves), arg0, arg1) +} + +// GetToCosignTreeHead mocks base method. +func (m *MockClient) GetToCosignTreeHead(arg0 context.Context) (types.SignedTreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetToCosignTreeHead", arg0) +	ret0, _ := ret[0].(types.SignedTreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetToCosignTreeHead indicates an expected call of GetToCosignTreeHead. +func (mr *MockClientMockRecorder) GetToCosignTreeHead(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetToCosignTreeHead", reflect.TypeOf((*MockClient)(nil).GetToCosignTreeHead), arg0) +} + +// GetUnsignedTreeHead mocks base method. +func (m *MockClient) GetUnsignedTreeHead(arg0 context.Context) (types.TreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetUnsignedTreeHead", arg0) +	ret0, _ := ret[0].(types.TreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetUnsignedTreeHead indicates an expected call of GetUnsignedTreeHead. +func (mr *MockClientMockRecorder) GetUnsignedTreeHead(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUnsignedTreeHead", reflect.TypeOf((*MockClient)(nil).GetUnsignedTreeHead), arg0) +} + +// Initiated mocks base method. +func (m *MockClient) Initiated() bool { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Initiated") +	ret0, _ := ret[0].(bool) +	return ret0 +} + +// Initiated indicates an expected call of Initiated. +func (mr *MockClientMockRecorder) Initiated() *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initiated", reflect.TypeOf((*MockClient)(nil).Initiated)) +} diff --git a/internal/mocks/crypto/crypto.go b/internal/mocks/crypto/crypto.go new file mode 100644 index 0000000..0871e79 --- /dev/null +++ b/internal/mocks/crypto/crypto.go @@ -0,0 +1,65 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: crypto (interfaces: Signer) + +// Package crypto is a generated GoMock package. +package crypto + +import ( +	crypto "crypto" +	io "io" +	reflect "reflect" + +	gomock "github.com/golang/mock/gomock" +) + +// MockSigner is a mock of Signer interface. +type MockSigner struct { +	ctrl     *gomock.Controller +	recorder *MockSignerMockRecorder +} + +// MockSignerMockRecorder is the mock recorder for MockSigner. +type MockSignerMockRecorder struct { +	mock *MockSigner +} + +// NewMockSigner creates a new mock instance. +func NewMockSigner(ctrl *gomock.Controller) *MockSigner { +	mock := &MockSigner{ctrl: ctrl} +	mock.recorder = &MockSignerMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSigner) EXPECT() *MockSignerMockRecorder { +	return m.recorder +} + +// Public mocks base method. +func (m *MockSigner) Public() crypto.PublicKey { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Public") +	ret0, _ := ret[0].(crypto.PublicKey) +	return ret0 +} + +// Public indicates an expected call of Public. +func (mr *MockSignerMockRecorder) Public() *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Public", reflect.TypeOf((*MockSigner)(nil).Public)) +} + +// Sign mocks base method. +func (m *MockSigner) Sign(arg0 io.Reader, arg1 []byte, arg2 crypto.SignerOpts) ([]byte, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Sign", arg0, arg1, arg2) +	ret0, _ := ret[0].([]byte) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// Sign indicates an expected call of Sign. +func (mr *MockSignerMockRecorder) Sign(arg0, arg1, arg2 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*MockSigner)(nil).Sign), arg0, arg1, arg2) +} diff --git a/internal/mocks/db/db.go b/internal/mocks/db/db.go new file mode 100644 index 0000000..b96328d --- /dev/null +++ b/internal/mocks/db/db.go @@ -0,0 +1,126 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: git.sigsum.org/log-go/internal/db (interfaces: Client) + +// Package db is a generated GoMock package. +package db + +import ( +	context "context" +	reflect "reflect" + +	requests "git.sigsum.org/sigsum-go/pkg/requests" +	types "git.sigsum.org/sigsum-go/pkg/types" +	gomock "github.com/golang/mock/gomock" +) + +// MockClient is a mock of Client interface. +type MockClient struct { +	ctrl     *gomock.Controller +	recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { +	mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { +	mock := &MockClient{ctrl: ctrl} +	mock.recorder = &MockClientMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { +	return m.recorder +} + +// AddLeaf mocks base method. +func (m *MockClient) AddLeaf(arg0 context.Context, arg1 *requests.Leaf, arg2 uint64) (bool, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "AddLeaf", arg0, arg1, arg2) +	ret0, _ := ret[0].(bool) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// AddLeaf indicates an expected call of AddLeaf. +func (mr *MockClientMockRecorder) AddLeaf(arg0, arg1, arg2 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLeaf", reflect.TypeOf((*MockClient)(nil).AddLeaf), arg0, arg1, arg2) +} + +// AddSequencedLeaves mocks base method. +func (m *MockClient) AddSequencedLeaves(arg0 context.Context, arg1 types.Leaves, arg2 int64) error { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "AddSequencedLeaves", arg0, arg1, arg2) +	ret0, _ := ret[0].(error) +	return ret0 +} + +// AddSequencedLeaves indicates an expected call of AddSequencedLeaves. +func (mr *MockClientMockRecorder) AddSequencedLeaves(arg0, arg1, arg2 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaves", reflect.TypeOf((*MockClient)(nil).AddSequencedLeaves), arg0, arg1, arg2) +} + +// GetConsistencyProof mocks base method. +func (m *MockClient) GetConsistencyProof(arg0 context.Context, arg1 *requests.ConsistencyProof) (*types.ConsistencyProof, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetConsistencyProof", arg0, arg1) +	ret0, _ := ret[0].(*types.ConsistencyProof) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetConsistencyProof indicates an expected call of GetConsistencyProof. +func (mr *MockClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockClient)(nil).GetConsistencyProof), arg0, arg1) +} + +// GetInclusionProof mocks base method. +func (m *MockClient) GetInclusionProof(arg0 context.Context, arg1 *requests.InclusionProof) (*types.InclusionProof, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetInclusionProof", arg0, arg1) +	ret0, _ := ret[0].(*types.InclusionProof) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetInclusionProof indicates an expected call of GetInclusionProof. +func (mr *MockClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockClient)(nil).GetInclusionProof), arg0, arg1) +} + +// GetLeaves mocks base method. +func (m *MockClient) GetLeaves(arg0 context.Context, arg1 *requests.Leaves) (*types.Leaves, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetLeaves", arg0, arg1) +	ret0, _ := ret[0].(*types.Leaves) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeaves indicates an expected call of GetLeaves. +func (mr *MockClientMockRecorder) GetLeaves(arg0, arg1 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeaves", reflect.TypeOf((*MockClient)(nil).GetLeaves), arg0, arg1) +} + +// GetTreeHead mocks base method. +func (m *MockClient) GetTreeHead(arg0 context.Context) (*types.TreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "GetTreeHead", arg0) +	ret0, _ := ret[0].(*types.TreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetTreeHead indicates an expected call of GetTreeHead. +func (mr *MockClientMockRecorder) GetTreeHead(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTreeHead", reflect.TypeOf((*MockClient)(nil).GetTreeHead), arg0) +} diff --git a/internal/mocks/node/handler/handler.go b/internal/mocks/node/handler/handler.go new file mode 100644 index 0000000..97cac8e --- /dev/null +++ b/internal/mocks/node/handler/handler.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: git.sigsum.org/log-go/internal/node/handler (interfaces: Config) + +// Package handler is a generated GoMock package. +package handler + +import ( +	reflect "reflect" +	time "time" + +	gomock "github.com/golang/mock/gomock" +) + +// MockConfig is a mock of Config interface. +type MockConfig struct { +	ctrl     *gomock.Controller +	recorder *MockConfigMockRecorder +} + +// MockConfigMockRecorder is the mock recorder for MockConfig. +type MockConfigMockRecorder struct { +	mock *MockConfig +} + +// NewMockConfig creates a new mock instance. +func NewMockConfig(ctrl *gomock.Controller) *MockConfig { +	mock := &MockConfig{ctrl: ctrl} +	mock.recorder = &MockConfigMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConfig) EXPECT() *MockConfigMockRecorder { +	return m.recorder +} + +// Deadline mocks base method. +func (m *MockConfig) Deadline() time.Duration { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Deadline") +	ret0, _ := ret[0].(time.Duration) +	return ret0 +} + +// Deadline indicates an expected call of Deadline. +func (mr *MockConfigMockRecorder) Deadline() *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Deadline", reflect.TypeOf((*MockConfig)(nil).Deadline)) +} + +// LogID mocks base method. +func (m *MockConfig) LogID() string { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "LogID") +	ret0, _ := ret[0].(string) +	return ret0 +} + +// LogID indicates an expected call of LogID. +func (mr *MockConfigMockRecorder) LogID() *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogID", reflect.TypeOf((*MockConfig)(nil).LogID)) +} + +// Prefix mocks base method. +func (m *MockConfig) Prefix() string { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Prefix") +	ret0, _ := ret[0].(string) +	return ret0 +} + +// Prefix indicates an expected call of Prefix. +func (mr *MockConfigMockRecorder) Prefix() *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prefix", reflect.TypeOf((*MockConfig)(nil).Prefix)) +} diff --git a/internal/mocks/state/state.go b/internal/mocks/state/state.go new file mode 100644 index 0000000..52dfb09 --- /dev/null +++ b/internal/mocks/state/state.go @@ -0,0 +1,91 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: git.sigsum.org/log-go/internal/state (interfaces: StateManager) + +// Package state is a generated GoMock package. +package state + +import ( +	context "context" +	reflect "reflect" + +	types "git.sigsum.org/sigsum-go/pkg/types" +	gomock "github.com/golang/mock/gomock" +) + +// MockStateManager is a mock of StateManager interface. +type MockStateManager struct { +	ctrl     *gomock.Controller +	recorder *MockStateManagerMockRecorder +} + +// MockStateManagerMockRecorder is the mock recorder for MockStateManager. +type MockStateManagerMockRecorder struct { +	mock *MockStateManager +} + +// NewMockStateManager creates a new mock instance. +func NewMockStateManager(ctrl *gomock.Controller) *MockStateManager { +	mock := &MockStateManager{ctrl: ctrl} +	mock.recorder = &MockStateManagerMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStateManager) EXPECT() *MockStateManagerMockRecorder { +	return m.recorder +} + +// AddCosignature mocks base method. +func (m *MockStateManager) AddCosignature(arg0 context.Context, arg1 *types.PublicKey, arg2 *types.Signature) error { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "AddCosignature", arg0, arg1, arg2) +	ret0, _ := ret[0].(error) +	return ret0 +} + +// AddCosignature indicates an expected call of AddCosignature. +func (mr *MockStateManagerMockRecorder) AddCosignature(arg0, arg1, arg2 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddCosignature", reflect.TypeOf((*MockStateManager)(nil).AddCosignature), arg0, arg1, arg2) +} + +// CosignedTreeHead mocks base method. +func (m *MockStateManager) CosignedTreeHead(arg0 context.Context) (*types.CosignedTreeHead, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "CosignedTreeHead", arg0) +	ret0, _ := ret[0].(*types.CosignedTreeHead) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// CosignedTreeHead indicates an expected call of CosignedTreeHead. +func (mr *MockStateManagerMockRecorder) CosignedTreeHead(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CosignedTreeHead", reflect.TypeOf((*MockStateManager)(nil).CosignedTreeHead), arg0) +} + +// Run mocks base method. +func (m *MockStateManager) Run(arg0 context.Context) { +	m.ctrl.T.Helper() +	m.ctrl.Call(m, "Run", arg0) +} + +// Run indicates an expected call of Run. +func (mr *MockStateManagerMockRecorder) Run(arg0 interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockStateManager)(nil).Run), arg0) +} + +// ToCosignTreeHead mocks base method. +func (m *MockStateManager) ToCosignTreeHead() *types.SignedTreeHead { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "ToCosignTreeHead") +	ret0, _ := ret[0].(*types.SignedTreeHead) +	return ret0 +} + +// ToCosignTreeHead indicates an expected call of ToCosignTreeHead. +func (mr *MockStateManagerMockRecorder) ToCosignTreeHead() *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ToCosignTreeHead", reflect.TypeOf((*MockStateManager)(nil).ToCosignTreeHead)) +} diff --git a/internal/mocks/trillian/trillian.go b/internal/mocks/trillian/trillian.go new file mode 100644 index 0000000..b923e23 --- /dev/null +++ b/internal/mocks/trillian/trillian.go @@ -0,0 +1,317 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/google/trillian (interfaces: TrillianLogClient) + +// Package trillian is a generated GoMock package. +package trillian + +import ( +	context "context" +	reflect "reflect" + +	gomock "github.com/golang/mock/gomock" +	trillian "github.com/google/trillian" +	grpc "google.golang.org/grpc" +) + +// MockTrillianLogClient is a mock of TrillianLogClient interface. +type MockTrillianLogClient struct { +	ctrl     *gomock.Controller +	recorder *MockTrillianLogClientMockRecorder +} + +// MockTrillianLogClientMockRecorder is the mock recorder for MockTrillianLogClient. +type MockTrillianLogClientMockRecorder struct { +	mock *MockTrillianLogClient +} + +// NewMockTrillianLogClient creates a new mock instance. +func NewMockTrillianLogClient(ctrl *gomock.Controller) *MockTrillianLogClient { +	mock := &MockTrillianLogClient{ctrl: ctrl} +	mock.recorder = &MockTrillianLogClientMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrillianLogClient) EXPECT() *MockTrillianLogClientMockRecorder { +	return m.recorder +} + +// AddSequencedLeaf mocks base method. +func (m *MockTrillianLogClient) AddSequencedLeaf(arg0 context.Context, arg1 *trillian.AddSequencedLeafRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeafResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "AddSequencedLeaf", varargs...) +	ret0, _ := ret[0].(*trillian.AddSequencedLeafResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// AddSequencedLeaf indicates an expected call of AddSequencedLeaf. +func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaf), varargs...) +} + +// AddSequencedLeaves mocks base method. +func (m *MockTrillianLogClient) AddSequencedLeaves(arg0 context.Context, arg1 *trillian.AddSequencedLeavesRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeavesResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "AddSequencedLeaves", varargs...) +	ret0, _ := ret[0].(*trillian.AddSequencedLeavesResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// AddSequencedLeaves indicates an expected call of AddSequencedLeaves. +func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaves), varargs...) +} + +// GetConsistencyProof mocks base method. +func (m *MockTrillianLogClient) GetConsistencyProof(arg0 context.Context, arg1 *trillian.GetConsistencyProofRequest, arg2 ...grpc.CallOption) (*trillian.GetConsistencyProofResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetConsistencyProof", varargs...) +	ret0, _ := ret[0].(*trillian.GetConsistencyProofResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetConsistencyProof indicates an expected call of GetConsistencyProof. +func (mr *MockTrillianLogClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetConsistencyProof), varargs...) +} + +// GetEntryAndProof mocks base method. +func (m *MockTrillianLogClient) GetEntryAndProof(arg0 context.Context, arg1 *trillian.GetEntryAndProofRequest, arg2 ...grpc.CallOption) (*trillian.GetEntryAndProofResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetEntryAndProof", varargs...) +	ret0, _ := ret[0].(*trillian.GetEntryAndProofResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetEntryAndProof indicates an expected call of GetEntryAndProof. +func (mr *MockTrillianLogClientMockRecorder) GetEntryAndProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntryAndProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetEntryAndProof), varargs...) +} + +// GetInclusionProof mocks base method. +func (m *MockTrillianLogClient) GetInclusionProof(arg0 context.Context, arg1 *trillian.GetInclusionProofRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetInclusionProof", varargs...) +	ret0, _ := ret[0].(*trillian.GetInclusionProofResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetInclusionProof indicates an expected call of GetInclusionProof. +func (mr *MockTrillianLogClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProof), varargs...) +} + +// GetInclusionProofByHash mocks base method. +func (m *MockTrillianLogClient) GetInclusionProofByHash(arg0 context.Context, arg1 *trillian.GetInclusionProofByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofByHashResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetInclusionProofByHash", varargs...) +	ret0, _ := ret[0].(*trillian.GetInclusionProofByHashResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetInclusionProofByHash indicates an expected call of GetInclusionProofByHash. +func (mr *MockTrillianLogClientMockRecorder) GetInclusionProofByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProofByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProofByHash), varargs...) +} + +// GetLatestSignedLogRoot mocks base method. +func (m *MockTrillianLogClient) GetLatestSignedLogRoot(arg0 context.Context, arg1 *trillian.GetLatestSignedLogRootRequest, arg2 ...grpc.CallOption) (*trillian.GetLatestSignedLogRootResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetLatestSignedLogRoot", varargs...) +	ret0, _ := ret[0].(*trillian.GetLatestSignedLogRootResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLatestSignedLogRoot indicates an expected call of GetLatestSignedLogRoot. +func (mr *MockTrillianLogClientMockRecorder) GetLatestSignedLogRoot(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestSignedLogRoot", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLatestSignedLogRoot), varargs...) +} + +// GetLeavesByHash mocks base method. +func (m *MockTrillianLogClient) GetLeavesByHash(arg0 context.Context, arg1 *trillian.GetLeavesByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByHashResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetLeavesByHash", varargs...) +	ret0, _ := ret[0].(*trillian.GetLeavesByHashResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeavesByHash indicates an expected call of GetLeavesByHash. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByHash), varargs...) +} + +// GetLeavesByIndex mocks base method. +func (m *MockTrillianLogClient) GetLeavesByIndex(arg0 context.Context, arg1 *trillian.GetLeavesByIndexRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByIndexResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetLeavesByIndex", varargs...) +	ret0, _ := ret[0].(*trillian.GetLeavesByIndexResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeavesByIndex indicates an expected call of GetLeavesByIndex. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByIndex(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByIndex", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByIndex), varargs...) +} + +// GetLeavesByRange mocks base method. +func (m *MockTrillianLogClient) GetLeavesByRange(arg0 context.Context, arg1 *trillian.GetLeavesByRangeRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByRangeResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetLeavesByRange", varargs...) +	ret0, _ := ret[0].(*trillian.GetLeavesByRangeResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetLeavesByRange indicates an expected call of GetLeavesByRange. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByRange(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByRange", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByRange), varargs...) +} + +// GetSequencedLeafCount mocks base method. +func (m *MockTrillianLogClient) GetSequencedLeafCount(arg0 context.Context, arg1 *trillian.GetSequencedLeafCountRequest, arg2 ...grpc.CallOption) (*trillian.GetSequencedLeafCountResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "GetSequencedLeafCount", varargs...) +	ret0, _ := ret[0].(*trillian.GetSequencedLeafCountResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// GetSequencedLeafCount indicates an expected call of GetSequencedLeafCount. +func (mr *MockTrillianLogClientMockRecorder) GetSequencedLeafCount(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSequencedLeafCount", reflect.TypeOf((*MockTrillianLogClient)(nil).GetSequencedLeafCount), varargs...) +} + +// InitLog mocks base method. +func (m *MockTrillianLogClient) InitLog(arg0 context.Context, arg1 *trillian.InitLogRequest, arg2 ...grpc.CallOption) (*trillian.InitLogResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "InitLog", varargs...) +	ret0, _ := ret[0].(*trillian.InitLogResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// InitLog indicates an expected call of InitLog. +func (mr *MockTrillianLogClientMockRecorder) InitLog(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitLog", reflect.TypeOf((*MockTrillianLogClient)(nil).InitLog), varargs...) +} + +// QueueLeaf mocks base method. +func (m *MockTrillianLogClient) QueueLeaf(arg0 context.Context, arg1 *trillian.QueueLeafRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeafResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "QueueLeaf", varargs...) +	ret0, _ := ret[0].(*trillian.QueueLeafResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// QueueLeaf indicates an expected call of QueueLeaf. +func (mr *MockTrillianLogClientMockRecorder) QueueLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaf), varargs...) +} + +// QueueLeaves mocks base method. +func (m *MockTrillianLogClient) QueueLeaves(arg0 context.Context, arg1 *trillian.QueueLeavesRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeavesResponse, error) { +	m.ctrl.T.Helper() +	varargs := []interface{}{arg0, arg1} +	for _, a := range arg2 { +		varargs = append(varargs, a) +	} +	ret := m.ctrl.Call(m, "QueueLeaves", varargs...) +	ret0, _ := ret[0].(*trillian.QueueLeavesResponse) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// QueueLeaves indicates an expected call of QueueLeaves. +func (mr *MockTrillianLogClientMockRecorder) QueueLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	varargs := append([]interface{}{arg0, arg1}, arg2...) +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaves), varargs...) +} diff --git a/internal/node/handler/handler.go b/internal/node/handler/handler.go new file mode 100644 index 0000000..2871c5d --- /dev/null +++ b/internal/node/handler/handler.go @@ -0,0 +1,91 @@ +package handler + +import ( +	"context" +	"fmt" +	"net/http" +	"time" + +	"git.sigsum.org/sigsum-go/pkg/log" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +type Config interface { +	Prefix() string +	LogID() string +	Deadline() time.Duration +} + +// Handler implements the http.Handler interface +type Handler struct { +	Config +	Fun      func(context.Context, Config, http.ResponseWriter, *http.Request) (int, error) +	Endpoint types.Endpoint +	Method   string +} + +// Path returns a path that should be configured for this handler +func (h Handler) Path() string { +	if len(h.Prefix()) == 0 { +		return h.Endpoint.Path("", "sigsum", "v0") +	} +	return h.Endpoint.Path("", h.Prefix(), "sigsum", "v0") +} + +// ServeHTTP is part of the http.Handler interface +func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +	start := time.Now() +	code := 0 +	defer func() { +		end := time.Now().Sub(start).Seconds() +		sc := fmt.Sprintf("%d", code) + +		rspcnt.Inc(h.LogID(), string(h.Endpoint), sc) +		latency.Observe(end, h.LogID(), string(h.Endpoint), sc) +	}() +	reqcnt.Inc(h.LogID(), string(h.Endpoint)) + +	code = h.verifyMethod(w, r) +	if code != 0 { +		return +	} +	h.handle(w, r) +} + +// verifyMethod checks that an appropriate HTTP method is used and +// returns 0 if so, or an HTTP status code if not.  Error handling is +// based on RFC 7231, see Sections 6.5.5 (Status 405) and 6.5.1 +// (Status 400). +func (h Handler) verifyMethod(w http.ResponseWriter, r *http.Request) int { +	checkHTTPMethod := func(m string) bool { +		return m == http.MethodGet || m == http.MethodPost +	} + +	if h.Method == r.Method { +		return 0 +	} + +	code := http.StatusBadRequest +	if ok := checkHTTPMethod(r.Method); ok { +		w.Header().Set("Allow", h.Method) +		code = http.StatusMethodNotAllowed +	} + +	http.Error(w, fmt.Sprintf("error=%s", http.StatusText(code)), code) +	return code +} + +// handle handles an HTTP request for which the HTTP method is already verified +func (h Handler) handle(w http.ResponseWriter, r *http.Request) { +	deadline := time.Now().Add(h.Deadline()) +	ctx, cancel := context.WithDeadline(r.Context(), deadline) +	defer cancel() + +	code, err := h.Fun(ctx, h.Config, w, r) +	if err != nil { +		log.Debug("%s/%s: %v", h.Prefix(), h.Endpoint, err) +		http.Error(w, fmt.Sprintf("error=%s", err.Error()), code) +	} else if code != 200 { +		w.WriteHeader(code) +	} +} diff --git a/internal/node/handler/handler_test.go b/internal/node/handler/handler_test.go new file mode 100644 index 0000000..dfd27bd --- /dev/null +++ b/internal/node/handler/handler_test.go @@ -0,0 +1,113 @@ +package handler + +import ( +	"context" +	"net/http" +	"net/http/httptest" +	"testing" +	"time" + +	"git.sigsum.org/sigsum-go/pkg/types" +) + +type dummyConfig struct { +	prefix string +} + +func (c dummyConfig) Prefix() string          { return c.prefix } +func (c dummyConfig) LogID() string           { return "dummyLogID" } +func (c dummyConfig) Deadline() time.Duration { return time.Nanosecond } + +// TestPath checks that Path works for an endpoint (add-leaf) +func TestPath(t *testing.T) { +	testFun := func(_ context.Context, _ Config, _ http.ResponseWriter, _ *http.Request) (int, error) { +		return 0, nil +	} +	for _, table := range []struct { +		description string +		prefix      string +		want        string +	}{ +		{ +			description: "no prefix", +			want:        "/sigsum/v0/add-leaf", +		}, +		{ +			description: "a prefix", +			prefix:      "test-prefix", +			want:        "/test-prefix/sigsum/v0/add-leaf", +		}, +	} { +		testConfig := dummyConfig{ +			prefix: table.prefix, +		} +		h := Handler{testConfig, testFun, types.EndpointAddLeaf, http.MethodPost} +		if got, want := h.Path(), table.want; got != want { +			t.Errorf("got path %v but wanted %v", got, want) +		} +	} +} + +// func TestServeHTTP(t *testing.T) { +// 	h.ServeHTTP(w http.ResponseWriter, r *http.Request) +// } + +func TestVerifyMethod(t *testing.T) { +	badMethod := http.MethodHead +	for _, h := range []Handler{ +		{ +			Endpoint: types.EndpointAddLeaf, +			Method:   http.MethodPost, +		}, +		{ +			Endpoint: types.EndpointGetTreeHeadToCosign, +			Method:   http.MethodGet, +		}, +	} { +		for _, method := range []string{ +			http.MethodGet, +			http.MethodPost, +			badMethod, +		} { +			url := h.Endpoint.Path("http://log.example.com", "fixme") +			req, err := http.NewRequest(method, url, nil) +			if err != nil { +				t.Fatalf("must create HTTP request: %v", err) +			} + +			w := httptest.NewRecorder() +			code := h.verifyMethod(w, req) +			if got, want := code == 0, h.Method == method; got != want { +				t.Errorf("%s %s: got %v but wanted %v: %v", method, url, got, want, err) +				continue +			} +			if code == 0 { +				continue +			} + +			if method == badMethod { +				if got, want := code, http.StatusBadRequest; got != want { +					t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want) +				} +				if _, ok := w.Header()["Allow"]; ok { +					t.Errorf("%s %s: got Allow header, wanted none", method, url) +				} +				continue +			} + +			if got, want := code, http.StatusMethodNotAllowed; got != want { +				t.Errorf("%s %s: got status %d, wanted %d", method, url, got, want) +			} else if methods, ok := w.Header()["Allow"]; !ok { +				t.Errorf("%s %s: got no allow header, expected one", method, url) +			} else if got, want := len(methods), 1; got != want { +				t.Errorf("%s %s: got %d allowed method(s), wanted %d", method, url, got, want) +			} else if got, want := methods[0], h.Method; got != want { +				t.Errorf("%s %s: got allowed method %s, wanted %s", method, url, got, want) +			} +		} +	} +} + +// func TestHandle(t *testing.T) { +// 	h.handle(w http.ResponseWriter, r *http.Request) +// } diff --git a/internal/node/handler/metric.go b/internal/node/handler/metric.go new file mode 100644 index 0000000..ced0096 --- /dev/null +++ b/internal/node/handler/metric.go @@ -0,0 +1,19 @@ +package handler + +import ( +	"github.com/google/trillian/monitoring" +	"github.com/google/trillian/monitoring/prometheus" +) + +var ( +	reqcnt  monitoring.Counter   // number of incoming http requests +	rspcnt  monitoring.Counter   // number of valid http responses +	latency monitoring.Histogram // request-response latency +) + +func init() { +	mf := prometheus.MetricFactory{} +	reqcnt = mf.NewCounter("http_req", "number of http requests", "logid", "endpoint") +	rspcnt = mf.NewCounter("http_rsp", "number of http requests", "logid", "endpoint", "status") +	latency = mf.NewHistogram("http_latency", "http request-response latency", "logid", "endpoint", "status") +} diff --git a/internal/node/primary/endpoint_external.go b/internal/node/primary/endpoint_external.go new file mode 100644 index 0000000..42e8fcb --- /dev/null +++ b/internal/node/primary/endpoint_external.go @@ -0,0 +1,147 @@ +package primary + +// This file implements external HTTP handler callbacks for primary nodes. + +import ( +	"context" +	"fmt" +	"net/http" + +	"git.sigsum.org/log-go/internal/node/handler" +	"git.sigsum.org/log-go/internal/requests" +	"git.sigsum.org/sigsum-go/pkg/log" +) + +func addLeaf(ctx context.Context, c handler.Config, w http.ResponseWriter, r *http.Request) (int, error) { +	p := c.(Primary) +	log.Debug("handling add-leaf request") +	req, err := requests.LeafRequestFromHTTP(r, p.Config.ShardStart, ctx, p.DNS) +	if err != nil { +		return http.StatusBadRequest, err +	} + +	sth := p.Stateman.ToCosignTreeHead() +	sequenced, err := p.TrillianClient.AddLeaf(ctx, req, sth.TreeSize) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if sequenced { +		return http.StatusOK, nil +	} else { +		return http.StatusAccepted, nil +	} +} + +func addCosignature(ctx context.Context, c handler.Config, w http.ResponseWriter, r *http.Request) (int, error) { +	p := c.(Primary) +	log.Debug("handling add-cosignature request") +	req, err := requests.CosignatureRequestFromHTTP(r, p.Witnesses) +	if err != nil { +		return http.StatusBadRequest, err +	} +	vk := p.Witnesses[req.KeyHash] +	if err := p.Stateman.AddCosignature(ctx, &vk, &req.Cosignature); err != nil { +		return http.StatusBadRequest, err +	} +	return http.StatusOK, nil +} + +func getTreeHeadToCosign(ctx context.Context, c handler.Config, w http.ResponseWriter, _ *http.Request) (int, error) { +	p := c.(Primary) +	log.Debug("handling get-tree-head-to-cosign request") +	sth := p.Stateman.ToCosignTreeHead() +	if err := sth.ToASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getTreeHeadCosigned(ctx context.Context, c handler.Config, w http.ResponseWriter, _ *http.Request) (int, error) { +	p := c.(Primary) +	log.Debug("handling get-tree-head-cosigned request") +	cth, err := p.Stateman.CosignedTreeHead(ctx) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := cth.ToASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getConsistencyProof(ctx context.Context, c handler.Config, w http.ResponseWriter, r *http.Request) (int, error) { +	p := c.(Primary) +	log.Debug("handling get-consistency-proof request") +	req, err := requests.ConsistencyProofRequestFromHTTP(r) +	if err != nil { +		return http.StatusBadRequest, err +	} + +	curTree := p.Stateman.ToCosignTreeHead() +	if req.NewSize > curTree.TreeHead.TreeSize { +		return http.StatusBadRequest, fmt.Errorf("new_size outside of current tree") +	} + +	proof, err := p.TrillianClient.GetConsistencyProof(ctx, req) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := proof.ToASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getInclusionProof(ctx context.Context, c handler.Config, w http.ResponseWriter, r *http.Request) (int, error) { +	p := c.(Primary) +	log.Debug("handling get-inclusion-proof request") +	req, err := requests.InclusionProofRequestFromHTTP(r) +	if err != nil { +		return http.StatusBadRequest, err +	} + +	curTree := p.Stateman.ToCosignTreeHead() +	if req.TreeSize > curTree.TreeHead.TreeSize { +		return http.StatusBadRequest, fmt.Errorf("tree_size outside of current tree") +	} + +	proof, err := p.TrillianClient.GetInclusionProof(ctx, req) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := proof.ToASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getLeavesGeneral(ctx context.Context, c handler.Config, w http.ResponseWriter, r *http.Request, doLimitToCurrentTree bool) (int, error) { +	p := c.(Primary) +	log.Debug("handling get-leaves request") +	req, err := requests.LeavesRequestFromHTTP(r, uint64(p.MaxRange)) +	if err != nil { +		return http.StatusBadRequest, err +	} + +	if doLimitToCurrentTree { +		curTree := p.Stateman.ToCosignTreeHead() +		if req.EndSize >= curTree.TreeHead.TreeSize { +			return http.StatusBadRequest, fmt.Errorf("end_size outside of current tree") +		} +	} + +	leaves, err := p.TrillianClient.GetLeaves(ctx, req) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	for _, leaf := range *leaves { +		if err := leaf.ToASCII(w); err != nil { +			return http.StatusInternalServerError, err +		} +	} +	return http.StatusOK, nil +} + +func getLeavesExternal(ctx context.Context, c handler.Config, w http.ResponseWriter, r *http.Request) (int, error) { +	return getLeavesGeneral(ctx, c, w, r, true) +} diff --git a/internal/node/primary/endpoint_external_test.go b/internal/node/primary/endpoint_external_test.go new file mode 100644 index 0000000..7ee161b --- /dev/null +++ b/internal/node/primary/endpoint_external_test.go @@ -0,0 +1,626 @@ +package primary + +import ( +	"bytes" +	"crypto/ed25519" +	"crypto/rand" +	"fmt" +	"io" +	"net/http" +	"net/http/httptest" +	"reflect" +	"testing" +	"time" + +	mocksDB "git.sigsum.org/log-go/internal/mocks/db" +	mocksDNS "git.sigsum.org/log-go/internal/mocks/dns" +	mocksState "git.sigsum.org/log-go/internal/mocks/state" +	"git.sigsum.org/log-go/internal/node/handler" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +	"github.com/golang/mock/gomock" +) + +var ( +	testSTH = &types.SignedTreeHead{ +		TreeHead:  *testTH, +		Signature: types.Signature{}, +	} +	testCTH = &types.CosignedTreeHead{ +		SignedTreeHead: *testSTH, +		Cosignature:    []types.Signature{types.Signature{}}, +		KeyHash:        []merkle.Hash{merkle.Hash{}}, +	} +	sth1 = types.SignedTreeHead{TreeHead: types.TreeHead{TreeSize: 1}} +	sth2 = types.SignedTreeHead{TreeHead: types.TreeHead{TreeSize: 2}} // 2 < testConfig.MaxRange +	sth5 = types.SignedTreeHead{TreeHead: types.TreeHead{TreeSize: 5}} // 5 >= testConfig.MaxRange+1 +) + +// TODO: remove tests that are now located in internal/requests instead + +func TestAddLeaf(t *testing.T) { +	for _, table := range []struct { +		description    string +		ascii          io.Reader // buffer used to populate HTTP request +		expectTrillian bool      // expect Trillian client code path +		errTrillian    error     // error from Trillian client +		expectDNS      bool      // expect DNS verifier code path +		errDNS         error     // error from DNS verifier +		wantCode       int       // HTTP status ok +		expectStateman bool +		sequenced      bool // return value from db.AddLeaf() +		sthStateman    *types.SignedTreeHead +	}{ +		{ +			description: "invalid: bad request (parser error)", +			ascii:       bytes.NewBufferString("key=value\n"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (signature error)", +			ascii:       mustLeafBuffer(t, 10, merkle.Hash{}, false), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (shard hint is before shard start)", +			ascii:       mustLeafBuffer(t, 9, merkle.Hash{}, true), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (shard hint is after shard end)", +			ascii:       mustLeafBuffer(t, uint64(time.Now().Unix())+1024, merkle.Hash{}, true), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: failed verifying domain hint", +			ascii:       mustLeafBuffer(t, 10, merkle.Hash{}, true), +			expectDNS:   true, +			errDNS:      fmt.Errorf("something went wrong"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description:    "invalid: backend failure", +			ascii:          mustLeafBuffer(t, 10, merkle.Hash{}, true), +			expectDNS:      true, +			expectStateman: true, +			sthStateman:    testSTH, +			expectTrillian: true, +			errTrillian:    fmt.Errorf("something went wrong"), +			wantCode:       http.StatusInternalServerError, +		}, +		{ +			description:    "valid: 202", +			ascii:          mustLeafBuffer(t, 10, merkle.Hash{}, true), +			expectDNS:      true, +			expectStateman: true, +			sthStateman:    testSTH, +			expectTrillian: true, +			wantCode:       http.StatusAccepted, +		}, +		{ +			description:    "valid: 200", +			ascii:          mustLeafBuffer(t, 10, merkle.Hash{}, true), +			expectDNS:      true, +			expectStateman: true, +			sthStateman:    testSTH, +			expectTrillian: true, +			sequenced:      true, +			wantCode:       http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			dns := mocksDNS.NewMockVerifier(ctrl) +			if table.expectDNS { +				dns.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).Return(table.errDNS) +			} +			client := mocksDB.NewMockClient(ctrl) +			if table.expectTrillian { +				client.EXPECT().AddLeaf(gomock.Any(), gomock.Any(), gomock.Any()).Return(table.sequenced, table.errTrillian) +			} +			stateman := mocksState.NewMockStateManager(ctrl) +			if table.expectStateman { +				stateman.EXPECT().ToCosignTreeHead().Return(table.sthStateman) +			} +			node := Primary{ +				Config:         testConfig, +				TrillianClient: client, +				Stateman:       stateman, +				DNS:            dns, +			} + +			// Create HTTP request +			url := types.EndpointAddLeaf.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest("POST", url, table.ascii) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandlePublic(t, node, types.EndpointAddLeaf).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestAddCosignature(t *testing.T) { +	buf := func() io.Reader { +		return bytes.NewBufferString(fmt.Sprintf("%s=%x\n%s=%x\n", +			"cosignature", types.Signature{}, +			"key_hash", *merkle.HashFn(testWitVK[:]), +		)) +	} +	for _, table := range []struct { +		description string +		ascii       io.Reader // buffer used to populate HTTP request +		expect      bool      // set if a mock answer is expected +		err         error     // error from Trillian client +		wantCode    int       // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			ascii:       bytes.NewBufferString("key=value\n"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (unknown witness)", +			ascii: bytes.NewBufferString(fmt.Sprintf("%s=%x\n%s=%x\n", +				"cosignature", types.Signature{}, +				"key_hash", *merkle.HashFn(testWitVK[1:]), +			)), +			wantCode: http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			ascii:       buf(), +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "valid", +			ascii:       buf(), +			expect:      true, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			stateman := mocksState.NewMockStateManager(ctrl) +			if table.expect { +				stateman.EXPECT().AddCosignature(gomock.Any(), gomock.Any(), gomock.Any()).Return(table.err) +			} +			node := Primary{ +				Config:   testConfig, +				Stateman: stateman, +			} + +			// Create HTTP request +			url := types.EndpointAddCosignature.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest("POST", url, table.ascii) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandlePublic(t, node, types.EndpointAddCosignature).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetTreeToCosign(t *testing.T) { +	for _, table := range []struct { +		description string +		expect      bool                  // set if a mock answer is expected +		rsp         *types.SignedTreeHead // signed tree head from Trillian client +		err         error                 // error from Trillian client +		wantCode    int                   // HTTP status ok +	}{ +		{ +			description: "valid", +			expect:      true, +			rsp:         testSTH, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			stateman := mocksState.NewMockStateManager(ctrl) +			if table.expect { +				stateman.EXPECT().ToCosignTreeHead().Return(table.rsp) +			} +			node := Primary{ +				Config:   testConfig, +				Stateman: stateman, +			} + +			// Create HTTP request +			url := types.EndpointGetTreeHeadToCosign.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandlePublic(t, node, types.EndpointGetTreeHeadToCosign).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetTreeCosigned(t *testing.T) { +	for _, table := range []struct { +		description string +		expect      bool                    // set if a mock answer is expected +		rsp         *types.CosignedTreeHead // cosigned tree head from Trillian client +		err         error                   // error from Trillian client +		wantCode    int                     // HTTP status ok +	}{ +		{ +			description: "invalid: no cosigned STH", +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			expect:      true, +			rsp:         testCTH, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			stateman := mocksState.NewMockStateManager(ctrl) +			if table.expect { +				stateman.EXPECT().CosignedTreeHead(gomock.Any()).Return(table.rsp, table.err) +			} +			node := Primary{ +				Config:   testConfig, +				Stateman: stateman, +			} + +			// Create HTTP request +			url := types.EndpointGetTreeHeadCosigned.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandlePublic(t, node, types.EndpointGetTreeHeadCosigned).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetConsistencyProof(t *testing.T) { +	for _, table := range []struct { +		description string +		params      string // params is the query's url params +		sth         *types.SignedTreeHead +		expect      bool                    // set if a mock answer is expected +		rsp         *types.ConsistencyProof // consistency proof from Trillian client +		err         error                   // error from Trillian client +		wantCode    int                     // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			params:      "a/1", +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (OldSize is zero)", +			params:      "0/1", +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (OldSize > NewSize)", +			params:      "2/1", +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (NewSize > tree size)", +			params:      "1/2", +			sth:         &sth1, +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			params:      "1/2", +			sth:         &sth2, +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			params:      "1/2", +			sth:         &sth2, +			expect:      true, +			rsp: &types.ConsistencyProof{ +				OldSize: 1, +				NewSize: 2, +				Path: []merkle.Hash{ +					*merkle.HashFn([]byte{}), +				}, +			}, +			wantCode: http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocksDB.NewMockClient(ctrl) +			if table.expect { +				client.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			} +			stateman := mocksState.NewMockStateManager(ctrl) +			if table.sth != nil { +				stateman.EXPECT().ToCosignTreeHead().Return(table.sth) +			} +			node := Primary{ +				Config:         testConfig, +				TrillianClient: client, +				Stateman:       stateman, +			} + +			// Create HTTP request +			url := types.EndpointGetConsistencyProof.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest(http.MethodGet, url+table.params, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandlePublic(t, node, types.EndpointGetConsistencyProof).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetInclusionProof(t *testing.T) { +	for _, table := range []struct { +		description string +		params      string // params is the query's url params +		sth         *types.SignedTreeHead +		expect      bool                  // set if a mock answer is expected +		rsp         *types.InclusionProof // inclusion proof from Trillian client +		err         error                 // error from Trillian client +		wantCode    int                   // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			params:      "a/0000000000000000000000000000000000000000000000000000000000000000", +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (no proof available for tree size 1)", +			params:      "1/0000000000000000000000000000000000000000000000000000000000000000", +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (request outside current tree size)", +			params:      "2/0000000000000000000000000000000000000000000000000000000000000000", +			sth:         &sth1, +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			params:      "2/0000000000000000000000000000000000000000000000000000000000000000", +			sth:         &sth2, +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			params:      "2/0000000000000000000000000000000000000000000000000000000000000000", +			sth:         &sth2, +			expect:      true, +			rsp: &types.InclusionProof{ +				TreeSize:  2, +				LeafIndex: 0, +				Path: []merkle.Hash{ +					*merkle.HashFn([]byte{}), +				}, +			}, +			wantCode: http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocksDB.NewMockClient(ctrl) +			if table.expect { +				client.EXPECT().GetInclusionProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			} +			stateman := mocksState.NewMockStateManager(ctrl) +			if table.sth != nil { +				stateman.EXPECT().ToCosignTreeHead().Return(table.sth) +			} +			node := Primary{ +				Config:         testConfig, +				TrillianClient: client, +				Stateman:       stateman, +			} + +			// Create HTTP request +			url := types.EndpointGetInclusionProof.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest(http.MethodGet, url+table.params, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandlePublic(t, node, types.EndpointGetInclusionProof).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func TestGetLeaves(t *testing.T) { +	for _, table := range []struct { +		description string +		params      string // params is the query's url params +		sth         *types.SignedTreeHead +		expect      bool          // set if a mock answer is expected +		rsp         *types.Leaves // list of leaves from Trillian client +		err         error         // error from Trillian client +		wantCode    int           // HTTP status ok +	}{ +		{ +			description: "invalid: bad request (parser error)", +			params:      "a/1", +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (StartSize > EndSize)", +			params:      "1/0", +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: bad request (EndSize >= current tree size)", +			params:      "0/2", +			sth:         &sth2, +			wantCode:    http.StatusBadRequest, +		}, +		{ +			description: "invalid: backend failure", +			params:      "0/0", +			sth:         &sth2, +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid: one more entry than the configured MaxRange", +			params:      fmt.Sprintf("%d/%d", 0, testConfig.MaxRange), // query will be pruned +			sth:         &sth5, +			expect:      true, +			rsp: func() *types.Leaves { +				var list types.Leaves +				for i := int64(0); i < testConfig.MaxRange; i++ { +					list = append(list[:], types.Leaf{ +						Statement: types.Statement{ +							ShardHint: 0, +							Checksum:  merkle.Hash{}, +						}, +						Signature: types.Signature{}, +						KeyHash:   merkle.Hash{}, +					}) +				} +				return &list +			}(), +			wantCode: http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			client := mocksDB.NewMockClient(ctrl) +			if table.expect { +				client.EXPECT().GetLeaves(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) +			} +			stateman := mocksState.NewMockStateManager(ctrl) +			if table.sth != nil { +				stateman.EXPECT().ToCosignTreeHead().Return(table.sth) +			} +			node := Primary{ +				Config:         testConfig, +				TrillianClient: client, +				Stateman:       stateman, +			} + +			// Create HTTP request +			url := types.EndpointGetLeaves.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest(http.MethodGet, url+table.params, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandlePublic(t, node, types.EndpointGetLeaves).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +			if w.Code != http.StatusOK { +				return +			} + +			list := types.Leaves{} +			if err := list.FromASCII(w.Body); err != nil { +				t.Fatalf("must unmarshal leaf list: %v", err) +			} +			if got, want := &list, table.rsp; !reflect.DeepEqual(got, want) { +				t.Errorf("got leaf list\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) +			} +		}() +	} +} + +func mustHandlePublic(t *testing.T, p Primary, e types.Endpoint) handler.Handler { +	for _, handler := range p.PublicHTTPHandlers() { +		if handler.Endpoint == e { +			return handler +		} +	} +	t.Fatalf("must handle endpoint: %v", e) +	return handler.Handler{} +} + +func mustLeafBuffer(t *testing.T, shardHint uint64, message merkle.Hash, wantSig bool) io.Reader { +	t.Helper() + +	vk, sk, err := ed25519.GenerateKey(rand.Reader) +	if err != nil { +		t.Fatalf("must generate ed25519 keys: %v", err) +	} +	msg := types.Statement{ +		ShardHint: shardHint, +		Checksum:  *merkle.HashFn(message[:]), +	} +	sig := ed25519.Sign(sk, msg.ToBinary()) +	if !wantSig { +		sig[0] += 1 +	} +	return bytes.NewBufferString(fmt.Sprintf( +		"%s=%d\n"+"%s=%x\n"+"%s=%x\n"+"%s=%x\n"+"%s=%s\n", +		"shard_hint", shardHint, +		"message", message[:], +		"signature", sig, +		"public_key", vk, +		"domain_hint", "example.com", +	)) +} diff --git a/internal/node/primary/endpoint_internal.go b/internal/node/primary/endpoint_internal.go new file mode 100644 index 0000000..f7a684f --- /dev/null +++ b/internal/node/primary/endpoint_internal.go @@ -0,0 +1,30 @@ +package primary + +// This file implements internal HTTP handler callbacks for primary nodes. + +import ( +	"context" +	"fmt" +	"net/http" + +	"git.sigsum.org/log-go/internal/node/handler" +	"git.sigsum.org/sigsum-go/pkg/log" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +func getTreeHeadUnsigned(ctx context.Context, c handler.Config, w http.ResponseWriter, _ *http.Request) (int, error) { +	log.Debug("handling %s request", types.EndpointGetTreeHeadUnsigned) +	p := c.(Primary) +	th, err := p.TrillianClient.GetTreeHead(ctx) +	if err != nil { +		return http.StatusInternalServerError, fmt.Errorf("failed getting tree head: %v", err) +	} +	if err := th.ToASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} + +func getLeavesInternal(ctx context.Context, c handler.Config, w http.ResponseWriter, r *http.Request) (int, error) { +	return getLeavesGeneral(ctx, c, w, r, false) +} diff --git a/internal/node/primary/endpoint_internal_test.go b/internal/node/primary/endpoint_internal_test.go new file mode 100644 index 0000000..6c76c33 --- /dev/null +++ b/internal/node/primary/endpoint_internal_test.go @@ -0,0 +1,82 @@ +package primary + +import ( +	"fmt" +	"net/http" +	"net/http/httptest" +	"testing" + +	mocksDB "git.sigsum.org/log-go/internal/mocks/db" +	"git.sigsum.org/log-go/internal/node/handler" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +	"github.com/golang/mock/gomock" +) + +var ( +	testTH = &types.TreeHead{ +		Timestamp: 0, +		TreeSize:  0, +		RootHash:  *merkle.HashFn([]byte("root hash")), +	} +) + +func TestGetTreeHeadUnsigned(t *testing.T) { +	for _, table := range []struct { +		description string +		expect      bool            // set if a mock answer is expected +		rsp         *types.TreeHead // tree head from Trillian client +		err         error           // error from Trillian client +		wantCode    int             // HTTP status ok +	}{ +		{ +			description: "invalid: backend failure", +			expect:      true, +			err:         fmt.Errorf("something went wrong"), +			wantCode:    http.StatusInternalServerError, +		}, +		{ +			description: "valid", +			expect:      true, +			rsp:         testTH, +			wantCode:    http.StatusOK, +		}, +	} { +		// Run deferred functions at the end of each iteration +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			trillianClient := mocksDB.NewMockClient(ctrl) +			trillianClient.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err) + +			node := Primary{ +				Config:         testConfig, +				TrillianClient: trillianClient, +			} + +			// Create HTTP request +			url := types.EndpointGetTreeHeadUnsigned.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandleInternal(t, node, types.EndpointGetTreeHeadUnsigned).ServeHTTP(w, req) +			if got, want := w.Code, table.wantCode; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, table.description) +			} +		}() +	} +} + +func mustHandleInternal(t *testing.T, p Primary, e types.Endpoint) handler.Handler { +	for _, handler := range p.InternalHTTPHandlers() { +		if handler.Endpoint == e { +			return handler +		} +	} +	t.Fatalf("must handle endpoint: %v", e) +	return handler.Handler{} +} diff --git a/internal/node/primary/primary.go b/internal/node/primary/primary.go new file mode 100644 index 0000000..6128f49 --- /dev/null +++ b/internal/node/primary/primary.go @@ -0,0 +1,74 @@ +package primary + +import ( +	"crypto" +	"net/http" +	"time" + +	"git.sigsum.org/log-go/internal/db" +	"git.sigsum.org/log-go/internal/node/handler" +	"git.sigsum.org/log-go/internal/state" +	"git.sigsum.org/sigsum-go/pkg/client" +	"git.sigsum.org/sigsum-go/pkg/dns" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +// Config is a collection of log parameters +type Config struct { +	LogID      string        // H(public key), then hex-encoded +	TreeID     int64         // Merkle tree identifier used by Trillian +	Prefix     string        // The portion between base URL and st/v0 (may be "") +	MaxRange   int64         // Maximum number of leaves per get-leaves request +	Deadline   time.Duration // Deadline used for gRPC requests +	Interval   time.Duration // Cosigning frequency +	ShardStart uint64        // Shard interval start (num seconds since UNIX epoch) + +	// Witnesses map trusted witness identifiers to public keys +	Witnesses map[merkle.Hash]types.PublicKey +} + +// Primary is an instance of the log's primary node +type Primary struct { +	Config +	PublicHTTPMux   *http.ServeMux +	InternalHTTPMux *http.ServeMux +	TrillianClient  db.Client          // provides access to the Trillian backend +	Signer          crypto.Signer      // provides access to Ed25519 private key +	Stateman        state.StateManager // coordinates access to (co)signed tree heads +	DNS             dns.Verifier       // checks if domain name knows a public key +	Secondary       client.Client +} + +// Implementing handler.Config +func (p Primary) Prefix() string { +	return p.Config.Prefix +} +func (p Primary) LogID() string { +	return p.Config.LogID +} +func (p Primary) Deadline() time.Duration { +	return p.Config.Deadline +} + +// PublicHTTPHandlers returns all external handlers +func (p Primary) PublicHTTPHandlers() []handler.Handler { +	return []handler.Handler{ +		handler.Handler{p, addLeaf, types.EndpointAddLeaf, http.MethodPost}, +		handler.Handler{p, addCosignature, types.EndpointAddCosignature, http.MethodPost}, +		handler.Handler{p, getTreeHeadToCosign, types.EndpointGetTreeHeadToCosign, http.MethodGet}, +		handler.Handler{p, getTreeHeadCosigned, types.EndpointGetTreeHeadCosigned, http.MethodGet}, +		handler.Handler{p, getConsistencyProof, types.EndpointGetConsistencyProof, http.MethodGet}, +		handler.Handler{p, getInclusionProof, types.EndpointGetInclusionProof, http.MethodGet}, +		handler.Handler{p, getLeavesExternal, types.EndpointGetLeaves, http.MethodGet}, +	} +} + +// InternalHTTPHandlers() returns all internal handlers +func (p Primary) InternalHTTPHandlers() []handler.Handler { +	return []handler.Handler{ +		handler.Handler{p, getTreeHeadUnsigned, types.EndpointGetTreeHeadUnsigned, http.MethodGet}, +		handler.Handler{p, getConsistencyProof, types.EndpointGetConsistencyProof, http.MethodGet}, +		handler.Handler{p, getLeavesInternal, types.EndpointGetLeaves, http.MethodGet}, +	} +} diff --git a/internal/node/primary/primary_test.go b/internal/node/primary/primary_test.go new file mode 100644 index 0000000..5062955 --- /dev/null +++ b/internal/node/primary/primary_test.go @@ -0,0 +1,75 @@ +package primary + +import ( +	"fmt" +	"testing" + +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +var ( +	testWitVK  = types.PublicKey{} +	testConfig = Config{ +		LogID:      fmt.Sprintf("%x", merkle.HashFn([]byte("logid"))[:]), +		TreeID:     0, +		Prefix:     "testonly", +		MaxRange:   3, +		Deadline:   10, +		Interval:   10, +		ShardStart: 10, +		Witnesses: map[merkle.Hash]types.PublicKey{ +			*merkle.HashFn(testWitVK[:]): testWitVK, +		}, +	} +) + +// TestPublicHandlers checks that the expected external handlers are configured +func TestPublicHandlers(t *testing.T) { +	endpoints := map[types.Endpoint]bool{ +		types.EndpointAddLeaf:             false, +		types.EndpointAddCosignature:      false, +		types.EndpointGetTreeHeadToCosign: false, +		types.EndpointGetTreeHeadCosigned: false, +		types.EndpointGetConsistencyProof: false, +		types.EndpointGetInclusionProof:   false, +		types.EndpointGetLeaves:           false, +	} +	node := Primary{ +		Config: testConfig, +	} +	for _, handler := range node.PublicHTTPHandlers() { +		if _, ok := endpoints[handler.Endpoint]; !ok { +			t.Errorf("got unexpected endpoint: %s", handler.Endpoint) +		} +		endpoints[handler.Endpoint] = true +	} +	for endpoint, ok := range endpoints { +		if !ok { +			t.Errorf("endpoint %s is not configured", endpoint) +		} +	} +} + +// TestIntHandlers checks that the expected internal handlers are configured +func TestIntHandlers(t *testing.T) { +	endpoints := map[types.Endpoint]bool{ +		types.EndpointGetTreeHeadUnsigned: false, +		types.EndpointGetConsistencyProof: false, +		types.EndpointGetLeaves:           false, +	} +	node := Primary{ +		Config: testConfig, +	} +	for _, handler := range node.InternalHTTPHandlers() { +		if _, ok := endpoints[handler.Endpoint]; !ok { +			t.Errorf("got unexpected endpoint: %s", handler.Endpoint) +		} +		endpoints[handler.Endpoint] = true +	} +	for endpoint, ok := range endpoints { +		if !ok { +			t.Errorf("endpoint %s is not configured", endpoint) +		} +	} +} diff --git a/internal/node/secondary/endpoint_internal.go b/internal/node/secondary/endpoint_internal.go new file mode 100644 index 0000000..f60d6d8 --- /dev/null +++ b/internal/node/secondary/endpoint_internal.go @@ -0,0 +1,44 @@ +package secondary + +// This file implements internal HTTP handler callbacks for secondary nodes. + +import ( +	"context" +	"crypto/ed25519" +	"fmt" +	"net/http" + +	"git.sigsum.org/log-go/internal/node/handler" +	"git.sigsum.org/sigsum-go/pkg/log" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +func getTreeHeadToCosign(ctx context.Context, c handler.Config, w http.ResponseWriter, _ *http.Request) (int, error) { +	s := c.(Secondary) +	log.Debug("handling get-tree-head-to-cosign request") + +	signedTreeHead := func() (*types.SignedTreeHead, error) { +		tctx, cancel := context.WithTimeout(ctx, s.Config.Deadline) +		defer cancel() +		th, err := treeHeadFromTrillian(tctx, s.TrillianClient) +		if err != nil { +			return nil, fmt.Errorf("getting tree head: %w", err) +		} +		namespace := merkle.HashFn(s.Signer.Public().(ed25519.PublicKey)) +		sth, err := th.Sign(s.Signer, namespace) +		if err != nil { +			return nil, fmt.Errorf("signing tree head: %w", err) +		} +		return sth, nil +	} + +	sth, err := signedTreeHead() +	if err != nil { +		return http.StatusInternalServerError, err +	} +	if err := sth.ToASCII(w); err != nil { +		return http.StatusInternalServerError, err +	} +	return http.StatusOK, nil +} diff --git a/internal/node/secondary/endpoint_internal_test.go b/internal/node/secondary/endpoint_internal_test.go new file mode 100644 index 0000000..9637e29 --- /dev/null +++ b/internal/node/secondary/endpoint_internal_test.go @@ -0,0 +1,111 @@ +package secondary + +import ( +	"crypto" +	"crypto/ed25519" +	"fmt" +	"io" +	"net/http" +	"net/http/httptest" +	"testing" + +	mocksDB "git.sigsum.org/log-go/internal/mocks/db" +	"git.sigsum.org/log-go/internal/node/handler" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +	"github.com/golang/mock/gomock" +) + +// TestSigner implements the signer interface.  It can be used to mock +// an Ed25519 signer that always return the same public key, +// signature, and error. +// NOTE: Code duplication with internal/state/single_test.go +type TestSigner struct { +	PublicKey [ed25519.PublicKeySize]byte +	Signature [ed25519.SignatureSize]byte +	Error     error +} + +func (ts *TestSigner) Public() crypto.PublicKey { +	return ed25519.PublicKey(ts.PublicKey[:]) +} + +func (ts *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { +	return ts.Signature[:], ts.Error +} + +var ( +	testTH = types.TreeHead{ +		Timestamp: 0, +		TreeSize:  0, +		RootHash:  *merkle.HashFn([]byte("root hash")), +	} +	testSignerFailing    = TestSigner{types.PublicKey{}, types.Signature{}, fmt.Errorf("mocked error")} +	testSignerSucceeding = TestSigner{types.PublicKey{}, types.Signature{}, nil} +) + +func TestGetTreeHeadToCosign(t *testing.T) { +	for _, tbl := range []struct { +		desc          string +		trillianTHErr error +		trillianTHRet *types.TreeHead +		signer        crypto.Signer +		httpStatus    int +	}{ +		{ +			desc:          "trillian GetTreeHead error", +			trillianTHErr: fmt.Errorf("mocked error"), +			httpStatus:    http.StatusInternalServerError, +		}, +		{ +			desc:          "signer error", +			trillianTHRet: &testTH, +			signer:        &testSignerFailing, +			httpStatus:    http.StatusInternalServerError, +		}, +		{ +			desc:          "success", +			trillianTHRet: &testTH, +			signer:        &testSignerSucceeding, +			httpStatus:    http.StatusOK, +		}, +	} { +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() + +			trillianClient := mocksDB.NewMockClient(ctrl) +			trillianClient.EXPECT().GetTreeHead(gomock.Any()).Return(tbl.trillianTHRet, tbl.trillianTHErr) + +			node := Secondary{ +				Config:         testConfig, +				TrillianClient: trillianClient, +				Signer:         tbl.signer, +			} + +			// Create HTTP request +			url := types.EndpointAddLeaf.Path("http://example.com", node.Prefix()) +			req, err := http.NewRequest("GET", url, nil) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			// Run HTTP request +			w := httptest.NewRecorder() +			mustHandleInternal(t, node, types.EndpointGetTreeHeadToCosign).ServeHTTP(w, req) +			if got, want := w.Code, tbl.httpStatus; got != want { +				t.Errorf("got HTTP status code %v but wanted %v in test %q", got, want, tbl.desc) +			} +		}() +	} +} + +func mustHandleInternal(t *testing.T, s Secondary, e types.Endpoint) handler.Handler { +	for _, h := range s.InternalHTTPHandlers() { +		if h.Endpoint == e { +			return h +		} +	} +	t.Fatalf("must handle endpoint: %v", e) +	return handler.Handler{} +} diff --git a/internal/node/secondary/secondary.go b/internal/node/secondary/secondary.go new file mode 100644 index 0000000..c181420 --- /dev/null +++ b/internal/node/secondary/secondary.go @@ -0,0 +1,112 @@ +package secondary + +import ( +	"context" +	"crypto" +	"fmt" +	"net/http" +	"time" + +	"git.sigsum.org/log-go/internal/db" +	"git.sigsum.org/log-go/internal/node/handler" +	"git.sigsum.org/sigsum-go/pkg/client" +	"git.sigsum.org/sigsum-go/pkg/log" +	"git.sigsum.org/sigsum-go/pkg/requests" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +// Config is a collection of log parameters +type Config struct { +	LogID    string        // H(public key), then hex-encoded +	TreeID   int64         // Merkle tree identifier used by Trillian +	Prefix   string        // The portion between base URL and st/v0 (may be "") +	Deadline time.Duration // Deadline used for gRPC requests +	Interval time.Duration // Signing frequency +} + +// Secondary is an instance of a secondary node +type Secondary struct { +	Config +	PublicHTTPMux   *http.ServeMux +	InternalHTTPMux *http.ServeMux +	TrillianClient  db.Client     // provides access to the Trillian backend +	Signer          crypto.Signer // provides access to Ed25519 private key +	Primary         client.Client +} + +// Implementing handler.Config +func (s Secondary) Prefix() string { +	return s.Config.Prefix +} +func (s Secondary) LogID() string { +	return s.Config.LogID +} +func (s Secondary) Deadline() time.Duration { +	return s.Config.Deadline +} + +func (s Secondary) Run(ctx context.Context) { +	ticker := time.NewTicker(s.Interval) +	defer ticker.Stop() + +	for { +		select { +		case <-ticker.C: +			s.fetchLeavesFromPrimary(ctx) +		case <-ctx.Done(): +			return +		} +	} +} + +// TODO: nit-pick: the internal endpoint is used by primaries to figure out how much can be signed; not cosigned - update name? + +func (s Secondary) InternalHTTPHandlers() []handler.Handler { +	return []handler.Handler{ +		handler.Handler{s, getTreeHeadToCosign, types.EndpointGetTreeHeadToCosign, http.MethodGet}, +	} +} + +func (s Secondary) fetchLeavesFromPrimary(ctx context.Context) { +	sctx, cancel := context.WithTimeout(ctx, time.Second*10) // FIXME: parameterize 10 +	defer cancel() + +	prim, err := s.Primary.GetUnsignedTreeHead(sctx) +	if err != nil { +		log.Warning("unable to get tree head from primary: %v", err) +		return +	} +	log.Debug("got tree head from primary, size %d", prim.TreeSize) + +	curTH, err := treeHeadFromTrillian(sctx, s.TrillianClient) +	if err != nil { +		log.Warning("unable to get tree head from trillian: %v", err) +		return +	} +	var leaves types.Leaves +	for index := int64(curTH.TreeSize); index < int64(prim.TreeSize); index += int64(len(leaves)) { +		req := requests.Leaves{ +			StartSize: uint64(index), +			EndSize:   prim.TreeSize - 1, +		} +		leaves, err = s.Primary.GetLeaves(sctx, req) +		if err != nil { +			log.Warning("error fetching leaves [%d..%d] from primary: %v", req.StartSize, req.EndSize, err) +			return +		} +		log.Debug("got %d leaves from primary when asking for [%d..%d]", len(leaves), req.StartSize, req.EndSize) +		if err := s.TrillianClient.AddSequencedLeaves(ctx, leaves, index); err != nil { +			log.Error("AddSequencedLeaves: %v", err) +			return +		} +	} +} + +func treeHeadFromTrillian(ctx context.Context, trillianClient db.Client) (*types.TreeHead, error) { +	th, err := trillianClient.GetTreeHead(ctx) +	if err != nil { +		return nil, fmt.Errorf("fetching tree head from trillian: %v", err) +	} +	log.Debug("got tree head from trillian, size %d", th.TreeSize) +	return th, nil +} diff --git a/internal/node/secondary/secondary_test.go b/internal/node/secondary/secondary_test.go new file mode 100644 index 0000000..164bdf6 --- /dev/null +++ b/internal/node/secondary/secondary_test.go @@ -0,0 +1,138 @@ +package secondary + +import ( +	"context" +	"fmt" +	"testing" + +	mocksClient "git.sigsum.org/log-go/internal/mocks/client" +	mocksDB "git.sigsum.org/log-go/internal/mocks/db" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +	"github.com/golang/mock/gomock" +) + +var ( +	testConfig = Config{ +		LogID:    fmt.Sprintf("%x", merkle.HashFn([]byte("logid"))[:]), +		TreeID:   0, +		Prefix:   "testonly", +		Deadline: 10, +	} +) + +// TestHandlers checks that the expected internal handlers are configured +func TestIntHandlers(t *testing.T) { +	endpoints := map[types.Endpoint]bool{ +		types.EndpointGetTreeHeadToCosign: false, +	} +	node := Secondary{ +		Config: testConfig, +	} +	for _, handler := range node.InternalHTTPHandlers() { +		if _, ok := endpoints[handler.Endpoint]; !ok { +			t.Errorf("got unexpected endpoint: %s", handler.Endpoint) +		} +		endpoints[handler.Endpoint] = true +	} +	for endpoint, ok := range endpoints { +		if !ok { +			t.Errorf("endpoint %s is not configured", endpoint) +		} +	} +} + +func TestFetchLeavesFromPrimary(t *testing.T) { +	for _, tbl := range []struct { +		desc string +		// client.GetUnsignedTreeHead() +		primaryTHRet types.TreeHead +		primaryTHErr error +		// db.GetTreeHead() +		trillianTHRet *types.TreeHead +		trillianTHErr error +		// client.GetLeaves() +		primaryGetLeavesRet types.Leaves +		primaryGetLeavesErr error +		// db.AddSequencedLeaves() +		trillianAddLeavesExp bool +		trillianAddLeavesErr error +	}{ +		{ +			desc:         "no tree head from primary", +			primaryTHErr: fmt.Errorf("mocked error"), +		}, +		{ +			desc:          "no tree head from trillian", +			primaryTHRet:  types.TreeHead{}, +			trillianTHErr: fmt.Errorf("mocked error"), +		}, +		{ +			desc:                "error fetching leaves", +			primaryTHRet:        types.TreeHead{TreeSize: 6}, +			trillianTHRet:       &types.TreeHead{TreeSize: 5}, // 6-5 => 1 expected GetLeaves +			primaryGetLeavesErr: fmt.Errorf("mocked error"), +		}, +		{ +			desc:          "error adding leaves", +			primaryTHRet:  types.TreeHead{TreeSize: 6}, +			trillianTHRet: &types.TreeHead{TreeSize: 5}, // 6-5 => 1 expected GetLeaves +			primaryGetLeavesRet: types.Leaves{ +				types.Leaf{}, +			}, +			trillianAddLeavesErr: fmt.Errorf("mocked error"), +		}, +		{ +			desc:          "success", +			primaryTHRet:  types.TreeHead{TreeSize: 10}, +			trillianTHRet: &types.TreeHead{TreeSize: 5}, +			primaryGetLeavesRet: types.Leaves{ +				types.Leaf{}, +			}, +			trillianAddLeavesExp: true, +		}, +	} { +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() + +			primaryClient := mocksClient.NewMockClient(ctrl) +			primaryClient.EXPECT().GetUnsignedTreeHead(gomock.Any()).Return(tbl.primaryTHRet, tbl.primaryTHErr) + +			trillianClient := mocksDB.NewMockClient(ctrl) +			if tbl.trillianTHErr != nil || tbl.trillianTHRet != nil { +				trillianClient.EXPECT().GetTreeHead(gomock.Any()).Return(tbl.trillianTHRet, tbl.trillianTHErr) +			} + +			if tbl.primaryGetLeavesErr != nil || tbl.primaryGetLeavesRet != nil { +				primaryClient.EXPECT().GetLeaves(gomock.Any(), gomock.Any()).Return(tbl.primaryGetLeavesRet, tbl.primaryGetLeavesErr) +				if tbl.trillianAddLeavesExp { +					for i := tbl.trillianTHRet.TreeSize; i < tbl.primaryTHRet.TreeSize-1; i++ { +						primaryClient.EXPECT().GetLeaves(gomock.Any(), gomock.Any()).Return(tbl.primaryGetLeavesRet, tbl.primaryGetLeavesErr) +					} +				} +			} + +			if tbl.trillianAddLeavesErr != nil || tbl.trillianAddLeavesExp { +				trillianClient.EXPECT().AddSequencedLeaves(gomock.Any(), gomock.Any(), gomock.Any()).Return(tbl.trillianAddLeavesErr) +				if tbl.trillianAddLeavesExp { +					for i := tbl.trillianTHRet.TreeSize; i < tbl.primaryTHRet.TreeSize-1; i++ { +						trillianClient.EXPECT().AddSequencedLeaves(gomock.Any(), gomock.Any(), gomock.Any()).Return(tbl.trillianAddLeavesErr) +					} +				} +			} + +			node := Secondary{ +				Config:         testConfig, +				Primary:        primaryClient, +				TrillianClient: trillianClient, +			} + +			node.fetchLeavesFromPrimary(context.Background()) + +			// NOTE: We are not verifying that +			// AddSequencedLeaves() is being called with +			// the right data. +		}() +	} +} diff --git a/internal/requests/requests.go b/internal/requests/requests.go new file mode 100644 index 0000000..cfd563f --- /dev/null +++ b/internal/requests/requests.go @@ -0,0 +1,91 @@ +package requests + +import ( +	"context" +	"fmt" +	"net/http" +	"time" + +	"git.sigsum.org/sigsum-go/pkg/dns" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	sigsumreq "git.sigsum.org/sigsum-go/pkg/requests" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +func LeafRequestFromHTTP(r *http.Request, shardStart uint64, ctx context.Context, vf dns.Verifier) (*sigsumreq.Leaf, error) { +	var req sigsumreq.Leaf +	if err := req.FromASCII(r.Body); err != nil { +		return nil, fmt.Errorf("parse ascii: %w", err) +	} +	stmt := types.Statement{ +		ShardHint: req.ShardHint, +		Checksum:  *merkle.HashFn(req.Message[:]), +	} +	if !stmt.Verify(&req.PublicKey, &req.Signature) { +		return nil, fmt.Errorf("invalid signature") +	} +	shardEnd := uint64(time.Now().Unix()) +	if req.ShardHint < shardStart { +		return nil, fmt.Errorf("invalid shard hint: %d not in [%d, %d]", req.ShardHint, shardStart, shardEnd) +	} +	if req.ShardHint > shardEnd { +		return nil, fmt.Errorf("invalid shard hint: %d not in [%d, %d]", req.ShardHint, shardStart, shardEnd) +	} +	if err := vf.Verify(ctx, req.DomainHint, &req.PublicKey); err != nil { +		return nil, fmt.Errorf("invalid domain hint: %v", err) +	} +	return &req, nil +} + +func CosignatureRequestFromHTTP(r *http.Request, w map[merkle.Hash]types.PublicKey) (*sigsumreq.Cosignature, error) { +	var req sigsumreq.Cosignature +	if err := req.FromASCII(r.Body); err != nil { +		return nil, fmt.Errorf("parse ascii: %w", err) +	} +	if _, ok := w[req.KeyHash]; !ok { +		return nil, fmt.Errorf("unknown witness: %x", req.KeyHash) +	} +	return &req, nil +} + +func ConsistencyProofRequestFromHTTP(r *http.Request) (*sigsumreq.ConsistencyProof, error) { +	var req sigsumreq.ConsistencyProof +	if err := req.FromURL(r.URL.Path); err != nil { +		return nil, fmt.Errorf("parse url: %w", err) +	} +	if req.OldSize < 1 { +		return nil, fmt.Errorf("old_size(%d) must be larger than zero", req.OldSize) +	} +	if req.NewSize <= req.OldSize { +		return nil, fmt.Errorf("new_size(%d) must be larger than old_size(%d)", req.NewSize, req.OldSize) +	} +	return &req, nil +} + +func InclusionProofRequestFromHTTP(r *http.Request) (*sigsumreq.InclusionProof, error) { +	var req sigsumreq.InclusionProof +	if err := req.FromURL(r.URL.Path); err != nil { +		return nil, fmt.Errorf("parse url: %w", err) +	} +	if req.TreeSize < 2 { +		// TreeSize:0 => not possible to prove inclusion of anything +		// TreeSize:1 => you don't need an inclusion proof (it is always empty) +		return nil, fmt.Errorf("tree_size(%d) must be larger than one", req.TreeSize) +	} +	return &req, nil +} + +func LeavesRequestFromHTTP(r *http.Request, maxRange uint64) (*sigsumreq.Leaves, error) { +	var req sigsumreq.Leaves +	if err := req.FromURL(r.URL.Path); err != nil { +		return nil, fmt.Errorf("parse url: %w", err) +	} + +	if req.StartSize > req.EndSize { +		return nil, fmt.Errorf("start_size(%d) must be less than or equal to end_size(%d)", req.StartSize, req.EndSize) +	} +	if req.EndSize-req.StartSize+1 > maxRange { +		req.EndSize = req.StartSize + maxRange - 1 +	} +	return &req, nil +} diff --git a/internal/requests/requests_test.go b/internal/requests/requests_test.go new file mode 100644 index 0000000..46d6e15 --- /dev/null +++ b/internal/requests/requests_test.go @@ -0,0 +1,218 @@ +package requests + +import ( +	"bytes" +	"context" +	"crypto/ed25519" +	"crypto/rand" +	"fmt" +	"io" +	"net/http" +	"reflect" +	"testing" +	"time" + +	mocksDNS "git.sigsum.org/log-go/internal/mocks/dns" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	sigsumreq "git.sigsum.org/sigsum-go/pkg/requests" +	"git.sigsum.org/sigsum-go/pkg/types" +	"github.com/golang/mock/gomock" +) + +func TestLeafRequestFromHTTP(t *testing.T) { +	st := uint64(10) +	dh := "_sigsum_v0.example.org" +	msg := merkle.Hash{} +	var pub types.PublicKey +	b, priv, err := ed25519.GenerateKey(rand.Reader) +	if err != nil { +		t.Fatalf("must generate key pair: %v", err) +	} +	copy(pub[:], b) + +	sign := func(sh uint64, msg merkle.Hash) *types.Signature { +		stm := types.Statement{sh, *merkle.HashFn(msg[:])} +		sig, err := stm.Sign(priv) +		if err != nil { +			t.Fatalf("must sign: %v", err) +		} +		return sig +	} +	input := func(sh uint64, msg merkle.Hash, badSig bool) io.Reader { +		sig := sign(sh, msg)[:] +		if badSig { +			msg[0] += 1 // use a different message +		} +		str := fmt.Sprintf("shard_hint=%d\n", sh) +		str += fmt.Sprintf("message=%x\n", msg[:]) +		str += fmt.Sprintf("signature=%x\n", sig[:]) +		str += fmt.Sprintf("public_key=%x\n", pub[:]) +		str += fmt.Sprintf("domain_hint=%s\n", dh) +		return bytes.NewBufferString(str) +	} + +	for _, table := range []struct { +		desc      string +		params    io.Reader +		dnsExpect bool +		dnsErr    error +		wantRsp   *sigsumreq.Leaf +	}{ +		{"invalid: parse ascii", bytes.NewBufferString("a=b"), false, nil, nil}, +		{"invalid: signature", input(st, msg, true), false, nil, nil}, +		{"invalid: shard start", input(st-1, msg, false), false, nil, nil}, +		{"invalid: shard end", input(uint64(time.Now().Unix())+1024, msg, false), false, nil, nil}, +		{"invalid: mocked dns error", input(st, msg, false), true, fmt.Errorf("mocked dns error"), nil}, +		{"valid", input(st, msg, false), true, nil, &sigsumreq.Leaf{st, msg, *sign(st, msg), pub, dh}}, +	} { +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			vf := mocksDNS.NewMockVerifier(ctrl) +			if table.dnsExpect { +				vf.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).Return(table.dnsErr) +			} + +			url := types.EndpointAddLeaf.Path("http://example.org/sigsum/v0") +			req, err := http.NewRequest(http.MethodPost, url, table.params) +			if err != nil { +				t.Fatalf("must create http request: %v", err) +			} + +			parsedReq, err := LeafRequestFromHTTP(req, st, context.Background(), vf) +			if got, want := err != nil, table.desc != "valid"; got != want { +				t.Errorf("%s: got error %v but wanted %v: %v", table.desc, got, want, err) +			} +			if err != nil { +				return +			} +			if got, want := parsedReq, table.wantRsp; !reflect.DeepEqual(got, want) { +				t.Errorf("%s: got request %v but wanted %v", table.desc, got, want) +			} +		}() +	} +} + +func TestCosignatureRequestFromHTTP(t *testing.T) { +	input := func(h merkle.Hash) io.Reader { +		return bytes.NewBufferString(fmt.Sprintf("cosignature=%x\nkey_hash=%x\n", types.Signature{}, h)) +	} +	w := map[merkle.Hash]types.PublicKey{ +		*merkle.HashFn([]byte("w1")): types.PublicKey{}, +	} +	for _, table := range []struct { +		desc    string +		params  io.Reader +		wantRsp *sigsumreq.Cosignature +	}{ +		{"invalid: parser error", bytes.NewBufferString("abcd"), nil}, +		{"invalid: unknown witness", input(*merkle.HashFn([]byte("w2"))), nil}, +		{"valid", input(*merkle.HashFn([]byte("w1"))), &sigsumreq.Cosignature{types.Signature{}, *merkle.HashFn([]byte("w1"))}}, +	} { +		url := types.EndpointAddCosignature.Path("http://example.org/sigsum/v0") +		req, err := http.NewRequest(http.MethodPost, url, table.params) +		if err != nil { +			t.Fatalf("must create http request: %v", err) +		} + +		parsedReq, err := CosignatureRequestFromHTTP(req, w) +		if got, want := err != nil, table.desc != "valid"; got != want { +			t.Errorf("%s: got error %v but wanted %v: %v", table.desc, got, want, err) +		} +		if err != nil { +			continue +		} +		if got, want := parsedReq, table.wantRsp; !reflect.DeepEqual(got, want) { +			t.Errorf("%s: got request %v but wanted %v", table.desc, got, want) +		} +	} +} + +func TestConsistencyProofRequestFromHTTP(t *testing.T) { +	for _, table := range []struct { +		desc    string +		params  string +		wantRsp *sigsumreq.ConsistencyProof +	}{ +		{"invalid: bad request (parser error)", "a/1", nil}, +		{"invalid: bad request (out of range 1/2)", "0/1", nil}, +		{"invalid: bad request (out of range 2/2)", "1/1", nil}, +		{"valid", "1/2", &sigsumreq.ConsistencyProof{1, 2}}, +	} { +		url := types.EndpointGetConsistencyProof.Path("http://example.org/sigsum/v0/") +		req, err := http.NewRequest(http.MethodGet, url+table.params, nil) +		if err != nil { +			t.Fatalf("must create http request: %v", err) +		} + +		parsedReq, err := ConsistencyProofRequestFromHTTP(req) +		if got, want := err != nil, table.desc != "valid"; got != want { +			t.Errorf("%s: got error %v but wanted %v: %v", table.desc, got, want, err) +		} +		if err != nil { +			continue +		} +		if got, want := parsedReq, table.wantRsp; !reflect.DeepEqual(got, want) { +			t.Errorf("%s: got request %v but wanted %v", table.desc, got, want) +		} +	} +} + +func TestInclusionProofRequestFromHTTP(t *testing.T) { +	for _, table := range []struct { +		desc    string +		params  string +		wantRsp *sigsumreq.InclusionProof +	}{ +		{"invalid: bad request (parser error)", "a/0000000000000000000000000000000000000000000000000000000000000000", nil}, +		{"invalid: bad request (out of range)", "1/0000000000000000000000000000000000000000000000000000000000000000", nil}, +		{"valid", "2/0000000000000000000000000000000000000000000000000000000000000000", &sigsumreq.InclusionProof{2, merkle.Hash{}}}, +	} { +		url := types.EndpointGetInclusionProof.Path("http://example.org/sigsum/v0/") +		req, err := http.NewRequest(http.MethodGet, url+table.params, nil) +		if err != nil { +			t.Fatalf("must create http request: %v", err) +		} + +		parsedReq, err := InclusionProofRequestFromHTTP(req) +		if got, want := err != nil, table.desc != "valid"; got != want { +			t.Errorf("%s: got error %v but wanted %v: %v", table.desc, got, want, err) +		} +		if err != nil { +			continue +		} +		if got, want := parsedReq, table.wantRsp; !reflect.DeepEqual(got, want) { +			t.Errorf("%s: got request %v but wanted %v", table.desc, got, want) +		} +	} +} + +func TestGetLeaves(t *testing.T) { +	maxRange := uint64(10) +	for _, table := range []struct { +		desc    string +		params  string +		wantRsp *sigsumreq.Leaves +	}{ +		{"invalid: bad request (parser error)", "a/1", nil}, +		{"invalid: bad request (StartSize > EndSize)", "1/0", nil}, +		{"valid", "0/10", &sigsumreq.Leaves{0, maxRange - 1}}, +	} { +		url := types.EndpointGetLeaves.Path("http://example.org/sigsum/v0/") +		req, err := http.NewRequest(http.MethodGet, url+table.params, nil) +		if err != nil { +			t.Fatalf("must create http request: %v", err) +		} + +		parsedReq, err := LeavesRequestFromHTTP(req, maxRange) +		if got, want := err != nil, table.desc != "valid"; got != want { +			t.Errorf("%s: got error %v but wanted %v: %v", table.desc, got, want, err) +		} +		if err != nil { +			continue +		} +		if got, want := parsedReq, table.wantRsp; !reflect.DeepEqual(got, want) { +			t.Errorf("%s: got request %v but wanted %v", table.desc, got, want) +		} +	} +} diff --git a/internal/state/single.go b/internal/state/single.go new file mode 100644 index 0000000..fd73b3f --- /dev/null +++ b/internal/state/single.go @@ -0,0 +1,265 @@ +package state + +import ( +	"context" +	"crypto" +	"crypto/ed25519" +	"fmt" +	"sync" +	"time" + +	"git.sigsum.org/log-go/internal/db" +	"git.sigsum.org/sigsum-go/pkg/client" +	"git.sigsum.org/sigsum-go/pkg/log" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/requests" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +// StateManagerSingle implements a single-instance StateManagerPrimary for primary nodes +type StateManagerSingle struct { +	client    db.Client +	signer    crypto.Signer +	namespace merkle.Hash +	interval  time.Duration +	deadline  time.Duration +	secondary client.Client + +	// Lock-protected access to pointers.  A write lock is only obtained once +	// per interval when doing pointer rotation.  All endpoints are readers. +	sync.RWMutex +	signedTreeHead   *types.SignedTreeHead +	cosignedTreeHead *types.CosignedTreeHead + +	// Syncronized and deduplicated witness cosignatures for signedTreeHead +	events       chan *event +	cosignatures map[merkle.Hash]*types.Signature +} + +// NewStateManagerSingle() sets up a new state manager, in particular its +// signedTreeHead.  An optional secondary node can be used to ensure that +// a newer primary tree is not signed unless it has been replicated. +func NewStateManagerSingle(dbcli db.Client, signer crypto.Signer, interval, deadline time.Duration, secondary client.Client) (*StateManagerSingle, error) { +	sm := &StateManagerSingle{ +		client:    dbcli, +		signer:    signer, +		namespace: *merkle.HashFn(signer.Public().(ed25519.PublicKey)), +		interval:  interval, +		deadline:  deadline, +		secondary: secondary, +	} +	sth, err := sm.restoreTreeHead() +	if err != nil { +		return nil, fmt.Errorf("restore signed tree head: %v", err) +	} +	sm.signedTreeHead = sth + +	ictx, cancel := context.WithTimeout(context.Background(), sm.deadline) +	defer cancel() +	return sm, sm.tryRotate(ictx) +} + +func (sm *StateManagerSingle) ToCosignTreeHead() *types.SignedTreeHead { +	sm.RLock() +	defer sm.RUnlock() +	return sm.signedTreeHead +} + +func (sm *StateManagerSingle) CosignedTreeHead(_ context.Context) (*types.CosignedTreeHead, error) { +	sm.RLock() +	defer sm.RUnlock() +	if sm.cosignedTreeHead == nil { +		return nil, fmt.Errorf("no cosignatures available") +	} +	return sm.cosignedTreeHead, nil +} + +func (sm *StateManagerSingle) AddCosignature(ctx context.Context, pub *types.PublicKey, sig *types.Signature) error { +	sm.RLock() +	defer sm.RUnlock() + +	msg := sm.signedTreeHead.TreeHead.ToBinary(&sm.namespace) +	if !ed25519.Verify(ed25519.PublicKey(pub[:]), msg, sig[:]) { +		return fmt.Errorf("invalid cosignature") +	} +	select { +	case sm.events <- &event{merkle.HashFn(pub[:]), sig}: +		return nil +	case <-ctx.Done(): +		return fmt.Errorf("request timeout") +	} +} + +func (sm *StateManagerSingle) Run(ctx context.Context) { +	sm.events = make(chan *event, 4096) +	defer close(sm.events) +	ticker := time.NewTicker(sm.interval) +	defer ticker.Stop() + +	for { +		select { +		case <-ticker.C: +			ictx, cancel := context.WithTimeout(ctx, sm.deadline) +			defer cancel() +			if err := sm.tryRotate(ictx); err != nil { +				log.Warning("failed rotating tree heads: %v", err) +			} +		case ev := <-sm.events: +			sm.handleEvent(ev) +		case <-ctx.Done(): +			return +		} +	} +} + +func (sm *StateManagerSingle) tryRotate(ctx context.Context) error { +	th, err := sm.client.GetTreeHead(ctx) +	if err != nil { +		return fmt.Errorf("get tree head: %v", err) +	} +	nextSTH, err := sm.chooseTree(ctx, th).Sign(sm.signer, &sm.namespace) +	if err != nil { +		return fmt.Errorf("sign tree head: %v", err) +	} +	log.Debug("wanted to advance to size %d, chose size %d", th.TreeSize, nextSTH.TreeSize) + +	sm.rotate(nextSTH) +	return nil +} + +// chooseTree picks a tree to publish, taking the state of a possible secondary node into account. +func (sm *StateManagerSingle) chooseTree(ctx context.Context, proposedTreeHead *types.TreeHead) *types.TreeHead { +	if !sm.secondary.Initiated() { +		return proposedTreeHead +	} + +	secSTH, err := sm.secondary.GetToCosignTreeHead(ctx) +	if err != nil { +		log.Warning("failed fetching tree head from secondary: %v", err) +		return refreshTreeHead(sm.signedTreeHead.TreeHead) +	} +	if secSTH.TreeSize > proposedTreeHead.TreeSize { +		log.Error("secondary is ahead of us: %d > %d", secSTH.TreeSize, proposedTreeHead.TreeSize) +		return refreshTreeHead(sm.signedTreeHead.TreeHead) +	} + +	if secSTH.TreeSize == proposedTreeHead.TreeSize { +		if secSTH.RootHash != proposedTreeHead.RootHash { +			log.Error("secondary root hash doesn't match our root hash at tree size %d", secSTH.TreeSize) +			return refreshTreeHead(sm.signedTreeHead.TreeHead) +		} +		log.Debug("secondary is up-to-date with matching tree head, using proposed tree, size %d", proposedTreeHead.TreeSize) +		return proposedTreeHead +	} +	// +	// Now we know that the proposed tree size is larger than the secondary's tree size. +	// We also now that the secondary's minimum tree size is 0. +	// This means that the proposed tree size is at least 1. +	// +	// Case 1: secondary tree size is 0, primary tree size is >0 --> return based on what we signed before +	// Case 2: secondary tree size is 1, primary tree size is >1 --> fetch consistency proof, if ok -> +	//   2a) secondary tree size is smaller than or equal to what we than signed before -> return whatever we signed before +	//   2b) secondary tree size is larger than what we signed before -> return secondary tree head +	// +	// (If not ok in case 2, return based on what we signed before) +	// +	if secSTH.TreeSize == 0 { +		return refreshTreeHead(sm.signedTreeHead.TreeHead) +	} +	if err := sm.verifyConsistencyWithLatest(ctx, secSTH.TreeHead); err != nil { +		log.Error("secondaries tree not consistent with ours: %v", err) +		return refreshTreeHead(sm.signedTreeHead.TreeHead) +	} +	if secSTH.TreeSize <= sm.signedTreeHead.TreeSize { +		log.Warning("secondary is behind what primary already signed: %d <= %d", secSTH.TreeSize, sm.signedTreeHead.TreeSize) +		return refreshTreeHead(sm.signedTreeHead.TreeHead) +	} + +	log.Debug("using latest tree head from secondary: size %d", secSTH.TreeSize) +	return refreshTreeHead(secSTH.TreeHead) +} + +func (sm *StateManagerSingle) verifyConsistencyWithLatest(ctx context.Context, to types.TreeHead) error { +	from := sm.signedTreeHead.TreeHead +	req := &requests.ConsistencyProof{ +		OldSize: from.TreeSize, +		NewSize: to.TreeSize, +	} +	proof, err := sm.client.GetConsistencyProof(ctx, req) +	if err != nil { +		return fmt.Errorf("unable to get consistency proof from %d to %d: %w", req.OldSize, req.NewSize, err) +	} +	if err := proof.Verify(&from.RootHash, &to.RootHash); err != nil { +		return fmt.Errorf("invalid consistency proof from %d to %d: %v", req.OldSize, req.NewSize, err) +	} +	log.Debug("consistency proof from %d to %d verified", req.OldSize, req.NewSize) +	return nil +} + +func (sm *StateManagerSingle) rotate(nextSTH *types.SignedTreeHead) { +	sm.Lock() +	defer sm.Unlock() + +	log.Debug("about to rotate tree heads, next at %d: %s", nextSTH.TreeSize, sm.treeStatusString()) +	sm.handleEvents() +	sm.setCosignedTreeHead() +	sm.setToCosignTreeHead(nextSTH) +	log.Debug("tree heads rotated: %s", sm.treeStatusString()) +} + +func (sm *StateManagerSingle) handleEvents() { +	log.Debug("handling any outstanding events") +	for i, n := 0, len(sm.events); i < n; i++ { +		sm.handleEvent(<-sm.events) +	} +} + +func (sm *StateManagerSingle) handleEvent(ev *event) { +	log.Debug("handling event from witness %x", ev.keyHash[:]) +	sm.cosignatures[*ev.keyHash] = ev.cosignature +} + +func (sm *StateManagerSingle) setCosignedTreeHead() { +	n := len(sm.cosignatures) +	if n == 0 { +		sm.cosignedTreeHead = nil +		return +	} + +	var cth types.CosignedTreeHead +	cth.SignedTreeHead = *sm.signedTreeHead +	cth.Cosignature = make([]types.Signature, 0, n) +	cth.KeyHash = make([]merkle.Hash, 0, n) +	for keyHash, cosignature := range sm.cosignatures { +		cth.KeyHash = append(cth.KeyHash, keyHash) +		cth.Cosignature = append(cth.Cosignature, *cosignature) +	} +	sm.cosignedTreeHead = &cth +} + +func (sm *StateManagerSingle) setToCosignTreeHead(nextSTH *types.SignedTreeHead) { +	sm.cosignatures = make(map[merkle.Hash]*types.Signature) +	sm.signedTreeHead = nextSTH +} + +func (sm *StateManagerSingle) treeStatusString() string { +	var cosigned uint64 +	if sm.cosignedTreeHead != nil { +		cosigned = sm.cosignedTreeHead.TreeSize +	} +	return fmt.Sprintf("signed at %d, cosigned at %d", sm.signedTreeHead.TreeSize, cosigned) +} + +func (sm *StateManagerSingle) restoreTreeHead() (*types.SignedTreeHead, error) { +	th := zeroTreeHead() // TODO: restore from disk, stored when advanced the tree; zero tree head if "bootstrap" +	return refreshTreeHead(*th).Sign(sm.signer, &sm.namespace) +} + +func zeroTreeHead() *types.TreeHead { +	return refreshTreeHead(types.TreeHead{RootHash: *merkle.HashFn([]byte(""))}) +} + +func refreshTreeHead(th types.TreeHead) *types.TreeHead { +	th.Timestamp = uint64(time.Now().Unix()) +	return &th +} diff --git a/internal/state/single_test.go b/internal/state/single_test.go new file mode 100644 index 0000000..9442fdc --- /dev/null +++ b/internal/state/single_test.go @@ -0,0 +1,233 @@ +package state + +import ( +	"bytes" +	"context" +	"crypto" +	"crypto/ed25519" +	"crypto/rand" +	"fmt" +	"io" +	"reflect" +	"testing" +	"time" + +	mocksClient "git.sigsum.org/log-go/internal/mocks/client" +	mocksDB "git.sigsum.org/log-go/internal/mocks/db" +	"git.sigsum.org/sigsum-go/pkg/hex" +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +	"github.com/golang/mock/gomock" +) + +// TestSigner implements the signer interface.  It can be used to mock +// an Ed25519 signer that always return the same public key, +// signature, and error. +// NOTE: Code duplication with internal/node/secondary/endpoint_internal_test.go +type TestSigner struct { +	PublicKey [ed25519.PublicKeySize]byte +	Signature [ed25519.SignatureSize]byte +	Error     error +} + +func (ts *TestSigner) Public() crypto.PublicKey { +	return ed25519.PublicKey(ts.PublicKey[:]) +} + +func (ts *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { +	return ts.Signature[:], ts.Error +} + +func TestNewStateManagerSingle(t *testing.T) { +	signerOk := &TestSigner{types.PublicKey{}, types.Signature{}, nil} +	signerErr := &TestSigner{types.PublicKey{}, types.Signature{}, fmt.Errorf("err")} +	for _, table := range []struct { +		description string +		signer      crypto.Signer +		thExp       bool +		thErr       error +		th          types.TreeHead +		secExp      bool +		wantErr     bool +	}{ +		{"invalid: signer failure", signerErr, false, nil, types.TreeHead{}, false, true}, +		{"valid", signerOk, true, nil, types.TreeHead{Timestamp: now(t)}, true, false}, +	} { +		func() { +			ctrl := gomock.NewController(t) +			defer ctrl.Finish() +			trillianClient := mocksDB.NewMockClient(ctrl) +			if table.thExp { +				trillianClient.EXPECT().GetTreeHead(gomock.Any()).Return(&table.th, table.thErr) +			} +			secondary := mocksClient.NewMockClient(ctrl) +			if table.secExp { +				secondary.EXPECT().Initiated().Return(false) +			} + +			sm, err := NewStateManagerSingle(trillianClient, table.signer, time.Duration(0), time.Duration(0), secondary) +			if got, want := err != nil, table.description != "valid"; got != want { +				t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) +			} +			if err != nil { +				return +			} + +			if got, want := sm.signedTreeHead.TreeSize, table.th.TreeSize; got != want { +				t.Errorf("%q: got tree size %d but wanted %d", table.description, got, want) +			} +			if got, want := sm.signedTreeHead.RootHash[:], table.th.RootHash[:]; !bytes.Equal(got, want) { +				t.Errorf("%q: got tree size %v but wanted %v", table.description, got, want) +			} +			if got, want := sm.signedTreeHead.Timestamp, table.th.Timestamp; got < want { +				t.Errorf("%q: got timestamp %d but wanted at least %d", table.description, got, want) +			} +			if got := sm.cosignedTreeHead; got != nil { +				t.Errorf("%q: got cosigned tree head but should have none", table.description) +			} +		}() +	} +} + +func TestToCosignTreeHead(t *testing.T) { +	want := &types.SignedTreeHead{} +	sm := StateManagerSingle{ +		signedTreeHead: want, +	} +	sth := sm.ToCosignTreeHead() +	if got := sth; !reflect.DeepEqual(got, want) { +		t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v", got, want) +	} +} + +func TestCosignedTreeHead(t *testing.T) { +	want := &types.CosignedTreeHead{ +		Cosignature: make([]types.Signature, 1), +		KeyHash:     make([]merkle.Hash, 1), +	} +	sm := StateManagerSingle{ +		cosignedTreeHead: want, +	} +	cth, err := sm.CosignedTreeHead(context.Background()) +	if err != nil { +		t.Errorf("should not fail with error: %v", err) +		return +	} +	if got := cth; !reflect.DeepEqual(got, want) { +		t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v", got, want) +	} + +	sm.cosignedTreeHead = nil +	cth, err = sm.CosignedTreeHead(context.Background()) +	if err == nil { +		t.Errorf("should fail without a cosigned tree head") +		return +	} +} + +func TestAddCosignature(t *testing.T) { +	secret, public := mustKeyPair(t) +	for _, table := range []struct { +		desc    string +		signer  crypto.Signer +		vk      types.PublicKey +		wantErr bool +	}{ +		{ +			desc:    "invalid: wrong public key", +			signer:  secret, +			vk:      types.PublicKey{}, +			wantErr: true, +		}, +		{ +			desc:   "valid", +			signer: secret, +			vk:     public, +		}, +	} { +		sm := &StateManagerSingle{ +			namespace:      *merkle.HashFn(nil), +			signedTreeHead: &types.SignedTreeHead{}, +			events:         make(chan *event, 1), +		} +		defer close(sm.events) + +		sth := mustSign(t, table.signer, &sm.signedTreeHead.TreeHead, &sm.namespace) +		ctx := context.Background() +		err := sm.AddCosignature(ctx, &table.vk, &sth.Signature) +		if got, want := err != nil, table.wantErr; got != want { +			t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.desc, err) +		} +		if err != nil { +			continue +		} + +		ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) +		defer cancel() +		if err := sm.AddCosignature(ctx, &table.vk, &sth.Signature); err == nil { +			t.Errorf("expected full channel in test %q", table.desc) +		} +		if got, want := len(sm.events), 1; got != want { +			t.Errorf("wanted %d cosignatures but got %d in test %q", want, got, table.desc) +		} +	} +} + +func mustKeyPair(t *testing.T) (crypto.Signer, types.PublicKey) { +	t.Helper() +	vk, sk, err := ed25519.GenerateKey(rand.Reader) +	if err != nil { +		t.Fatal(err) +	} +	var pub types.PublicKey +	copy(pub[:], vk[:]) +	return sk, pub +} + +func mustSign(t *testing.T, s crypto.Signer, th *types.TreeHead, kh *merkle.Hash) *types.SignedTreeHead { +	t.Helper() +	sth, err := th.Sign(s, kh) +	if err != nil { +		t.Fatal(err) +	} +	return sth +} + +func newHashBufferInc(t *testing.T) *merkle.Hash { +	t.Helper() + +	var buf merkle.Hash +	for i := 0; i < len(buf); i++ { +		buf[i] = byte(i) +	} +	return &buf +} +func validConsistencyProof_5_10(t *testing.T) *types.ConsistencyProof { +	t.Helper() +	// # old tree head +	//     tree_size=5 +	//     root_hash=c8e73a8c09e44c344d515eb717e248c5dbf12420908a6d29568197fae7751803 +	// # new tree head +	//     tree_size=10 +	//     root_hash=2a40f11563b45522ca9eccf993c934238a8fbadcf7d7d65be3583ab2584838aa +	r := bytes.NewReader([]byte("consistency_path=fadca95ab8ca34f17c5f3fa719183fe0e5c194a44c25324745388964a743ecce\nconsistency_path=6366fc0c20f9b8a8c089ed210191e401da6c995592eba78125f0ba0ba142ebaf\nconsistency_path=72b8d4f990b555a72d76fb8da075a65234519070cfa42e082026a8c686160349\nconsistency_path=d92714be792598ff55560298cd3ff099dfe5724646282578531c0d0063437c00\nconsistency_path=4b20d58bbae723755304fb179aef6d5f04d755a601884828c62c07929f6bd84a\n")) +	var proof types.ConsistencyProof +	if err := proof.FromASCII(r, 5, 10); err != nil { +		t.Fatal(err) +	} +	return &proof +} + +func hashFromString(t *testing.T, s string) (h merkle.Hash) { +	b, err := hex.Deserialize(s) +	if err != nil { +		t.Fatal(err) +	} +	copy(h[:], b) +	return h +} + +func now(t *testing.T) uint64 { +	t.Helper() +	return uint64(time.Now().Unix()) +} diff --git a/internal/state/state_manager.go b/internal/state/state_manager.go new file mode 100644 index 0000000..60d2af1 --- /dev/null +++ b/internal/state/state_manager.go @@ -0,0 +1,30 @@ +package state + +import ( +	"context" + +	"git.sigsum.org/sigsum-go/pkg/merkle" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +// StateManager coordinates access to a nodes tree heads and (co)signatures. +type StateManager interface { +	// ToCosignTreeHead returns the node's to-cosign tree head +	ToCosignTreeHead() *types.SignedTreeHead + +	// CosignedTreeHead returns the node's cosigned tree head +	CosignedTreeHead(context.Context) (*types.CosignedTreeHead, error) + +	// AddCosignature verifies that a cosignature is valid for the to-cosign +	// tree head before adding it +	AddCosignature(context.Context, *types.PublicKey, *types.Signature) error + +	// Run peridically rotates the node's to-cosign and cosigned tree heads +	Run(context.Context) +} + +// event is a verified cosignature request +type event struct { +	keyHash     *merkle.Hash +	cosignature *types.Signature +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 0000000..a453107 --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,69 @@ +package utils + +import ( +	"crypto" +	"crypto/ed25519" +	"encoding/hex" +	"fmt" +	"io/ioutil" +	"os" +	"strings" + +	"git.sigsum.org/sigsum-go/pkg/log" +	"git.sigsum.org/sigsum-go/pkg/types" +) + +// TODO: Move SetupLogging to sigsum-go/pkg/log + +func SetupLogging(logFile, logLevel string, logColor bool) error { +	if len(logFile) != 0 { +		f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) +		if err != nil { +			return err +		} +		log.SetOutput(f) +	} + +	switch logLevel { +	case "debug": +		log.SetLevel(log.DebugLevel) +	case "info": +		log.SetLevel(log.InfoLevel) +	case "warning": +		log.SetLevel(log.WarningLevel) +	case "error": +		log.SetLevel(log.ErrorLevel) +	default: +		return fmt.Errorf("invalid logging level %s", logLevel) +	} + +	log.SetColor(logColor) +	return nil +} + +func PubkeyFromHexString(pkhex string) (*types.PublicKey, error) { +	pkbuf, err := hex.DecodeString(pkhex) +	if err != nil { +		return nil, fmt.Errorf("DecodeString: %v", err) +	} + +	var pk types.PublicKey +	if n := copy(pk[:], pkbuf); n != types.PublicKeySize { +		return nil, fmt.Errorf("invalid pubkey size: %v", n) +	} + +	return &pk, nil +} + +func NewLogIdentity(keyFile string) (crypto.Signer, string, error) { +	buf, err := ioutil.ReadFile(keyFile) +	if err != nil { +		return nil, "", err +	} +	if buf, err = hex.DecodeString(strings.TrimSpace(string(buf))); err != nil { +		return nil, "", fmt.Errorf("DecodeString: %v", err) +	} +	sk := crypto.Signer(ed25519.NewKeyFromSeed(buf)) +	vk := sk.Public().(ed25519.PublicKey) +	return sk, hex.EncodeToString([]byte(vk[:])), nil +} | 
