From 47bacfd5c5d22470340e0823c4ad37b45914b68e Mon Sep 17 00:00:00 2001 From: Rasmus Dahlberg Date: Sun, 6 Jun 2021 15:59:35 +0200 Subject: started using the refactored packages in siglog server --- client/client.go | 242 -------------- client/cmd/add-entry/main.go | 52 --- client/cmd/cosign/main.go | 56 ---- client/cmd/example.sh | 49 --- client/cmd/get-consistency-proof/main.go | 70 ---- client/cmd/get-entries/main.go | 83 ----- client/cmd/get-proof-by-hash/main.go | 66 ---- client/cmd/get-sth/main.go | 35 -- client/cmd/keygen/main.go | 17 - client/cmd/submit/main.go | 28 -- client/flag.go | 55 ---- client/verify.go | 52 --- cmd/siglog_server/.gitignore | 1 + cmd/siglog_server/README.md | 60 ++++ cmd/siglog_server/main.go | 176 ++++++++++ cmd/tmp/README.md | 2 + cmd/tmp/cosign/main.go | 56 ++++ cmd/tmp/keygen/main.go | 17 + cmd/tmp/submit/main.go | 29 ++ doc.go | 3 - endpoint.go | 163 ---------- endpoint_test.go | 481 ---------------------------- go.sum | 3 + instance.go | 75 ----- instance_test.go | 159 --------- log_parameters.go | 47 --- log_parameters_test.go | 99 ------ metric.go | 23 -- mocks/crypto.go | 23 -- mocks/stfe.go | 110 ------- mocks/trillian.go | 317 ------------------ pkg/instance/endpoint.go | 122 +++++++ pkg/instance/endpoint_test.go | 480 ++++++++++++++++++++++++++++ pkg/instance/instance.go | 90 ++++++ pkg/instance/instance_test.go | 158 +++++++++ pkg/instance/metric.go | 23 ++ pkg/instance/request.go | 77 +++++ pkg/instance/request_test.go | 318 ++++++++++++++++++ pkg/mocks/crypto.go | 23 ++ pkg/mocks/stfe.go | 110 +++++++ pkg/mocks/trillian.go | 317 ++++++++++++++++++ pkg/state/state_manager.go | 154 +++++++++ pkg/state/state_manager_test.go | 393 +++++++++++++++++++++++ pkg/trillian/client.go | 178 +++++++++++ pkg/trillian/client_test.go | 533 +++++++++++++++++++++++++++++++ pkg/trillian/util.go | 33 ++ pkg/types/ascii.go | 421 ++++++++++++++++++++++++ pkg/types/ascii_test.go | 465 +++++++++++++++++++++++++++ pkg/types/trunnel.go | 60 ++++ pkg/types/trunnel_test.go | 114 +++++++ pkg/types/types.go | 155 +++++++++ pkg/types/types_test.go | 58 ++++ pkg/types/util.go | 21 ++ request.go | 81 ----- request_test.go | 318 ------------------ server/.gitignore | 1 - server/README.md | 60 ---- server/main.go | 165 ---------- state/state_manager.go | 154 --------- state/state_manager_test.go | 393 ----------------------- sth.go | 143 --------- sth_test.go | 466 --------------------------- testdata/data.go | 287 ----------------- trillian.go | 125 -------- trillian/client.go | 178 ----------- trillian/client_test.go | 533 ------------------------------- trillian/util.go | 33 -- trillian_test.go | 282 ---------------- types/ascii.go | 421 ------------------------ types/ascii_test.go | 465 --------------------------- types/trunnel.go | 60 ---- types/trunnel_test.go | 114 ------- types/types.go | 155 --------- types/types_test.go | 58 ---- types/util.go | 21 -- util.go | 27 -- util_test.go | 17 - 77 files changed, 4647 insertions(+), 6862 deletions(-) delete mode 100644 client/client.go delete mode 100644 client/cmd/add-entry/main.go delete mode 100644 client/cmd/cosign/main.go delete mode 100755 client/cmd/example.sh delete mode 100644 client/cmd/get-consistency-proof/main.go delete mode 100644 client/cmd/get-entries/main.go delete mode 100644 client/cmd/get-proof-by-hash/main.go delete mode 100644 client/cmd/get-sth/main.go delete mode 100644 client/cmd/keygen/main.go delete mode 100644 client/cmd/submit/main.go delete mode 100644 client/flag.go delete mode 100644 client/verify.go create mode 100644 cmd/siglog_server/.gitignore create mode 100644 cmd/siglog_server/README.md create mode 100644 cmd/siglog_server/main.go create mode 100644 cmd/tmp/README.md create mode 100644 cmd/tmp/cosign/main.go create mode 100644 cmd/tmp/keygen/main.go create mode 100644 cmd/tmp/submit/main.go delete mode 100644 doc.go delete mode 100644 endpoint.go delete mode 100644 endpoint_test.go delete mode 100644 instance.go delete mode 100644 instance_test.go delete mode 100644 log_parameters.go delete mode 100644 log_parameters_test.go delete mode 100644 metric.go delete mode 100644 mocks/crypto.go delete mode 100644 mocks/stfe.go delete mode 100644 mocks/trillian.go create mode 100644 pkg/instance/endpoint.go create mode 100644 pkg/instance/endpoint_test.go create mode 100644 pkg/instance/instance.go create mode 100644 pkg/instance/instance_test.go create mode 100644 pkg/instance/metric.go create mode 100644 pkg/instance/request.go create mode 100644 pkg/instance/request_test.go create mode 100644 pkg/mocks/crypto.go create mode 100644 pkg/mocks/stfe.go create mode 100644 pkg/mocks/trillian.go create mode 100644 pkg/state/state_manager.go create mode 100644 pkg/state/state_manager_test.go create mode 100644 pkg/trillian/client.go create mode 100644 pkg/trillian/client_test.go create mode 100644 pkg/trillian/util.go create mode 100644 pkg/types/ascii.go create mode 100644 pkg/types/ascii_test.go create mode 100644 pkg/types/trunnel.go create mode 100644 pkg/types/trunnel_test.go create mode 100644 pkg/types/types.go create mode 100644 pkg/types/types_test.go create mode 100644 pkg/types/util.go delete mode 100644 request.go delete mode 100644 request_test.go delete mode 100644 server/.gitignore delete mode 100644 server/README.md delete mode 100644 server/main.go delete mode 100644 state/state_manager.go delete mode 100644 state/state_manager_test.go delete mode 100644 sth.go delete mode 100644 sth_test.go delete mode 100644 testdata/data.go delete mode 100644 trillian.go delete mode 100644 trillian/client.go delete mode 100644 trillian/client_test.go delete mode 100644 trillian/util.go delete mode 100644 trillian_test.go delete mode 100644 types/ascii.go delete mode 100644 types/ascii_test.go delete mode 100644 types/trunnel.go delete mode 100644 types/trunnel_test.go delete mode 100644 types/types.go delete mode 100644 types/types_test.go delete mode 100644 types/util.go delete mode 100644 util.go delete mode 100644 util_test.go diff --git a/client/client.go b/client/client.go deleted file mode 100644 index ba81f4d..0000000 --- a/client/client.go +++ /dev/null @@ -1,242 +0,0 @@ -package client - -import ( - "bytes" - "context" - "crypto" - "fmt" - - "io/ioutil" - "net/http" - - "github.com/golang/glog" - "github.com/google/trillian/merkle/rfc6962" - "github.com/system-transparency/stfe" - "github.com/system-transparency/stfe/types" - "golang.org/x/net/context/ctxhttp" -) - -// Descriptor is a log descriptor -type Descriptor struct { - Namespace *types.Namespace // log identifier is a namespace - Url string // log url, e.g., http://example.com/st/v1 -} - -// Client is a log client -type Client struct { - HttpClient *http.Client - Signer crypto.Signer // client's private identity - Namespace *types.Namespace // client's public identity - Log *Descriptor // log's public identity -} - -// GetLatestSth fetches and verifies the signature of the most recent STH. -// Outputs the resulting STH. -func (c *Client) GetLatestSth(ctx context.Context) (*types.StItem, error) { - url := stfe.EndpointGetLatestSth.Path(c.Log.Url) - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return nil, fmt.Errorf("failed creating http request: %v", err) - } - glog.V(3).Infof("created http request: %s %s", req.Method, req.URL) - - item, err := c.doRequestWithStItemResponse(ctx, req) - if err != nil { - return nil, err - } - if got, want := item.Format, types.StFormatSignedTreeHeadV1; got != want { - return nil, fmt.Errorf("unexpected StItem format: %v", got) - } - if err := VerifySignedTreeHeadV1(c.Log.Namespace, item); err != nil { - return nil, fmt.Errorf("signature verification failed: %v", err) - } - glog.V(3).Infof("verified sth") - return item, nil -} - -// GetProofByHash fetches and verifies an inclusion proof for a leaf hash -// against an STH. Outputs the resulting proof. -func (c *Client) GetProofByHash(ctx context.Context, leafHash []byte, sth *types.StItem) (*types.StItem, error) { - if err := VerifySignedTreeHeadV1(c.Log.Namespace, sth); err != nil { - return nil, fmt.Errorf("invalid sth: %v", err) - } - glog.V(3).Infof("verified sth") - params := types.GetProofByHashV1{ - TreeSize: sth.SignedTreeHeadV1.TreeHead.TreeSize, - } - copy(params.Hash[:], leafHash) - buf, err := types.Marshal(params) - if err != nil { - return nil, fmt.Errorf("req: Marshal: %v", err) - } - - url := stfe.EndpointGetProofByHash.Path(c.Log.Url) - req, err := http.NewRequest("POST", url, bytes.NewBuffer(buf)) - if err != nil { - return nil, fmt.Errorf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - glog.V(3).Infof("created http request: %s %s", req.Method, req.URL) - - item, err := c.doRequestWithStItemResponse(ctx, req) - if err != nil { - return nil, fmt.Errorf("doRequestWithStItemResponse: %v", err) - } - if got, want := item.Format, types.StFormatInclusionProofV1; got != want { - return nil, fmt.Errorf("unexpected StItem format: %v", item.Format) - } - if err := VerifyInclusionProofV1(item, sth, params.Hash[:]); err != nil { - return nil, fmt.Errorf("invalid inclusion proof: %v", err) - } - glog.V(3).Infof("verified inclusion proof") - return item, nil -} - -// GetConsistencyProof fetches and verifies a consistency proof betweeen two -// STHs. Outputs the resulting proof. -func (c *Client) GetConsistencyProof(ctx context.Context, sth1, sth2 *types.StItem) (*types.StItem, error) { - if err := VerifySignedTreeHeadV1(c.Log.Namespace, sth1); err != nil { - return nil, fmt.Errorf("invalid first sth: %v", err) - } - if err := VerifySignedTreeHeadV1(c.Log.Namespace, sth2); err != nil { - return nil, fmt.Errorf("invalid second sth: %v", err) - } - glog.V(3).Infof("verified sths") - buf, err := types.Marshal(types.GetConsistencyProofV1{ - First: sth1.SignedTreeHeadV1.TreeHead.TreeSize, - Second: sth2.SignedTreeHeadV1.TreeHead.TreeSize, - }) - if err != nil { - return nil, fmt.Errorf("req: Marshal: %v", err) - } - - url := stfe.EndpointGetConsistencyProof.Path(c.Log.Url) - req, err := http.NewRequest("POST", url, bytes.NewBuffer(buf)) - if err != nil { - return nil, fmt.Errorf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - glog.V(3).Infof("created http request: %s %s", req.Method, req.URL) - - item, err := c.doRequestWithStItemResponse(ctx, req) - if err != nil { - return nil, fmt.Errorf("doRequestWithStItemResponse: %v", err) - } - if got, want := item.Format, types.StFormatConsistencyProofV1; got != want { - return nil, fmt.Errorf("unexpected StItem format: %v", item.Format) - } - if err := VerifyConsistencyProofV1(item, sth1, sth2); err != nil { - return nil, fmt.Errorf("invalid inclusion proof: %v", err) - } - glog.V(3).Infof("verified inclusion proof") - return item, nil -} - -// AddEntry signs and submits a checksum_v1 entry to the log. Outputs the -// resulting leaf-hash on success. -func (c *Client) AddEntry(ctx context.Context, data *types.ChecksumV1) ([]byte, error) { - msg, err := types.Marshal(*data) - if err != nil { - return nil, fmt.Errorf("failed marshaling ChecksumV1: %v", err) - } - sig, err := c.Signer.Sign(nil, msg, crypto.Hash(0)) - if err != nil { - return nil, fmt.Errorf("failed signing ChecksumV1: %v", err) - } - leaf, err := types.Marshal(*types.NewSignedChecksumV1(data, &types.SignatureV1{ - Namespace: *c.Namespace, - Signature: sig, - })) - if err != nil { - return nil, fmt.Errorf("failed marshaling SignedChecksumV1: %v", err) - } - glog.V(3).Infof("signed checksum entry for identifier %q", string(data.Identifier)) - - url := stfe.EndpointAddEntry.Path(c.Log.Url) - req, err := http.NewRequest("POST", url, bytes.NewBuffer(leaf)) - if err != nil { - return nil, fmt.Errorf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - glog.V(3).Infof("created http request: %s %s", req.Method, req.URL) - - if rsp, err := c.doRequest(ctx, req); err != nil { - return nil, fmt.Errorf("doRequest: %v", err) - } else if len(rsp) != 0 { - return nil, fmt.Errorf("extra data: %v", err) - } - glog.V(3).Infof("add-entry succeded") - return rfc6962.DefaultHasher.HashLeaf(leaf), nil -} - -// GetEntries fetches a range of entries from the log, verifying that they are -// of type signed_checksum_v1 but nothing more than that. Outputs the resulting -// range that may be truncated by the log if [start,end] is too large. -func (c *Client) GetEntries(ctx context.Context, start, end uint64) ([]*types.StItem, error) { - buf, err := types.Marshal(types.GetEntriesV1{ - Start: start, - End: end, - }) - if err != nil { - return nil, fmt.Errorf("Marshal: %v", err) - } - url := stfe.EndpointGetEntries.Path(c.Log.Url) - req, err := http.NewRequest("POST", url, bytes.NewBuffer(buf)) - if err != nil { - return nil, fmt.Errorf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - glog.V(3).Infof("created http request: %s %s", req.Method, req.URL) - glog.V(3).Infof("request data: start(%d), end(%d)", start, end) - - body, err := c.doRequest(ctx, req) - if err != nil { - return nil, fmt.Errorf("doRequest: %v", err) - } - var list types.StItemList - if err := types.Unmarshal(body, &list); err != nil { - return nil, fmt.Errorf("Unmarshal: %v", err) - } - ret := make([]*types.StItem, 0, len(list.Items)) - for i, _ := range list.Items { - item := list.Items[i] - if got, want := item.Format, types.StFormatSignedChecksumV1; got != want { - return nil, fmt.Errorf("unexpected StItem format: %v", got) - } - ret = append(ret, &item) - } - return ret, nil -} - -// doRequest sends an HTTP request and outputs the raw body -func (c *Client) doRequest(ctx context.Context, req *http.Request) ([]byte, error) { - rsp, err := ctxhttp.Do(ctx, c.HttpClient, req) - if err != nil { - return nil, fmt.Errorf("no response: %v", err) - } - defer rsp.Body.Close() - if got, want := rsp.StatusCode, http.StatusOK; got != want { - return nil, fmt.Errorf("bad http status: %v", got) - } - body, err := ioutil.ReadAll(rsp.Body) - if err != nil { - return nil, fmt.Errorf("cannot read body: %v", err) - } - return body, nil -} - -// -// doRequestWithStItemResponse sends an HTTP request and returns a decoded -// StItem that the resulting HTTP response contained json:ed and marshaled -func (c *Client) doRequestWithStItemResponse(ctx context.Context, req *http.Request) (*types.StItem, error) { - body, err := c.doRequest(ctx, req) - if err != nil { - return nil, err - } - var item types.StItem - if err := types.Unmarshal(body, &item); err != nil { - return nil, fmt.Errorf("failed decoding StItem: %v", err) - } - glog.V(9).Infof("got StItem: %v", item) - return &item, nil -} diff --git a/client/cmd/add-entry/main.go b/client/cmd/add-entry/main.go deleted file mode 100644 index a29d01f..0000000 --- a/client/cmd/add-entry/main.go +++ /dev/null @@ -1,52 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - - "encoding/base64" - - "github.com/golang/glog" - "github.com/system-transparency/stfe/client" - "github.com/system-transparency/stfe/types" -) - -var ( - identifier = flag.String("identifier", "", "checksum identifier") - checksum = flag.String("checksum", "", "base64-encoded checksum") -) - -func main() { - flag.Parse() - defer glog.Flush() - - client, err := client.NewClientFromFlags() - if err != nil { - glog.Errorf("NewClientFromFlags: %v", err) - return - } - data, err := NewChecksumV1FromFlags() - if err != nil { - glog.Errorf("NewChecksumV1FromFlags: %v", err) - return - } - leafHash, err := client.AddEntry(context.Background(), data) - if err != nil { - glog.Errorf("AddEntry: %v", err) - return - } - fmt.Println("leaf hash:", base64.StdEncoding.EncodeToString(leafHash)) -} - -func NewChecksumV1FromFlags() (*types.ChecksumV1, error) { - var err error - data := types.ChecksumV1{ - Identifier: []byte(*identifier), - } - data.Checksum, err = base64.StdEncoding.DecodeString(*checksum) - if err != nil { - return nil, fmt.Errorf("entry_checksum: DecodeString: %v", err) - } - return &data, nil -} diff --git a/client/cmd/cosign/main.go b/client/cmd/cosign/main.go deleted file mode 100644 index e86842b..0000000 --- a/client/cmd/cosign/main.go +++ /dev/null @@ -1,56 +0,0 @@ -package main - -import ( - "bytes" - "crypto/ed25519" - "encoding/hex" - "flag" - "fmt" - "log" - "net/http" - - "github.com/system-transparency/stfe/types" -) - -var ( - url = flag.String("url", "http://localhost:6965/st/v0", "base url") - sk = flag.String("sk", "e1d7c494dacb0ddf809a17e4528b01f584af22e3766fa740ec52a1711c59500d711090dd2286040b50961b0fe09f58aa665ccee5cb7ee042d819f18f6ab5046b", "hex key") -) - -func main() { - priv, err := hex.DecodeString(*sk) - if err != nil { - log.Fatalf("DecodeString: %v", err) - } - sk := ed25519.PrivateKey(priv) - vk := sk.Public().(ed25519.PublicKey) - fmt.Printf("sk: %x\nvk: %x\n", sk, vk) - - rsp, err := http.Get(*url + "/get-tree-head-to-sign") - if err != nil { - log.Fatalf("Get: %v", err) - } - var sth types.SignedTreeHead - if err := sth.UnmarshalASCII(rsp.Body); err != nil { - log.Fatalf("UnmarshalASCII: %v", err) - } - fmt.Printf("%+v\n", sth) - - msg := sth.TreeHead.Marshal() - sig := ed25519.Sign(sk, msg) - sigident := &types.SigIdent{ - KeyHash: types.Hash(vk[:]), - Signature: &[types.SignatureSize]byte{}, - } - copy(sigident.Signature[:], sig) - - buf := bytes.NewBuffer(nil) - if err := sigident.MarshalASCII(buf); err != nil { - log.Fatalf("MarshalASCII: %v", err) - } - rsp, err = http.Post(*url+"/add-cosignature", "type/stfe", buf) - if err != nil { - log.Fatalf("Post: %v", err) - } - fmt.Printf("Status: %v\n", rsp.StatusCode) -} diff --git a/client/cmd/example.sh b/client/cmd/example.sh deleted file mode 100755 index d790712..0000000 --- a/client/cmd/example.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -set -eu - -log_url=http://tlog-poc.system-transparency.org:4780/st/v1 -log_id=AAG+ZW+UesWdMFytUGkp28csBcziomSB3U2vvkAW55MVZQ== -tmpdir=$(mktemp -dt stfe.XXXXXXXX) -cp $0 $tmpdir/ -cd $tmpdir - -commonargs="--log_id $log_id --log_url $log_url" # --logtostderr -v 3 -pause="sleep 1" - -echo "arguments used:" -echo $commonargs -echo "" - -echo "fetching sth..." -get-sth $commonargs | tee sth1.output -echo "" && $pause - -echo "adding an entry..." -add-entry $commonargs \ - --identifier "example.sh v0.0.1-$(cat /dev/urandom | base64 | head -c 10)" \ - --checksum $(sha256sum "$0") | tee add-entry.output -echo "" && $pause - -echo "fetching another sth..." -get-sth $commonargs | tee sth2.output -echo "" && $pause - -echo "verifying inclusion..." -get-proof-by-hash $commonargs \ - --leaf_hash $(cat add-entry.output | awk '{print $3}') \ - --sth $(cat sth2.output | awk '{print $2}') -echo "" && $pause - -echo "verifying consistency..." -get-consistency-proof $commonargs \ - --first $(cat sth1.output | awk '{print $2}') \ - --second $(cat sth2.output | awk '{print $2}') -echo "" && $pause - -echo "fetching the log's first entry..." -get-entries $commonargs --start 0 --end 0 -echo "" - -rm *.output $0 -cd -rmdir $tmpdir diff --git a/client/cmd/get-consistency-proof/main.go b/client/cmd/get-consistency-proof/main.go deleted file mode 100644 index bb8a7a6..0000000 --- a/client/cmd/get-consistency-proof/main.go +++ /dev/null @@ -1,70 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - - "encoding/base64" - - "github.com/golang/glog" - "github.com/system-transparency/stfe/client" - "github.com/system-transparency/stfe/types" -) - -var ( - first = flag.String("first", "", "base64-encoded sth") - second = flag.String("second", "", "base64-encoded sth") -) - -func main() { - flag.Parse() - defer glog.Flush() - - client, err := client.NewClientFromFlags() - if err != nil { - glog.Errorf("NewClientFromFlags: %v", err) - return - } - sth1, sth2, err := newParamsFromFlags() - if err != nil { - glog.Errorf("NewRequestFromFlags: %v", err) - return - } - - proof, err := client.GetConsistencyProof(context.Background(), sth1, sth2) - if err != nil { - glog.Errorf("GetConsistencyProof: %v", err) - return - } - serialized, err := types.Marshal(*proof) - if err != nil { - glog.Errorf("Marshal: %v", err) - return - } - fmt.Println("proof:", base64.StdEncoding.EncodeToString(serialized)) -} - -func newParamsFromFlags() (*types.StItem, *types.StItem, error) { - sth1, err := decodeSthStr(*first) - if err != nil { - return nil, nil, fmt.Errorf("first: decodeSthStr: %v", err) - } - sth2, err := decodeSthStr(*second) - if err != nil { - return nil, nil, fmt.Errorf("second: decodeSthStr: %v", err) - } - return sth1, sth2, nil -} - -func decodeSthStr(sthStr string) (*types.StItem, error) { - serialized, err := base64.StdEncoding.DecodeString(sthStr) - if err != nil { - return nil, fmt.Errorf("DecodeString: %v", err) - } - var item types.StItem - if err = types.Unmarshal(serialized, &item); err != nil { - return nil, fmt.Errorf("Unmarshal: %v", err) - } - return &item, nil -} diff --git a/client/cmd/get-entries/main.go b/client/cmd/get-entries/main.go deleted file mode 100644 index f32fdbf..0000000 --- a/client/cmd/get-entries/main.go +++ /dev/null @@ -1,83 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - - "encoding/base64" - - "github.com/golang/glog" - "github.com/google/trillian/merkle/rfc6962" - "github.com/system-transparency/stfe/client" - "github.com/system-transparency/stfe/types" -) - -var ( - start = flag.Uint64("start", 0, "inclusive start index to download") - end = flag.Uint64("end", 0, "inclusive stop index to download") -) - -func main() { - flag.Parse() - defer glog.Flush() - - client, err := client.NewClientFromFlags() - if err != nil { - glog.Errorf("NewClientFromFlags: %v", err) - return - } - items, err := getRange(client, *start, *end) - if err != nil { - glog.Errorf("getRange: %v", err) - return - } - if err := printRange(items); err != nil { - glog.Errorf("printRange: %v", err) - return - } -} - -func getRange(client *client.Client, start, end uint64) ([]*types.StItem, error) { - items := make([]*types.StItem, 0, end-start+1) - for len(items) != cap(items) { - rsp, err := client.GetEntries(context.Background(), start, end) - if err != nil { - return nil, fmt.Errorf("fetching entries failed: %v", err) - } - items = append(items, rsp...) - start += uint64(len(rsp)) - } - return items, nil -} - -func printRange(items []*types.StItem) error { - for i, item := range items { - var status string - msg, err := types.Marshal(item.SignedChecksumV1.Data) - if err != nil { - return fmt.Errorf("Marshal data failed: %v", err) - } - sig := item.SignedChecksumV1.Signature.Signature - namespace := &item.SignedChecksumV1.Signature.Namespace - if err := namespace.Verify(msg, sig); err != nil { - status = "unverified signature" - } else { - status = "verified signature" - } - serializedNamespace, err := types.Marshal(*namespace) - if err != nil { - return fmt.Errorf("Marshal namespace failed: %v", err) - } - serializedLeaf, err := types.Marshal(*item) - if err != nil { - return fmt.Errorf("Marshal item on index %d: %v", *start+uint64(i), err) - } - fmt.Printf("Index(%d) - %s\n", *start+uint64(i), status) - fmt.Printf("-> Namespace: %s\n", base64.StdEncoding.EncodeToString(serializedNamespace)) - fmt.Printf("-> Identifier: %s\n", string(item.SignedChecksumV1.Data.Identifier)) - fmt.Printf("-> Checksum: %s\n", base64.StdEncoding.EncodeToString(item.SignedChecksumV1.Data.Checksum)) - fmt.Printf("-> Leaf hash: %s\n", base64.StdEncoding.EncodeToString(rfc6962.DefaultHasher.HashLeaf(serializedLeaf))) - } - return nil -} diff --git a/client/cmd/get-proof-by-hash/main.go b/client/cmd/get-proof-by-hash/main.go deleted file mode 100644 index 1f4f304..0000000 --- a/client/cmd/get-proof-by-hash/main.go +++ /dev/null @@ -1,66 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - - "encoding/base64" - - "github.com/golang/glog" - "github.com/system-transparency/stfe/client" - "github.com/system-transparency/stfe/types" -) - -var ( - sthStr = flag.String("sth", "", "base64-encoded StItem of type StFormatSignedTreeHeadV1 (default: fetch new sth)") - leafHashStr = flag.String("leaf_hash", "", "base64-encoded leaf hash") -) - -func main() { - flag.Parse() - defer glog.Flush() - - client, err := client.NewClientFromFlags() - if err != nil { - glog.Errorf("NewClientFromFlags: %v", err) - return - } - leafHash, sth, err := newParamsFromFlags(client) - if err != nil { - glog.Errorf("NewRequestFromFlags: %v", err) - return - } - - proof, err := client.GetProofByHash(context.Background(), leafHash, sth) - if err != nil { - glog.Errorf("GetProofByHash: %v", err) - return - } - serialized, err := types.Marshal(*proof) - if err != nil { - glog.Errorf("Marshal: %v", err) - } - fmt.Println("proof:", base64.StdEncoding.EncodeToString(serialized)) -} - -func newParamsFromFlags(client *client.Client) ([]byte, *types.StItem, error) { - serialized, err := base64.StdEncoding.DecodeString(*sthStr) - if err != nil { - return nil, nil, fmt.Errorf("sth: DecodeString: %v", err) - } - var item types.StItem - if err = types.Unmarshal(serialized, &item); err != nil { - return nil, nil, fmt.Errorf("sth: Unmarshal: %v", err) - } else if got, want := item.Format, types.StFormatSignedTreeHeadV1; got != want { - return nil, nil, fmt.Errorf("unexpected StItem format: %v", got) - } - leafHash, err := base64.StdEncoding.DecodeString(*leafHashStr) - if err != nil { - return nil, nil, fmt.Errorf("leaf_hash: DecodeString: %v", err) - } else if got, want := len(leafHash), 32; got != want { - return nil, nil, fmt.Errorf("leaf_hash: unexpected size: %v", got) - } - glog.V(3).Infof("created request parameters TreeSize(%d) and LeafHash(%s)", item.SignedTreeHeadV1.TreeHead.TreeSize, *leafHashStr) - return leafHash, &item, nil -} diff --git a/client/cmd/get-sth/main.go b/client/cmd/get-sth/main.go deleted file mode 100644 index 6b23b06..0000000 --- a/client/cmd/get-sth/main.go +++ /dev/null @@ -1,35 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - - "encoding/base64" - - "github.com/golang/glog" - "github.com/system-transparency/stfe/client" - "github.com/system-transparency/stfe/types" -) - -func main() { - flag.Parse() - defer glog.Flush() - - client, err := client.NewClientFromFlags() - if err != nil { - glog.Errorf("NewClientFromFlags: %v", err) - return - } - sth, err := client.GetLatestSth(context.Background()) - if err != nil { - glog.Errorf("GetLatestSth: %v", err) - return - } - serialized, err := types.Marshal(*sth) - if err != nil { - glog.Errorf("Marshal: %v", err) - return - } - fmt.Println("sth:", base64.StdEncoding.EncodeToString(serialized)) -} diff --git a/client/cmd/keygen/main.go b/client/cmd/keygen/main.go deleted file mode 100644 index c1c1b58..0000000 --- a/client/cmd/keygen/main.go +++ /dev/null @@ -1,17 +0,0 @@ -package main - -import ( - "crypto/ed25519" - "crypto/rand" - "fmt" - "log" -) - -func main() { - vk, sk, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("GenerateKey: %v", err) - } - fmt.Printf("sk: %x\n", sk[:]) - fmt.Printf("vk: %x\n", vk[:]) -} diff --git a/client/cmd/submit/main.go b/client/cmd/submit/main.go deleted file mode 100644 index 36c7271..0000000 --- a/client/cmd/submit/main.go +++ /dev/null @@ -1,28 +0,0 @@ -package main - -// go run . | bash - -import ( - "crypto/ed25519" - "crypto/rand" - "fmt" - "github.com/system-transparency/stfe/types" -) - -func main() { - checksum := [32]byte{} - msg := types.Message{ - ShardHint: 0, - Checksum: &checksum, - } - - vk, sk, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - fmt.Printf("ed25519.GenerateKey: %v\n", err) - return - } - sig := ed25519.Sign(sk, msg.Marshal()) - //fmt.Printf("sk: %x\nvk: %x\n", sk[:], vk[:]) - - fmt.Printf("echo \"shard_hint=%d\nchecksum=%x\nsignature_over_message=%x\nverification_key=%x\ndomain_hint=%s\" | curl --data-binary @- localhost:6965/st/v0/add-leaf\n", msg.ShardHint, msg.Checksum[:], sig, vk[:], "example.com") -} diff --git a/client/flag.go b/client/flag.go deleted file mode 100644 index 8ba7a10..0000000 --- a/client/flag.go +++ /dev/null @@ -1,55 +0,0 @@ -package client - -import ( - "flag" - "fmt" - - "crypto/ed25519" - "encoding/base64" - "net/http" - - "github.com/system-transparency/stfe/types" -) - -var ( - logId = flag.String("log_id", "AAG+ZW+UesWdMFytUGkp28csBcziomSB3U2vvkAW55MVZQ==", "base64-encoded log identifier") - logUrl = flag.String("log_url", "http://tlog-poc.system-transparency.org:4780/st/v1", "log url") - ed25519_sk = flag.String("ed25519_sk", "d8i6nud7PS1vdO0sIk9H+W0nyxbM63Y3/mSeUPRafWaFh8iH8QXvL7NaAYn2RZPrnEey+FdpmTYXE47OFO70eg==", "base64-encoded ed25519 signing key") -) - -func NewClientFromFlags() (*Client, error) { - var err error - c := Client{ - HttpClient: &http.Client{}, - } - if len(*ed25519_sk) != 0 { - sk, err := base64.StdEncoding.DecodeString(*ed25519_sk) - if err != nil { - return nil, fmt.Errorf("ed25519_sk: DecodeString: %v", err) - } - c.Signer = ed25519.PrivateKey(sk) - c.Namespace, err = types.NewNamespaceEd25519V1([]byte(ed25519.PrivateKey(sk).Public().(ed25519.PublicKey))) - if err != nil { - return nil, fmt.Errorf("ed25519_vk: NewNamespaceEd25519V1: %v", err) - } - } - if c.Log, err = NewDescriptorFromFlags(); err != nil { - return nil, fmt.Errorf("NewDescriptorFromFlags: %v", err) - } - return &c, nil -} - -func NewDescriptorFromFlags() (*Descriptor, error) { - b, err := base64.StdEncoding.DecodeString(*logId) - if err != nil { - return nil, fmt.Errorf("LogId: DecodeString: %v", err) - } - var namespace types.Namespace - if err := types.Unmarshal(b, &namespace); err != nil { - return nil, fmt.Errorf("LogId: Unmarshal: %v", err) - } - return &Descriptor{ - Namespace: &namespace, - Url: *logUrl, - }, nil -} diff --git a/client/verify.go b/client/verify.go deleted file mode 100644 index c95828c..0000000 --- a/client/verify.go +++ /dev/null @@ -1,52 +0,0 @@ -package client - -import ( - "fmt" - "reflect" - - "github.com/google/trillian/merkle" - "github.com/google/trillian/merkle/rfc6962" - "github.com/system-transparency/stfe/types" -) - -func VerifySignedTreeHeadV1(namespace *types.Namespace, sth *types.StItem) error { - if got, want := &sth.SignedTreeHeadV1.Signature.Namespace, namespace; !reflect.DeepEqual(got, want) { - return fmt.Errorf("unexpected log id: %v", want) - } - th, err := types.Marshal(sth.SignedTreeHeadV1.TreeHead) - if err != nil { - return fmt.Errorf("Marshal: %v", err) - } - if err := namespace.Verify(th, sth.SignedTreeHeadV1.Signature.Signature); err != nil { - return fmt.Errorf("Verify: %v", err) - } - return nil -} - -func VerifyConsistencyProofV1(proof, first, second *types.StItem) error { - path := make([][]byte, 0, len(proof.ConsistencyProofV1.ConsistencyPath)) - for _, nh := range proof.ConsistencyProofV1.ConsistencyPath { - path = append(path, nh.Data) - } - return merkle.NewLogVerifier(rfc6962.DefaultHasher).VerifyConsistencyProof( - int64(proof.ConsistencyProofV1.TreeSize1), - int64(proof.ConsistencyProofV1.TreeSize2), - first.SignedTreeHeadV1.TreeHead.RootHash.Data, - second.SignedTreeHeadV1.TreeHead.RootHash.Data, - path, - ) -} - -func VerifyInclusionProofV1(proof, sth *types.StItem, leafHash []byte) error { - path := make([][]byte, 0, len(proof.InclusionProofV1.InclusionPath)) - for _, nh := range proof.InclusionProofV1.InclusionPath { - path = append(path, nh.Data) - } - return merkle.NewLogVerifier(rfc6962.DefaultHasher).VerifyInclusionProof( - int64(proof.InclusionProofV1.LeafIndex), - int64(proof.InclusionProofV1.TreeSize), - path, - sth.SignedTreeHeadV1.TreeHead.RootHash.Data, - leafHash, - ) -} diff --git a/cmd/siglog_server/.gitignore b/cmd/siglog_server/.gitignore new file mode 100644 index 0000000..254defd --- /dev/null +++ b/cmd/siglog_server/.gitignore @@ -0,0 +1 @@ +server diff --git a/cmd/siglog_server/README.md b/cmd/siglog_server/README.md new file mode 100644 index 0000000..71bb3ac --- /dev/null +++ b/cmd/siglog_server/README.md @@ -0,0 +1,60 @@ +# Run Trillian + STFE locally +Trillian uses a database. So, we will need to set that up. It is documented +[here](https://github.com/google/trillian#mysql-setup), and how to check that it +is setup properly +[here](https://github.com/google/certificate-transparency-go/blob/master/trillian/docs/ManualDeployment.md#data-storage). + +Other than the database we need the Trillian log signer, Trillian log server, +and STFE server. +``` +$ go install github.com/google/trillian/cmd/trillian_log_signer +$ go install github.com/google/trillian/cmd/trillian_log_server +$ go install +``` + +Start Trillian log signer: +``` +trillian_log_signer --logtostderr -v 9 --force_master --rpc_endpoint=localhost:6961 --http_endpoint=localhost:6964 --num_sequencers 1 --sequencer_interval 100ms --batch_size 100 +``` + +Start Trillian log server: +``` +trillian_log_server --logtostderr -v 9 --rpc_endpoint=localhost:6962 --http_endpoint=localhost:6963 +``` + +As described in more detail +[here](https://github.com/google/certificate-transparency-go/blob/master/trillian/docs/ManualDeployment.md#trillian-services), +we need to provision a Merkle tree once: +``` +$ go install github.com/google/trillian/cmd/createtree +$ createtree --admin_server localhost:6962 + +``` + +Hang on to ``. Our STFE server will use it when talking to the +Trillian log server to specify which Merkle tree we are working against. + +(If you take a look in the `Trees` table you will see that the tree has been +provisioned.) + +We will also need a public key-pair and log identifier for the STFE server. +``` +$ go install github.com/system-transparency/stfe/types/cmd/new-namespace +sk: +vk: +ed25519_v1: +``` + +The log's identifier is `` and contains the public verification key +``. The log's corresponding secret signing key is ``. + +Start STFE server: +``` +$ ./server --logtostderr -v 9 --http_endpoint localhost:6965 --log_rpc_server localhost:6962 --trillian_id --key +``` + +If the log is responsive on, e.g., `GET http://localhost:6965/st/v1/get-latest-sth` you +may want to try running +`github.com/system-transparency/stfe/client/cmd/example.sh`. You need to +configure the log's id though for verification to work (flag `log_id`, which +should be set to the `` output above). diff --git a/cmd/siglog_server/main.go b/cmd/siglog_server/main.go new file mode 100644 index 0000000..368b0a7 --- /dev/null +++ b/cmd/siglog_server/main.go @@ -0,0 +1,176 @@ +// Package main provides an STFE server binary +package main + +import ( + "context" + "crypto" + "crypto/ed25519" + "encoding/hex" + "flag" + "fmt" + "net/http" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + "github.com/golang/glog" + "github.com/google/trillian" + "github.com/prometheus/client_golang/prometheus/promhttp" + stfe "github.com/system-transparency/stfe/pkg/instance" + "github.com/system-transparency/stfe/pkg/state" + trillianWrapper "github.com/system-transparency/stfe/pkg/trillian" + "github.com/system-transparency/stfe/pkg/types" + "google.golang.org/grpc" +) + +var ( + httpEndpoint = flag.String("http_endpoint", "localhost:6965", "host:port specification of where stfe serves clients") + rpcBackend = flag.String("log_rpc_server", "localhost:6962", "host:port specification of where Trillian serves clients") + prefix = flag.String("prefix", "", "a prefix that proceeds /st/v0/") + trillianID = flag.Int64("trillian_id", 0, "log identifier in the Trillian database") + deadline = flag.Duration("deadline", time.Second*10, "deadline for backend requests") + key = flag.String("key", "", "hex-encoded Ed25519 signing key") + witnesses = flag.String("witnesses", "", "comma-separated list of trusted witness verification keys in hex") + maxRange = flag.Int64("max_range", 10, "maximum number of entries that can be retrived in a single request") + interval = flag.Duration("interval", time.Second*30, "interval used to rotate the log's cosigned STH") +) + +func main() { + flag.Parse() + defer glog.Flush() + + // wait for clean-up before exit + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + glog.V(3).Infof("configuring stfe instance...") + instance, err := setupInstanceFromFlags() + if err != nil { + glog.Errorf("setupInstance: %v", err) + return + } + + glog.V(3).Infof("spawning state manager") + go func() { + wg.Add(1) + defer wg.Done() + instance.Stateman.Run(ctx) + glog.Errorf("state manager shutdown") + cancel() // must have state manager running + }() + + glog.V(3).Infof("spawning await") + server := http.Server{Addr: *httpEndpoint} + go await(ctx, func() { + wg.Add(1) + defer wg.Done() + ctxInner, _ := context.WithTimeout(ctx, time.Second*60) + glog.Infof("Shutting down HTTP server...") + server.Shutdown(ctxInner) + glog.V(3).Infof("HTTP server shutdown") + glog.Infof("Shutting down spawned go routines...") + cancel() + }) + + glog.Infof("Serving on %v/%v", *httpEndpoint, *prefix) + if err = server.ListenAndServe(); err != http.ErrServerClosed { + glog.Errorf("ListenAndServe: %v", err) + } +} + +// SetupInstance sets up a new STFE instance from flags +func setupInstanceFromFlags() (*stfe.Instance, error) { + var i stfe.Instance + var err error + + // Setup log configuration + i.Signer, i.LogID, err = newLogIdentity(*key) + if err != nil { + return nil, fmt.Errorf("newLogIdentity: %v", err) + } + i.TreeID = *trillianID + i.Prefix = *prefix + i.MaxRange = *maxRange + i.Deadline = *deadline + i.Interval = *interval + i.Witnesses, err = newWitnessMap(*witnesses) + if err != nil { + return nil, fmt.Errorf("newWitnessMap: %v", err) + } + + // Setup log client + dialOpts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(i.Deadline)} + conn, err := grpc.Dial(*rpcBackend, dialOpts...) + if err != nil { + return nil, fmt.Errorf("Dial: %v", err) + } + i.Client = &trillianWrapper.TrillianClient{ + TreeID: i.TreeID, + GRPC: trillian.NewTrillianLogClient(conn), + } + + // Setup state manager + i.Stateman, err = state.NewStateManagerSingle(i.Client, i.Signer, i.Interval, i.Deadline) + if err != nil { + return nil, fmt.Errorf("NewStateManager: %v", err) + } + + // Register HTTP endpoints + mux := http.NewServeMux() + http.Handle("/", mux) + for _, handler := range i.Handlers() { + glog.V(3).Infof("adding handler: %s", handler.Path()) + mux.Handle(handler.Path(), handler) + } + glog.V(3).Infof("Adding prometheus handler on path: /metrics") + http.Handle("/metrics", promhttp.Handler()) + + return &i, nil +} + +func newLogIdentity(key string) (crypto.Signer, string, error) { + buf, err := hex.DecodeString(key) + if err != nil { + return nil, "", fmt.Errorf("DecodeString: %v", err) + } + sk := crypto.Signer(ed25519.PrivateKey(buf)) + vk := sk.Public().(ed25519.PublicKey) + return sk, hex.EncodeToString([]byte(vk[:])), nil +} + +// newWitnessMap creates a new map of trusted witnesses +func newWitnessMap(witnesses string) (map[[types.HashSize]byte][types.VerificationKeySize]byte, error) { + w := make(map[[types.HashSize]byte][types.VerificationKeySize]byte) + if len(witnesses) > 0 { + for _, witness := range strings.Split(witnesses, ",") { + b, err := hex.DecodeString(witness) + if err != nil { + return nil, fmt.Errorf("DecodeString: %v", err) + } + + var vk [types.VerificationKeySize]byte + if n := copy(vk[:], b); n != types.VerificationKeySize { + return nil, fmt.Errorf("Invalid verification key size: %v", n) + } + w[*types.Hash(vk[:])] = vk + } + } + return w, nil +} + +// await waits for a shutdown signal and then runs a clean-up function +func await(ctx context.Context, done func()) { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + select { + case <-sigs: + case <-ctx.Done(): + } + glog.V(3).Info("received shutdown signal") + done() +} diff --git a/cmd/tmp/README.md b/cmd/tmp/README.md new file mode 100644 index 0000000..30d5317 --- /dev/null +++ b/cmd/tmp/README.md @@ -0,0 +1,2 @@ +# Warning +These basic commands will be moved or replaced by proper tooling. diff --git a/cmd/tmp/cosign/main.go b/cmd/tmp/cosign/main.go new file mode 100644 index 0000000..a51f17d --- /dev/null +++ b/cmd/tmp/cosign/main.go @@ -0,0 +1,56 @@ +package main + +import ( + "bytes" + "crypto/ed25519" + "encoding/hex" + "flag" + "fmt" + "log" + "net/http" + + "github.com/system-transparency/stfe/pkg/types" +) + +var ( + url = flag.String("url", "http://localhost:6965/st/v0", "base url") + sk = flag.String("sk", "e1d7c494dacb0ddf809a17e4528b01f584af22e3766fa740ec52a1711c59500d711090dd2286040b50961b0fe09f58aa665ccee5cb7ee042d819f18f6ab5046b", "hex key") +) + +func main() { + priv, err := hex.DecodeString(*sk) + if err != nil { + log.Fatalf("DecodeString: %v", err) + } + sk := ed25519.PrivateKey(priv) + vk := sk.Public().(ed25519.PublicKey) + fmt.Printf("sk: %x\nvk: %x\n", sk, vk) + + rsp, err := http.Get(*url + "/get-tree-head-to-sign") + if err != nil { + log.Fatalf("Get: %v", err) + } + var sth types.SignedTreeHead + if err := sth.UnmarshalASCII(rsp.Body); err != nil { + log.Fatalf("UnmarshalASCII: %v", err) + } + fmt.Printf("%+v\n", sth) + + msg := sth.TreeHead.Marshal() + sig := ed25519.Sign(sk, msg) + sigident := &types.SigIdent{ + KeyHash: types.Hash(vk[:]), + Signature: &[types.SignatureSize]byte{}, + } + copy(sigident.Signature[:], sig) + + buf := bytes.NewBuffer(nil) + if err := sigident.MarshalASCII(buf); err != nil { + log.Fatalf("MarshalASCII: %v", err) + } + rsp, err = http.Post(*url+"/add-cosignature", "type/stfe", buf) + if err != nil { + log.Fatalf("Post: %v", err) + } + fmt.Printf("Status: %v\n", rsp.StatusCode) +} diff --git a/cmd/tmp/keygen/main.go b/cmd/tmp/keygen/main.go new file mode 100644 index 0000000..c1c1b58 --- /dev/null +++ b/cmd/tmp/keygen/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "crypto/ed25519" + "crypto/rand" + "fmt" + "log" +) + +func main() { + vk, sk, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("GenerateKey: %v", err) + } + fmt.Printf("sk: %x\n", sk[:]) + fmt.Printf("vk: %x\n", vk[:]) +} diff --git a/cmd/tmp/submit/main.go b/cmd/tmp/submit/main.go new file mode 100644 index 0000000..3dcaa97 --- /dev/null +++ b/cmd/tmp/submit/main.go @@ -0,0 +1,29 @@ +package main + +// go run . | bash + +import ( + "crypto/ed25519" + "crypto/rand" + "fmt" + + "github.com/system-transparency/stfe/pkg/types" +) + +func main() { + checksum := [32]byte{} + msg := types.Message{ + ShardHint: 0, + Checksum: &checksum, + } + + vk, sk, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + fmt.Printf("ed25519.GenerateKey: %v\n", err) + return + } + sig := ed25519.Sign(sk, msg.Marshal()) + //fmt.Printf("sk: %x\nvk: %x\n", sk[:], vk[:]) + + fmt.Printf("echo \"shard_hint=%d\nchecksum=%x\nsignature_over_message=%x\nverification_key=%x\ndomain_hint=%s\" | curl --data-binary @- localhost:6965/st/v0/add-leaf\n", msg.ShardHint, msg.Checksum[:], sig, vk[:], "example.com") +} diff --git a/doc.go b/doc.go deleted file mode 100644 index 4e86552..0000000 --- a/doc.go +++ /dev/null @@ -1,3 +0,0 @@ -// Package stfe implements a System Transparency Front-End (STFE) personality -// for the Trillian log server gRPC API. -package stfe diff --git a/endpoint.go b/endpoint.go deleted file mode 100644 index 9be55b4..0000000 --- a/endpoint.go +++ /dev/null @@ -1,163 +0,0 @@ -package stfe - -import ( - "context" - "crypto/ed25519" - "fmt" - "net/http" - - "github.com/golang/glog" - "github.com/google/trillian" - "github.com/system-transparency/stfe/types" -) - -func addEntry(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { - glog.V(3).Info("handling add-entry request") - leaf, err := i.LogParameters.parseAddEntryV1Request(r) - if err != nil { - return http.StatusBadRequest, fmt.Errorf("parseAddEntryV1Request: %v", err) - } - trsp, err := i.Client.QueueLeaf(ctx, &trillian.QueueLeafRequest{ - LogId: i.LogParameters.TreeId, - Leaf: &trillian.LogLeaf{ - LeafValue: leaf.Marshal(), - ExtraData: nil, - }, - }) - if errInner := checkQueueLeaf(trsp, err); errInner != nil { - return http.StatusInternalServerError, fmt.Errorf("bad QueueLeafResponse: %v", errInner) - } - return http.StatusOK, nil -} - -func addCosignature(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { - glog.V(3).Info("handling add-cosignature request") - req, err := i.LogParameters.parseAddCosignatureRequest(r) - if err != nil { - return http.StatusBadRequest, fmt.Errorf("parseAddCosignatureRequest: %v", err) - } - vk := i.LogParameters.Witnesses[*req.KeyHash] - if err := i.SthSource.AddCosignature(ctx, ed25519.PublicKey(vk[:]), req.Signature); err != nil { - return http.StatusBadRequest, fmt.Errorf("AddCosignature: %v", err) - } - return http.StatusOK, nil -} - -func getLatestSth(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { - glog.V(3).Info("handling get-latest-sth request") - sth, err := i.SthSource.Latest(ctx) - if err != nil { - return http.StatusInternalServerError, fmt.Errorf("Latest: %v", err) - } - if err := sth.MarshalASCII(w); err != nil { - return http.StatusInternalServerError, fmt.Errorf("MarshalASCII: %v", err) - } - return http.StatusOK, nil -} - -func getStableSth(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { - glog.V(3).Info("handling get-stable-sth request") - sth, err := i.SthSource.Stable(ctx) - if err != nil { - return http.StatusInternalServerError, fmt.Errorf("Latest: %v", err) - } - if err := sth.MarshalASCII(w); err != nil { - return http.StatusInternalServerError, fmt.Errorf("MarshalASCII: %v", err) - } - return http.StatusOK, nil -} - -func getCosignedSth(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { - glog.V(3).Info("handling get-cosigned-sth request") - sth, err := i.SthSource.Cosigned(ctx) - if err != nil { - return http.StatusInternalServerError, fmt.Errorf("Cosigned: %v", err) - } - if err := sth.MarshalASCII(w); err != nil { - return http.StatusInternalServerError, fmt.Errorf("MarshalASCII: %v", err) - } - return http.StatusOK, nil -} - -func getConsistencyProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { - glog.V(3).Info("handling get-consistency-proof request") - req, err := i.LogParameters.parseGetConsistencyProofRequest(r) - if err != nil { - return http.StatusBadRequest, err - } - - trsp, err := i.Client.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{ - LogId: i.LogParameters.TreeId, - FirstTreeSize: int64(req.OldSize), - SecondTreeSize: int64(req.NewSize), - }) - if errInner := checkGetConsistencyProof(i.LogParameters, trsp, err); errInner != nil { - return http.StatusInternalServerError, fmt.Errorf("bad GetConsistencyProofResponse: %v", errInner) - } - - proof := &types.ConsistencyProof{ - NewSize: req.NewSize, - OldSize: req.OldSize, - Path: NodePathFromHashes(trsp.Proof.Hashes), - } - if err := proof.MarshalASCII(w); err != nil { - return http.StatusInternalServerError, fmt.Errorf("MarshalASCII: %v", err) - } - return http.StatusOK, nil -} - -func getProofByHash(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { - glog.V(3).Info("handling get-proof-by-hash request") - req, err := i.LogParameters.parseGetProofByHashRequest(r) - if err != nil { - return http.StatusBadRequest, err - } - - trsp, err := i.Client.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{ - LogId: i.LogParameters.TreeId, - LeafHash: req.LeafHash[:], - TreeSize: int64(req.TreeSize), - OrderBySequence: true, - }) - if errInner := checkGetInclusionProofByHash(i.LogParameters, trsp, err); errInner != nil { - return http.StatusInternalServerError, fmt.Errorf("bad GetInclusionProofByHashResponse: %v", errInner) - } - - proof := &types.InclusionProof{ - TreeSize: req.TreeSize, - LeafIndex: uint64(trsp.Proof[0].LeafIndex), - Path: NodePathFromHashes(trsp.Proof[0].Hashes), - } - if err := proof.MarshalASCII(w); err != nil { - return http.StatusInternalServerError, fmt.Errorf("MarshalASCII: %v", err) - } - return http.StatusOK, nil -} - -func getEntries(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { - glog.V(3).Info("handling get-entries request") - req, err := i.LogParameters.parseGetEntriesRequest(r) - if err != nil { - return http.StatusBadRequest, err - } - - trsp, err := i.Client.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{ - LogId: i.LogParameters.TreeId, - StartIndex: int64(req.StartSize), - Count: int64(req.EndSize-req.StartSize) + 1, - }) - if errInner := checkGetLeavesByRange(req, trsp, err); errInner != nil { - return http.StatusInternalServerError, fmt.Errorf("checkGetLeavesByRangeResponse: %v", errInner) // there is one StatusBadRequest in here tho.. - } - - for _, serialized := range trsp.Leaves { - var leaf types.Leaf - if err := leaf.Unmarshal(serialized.LeafValue); err != nil { - return http.StatusInternalServerError, fmt.Errorf("Unmarshal: %v", err) - } - if err := leaf.MarshalASCII(w); err != nil { - return http.StatusInternalServerError, fmt.Errorf("MarshalASCII: %v", err) - } - } - return http.StatusOK, nil -} diff --git a/endpoint_test.go b/endpoint_test.go deleted file mode 100644 index e515635..0000000 --- a/endpoint_test.go +++ /dev/null @@ -1,481 +0,0 @@ -package stfe - -import ( - "bytes" - "context" - "fmt" - "reflect" - "testing" - - "net/http" - "net/http/httptest" - - "github.com/golang/mock/gomock" - cttestdata "github.com/google/certificate-transparency-go/trillian/testdata" - "github.com/google/trillian" - "github.com/system-transparency/stfe/testdata" - "github.com/system-transparency/stfe/types" -) - -func TestEndpointAddEntry(t *testing.T) { - for _, table := range []struct { - description string - breq *bytes.Buffer - trsp *trillian.QueueLeafResponse - terr error - wantCode int - }{ - { - description: "invalid: bad request: empty", - breq: bytes.NewBuffer(nil), - wantCode: http.StatusBadRequest, - }, - { - description: "invalid: bad Trillian response: error", - breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), - terr: fmt.Errorf("backend failure"), - wantCode: http.StatusInternalServerError, - }, - { - description: "valid", - breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), - trsp: testdata.DefaultTQlr(t, false), - wantCode: http.StatusOK, - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, nil) - defer ti.ctrl.Finish() - - url := EndpointAddEntry.Path("http://example.com", ti.instance.LogParameters.Prefix) - req, err := http.NewRequest("POST", url, table.breq) - if err != nil { - t.Fatalf("must create http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - if table.trsp != nil || table.terr != nil { - ti.client.EXPECT().QueueLeaf(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) - } - - w := httptest.NewRecorder() - ti.postHandler(t, EndpointAddEntry).ServeHTTP(w, req) - if got, want := w.Code, table.wantCode; got != want { - t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) - } - }() - } -} - -func TestEndpointAddCosignature(t *testing.T) { - for _, table := range []struct { - description string - breq *bytes.Buffer - wantCode int - }{ - { - description: "invalid: bad request: empty", - breq: bytes.NewBuffer(nil), - wantCode: http.StatusBadRequest, - }, - { - description: "invalid: signed wrong sth", // newLogParameters() use testdata.Ed25519VkLog as default - breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog2), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), - wantCode: http.StatusBadRequest, - }, - { - description: "valid", - breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), - wantCode: http.StatusOK, - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, nil) - defer ti.ctrl.Finish() - - url := EndpointAddCosignature.Path("http://example.com", ti.instance.LogParameters.Prefix) - req, err := http.NewRequest("POST", url, table.breq) - if err != nil { - t.Fatalf("must create http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - - w := httptest.NewRecorder() - ti.postHandler(t, EndpointAddCosignature).ServeHTTP(w, req) - if got, want := w.Code, table.wantCode; got != want { - t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) - } - }() - } -} - -func TestEndpointGetLatestSth(t *testing.T) { - for _, table := range []struct { - description string - trsp *trillian.GetLatestSignedLogRootResponse - terr error - wantCode int - wantItem *types.StItem - }{ - { - description: "backend failure", - terr: fmt.Errorf("backend failure"), - wantCode: http.StatusInternalServerError, - }, - { - description: "valid", - trsp: testdata.DefaultTSlr(t), - wantCode: http.StatusOK, - wantItem: testdata.DefaultSth(t, testdata.Ed25519VkLog), - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, cttestdata.NewSignerWithFixedSig(nil, testdata.Signature)) - ti.ctrl.Finish() - - // Setup and run client query - url := EndpointGetLatestSth.Path("http://example.com", ti.instance.LogParameters.Prefix) - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("must create http request: %v", err) - } - if table.trsp != nil || table.terr != nil { - ti.client.EXPECT().GetLatestSignedLogRoot(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) - } - - w := httptest.NewRecorder() - ti.getHandler(t, EndpointGetLatestSth).ServeHTTP(w, req) - if got, want := w.Code, table.wantCode; got != want { - t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) - } - if w.Code != http.StatusOK { - return - } - - var item types.StItem - if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { - t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) - } - if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { - t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - }() - } -} - -func TestEndpointGetStableSth(t *testing.T) { - for _, table := range []struct { - description string - useBadSource bool - wantCode int - wantItem *types.StItem - }{ - { - description: "invalid: sth source failure", - useBadSource: true, - wantCode: http.StatusInternalServerError, - }, - { - description: "valid", - wantCode: http.StatusOK, - wantItem: testdata.DefaultSth(t, testdata.Ed25519VkLog), - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, nil) - ti.ctrl.Finish() - if table.useBadSource { - ti.instance.SthSource = &ActiveSthSource{} - } - - // Setup and run client query - url := EndpointGetStableSth.Path("http://example.com", ti.instance.LogParameters.Prefix) - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("must create http request: %v", err) - } - - w := httptest.NewRecorder() - ti.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req) - if got, want := w.Code, table.wantCode; got != want { - t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) - } - if w.Code != http.StatusOK { - return - } - - var item types.StItem - if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { - t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) - } - if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { - t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - }() - } -} - -func TestEndpointGetCosignedSth(t *testing.T) { - for _, table := range []struct { - description string - useBadSource bool - wantCode int - wantItem *types.StItem - }{ - { - description: "invalid: sth source failure", - useBadSource: true, - wantCode: http.StatusInternalServerError, - }, - { - description: "valid", - wantCode: http.StatusOK, - wantItem: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, nil) - ti.ctrl.Finish() - if table.useBadSource { - ti.instance.SthSource = &ActiveSthSource{} - } - - // Setup and run client query - url := EndpointGetCosignedSth.Path("http://example.com", ti.instance.LogParameters.Prefix) - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("must create http request: %v", err) - } - - w := httptest.NewRecorder() - ti.getHandler(t, EndpointGetCosignedSth).ServeHTTP(w, req) - if got, want := w.Code, table.wantCode; got != want { - t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) - } - if w.Code != http.StatusOK { - return - } - - var item types.StItem - if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { - t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) - } - if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { - t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - }() - } -} - -func TestEndpointGetProofByHash(t *testing.T) { - for _, table := range []struct { - description string - breq *bytes.Buffer - trsp *trillian.GetInclusionProofByHashResponse - terr error - wantCode int - wantItem *types.StItem - }{ - { - description: "invalid: bad request: empty", - breq: bytes.NewBuffer(nil), - wantCode: http.StatusBadRequest, - }, - { - description: "invalid: bad Trillian response: error", - breq: bytes.NewBuffer(marshal(t, types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash})), - terr: fmt.Errorf("backend failure"), - wantCode: http.StatusInternalServerError, - }, - { - description: "valid", - breq: bytes.NewBuffer(marshal(t, types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash})), - trsp: testdata.DefaultTGipbhr(t), - wantCode: http.StatusOK, - wantItem: testdata.DefaultInclusionProof(t, 1), - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, nil) - defer ti.ctrl.Finish() - - url := EndpointGetProofByHash.Path("http://example.com", ti.instance.LogParameters.Prefix) - req, err := http.NewRequest("POST", url, table.breq) - if err != nil { - t.Fatalf("must create http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - if table.trsp != nil || table.terr != nil { - ti.client.EXPECT().GetInclusionProofByHash(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) - } - - w := httptest.NewRecorder() - ti.postHandler(t, EndpointGetProofByHash).ServeHTTP(w, req) - if got, want := w.Code, table.wantCode; got != want { - t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) - } - if w.Code != http.StatusOK { - return - } - - var item types.StItem - if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { - t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) - } - if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { - t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - }() - } -} - -func TestEndpointGetConsistencyProof(t *testing.T) { - for _, table := range []struct { - description string - breq *bytes.Buffer - trsp *trillian.GetConsistencyProofResponse - terr error - wantCode int - wantItem *types.StItem - }{ - { - description: "invalid: bad request: empty", - breq: bytes.NewBuffer(nil), - wantCode: http.StatusBadRequest, - }, - { - description: "invalid: bad Trillian response: error", - breq: bytes.NewBuffer(marshal(t, types.GetConsistencyProofV1{First: 1, Second: 2})), - terr: fmt.Errorf("backend failure"), - wantCode: http.StatusInternalServerError, - }, - { - description: "valid", - breq: bytes.NewBuffer(marshal(t, types.GetConsistencyProofV1{First: 1, Second: 2})), - trsp: testdata.DefaultTGcpr(t), - wantCode: http.StatusOK, - wantItem: testdata.DefaultConsistencyProof(t, 1, 2), - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, nil) - defer ti.ctrl.Finish() - - url := EndpointGetConsistencyProof.Path("http://example.com", ti.instance.LogParameters.Prefix) - req, err := http.NewRequest("POST", url, table.breq) - if err != nil { - t.Fatalf("must create http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - if table.trsp != nil || table.terr != nil { - ti.client.EXPECT().GetConsistencyProof(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) - } - - w := httptest.NewRecorder() - ti.postHandler(t, EndpointGetConsistencyProof).ServeHTTP(w, req) - if got, want := w.Code, table.wantCode; got != want { - t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) - } - if w.Code != http.StatusOK { - return - } - - var item types.StItem - if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { - t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) - } - if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { - t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - }() - } -} - -func TestEndpointGetEntriesV1(t *testing.T) { - for _, table := range []struct { - description string - breq *bytes.Buffer - trsp *trillian.GetLeavesByRangeResponse - terr error - wantCode int - wantItem *types.StItemList - }{ - { - description: "invalid: bad request: empty", - breq: bytes.NewBuffer(nil), - wantCode: http.StatusBadRequest, - }, - { - description: "invalid: bad Trillian response: error", - breq: bytes.NewBuffer(marshal(t, types.GetEntriesV1{Start: 0, End: 0})), - terr: fmt.Errorf("backend failure"), - wantCode: http.StatusInternalServerError, - }, - { - description: "valid", // remember that newLogParameters() have testdata.MaxRange configured - breq: bytes.NewBuffer(marshal(t, types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange - 1)})), - trsp: testdata.DefaultTGlbrr(t, 0, testdata.MaxRange-1), - wantCode: http.StatusOK, - wantItem: testdata.DefaultStItemList(t, 0, uint64(testdata.MaxRange)-1), - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, nil) - defer ti.ctrl.Finish() - - url := EndpointGetEntries.Path("http://example.com", ti.instance.LogParameters.Prefix) - req, err := http.NewRequest("POST", url, table.breq) - if err != nil { - t.Fatalf("must create http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - if table.trsp != nil || table.terr != nil { - ti.client.EXPECT().GetLeavesByRange(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) - } - - w := httptest.NewRecorder() - ti.postHandler(t, EndpointGetEntries).ServeHTTP(w, req) - if got, want := w.Code, table.wantCode; got != want { - t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) - } - if w.Code != http.StatusOK { - return - } - - var item types.StItemList - if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { - t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) - } - if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { - t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - }() - } -} - -// TODO: TestWriteOctetResponse -func TestWriteOctetResponse(t *testing.T) { -} - -// deadlineMatcher implements gomock.Matcher, such that an error is raised if -// there is no context.Context deadline set -type deadlineMatcher struct{} - -// newDeadlineMatcher returns a new DeadlineMatcher -func newDeadlineMatcher() gomock.Matcher { - return &deadlineMatcher{} -} - -// Matches returns true if the passed interface is a context with a deadline -func (dm *deadlineMatcher) Matches(i interface{}) bool { - ctx, ok := i.(context.Context) - if !ok { - return false - } - _, ok = ctx.Deadline() - return ok -} - -// String is needed to implement gomock.Matcher -func (dm *deadlineMatcher) String() string { - return fmt.Sprintf("deadlineMatcher{}") -} diff --git a/go.sum b/go.sum index 68617e4..839df30 100644 --- a/go.sum +++ b/go.sum @@ -181,6 +181,7 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -461,6 +462,7 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -645,6 +647,7 @@ golang.org/x/tools v0.0.0-20200630154851-b2d8b0336632/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20200706234117-b22de6825cf7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/instance.go b/instance.go deleted file mode 100644 index 4425770..0000000 --- a/instance.go +++ /dev/null @@ -1,75 +0,0 @@ -package stfe - -import ( - "context" - "fmt" - "time" - - "net/http" - - "github.com/golang/glog" - "github.com/google/trillian" - "github.com/system-transparency/stfe/types" -) - -// Instance is an instance of the system transparency front-end -type Instance struct { - Client trillian.TrillianLogClient - LogParameters *LogParameters - SthSource SthSource -} - -// Handlers returns a list of STFE handlers -func (i *Instance) Handlers() []Handler { - return []Handler{ - Handler{Instance: i, Handler: addEntry, Endpoint: types.EndpointAddLeaf, Method: http.MethodPost}, - Handler{Instance: i, Handler: addCosignature, Endpoint: types.EndpointAddCosignature, Method: http.MethodPost}, - Handler{Instance: i, Handler: getLatestSth, Endpoint: types.EndpointGetTreeHeadLatest, Method: http.MethodGet}, - Handler{Instance: i, Handler: getStableSth, Endpoint: types.EndpointGetTreeHeadToSign, Method: http.MethodGet}, - Handler{Instance: i, Handler: getCosignedSth, Endpoint: types.EndpointGetTreeHeadCosigned, Method: http.MethodGet}, - Handler{Instance: i, Handler: getProofByHash, Endpoint: types.EndpointGetProofByHash, Method: http.MethodPost}, - Handler{Instance: i, Handler: getConsistencyProof, Endpoint: types.EndpointGetConsistencyProof, Method: http.MethodPost}, - Handler{Instance: i, Handler: getEntries, Endpoint: types.EndpointGetLeaves, Method: http.MethodPost}, - } -} - -// Handler implements the http.Handler interface, and contains a reference -// to an STFE server instance as well as a function that uses it. -type Handler struct { - Instance *Instance - Endpoint types.Endpoint - Method string - Handler func(context.Context, *Instance, http.ResponseWriter, *http.Request) (int, error) -} - -// Path returns a path that should be configured for this handler -func (h Handler) Path() string { - return h.Endpoint.Path("", h.Instance.LogParameters.Prefix) -} - -// ServeHTTP is part of the http.Handler interface -func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // export prometheus metrics - var now time.Time = time.Now() - var statusCode int - defer func() { - rspcnt.Inc(a.Instance.LogParameters.LogId, string(a.Endpoint), fmt.Sprintf("%d", statusCode)) - latency.Observe(time.Now().Sub(now).Seconds(), a.Instance.LogParameters.LogId, string(a.Endpoint), fmt.Sprintf("%d", statusCode)) - }() - reqcnt.Inc(a.Instance.LogParameters.LogId, string(a.Endpoint)) - - ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.Instance.LogParameters.Deadline)) - defer cancel() - - if r.Method != a.Method { - glog.Warningf("%s/%s: got HTTP %s, wanted HTTP %s", a.Instance.LogParameters.Prefix, string(a.Endpoint), r.Method, a.Method) - http.Error(w, "", http.StatusMethodNotAllowed) - return - } - - statusCode, err := a.Handler(ctx, a.Instance, w, r) - if err != nil { - glog.Warningf("handler error %s/%s: %v", a.Instance.LogParameters.Prefix, a.Endpoint, err) - http.Error(w, fmt.Sprintf("%s%s%s%s", "Error", types.Delim, err.Error(), types.EOL), statusCode) - } -} diff --git a/instance_test.go b/instance_test.go deleted file mode 100644 index de539a1..0000000 --- a/instance_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package stfe - -import ( - "crypto" - "testing" - - "net/http" - "net/http/httptest" - - "github.com/golang/mock/gomock" - "github.com/google/certificate-transparency-go/trillian/mockclient" - "github.com/system-transparency/stfe/testdata" - "github.com/system-transparency/stfe/types" -) - -type testInstance struct { - ctrl *gomock.Controller - client *mockclient.MockTrillianLogClient - instance *Instance -} - -// newTestInstances sets up a test instance that uses default log parameters -// with an optional signer, see newLogParameters() for further details. The -// SthSource is instantiated with an ActiveSthSource that has (i) the default -// STH as the currently cosigned STH based on testdata.Ed25519VkWitness, and -// (ii) the default STH without any cosignatures as the currently stable STH. -func newTestInstance(t *testing.T, signer crypto.Signer) *testInstance { - t.Helper() - ctrl := gomock.NewController(t) - client := mockclient.NewMockTrillianLogClient(ctrl) - return &testInstance{ - ctrl: ctrl, - client: client, - instance: &Instance{ - Client: client, - LogParameters: newLogParameters(t, signer), - SthSource: &ActiveSthSource{ - client: client, - logParameters: newLogParameters(t, signer), - currCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), - nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil), - cosignatureFrom: make(map[[types.NamespaceFingerprintSize]byte]bool), - }, - }, - } -} - -// getHandlers returns all endpoints that use HTTP GET as a map to handlers -func (ti *testInstance) getHandlers(t *testing.T) map[Endpoint]Handler { - t.Helper() - return map[Endpoint]Handler{ - EndpointGetLatestSth: Handler{Instance: ti.instance, Handler: getLatestSth, Endpoint: EndpointGetLatestSth, Method: http.MethodGet}, - EndpointGetStableSth: Handler{Instance: ti.instance, Handler: getStableSth, Endpoint: EndpointGetStableSth, Method: http.MethodGet}, - EndpointGetCosignedSth: Handler{Instance: ti.instance, Handler: getCosignedSth, Endpoint: EndpointGetCosignedSth, Method: http.MethodGet}, - } -} - -// postHandlers returns all endpoints that use HTTP POST as a map to handlers -func (ti *testInstance) postHandlers(t *testing.T) map[Endpoint]Handler { - t.Helper() - return map[Endpoint]Handler{ - EndpointAddEntry: Handler{Instance: ti.instance, Handler: addEntry, Endpoint: EndpointAddEntry, Method: http.MethodPost}, - EndpointAddCosignature: Handler{Instance: ti.instance, Handler: addCosignature, Endpoint: EndpointAddCosignature, Method: http.MethodPost}, - EndpointGetConsistencyProof: Handler{Instance: ti.instance, Handler: getConsistencyProof, Endpoint: EndpointGetConsistencyProof, Method: http.MethodPost}, - EndpointGetProofByHash: Handler{Instance: ti.instance, Handler: getProofByHash, Endpoint: EndpointGetProofByHash, Method: http.MethodPost}, - EndpointGetEntries: Handler{Instance: ti.instance, Handler: getEntries, Endpoint: EndpointGetEntries, Method: http.MethodPost}, - } -} - -// getHandler must return a particular HTTP GET handler -func (ti *testInstance) getHandler(t *testing.T, endpoint Endpoint) Handler { - t.Helper() - handler, ok := ti.getHandlers(t)[endpoint] - if !ok { - t.Fatalf("must return HTTP GET handler for endpoint: %s", endpoint) - } - return handler -} - -// postHandler must return a particular HTTP POST handler -func (ti *testInstance) postHandler(t *testing.T, endpoint Endpoint) Handler { - t.Helper() - handler, ok := ti.postHandlers(t)[endpoint] - if !ok { - t.Fatalf("must return HTTP POST handler for endpoint: %s", endpoint) - } - return handler -} - -// TestHandlers checks that we configured all endpoints and that there are no -// unexpected ones. -func TestHandlers(t *testing.T) { - endpoints := map[Endpoint]bool{ - EndpointAddEntry: false, - EndpointAddCosignature: false, - EndpointGetLatestSth: false, - EndpointGetStableSth: false, - EndpointGetCosignedSth: false, - EndpointGetConsistencyProof: false, - EndpointGetProofByHash: false, - EndpointGetEntries: false, - } - i := &Instance{nil, newLogParameters(t, nil), nil} - for _, handler := range i.Handlers() { - if _, ok := endpoints[handler.Endpoint]; !ok { - t.Errorf("got unexpected endpoint: %s", handler.Endpoint) - } - endpoints[handler.Endpoint] = true - } - for endpoint, ok := range endpoints { - if !ok { - t.Errorf("endpoint %s is not configured", endpoint) - } - } -} - -// TestGetHandlersRejectPost checks that all get handlers reject post requests -func TestGetHandlersRejectPost(t *testing.T) { - ti := newTestInstance(t, nil) - defer ti.ctrl.Finish() - - for endpoint, handler := range ti.getHandlers(t) { - t.Run(string(endpoint), func(t *testing.T) { - s := httptest.NewServer(handler) - defer s.Close() - - url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix) - if rsp, err := http.Post(url, "application/json", nil); err != nil { - t.Fatalf("http.Post(%s)=(_,%q), want (_,nil)", url, err) - } else if rsp.StatusCode != http.StatusMethodNotAllowed { - t.Errorf("http.Post(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed) - } - }) - } -} - -// TestPostHandlersRejectGet checks that all post handlers reject get requests -func TestPostHandlersRejectGet(t *testing.T) { - ti := newTestInstance(t, nil) - defer ti.ctrl.Finish() - - for endpoint, handler := range ti.postHandlers(t) { - t.Run(string(endpoint), func(t *testing.T) { - s := httptest.NewServer(handler) - defer s.Close() - - url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix) - if rsp, err := http.Get(url); err != nil { - t.Fatalf("http.Get(%s)=(_,%q), want (_,nil)", url, err) - } else if rsp.StatusCode != http.StatusMethodNotAllowed { - t.Errorf("http.Get(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed) - } - }) - } -} - -// TODO: TestHandlerPath -func TestHandlerPath(t *testing.T) { -} diff --git a/log_parameters.go b/log_parameters.go deleted file mode 100644 index aceff3e..0000000 --- a/log_parameters.go +++ /dev/null @@ -1,47 +0,0 @@ -package stfe - -import ( - "crypto" - "crypto/ed25519" - "fmt" - "time" - - "github.com/system-transparency/stfe/types" -) - -// LogParameters is a collection of log parameters -type LogParameters struct { - LogId string // serialized log id (hex) - TreeId int64 // used internally by Trillian - Prefix string // e.g., "test" for /test - MaxRange int64 // max entries per get-entries request - Deadline time.Duration // gRPC deadline - Interval time.Duration // cosigning sth frequency - HashType crypto.Hash // hash function used by Trillian - Signer crypto.Signer // access to Ed25519 private key - - // Witnesses map trusted witness identifiers to public verification keys - Witnesses map[[types.HashSize]byte][types.VerificationKeySize]byte -} - -// Sign signs a tree head -func (lp *LogParameters) Sign(th *types.TreeHead) (*types.SignedTreeHead, error) { - sig, err := lp.Signer.Sign(nil, th.Marshal(), crypto.Hash(0)) - if err != nil { - return nil, fmt.Errorf("Sign failed: %v", err) - } - lastSthTimestamp.Set(float64(time.Now().Unix()), lp.LogId) - lastSthSize.Set(float64(th.TreeSize), lp.LogId) - - sigident := types.SigIdent{ - KeyHash: types.Hash(lp.Signer.Public().(ed25519.PublicKey)[:]), - Signature: &[types.SignatureSize]byte{}, - } - copy(sigident.Signature[:], sig) - return &types.SignedTreeHead{ - TreeHead: *th, - SigIdent: []*types.SigIdent{ - &sigident, - }, - }, nil -} diff --git a/log_parameters_test.go b/log_parameters_test.go deleted file mode 100644 index 88e83ad..0000000 --- a/log_parameters_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package stfe - -import ( - "crypto" - "fmt" - "reflect" - "testing" - - cttestdata "github.com/google/certificate-transparency-go/trillian/testdata" - "github.com/system-transparency/stfe/testdata" - "github.com/system-transparency/stfe/types" -) - -// newLogParameters must create new log parameters with an optional log signer -// based on the parameters in "github.com/system-transparency/stfe/testdata". -// The log's namespace is initialized with testdata.LogEd25519Vk, the submmiter -// namespace list is initialized with testdata.SubmmiterEd25519, and the witness -// namespace list is initialized with testdata.WitnessEd25519Vk. The log's -// submitter and witness policies are set to reject unregistered namespace. -func newLogParameters(t *testing.T, signer crypto.Signer) *LogParameters { - t.Helper() - logId := testdata.NewNamespace(t, testdata.Ed25519VkLog) - witnessPool := testdata.NewNamespacePool(t, []*types.Namespace{ - testdata.NewNamespace(t, testdata.Ed25519VkWitness), - }) - submitPool := testdata.NewNamespacePool(t, []*types.Namespace{ - testdata.NewNamespace(t, testdata.Ed25519VkSubmitter), - }) - lp, err := NewLogParameters(signer, logId, testdata.TreeId, testdata.Prefix, submitPool, witnessPool, testdata.MaxRange, testdata.Interval, testdata.Deadline, true, true) - if err != nil { - t.Fatalf("must create new log parameters: %v", err) - } - return lp -} - -func TestNewLogParameters(t *testing.T) { - for _, table := range []struct { - description string - logId *types.Namespace - wantErr bool - }{ - { - description: "invalid: cannot marshal log id", - logId: &types.Namespace{ - Format: types.NamespaceFormatReserved, - }, - wantErr: true, - }, - { - description: "valid", - logId: testdata.NewNamespace(t, testdata.Ed25519VkLog), - }, - } { - _, err := NewLogParameters(nil, table.logId, testdata.TreeId, testdata.Prefix, nil, nil, testdata.MaxRange, testdata.Interval, testdata.Deadline, true, true) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - } -} - -func TestSignTreeHeadV1(t *testing.T) { - for _, table := range []struct { - description string - th *types.TreeHeadV1 - signer crypto.Signer - wantErr bool - wantSth *types.StItem - }{ - { - description: "invalid: marshal failure", - th: types.NewTreeHeadV1(testdata.Timestamp, testdata.TreeSize, nil, testdata.Extension), - wantErr: true, - }, - { - description: "invalid: signature failure", - th: types.NewTreeHeadV1(testdata.Timestamp, testdata.TreeSize, testdata.NodeHash, testdata.Extension), - signer: cttestdata.NewSignerWithErr(nil, fmt.Errorf("signer failed")), - wantErr: true, - }, - { - description: "valid", - th: testdata.DefaultTh(t), - signer: cttestdata.NewSignerWithFixedSig(nil, testdata.Signature), - wantSth: testdata.DefaultSth(t, testdata.Ed25519VkLog), - }, - } { - sth, err := newLogParameters(t, table.signer).SignTreeHeadV1(table.th) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - - if got, want := sth, table.wantSth; !reflect.DeepEqual(got, want) { - t.Errorf("got \n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - } -} diff --git a/metric.go b/metric.go deleted file mode 100644 index 7e3e8b2..0000000 --- a/metric.go +++ /dev/null @@ -1,23 +0,0 @@ -package stfe - -import ( - "github.com/google/trillian/monitoring" - "github.com/google/trillian/monitoring/prometheus" -) - -var ( - reqcnt monitoring.Counter // number of incoming http requests - rspcnt monitoring.Counter // number of valid http responses - latency monitoring.Histogram // request-response latency - lastSthTimestamp monitoring.Gauge // unix timestamp from the most recent sth - lastSthSize monitoring.Gauge // tree size of most recent sth -) - -func init() { - mf := prometheus.MetricFactory{} - reqcnt = mf.NewCounter("http_req", "number of http requests", "logid", "endpoint") - rspcnt = mf.NewCounter("http_rsp", "number of http requests", "logid", "endpoint", "status") - latency = mf.NewHistogram("http_latency", "http request-response latency", "logid", "endpoint", "status") - lastSthTimestamp = mf.NewGauge("last_sth_timestamp", "unix timestamp while handling the most recent sth", "logid") - lastSthSize = mf.NewGauge("last_sth_size", "most recent sth tree size", "logid") -} diff --git a/mocks/crypto.go b/mocks/crypto.go deleted file mode 100644 index 87c883a..0000000 --- a/mocks/crypto.go +++ /dev/null @@ -1,23 +0,0 @@ -package mocks - -import ( - "crypto" - "crypto/ed25519" - "io" -) - -// TestSign implements the signer interface. It can be used to mock an Ed25519 -// signer that always return the same public key, signature, and error. -type TestSigner struct { - PublicKey *[ed25519.PublicKeySize]byte - Signature *[ed25519.SignatureSize]byte - Error error -} - -func (ts *TestSigner) Public() crypto.PublicKey { - return ed25519.PublicKey(ts.PublicKey[:]) -} - -func (ts *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - return ts.Signature[:], ts.Error -} diff --git a/mocks/stfe.go b/mocks/stfe.go deleted file mode 100644 index e0fe7a9..0000000 --- a/mocks/stfe.go +++ /dev/null @@ -1,110 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/system-transparency/stfe/trillian (interfaces: Client) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - types "github.com/system-transparency/stfe/types" -) - -// MockClient is a mock of Client interface. -type MockClient struct { - ctrl *gomock.Controller - recorder *MockClientMockRecorder -} - -// MockClientMockRecorder is the mock recorder for MockClient. -type MockClientMockRecorder struct { - mock *MockClient -} - -// NewMockClient creates a new mock instance. -func NewMockClient(ctrl *gomock.Controller) *MockClient { - mock := &MockClient{ctrl: ctrl} - mock.recorder = &MockClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockClient) EXPECT() *MockClientMockRecorder { - return m.recorder -} - -// AddLeaf mocks base method. -func (m *MockClient) AddLeaf(arg0 context.Context, arg1 *types.LeafRequest) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddLeaf", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// AddLeaf indicates an expected call of AddLeaf. -func (mr *MockClientMockRecorder) AddLeaf(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLeaf", reflect.TypeOf((*MockClient)(nil).AddLeaf), arg0, arg1) -} - -// GetConsistencyProof mocks base method. -func (m *MockClient) GetConsistencyProof(arg0 context.Context, arg1 *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetConsistencyProof", arg0, arg1) - ret0, _ := ret[0].(*types.ConsistencyProof) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetConsistencyProof indicates an expected call of GetConsistencyProof. -func (mr *MockClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockClient)(nil).GetConsistencyProof), arg0, arg1) -} - -// GetInclusionProof mocks base method. -func (m *MockClient) GetInclusionProof(arg0 context.Context, arg1 *types.InclusionProofRequest) (*types.InclusionProof, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInclusionProof", arg0, arg1) - ret0, _ := ret[0].(*types.InclusionProof) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInclusionProof indicates an expected call of GetInclusionProof. -func (mr *MockClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockClient)(nil).GetInclusionProof), arg0, arg1) -} - -// GetLeaves mocks base method. -func (m *MockClient) GetLeaves(arg0 context.Context, arg1 *types.LeavesRequest) (*types.LeafList, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLeaves", arg0, arg1) - ret0, _ := ret[0].(*types.LeafList) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetLeaves indicates an expected call of GetLeaves. -func (mr *MockClientMockRecorder) GetLeaves(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeaves", reflect.TypeOf((*MockClient)(nil).GetLeaves), arg0, arg1) -} - -// GetTreeHead mocks base method. -func (m *MockClient) GetTreeHead(arg0 context.Context) (*types.TreeHead, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTreeHead", arg0) - ret0, _ := ret[0].(*types.TreeHead) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTreeHead indicates an expected call of GetTreeHead. -func (mr *MockClientMockRecorder) GetTreeHead(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTreeHead", reflect.TypeOf((*MockClient)(nil).GetTreeHead), arg0) -} diff --git a/mocks/trillian.go b/mocks/trillian.go deleted file mode 100644 index 8aa3a58..0000000 --- a/mocks/trillian.go +++ /dev/null @@ -1,317 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/google/trillian (interfaces: TrillianLogClient) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - trillian "github.com/google/trillian" - grpc "google.golang.org/grpc" -) - -// MockTrillianLogClient is a mock of TrillianLogClient interface. -type MockTrillianLogClient struct { - ctrl *gomock.Controller - recorder *MockTrillianLogClientMockRecorder -} - -// MockTrillianLogClientMockRecorder is the mock recorder for MockTrillianLogClient. -type MockTrillianLogClientMockRecorder struct { - mock *MockTrillianLogClient -} - -// NewMockTrillianLogClient creates a new mock instance. -func NewMockTrillianLogClient(ctrl *gomock.Controller) *MockTrillianLogClient { - mock := &MockTrillianLogClient{ctrl: ctrl} - mock.recorder = &MockTrillianLogClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTrillianLogClient) EXPECT() *MockTrillianLogClientMockRecorder { - return m.recorder -} - -// AddSequencedLeaf mocks base method. -func (m *MockTrillianLogClient) AddSequencedLeaf(arg0 context.Context, arg1 *trillian.AddSequencedLeafRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeafResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "AddSequencedLeaf", varargs...) - ret0, _ := ret[0].(*trillian.AddSequencedLeafResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AddSequencedLeaf indicates an expected call of AddSequencedLeaf. -func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaf), varargs...) -} - -// AddSequencedLeaves mocks base method. -func (m *MockTrillianLogClient) AddSequencedLeaves(arg0 context.Context, arg1 *trillian.AddSequencedLeavesRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeavesResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "AddSequencedLeaves", varargs...) - ret0, _ := ret[0].(*trillian.AddSequencedLeavesResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AddSequencedLeaves indicates an expected call of AddSequencedLeaves. -func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaves), varargs...) -} - -// GetConsistencyProof mocks base method. -func (m *MockTrillianLogClient) GetConsistencyProof(arg0 context.Context, arg1 *trillian.GetConsistencyProofRequest, arg2 ...grpc.CallOption) (*trillian.GetConsistencyProofResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetConsistencyProof", varargs...) - ret0, _ := ret[0].(*trillian.GetConsistencyProofResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetConsistencyProof indicates an expected call of GetConsistencyProof. -func (mr *MockTrillianLogClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetConsistencyProof), varargs...) -} - -// GetEntryAndProof mocks base method. -func (m *MockTrillianLogClient) GetEntryAndProof(arg0 context.Context, arg1 *trillian.GetEntryAndProofRequest, arg2 ...grpc.CallOption) (*trillian.GetEntryAndProofResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetEntryAndProof", varargs...) - ret0, _ := ret[0].(*trillian.GetEntryAndProofResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetEntryAndProof indicates an expected call of GetEntryAndProof. -func (mr *MockTrillianLogClientMockRecorder) GetEntryAndProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntryAndProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetEntryAndProof), varargs...) -} - -// GetInclusionProof mocks base method. -func (m *MockTrillianLogClient) GetInclusionProof(arg0 context.Context, arg1 *trillian.GetInclusionProofRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetInclusionProof", varargs...) - ret0, _ := ret[0].(*trillian.GetInclusionProofResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInclusionProof indicates an expected call of GetInclusionProof. -func (mr *MockTrillianLogClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProof), varargs...) -} - -// GetInclusionProofByHash mocks base method. -func (m *MockTrillianLogClient) GetInclusionProofByHash(arg0 context.Context, arg1 *trillian.GetInclusionProofByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofByHashResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetInclusionProofByHash", varargs...) - ret0, _ := ret[0].(*trillian.GetInclusionProofByHashResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInclusionProofByHash indicates an expected call of GetInclusionProofByHash. -func (mr *MockTrillianLogClientMockRecorder) GetInclusionProofByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProofByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProofByHash), varargs...) -} - -// GetLatestSignedLogRoot mocks base method. -func (m *MockTrillianLogClient) GetLatestSignedLogRoot(arg0 context.Context, arg1 *trillian.GetLatestSignedLogRootRequest, arg2 ...grpc.CallOption) (*trillian.GetLatestSignedLogRootResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetLatestSignedLogRoot", varargs...) - ret0, _ := ret[0].(*trillian.GetLatestSignedLogRootResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetLatestSignedLogRoot indicates an expected call of GetLatestSignedLogRoot. -func (mr *MockTrillianLogClientMockRecorder) GetLatestSignedLogRoot(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestSignedLogRoot", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLatestSignedLogRoot), varargs...) -} - -// GetLeavesByHash mocks base method. -func (m *MockTrillianLogClient) GetLeavesByHash(arg0 context.Context, arg1 *trillian.GetLeavesByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByHashResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetLeavesByHash", varargs...) - ret0, _ := ret[0].(*trillian.GetLeavesByHashResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetLeavesByHash indicates an expected call of GetLeavesByHash. -func (mr *MockTrillianLogClientMockRecorder) GetLeavesByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByHash), varargs...) -} - -// GetLeavesByIndex mocks base method. -func (m *MockTrillianLogClient) GetLeavesByIndex(arg0 context.Context, arg1 *trillian.GetLeavesByIndexRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByIndexResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetLeavesByIndex", varargs...) - ret0, _ := ret[0].(*trillian.GetLeavesByIndexResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetLeavesByIndex indicates an expected call of GetLeavesByIndex. -func (mr *MockTrillianLogClientMockRecorder) GetLeavesByIndex(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByIndex", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByIndex), varargs...) -} - -// GetLeavesByRange mocks base method. -func (m *MockTrillianLogClient) GetLeavesByRange(arg0 context.Context, arg1 *trillian.GetLeavesByRangeRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByRangeResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetLeavesByRange", varargs...) - ret0, _ := ret[0].(*trillian.GetLeavesByRangeResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetLeavesByRange indicates an expected call of GetLeavesByRange. -func (mr *MockTrillianLogClientMockRecorder) GetLeavesByRange(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByRange", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByRange), varargs...) -} - -// GetSequencedLeafCount mocks base method. -func (m *MockTrillianLogClient) GetSequencedLeafCount(arg0 context.Context, arg1 *trillian.GetSequencedLeafCountRequest, arg2 ...grpc.CallOption) (*trillian.GetSequencedLeafCountResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetSequencedLeafCount", varargs...) - ret0, _ := ret[0].(*trillian.GetSequencedLeafCountResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSequencedLeafCount indicates an expected call of GetSequencedLeafCount. -func (mr *MockTrillianLogClientMockRecorder) GetSequencedLeafCount(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSequencedLeafCount", reflect.TypeOf((*MockTrillianLogClient)(nil).GetSequencedLeafCount), varargs...) -} - -// InitLog mocks base method. -func (m *MockTrillianLogClient) InitLog(arg0 context.Context, arg1 *trillian.InitLogRequest, arg2 ...grpc.CallOption) (*trillian.InitLogResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "InitLog", varargs...) - ret0, _ := ret[0].(*trillian.InitLogResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// InitLog indicates an expected call of InitLog. -func (mr *MockTrillianLogClientMockRecorder) InitLog(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitLog", reflect.TypeOf((*MockTrillianLogClient)(nil).InitLog), varargs...) -} - -// QueueLeaf mocks base method. -func (m *MockTrillianLogClient) QueueLeaf(arg0 context.Context, arg1 *trillian.QueueLeafRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeafResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "QueueLeaf", varargs...) - ret0, _ := ret[0].(*trillian.QueueLeafResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// QueueLeaf indicates an expected call of QueueLeaf. -func (mr *MockTrillianLogClientMockRecorder) QueueLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaf), varargs...) -} - -// QueueLeaves mocks base method. -func (m *MockTrillianLogClient) QueueLeaves(arg0 context.Context, arg1 *trillian.QueueLeavesRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeavesResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "QueueLeaves", varargs...) - ret0, _ := ret[0].(*trillian.QueueLeavesResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// QueueLeaves indicates an expected call of QueueLeaves. -func (mr *MockTrillianLogClientMockRecorder) QueueLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaves), varargs...) -} diff --git a/pkg/instance/endpoint.go b/pkg/instance/endpoint.go new file mode 100644 index 0000000..5085c49 --- /dev/null +++ b/pkg/instance/endpoint.go @@ -0,0 +1,122 @@ +package stfe + +import ( + "context" + "net/http" + + "github.com/golang/glog" +) + +func addLeaf(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { + glog.V(3).Info("handling add-entry request") + req, err := i.leafRequestFromHTTP(r) + if err != nil { + return http.StatusBadRequest, err + } + if err := i.Client.AddLeaf(ctx, req); err != nil { + return http.StatusInternalServerError, err + } + return http.StatusOK, nil +} + +func addCosignature(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { + glog.V(3).Info("handling add-cosignature request") + req, err := i.cosignatureRequestFromHTTP(r) + if err != nil { + return http.StatusBadRequest, err + } + vk := i.Witnesses[*req.KeyHash] + if err := i.Stateman.AddCosignature(ctx, &vk, req.Signature); err != nil { + return http.StatusBadRequest, err + } + return http.StatusOK, nil +} + +func getTreeHeadLatest(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { + glog.V(3).Info("handling get-tree-head-latest request") + sth, err := i.Stateman.Latest(ctx) + if err != nil { + return http.StatusInternalServerError, err + } + if err := sth.MarshalASCII(w); err != nil { + return http.StatusInternalServerError, err + } + return http.StatusOK, nil +} + +func getTreeHeadToSign(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { + glog.V(3).Info("handling get-tree-head-to-sign request") + sth, err := i.Stateman.ToSign(ctx) + if err != nil { + return http.StatusInternalServerError, err + } + if err := sth.MarshalASCII(w); err != nil { + return http.StatusInternalServerError, err + } + return http.StatusOK, nil +} + +func getTreeHeadCosigned(ctx context.Context, i *Instance, w http.ResponseWriter, _ *http.Request) (int, error) { + glog.V(3).Info("handling get-tree-head-cosigned request") + sth, err := i.Stateman.Cosigned(ctx) + if err != nil { + return http.StatusInternalServerError, err + } + if err := sth.MarshalASCII(w); err != nil { + return http.StatusInternalServerError, err + } + return http.StatusOK, nil +} + +func getConsistencyProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { + glog.V(3).Info("handling get-consistency-proof request") + req, err := i.consistencyProofRequestFromHTTP(r) + if err != nil { + return http.StatusBadRequest, err + } + + proof, err := i.Client.GetConsistencyProof(ctx, req) + if err != nil { + return http.StatusInternalServerError, err + } + if err := proof.MarshalASCII(w); err != nil { + return http.StatusInternalServerError, err + } + return http.StatusOK, nil +} + +func getInclusionProof(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { + glog.V(3).Info("handling get-proof-by-hash request") + req, err := i.inclusionProofRequestFromHTTP(r) + if err != nil { + return http.StatusBadRequest, err + } + + proof, err := i.Client.GetInclusionProof(ctx, req) + if err != nil { + return http.StatusInternalServerError, err + } + if err := proof.MarshalASCII(w); err != nil { + return http.StatusInternalServerError, err + } + return http.StatusOK, nil +} + +func getLeaves(ctx context.Context, i *Instance, w http.ResponseWriter, r *http.Request) (int, error) { + glog.V(3).Info("handling get-leaves request") + req, err := i.leavesRequestFromHTTP(r) + if err != nil { + return http.StatusBadRequest, err + } + + leaves, err := i.Client.GetLeaves(ctx, req) + if err != nil { + return http.StatusInternalServerError, err + } + for _, leaf := range *leaves { + if err := leaf.MarshalASCII(w); err != nil { + return http.StatusInternalServerError, err + } + } + return http.StatusOK, nil +} diff --git a/pkg/instance/endpoint_test.go b/pkg/instance/endpoint_test.go new file mode 100644 index 0000000..8511b8d --- /dev/null +++ b/pkg/instance/endpoint_test.go @@ -0,0 +1,480 @@ +package stfe + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/golang/mock/gomock" + cttestdata "github.com/google/certificate-transparency-go/trillian/testdata" + "github.com/google/trillian" + "github.com/system-transparency/stfe/pkg/testdata" + "github.com/system-transparency/stfe/pkg/types" +) + +func TestEndpointAddEntry(t *testing.T) { + for _, table := range []struct { + description string + breq *bytes.Buffer + trsp *trillian.QueueLeafResponse + terr error + wantCode int + }{ + { + description: "invalid: bad request: empty", + breq: bytes.NewBuffer(nil), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: bad Trillian response: error", + breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), + terr: fmt.Errorf("backend failure"), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), + trsp: testdata.DefaultTQlr(t, false), + wantCode: http.StatusOK, + }, + } { + func() { // run deferred functions at the end of each iteration + ti := newTestInstance(t, nil) + defer ti.ctrl.Finish() + + url := EndpointAddEntry.Path("http://example.com", ti.instance.LogParameters.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + if table.trsp != nil || table.terr != nil { + ti.client.EXPECT().QueueLeaf(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) + } + + w := httptest.NewRecorder() + ti.postHandler(t, EndpointAddEntry).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) + } + }() + } +} + +func TestEndpointAddCosignature(t *testing.T) { + for _, table := range []struct { + description string + breq *bytes.Buffer + wantCode int + }{ + { + description: "invalid: bad request: empty", + breq: bytes.NewBuffer(nil), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: signed wrong sth", // newLogParameters() use testdata.Ed25519VkLog as default + breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog2), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), + wantCode: http.StatusBadRequest, + }, + { + description: "valid", + breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), + wantCode: http.StatusOK, + }, + } { + func() { // run deferred functions at the end of each iteration + ti := newTestInstance(t, nil) + defer ti.ctrl.Finish() + + url := EndpointAddCosignature.Path("http://example.com", ti.instance.LogParameters.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + + w := httptest.NewRecorder() + ti.postHandler(t, EndpointAddCosignature).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) + } + }() + } +} + +func TestEndpointGetLatestSth(t *testing.T) { + for _, table := range []struct { + description string + trsp *trillian.GetLatestSignedLogRootResponse + terr error + wantCode int + wantItem *types.StItem + }{ + { + description: "backend failure", + terr: fmt.Errorf("backend failure"), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + trsp: testdata.DefaultTSlr(t), + wantCode: http.StatusOK, + wantItem: testdata.DefaultSth(t, testdata.Ed25519VkLog), + }, + } { + func() { // run deferred functions at the end of each iteration + ti := newTestInstance(t, cttestdata.NewSignerWithFixedSig(nil, testdata.Signature)) + ti.ctrl.Finish() + + // Setup and run client query + url := EndpointGetLatestSth.Path("http://example.com", ti.instance.LogParameters.Prefix) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + if table.trsp != nil || table.terr != nil { + ti.client.EXPECT().GetLatestSignedLogRoot(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) + } + + w := httptest.NewRecorder() + ti.getHandler(t, EndpointGetLatestSth).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) + } + if w.Code != http.StatusOK { + return + } + + var item types.StItem + if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { + t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) + } + if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { + t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) + } + }() + } +} + +func TestEndpointGetStableSth(t *testing.T) { + for _, table := range []struct { + description string + useBadSource bool + wantCode int + wantItem *types.StItem + }{ + { + description: "invalid: sth source failure", + useBadSource: true, + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + wantCode: http.StatusOK, + wantItem: testdata.DefaultSth(t, testdata.Ed25519VkLog), + }, + } { + func() { // run deferred functions at the end of each iteration + ti := newTestInstance(t, nil) + ti.ctrl.Finish() + if table.useBadSource { + ti.instance.SthSource = &ActiveSthSource{} + } + + // Setup and run client query + url := EndpointGetStableSth.Path("http://example.com", ti.instance.LogParameters.Prefix) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + + w := httptest.NewRecorder() + ti.getHandler(t, EndpointGetStableSth).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) + } + if w.Code != http.StatusOK { + return + } + + var item types.StItem + if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { + t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) + } + if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { + t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) + } + }() + } +} + +func TestEndpointGetCosignedSth(t *testing.T) { + for _, table := range []struct { + description string + useBadSource bool + wantCode int + wantItem *types.StItem + }{ + { + description: "invalid: sth source failure", + useBadSource: true, + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + wantCode: http.StatusOK, + wantItem: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), + }, + } { + func() { // run deferred functions at the end of each iteration + ti := newTestInstance(t, nil) + ti.ctrl.Finish() + if table.useBadSource { + ti.instance.SthSource = &ActiveSthSource{} + } + + // Setup and run client query + url := EndpointGetCosignedSth.Path("http://example.com", ti.instance.LogParameters.Prefix) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + + w := httptest.NewRecorder() + ti.getHandler(t, EndpointGetCosignedSth).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) + } + if w.Code != http.StatusOK { + return + } + + var item types.StItem + if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { + t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) + } + if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { + t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) + } + }() + } +} + +func TestEndpointGetProofByHash(t *testing.T) { + for _, table := range []struct { + description string + breq *bytes.Buffer + trsp *trillian.GetInclusionProofByHashResponse + terr error + wantCode int + wantItem *types.StItem + }{ + { + description: "invalid: bad request: empty", + breq: bytes.NewBuffer(nil), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: bad Trillian response: error", + breq: bytes.NewBuffer(marshal(t, types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash})), + terr: fmt.Errorf("backend failure"), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + breq: bytes.NewBuffer(marshal(t, types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash})), + trsp: testdata.DefaultTGipbhr(t), + wantCode: http.StatusOK, + wantItem: testdata.DefaultInclusionProof(t, 1), + }, + } { + func() { // run deferred functions at the end of each iteration + ti := newTestInstance(t, nil) + defer ti.ctrl.Finish() + + url := EndpointGetProofByHash.Path("http://example.com", ti.instance.LogParameters.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + if table.trsp != nil || table.terr != nil { + ti.client.EXPECT().GetInclusionProofByHash(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) + } + + w := httptest.NewRecorder() + ti.postHandler(t, EndpointGetProofByHash).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) + } + if w.Code != http.StatusOK { + return + } + + var item types.StItem + if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { + t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) + } + if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { + t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) + } + }() + } +} + +func TestEndpointGetConsistencyProof(t *testing.T) { + for _, table := range []struct { + description string + breq *bytes.Buffer + trsp *trillian.GetConsistencyProofResponse + terr error + wantCode int + wantItem *types.StItem + }{ + { + description: "invalid: bad request: empty", + breq: bytes.NewBuffer(nil), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: bad Trillian response: error", + breq: bytes.NewBuffer(marshal(t, types.GetConsistencyProofV1{First: 1, Second: 2})), + terr: fmt.Errorf("backend failure"), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", + breq: bytes.NewBuffer(marshal(t, types.GetConsistencyProofV1{First: 1, Second: 2})), + trsp: testdata.DefaultTGcpr(t), + wantCode: http.StatusOK, + wantItem: testdata.DefaultConsistencyProof(t, 1, 2), + }, + } { + func() { // run deferred functions at the end of each iteration + ti := newTestInstance(t, nil) + defer ti.ctrl.Finish() + + url := EndpointGetConsistencyProof.Path("http://example.com", ti.instance.LogParameters.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + if table.trsp != nil || table.terr != nil { + ti.client.EXPECT().GetConsistencyProof(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) + } + + w := httptest.NewRecorder() + ti.postHandler(t, EndpointGetConsistencyProof).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) + } + if w.Code != http.StatusOK { + return + } + + var item types.StItem + if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { + t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) + } + if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { + t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) + } + }() + } +} + +func TestEndpointGetEntriesV1(t *testing.T) { + for _, table := range []struct { + description string + breq *bytes.Buffer + trsp *trillian.GetLeavesByRangeResponse + terr error + wantCode int + wantItem *types.StItemList + }{ + { + description: "invalid: bad request: empty", + breq: bytes.NewBuffer(nil), + wantCode: http.StatusBadRequest, + }, + { + description: "invalid: bad Trillian response: error", + breq: bytes.NewBuffer(marshal(t, types.GetEntriesV1{Start: 0, End: 0})), + terr: fmt.Errorf("backend failure"), + wantCode: http.StatusInternalServerError, + }, + { + description: "valid", // remember that newLogParameters() have testdata.MaxRange configured + breq: bytes.NewBuffer(marshal(t, types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange - 1)})), + trsp: testdata.DefaultTGlbrr(t, 0, testdata.MaxRange-1), + wantCode: http.StatusOK, + wantItem: testdata.DefaultStItemList(t, 0, uint64(testdata.MaxRange)-1), + }, + } { + func() { // run deferred functions at the end of each iteration + ti := newTestInstance(t, nil) + defer ti.ctrl.Finish() + + url := EndpointGetEntries.Path("http://example.com", ti.instance.LogParameters.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("must create http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + if table.trsp != nil || table.terr != nil { + ti.client.EXPECT().GetLeavesByRange(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) + } + + w := httptest.NewRecorder() + ti.postHandler(t, EndpointGetEntries).ServeHTTP(w, req) + if got, want := w.Code, table.wantCode; got != want { + t.Errorf("got error code %d but wanted %d in test %q", got, want, table.description) + } + if w.Code != http.StatusOK { + return + } + + var item types.StItemList + if err := types.Unmarshal([]byte(w.Body.String()), &item); err != nil { + t.Errorf("valid response cannot be unmarshalled in test %q: %v", table.description, err) + } + if got, want := item, *table.wantItem; !reflect.DeepEqual(got, want) { + t.Errorf("got item\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) + } + }() + } +} + +// TODO: TestWriteOctetResponse +func TestWriteOctetResponse(t *testing.T) { +} + +// deadlineMatcher implements gomock.Matcher, such that an error is raised if +// there is no context.Context deadline set +type deadlineMatcher struct{} + +// newDeadlineMatcher returns a new DeadlineMatcher +func newDeadlineMatcher() gomock.Matcher { + return &deadlineMatcher{} +} + +// Matches returns true if the passed interface is a context with a deadline +func (dm *deadlineMatcher) Matches(i interface{}) bool { + ctx, ok := i.(context.Context) + if !ok { + return false + } + _, ok = ctx.Deadline() + return ok +} + +// String is needed to implement gomock.Matcher +func (dm *deadlineMatcher) String() string { + return fmt.Sprintf("deadlineMatcher{}") +} diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go new file mode 100644 index 0000000..3441a0a --- /dev/null +++ b/pkg/instance/instance.go @@ -0,0 +1,90 @@ +package stfe + +import ( + "context" + "crypto" + "fmt" + "net/http" + "time" + + "github.com/golang/glog" + "github.com/system-transparency/stfe/pkg/state" + "github.com/system-transparency/stfe/pkg/trillian" + "github.com/system-transparency/stfe/pkg/types" +) + +// Config is a collection of log parameters +type Config struct { + LogID string // H(public key), then hex-encoded + TreeID int64 // Merkle tree identifier used by Trillian + Prefix string // The portion between base URL and st/v0 (may be "") + MaxRange int64 // Maximum number of leaves per get-leaves request + Deadline time.Duration // Deadline used for gRPC requests + Interval time.Duration // Cosigning frequency + + // Witnesses map trusted witness identifiers to public verification keys + Witnesses map[[types.HashSize]byte][types.VerificationKeySize]byte +} + +// Instance is an instance of the log's front-end +type Instance struct { + Config // configuration parameters + Client trillian.Client // provides access to the Trillian backend + Signer crypto.Signer // provides access to Ed25519 private key + Stateman state.StateManager // coordinates access to (co)signed tree heads +} + +// Handler implements the http.Handler interface, and contains a reference +// to an STFE server instance as well as a function that uses it. +type Handler struct { + Instance *Instance + Endpoint types.Endpoint + Method string + Handler func(context.Context, *Instance, http.ResponseWriter, *http.Request) (int, error) +} + +// Handlers returns a list of STFE handlers +func (i *Instance) Handlers() []Handler { + return []Handler{ + Handler{Instance: i, Handler: addLeaf, Endpoint: types.EndpointAddLeaf, Method: http.MethodPost}, + Handler{Instance: i, Handler: addCosignature, Endpoint: types.EndpointAddCosignature, Method: http.MethodPost}, + Handler{Instance: i, Handler: getTreeHeadLatest, Endpoint: types.EndpointGetTreeHeadLatest, Method: http.MethodGet}, + Handler{Instance: i, Handler: getTreeHeadToSign, Endpoint: types.EndpointGetTreeHeadToSign, Method: http.MethodGet}, + Handler{Instance: i, Handler: getTreeHeadCosigned, Endpoint: types.EndpointGetTreeHeadCosigned, Method: http.MethodGet}, + Handler{Instance: i, Handler: getConsistencyProof, Endpoint: types.EndpointGetConsistencyProof, Method: http.MethodPost}, + Handler{Instance: i, Handler: getInclusionProof, Endpoint: types.EndpointGetProofByHash, Method: http.MethodPost}, + Handler{Instance: i, Handler: getLeaves, Endpoint: types.EndpointGetLeaves, Method: http.MethodPost}, + } +} + +// Path returns a path that should be configured for this handler +func (h Handler) Path() string { + return h.Endpoint.Path(h.Instance.Prefix, "st", "v0") +} + +// ServeHTTP is part of the http.Handler interface +func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // export prometheus metrics + var now time.Time = time.Now() + var statusCode int + defer func() { + rspcnt.Inc(a.Instance.LogID, string(a.Endpoint), fmt.Sprintf("%d", statusCode)) + latency.Observe(time.Now().Sub(now).Seconds(), a.Instance.LogID, string(a.Endpoint), fmt.Sprintf("%d", statusCode)) + }() + reqcnt.Inc(a.Instance.LogID, string(a.Endpoint)) + + ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.Instance.Deadline)) + defer cancel() + + if r.Method != a.Method { + glog.Warningf("%s/%s: got HTTP %s, wanted HTTP %s", a.Instance.Prefix, string(a.Endpoint), r.Method, a.Method) + http.Error(w, "", http.StatusMethodNotAllowed) + return + } + + statusCode, err := a.Handler(ctx, a.Instance, w, r) + if err != nil { + glog.Warningf("handler error %s/%s: %v", a.Instance.Prefix, a.Endpoint, err) + http.Error(w, fmt.Sprintf("%s%s%s%s", "Error", types.Delim, err.Error(), types.EOL), statusCode) + } +} diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go new file mode 100644 index 0000000..a7a3d8a --- /dev/null +++ b/pkg/instance/instance_test.go @@ -0,0 +1,158 @@ +package stfe + +import ( + "crypto" + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/certificate-transparency-go/trillian/mockclient" + "github.com/system-transparency/stfe/pkg/testdata" + "github.com/system-transparency/stfe/pkg/types" +) + +type testInstance struct { + ctrl *gomock.Controller + client *mockclient.MockTrillianLogClient + instance *Instance +} + +// newTestInstances sets up a test instance that uses default log parameters +// with an optional signer, see newLogParameters() for further details. The +// SthSource is instantiated with an ActiveSthSource that has (i) the default +// STH as the currently cosigned STH based on testdata.Ed25519VkWitness, and +// (ii) the default STH without any cosignatures as the currently stable STH. +func newTestInstance(t *testing.T, signer crypto.Signer) *testInstance { + t.Helper() + ctrl := gomock.NewController(t) + client := mockclient.NewMockTrillianLogClient(ctrl) + return &testInstance{ + ctrl: ctrl, + client: client, + instance: &Instance{ + Client: client, + LogParameters: newLogParameters(t, signer), + SthSource: &ActiveSthSource{ + client: client, + logParameters: newLogParameters(t, signer), + currCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), + nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil), + cosignatureFrom: make(map[[types.NamespaceFingerprintSize]byte]bool), + }, + }, + } +} + +// getHandlers returns all endpoints that use HTTP GET as a map to handlers +func (ti *testInstance) getHandlers(t *testing.T) map[Endpoint]Handler { + t.Helper() + return map[Endpoint]Handler{ + EndpointGetLatestSth: Handler{Instance: ti.instance, Handler: getLatestSth, Endpoint: EndpointGetLatestSth, Method: http.MethodGet}, + EndpointGetStableSth: Handler{Instance: ti.instance, Handler: getStableSth, Endpoint: EndpointGetStableSth, Method: http.MethodGet}, + EndpointGetCosignedSth: Handler{Instance: ti.instance, Handler: getCosignedSth, Endpoint: EndpointGetCosignedSth, Method: http.MethodGet}, + } +} + +// postHandlers returns all endpoints that use HTTP POST as a map to handlers +func (ti *testInstance) postHandlers(t *testing.T) map[Endpoint]Handler { + t.Helper() + return map[Endpoint]Handler{ + EndpointAddEntry: Handler{Instance: ti.instance, Handler: addEntry, Endpoint: EndpointAddEntry, Method: http.MethodPost}, + EndpointAddCosignature: Handler{Instance: ti.instance, Handler: addCosignature, Endpoint: EndpointAddCosignature, Method: http.MethodPost}, + EndpointGetConsistencyProof: Handler{Instance: ti.instance, Handler: getConsistencyProof, Endpoint: EndpointGetConsistencyProof, Method: http.MethodPost}, + EndpointGetProofByHash: Handler{Instance: ti.instance, Handler: getProofByHash, Endpoint: EndpointGetProofByHash, Method: http.MethodPost}, + EndpointGetEntries: Handler{Instance: ti.instance, Handler: getEntries, Endpoint: EndpointGetEntries, Method: http.MethodPost}, + } +} + +// getHandler must return a particular HTTP GET handler +func (ti *testInstance) getHandler(t *testing.T, endpoint Endpoint) Handler { + t.Helper() + handler, ok := ti.getHandlers(t)[endpoint] + if !ok { + t.Fatalf("must return HTTP GET handler for endpoint: %s", endpoint) + } + return handler +} + +// postHandler must return a particular HTTP POST handler +func (ti *testInstance) postHandler(t *testing.T, endpoint Endpoint) Handler { + t.Helper() + handler, ok := ti.postHandlers(t)[endpoint] + if !ok { + t.Fatalf("must return HTTP POST handler for endpoint: %s", endpoint) + } + return handler +} + +// TestHandlers checks that we configured all endpoints and that there are no +// unexpected ones. +func TestHandlers(t *testing.T) { + endpoints := map[Endpoint]bool{ + EndpointAddEntry: false, + EndpointAddCosignature: false, + EndpointGetLatestSth: false, + EndpointGetStableSth: false, + EndpointGetCosignedSth: false, + EndpointGetConsistencyProof: false, + EndpointGetProofByHash: false, + EndpointGetEntries: false, + } + i := &Instance{nil, newLogParameters(t, nil), nil} + for _, handler := range i.Handlers() { + if _, ok := endpoints[handler.Endpoint]; !ok { + t.Errorf("got unexpected endpoint: %s", handler.Endpoint) + } + endpoints[handler.Endpoint] = true + } + for endpoint, ok := range endpoints { + if !ok { + t.Errorf("endpoint %s is not configured", endpoint) + } + } +} + +// TestGetHandlersRejectPost checks that all get handlers reject post requests +func TestGetHandlersRejectPost(t *testing.T) { + ti := newTestInstance(t, nil) + defer ti.ctrl.Finish() + + for endpoint, handler := range ti.getHandlers(t) { + t.Run(string(endpoint), func(t *testing.T) { + s := httptest.NewServer(handler) + defer s.Close() + + url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix) + if rsp, err := http.Post(url, "application/json", nil); err != nil { + t.Fatalf("http.Post(%s)=(_,%q), want (_,nil)", url, err) + } else if rsp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("http.Post(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed) + } + }) + } +} + +// TestPostHandlersRejectGet checks that all post handlers reject get requests +func TestPostHandlersRejectGet(t *testing.T) { + ti := newTestInstance(t, nil) + defer ti.ctrl.Finish() + + for endpoint, handler := range ti.postHandlers(t) { + t.Run(string(endpoint), func(t *testing.T) { + s := httptest.NewServer(handler) + defer s.Close() + + url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix) + if rsp, err := http.Get(url); err != nil { + t.Fatalf("http.Get(%s)=(_,%q), want (_,nil)", url, err) + } else if rsp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("http.Get(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed) + } + }) + } +} + +// TODO: TestHandlerPath +func TestHandlerPath(t *testing.T) { +} diff --git a/pkg/instance/metric.go b/pkg/instance/metric.go new file mode 100644 index 0000000..7e3e8b2 --- /dev/null +++ b/pkg/instance/metric.go @@ -0,0 +1,23 @@ +package stfe + +import ( + "github.com/google/trillian/monitoring" + "github.com/google/trillian/monitoring/prometheus" +) + +var ( + reqcnt monitoring.Counter // number of incoming http requests + rspcnt monitoring.Counter // number of valid http responses + latency monitoring.Histogram // request-response latency + lastSthTimestamp monitoring.Gauge // unix timestamp from the most recent sth + lastSthSize monitoring.Gauge // tree size of most recent sth +) + +func init() { + mf := prometheus.MetricFactory{} + reqcnt = mf.NewCounter("http_req", "number of http requests", "logid", "endpoint") + rspcnt = mf.NewCounter("http_rsp", "number of http requests", "logid", "endpoint", "status") + latency = mf.NewHistogram("http_latency", "http request-response latency", "logid", "endpoint", "status") + lastSthTimestamp = mf.NewGauge("last_sth_timestamp", "unix timestamp while handling the most recent sth", "logid") + lastSthSize = mf.NewGauge("last_sth_size", "most recent sth tree size", "logid") +} diff --git a/pkg/instance/request.go b/pkg/instance/request.go new file mode 100644 index 0000000..7475b26 --- /dev/null +++ b/pkg/instance/request.go @@ -0,0 +1,77 @@ +package stfe + +import ( + "crypto/ed25519" + "fmt" + "net/http" + + "github.com/system-transparency/stfe/pkg/types" +) + +func (i *Instance) leafRequestFromHTTP(r *http.Request) (*types.LeafRequest, error) { + var req types.LeafRequest + if err := req.UnmarshalASCII(r.Body); err != nil { + return nil, fmt.Errorf("UnmarshalASCII: %v", err) + } + + vk := ed25519.PublicKey(req.VerificationKey[:]) + msg := req.Message.Marshal() + sig := req.Signature[:] + if !ed25519.Verify(vk, msg, sig) { + return nil, fmt.Errorf("invalid signature") + } + // TODO: check shard hint + // TODO: check domain hint + return &req, nil +} + +func (i *Instance) cosignatureRequestFromHTTP(r *http.Request) (*types.CosignatureRequest, error) { + var req types.CosignatureRequest + if err := req.UnmarshalASCII(r.Body); err != nil { + return nil, fmt.Errorf("unpackOctetPost: %v", err) + } + if _, ok := i.Witnesses[*req.KeyHash]; !ok { + return nil, fmt.Errorf("Unknown witness: %x", req.KeyHash) + } + return &req, nil +} + +func (i *Instance) consistencyProofRequestFromHTTP(r *http.Request) (*types.ConsistencyProofRequest, error) { + var req types.ConsistencyProofRequest + if err := req.UnmarshalASCII(r.Body); err != nil { + return nil, fmt.Errorf("UnmarshalASCII: %v", err) + } + if req.OldSize < 1 { + return nil, fmt.Errorf("OldSize(%d) must be larger than zero", req.OldSize) + } + if req.NewSize <= req.OldSize { + return nil, fmt.Errorf("NewSize(%d) must be larger than OldSize(%d)", req.NewSize, req.OldSize) + } + return &req, nil +} + +func (i *Instance) inclusionProofRequestFromHTTP(r *http.Request) (*types.InclusionProofRequest, error) { + var req types.InclusionProofRequest + if err := req.UnmarshalASCII(r.Body); err != nil { + return nil, fmt.Errorf("UnmarshalASCII: %v", err) + } + if req.TreeSize < 1 { + return nil, fmt.Errorf("TreeSize(%d) must be larger than zero", req.TreeSize) + } + return &req, nil +} + +func (i *Instance) leavesRequestFromHTTP(r *http.Request) (*types.LeavesRequest, error) { + var req types.LeavesRequest + if err := req.UnmarshalASCII(r.Body); err != nil { + return nil, fmt.Errorf("UnmarshalASCII: %v", err) + } + + if req.StartSize > req.EndSize { + return nil, fmt.Errorf("StartSize(%d) must be less than or equal to EndSize(%d)", req.StartSize, req.EndSize) + } + if req.EndSize-req.StartSize+1 > uint64(i.MaxRange) { + req.EndSize = req.StartSize + uint64(i.MaxRange) - 1 + } + return &req, nil +} diff --git a/pkg/instance/request_test.go b/pkg/instance/request_test.go new file mode 100644 index 0000000..0a5a908 --- /dev/null +++ b/pkg/instance/request_test.go @@ -0,0 +1,318 @@ +package stfe + +import ( + "bytes" + //"fmt" + "reflect" + "testing" + //"testing/iotest" + + "net/http" + + "github.com/system-transparency/stfe/pkg/testdata" + "github.com/system-transparency/stfe/pkg/types" +) + +func TestParseAddEntryV1Request(t *testing.T) { + lp := newLogParameters(t, nil) + for _, table := range []struct { + description string + breq *bytes.Buffer + wantErr bool + }{ + { + description: "invalid: nothing to unpack", + breq: bytes.NewBuffer(nil), + wantErr: true, + }, + { + description: "invalid: not a signed checksum entry", + breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), + wantErr: true, + }, + { + description: "invalid: untrusted submitter", // only testdata.Ed25519VkSubmitter is registered by default in newLogParameters() + + breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter2, testdata.Ed25519VkSubmitter2), + wantErr: true, + }, + { + description: "invalid: signature does not cover message", + + breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter2, testdata.Ed25519VkSubmitter), + wantErr: true, + }, + { + description: "valid", + breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), + }, // TODO: add test case that disables submitter policy (i.e., unregistered namespaces are accepted) + } { + url := EndpointAddEntry.Path("http://example.com", lp.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + + _, err = lp.parseAddEntryV1Request(req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) + } + } +} + +func TestParseAddCosignatureV1Request(t *testing.T) { + lp := newLogParameters(t, nil) + for _, table := range []struct { + description string + breq *bytes.Buffer + wantErr bool + }{ + { + description: "invalid: nothing to unpack", + breq: bytes.NewBuffer(nil), + wantErr: true, + }, + { + description: "invalid: not a cosigned sth", + breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), + wantErr: true, + }, + { + description: "invalid: no cosignature", + breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, nil), + wantErr: true, + }, + { + description: "invalid: untrusted witness", // only testdata.Ed25519VkWitness is registered by default in newLogParameters() + breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness2, &testdata.Ed25519VkWitness2), + wantErr: true, + }, + { + description: "invalid: signature does not cover message", + breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness2, &testdata.Ed25519VkWitness), + wantErr: true, + }, + { + description: "valid", + breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), + }, // TODO: add test case that disables witness policy (i.e., unregistered namespaces are accepted) + } { + url := EndpointAddCosignature.Path("http://example.com", lp.Prefix) + req, err := http.NewRequest("POST", url, table.breq) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + + _, err = lp.parseAddCosignatureV1Request(req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) + } + } +} + +func TestNewGetConsistencyProofRequest(t *testing.T) { + lp := newLogParameters(t, nil) + for _, table := range []struct { + description string + req *types.GetConsistencyProofV1 + wantErr bool + }{ + { + description: "invalid: nothing to unpack", + req: nil, + wantErr: true, + }, + { + description: "invalid: first must be larger than zero", + req: &types.GetConsistencyProofV1{First: 0, Second: 0}, + wantErr: true, + }, + { + description: "invalid: second must be larger than first", + req: &types.GetConsistencyProofV1{First: 2, Second: 1}, + wantErr: true, + }, + { + description: "valid", + req: &types.GetConsistencyProofV1{First: 1, Second: 2}, + }, + } { + var buf *bytes.Buffer + if table.req == nil { + buf = bytes.NewBuffer(nil) + } else { + buf = bytes.NewBuffer(marshal(t, *table.req)) + } + + url := EndpointGetConsistencyProof.Path("http://example.com", lp.Prefix) + req, err := http.NewRequest("POST", url, buf) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + + _, err = lp.parseGetConsistencyProofV1Request(req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) + } + } +} + +func TestNewGetProofByHashRequest(t *testing.T) { + lp := newLogParameters(t, nil) + for _, table := range []struct { + description string + req *types.GetProofByHashV1 + wantErr bool + }{ + { + description: "invalid: nothing to unpack", + req: nil, + wantErr: true, + }, + { + description: "invalid: no entry in an empty tree", + req: &types.GetProofByHashV1{TreeSize: 0, Hash: testdata.LeafHash}, + wantErr: true, + }, + { + description: "valid", + req: &types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash}, + }, + } { + var buf *bytes.Buffer + if table.req == nil { + buf = bytes.NewBuffer(nil) + } else { + buf = bytes.NewBuffer(marshal(t, *table.req)) + } + + url := EndpointGetProofByHash.Path("http://example.com", lp.Prefix) + req, err := http.NewRequest("POST", url, buf) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + + _, err = lp.parseGetProofByHashV1Request(req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) + } + } +} + +func TestParseGetEntriesV1Request(t *testing.T) { + lp := newLogParameters(t, nil) + for _, table := range []struct { + description string + req *types.GetEntriesV1 + wantErr bool + wantReq *types.GetEntriesV1 + }{ + { + description: "invalid: nothing to unpack", + req: nil, + wantErr: true, + }, + { + description: "invalid: start must be larger than end", + req: &types.GetEntriesV1{Start: 1, End: 0}, + wantErr: true, + }, + { + description: "valid: want truncated range", + req: &types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange)}, + wantReq: &types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange) - 1}, + }, + { + description: "valid", + req: &types.GetEntriesV1{Start: 0, End: 0}, + wantReq: &types.GetEntriesV1{Start: 0, End: 0}, + }, + } { + var buf *bytes.Buffer + if table.req == nil { + buf = bytes.NewBuffer(nil) + } else { + buf = bytes.NewBuffer(marshal(t, *table.req)) + } + + url := EndpointGetEntries.Path("http://example.com", lp.Prefix) + req, err := http.NewRequest("POST", url, buf) + if err != nil { + t.Fatalf("failed creating http request: %v", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + + output, err := lp.parseGetEntriesV1Request(req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := output, table.wantReq; !reflect.DeepEqual(got, want) { + t.Errorf("got request\n%v\n\tbut wanted\n%v\n\t in test %q", got, want, table.description) + } + } +} + +func TestUnpackOctetPost(t *testing.T) { + for _, table := range []struct { + description string + req *http.Request + out interface{} + wantErr bool + }{ + //{ + // description: "invalid: cannot read request body", + // req: func() *http.Request { + // req, err := http.NewRequest(http.MethodPost, "", iotest.ErrReader(fmt.Errorf("bad reader"))) + // if err != nil { + // t.Fatalf("must make new http request: %v", err) + // } + // return req + // }(), + // out: &types.StItem{}, + // wantErr: true, + //}, // testcase requires Go 1.16 + { + description: "invalid: cannot unmarshal", + req: func() *http.Request { + req, err := http.NewRequest(http.MethodPost, "", bytes.NewBuffer(nil)) + if err != nil { + t.Fatalf("must make new http request: %v", err) + } + return req + }(), + out: &types.StItem{}, + wantErr: true, + }, + { + description: "valid", + req: func() *http.Request { + req, err := http.NewRequest(http.MethodPost, "", bytes.NewBuffer([]byte{0})) + if err != nil { + t.Fatalf("must make new http request: %v", err) + } + return req + }(), + out: &struct{ SomeUint8 uint8 }{}, + }, + } { + err := unpackOctetPost(table.req, table.out) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q", got, want, table.description) + } + } +} + +func marshal(t *testing.T, out interface{}) []byte { + b, err := types.Marshal(out) + if err != nil { + t.Fatalf("must marshal: %v", err) + } + return b +} diff --git a/pkg/mocks/crypto.go b/pkg/mocks/crypto.go new file mode 100644 index 0000000..87c883a --- /dev/null +++ b/pkg/mocks/crypto.go @@ -0,0 +1,23 @@ +package mocks + +import ( + "crypto" + "crypto/ed25519" + "io" +) + +// TestSign implements the signer interface. It can be used to mock an Ed25519 +// signer that always return the same public key, signature, and error. +type TestSigner struct { + PublicKey *[ed25519.PublicKeySize]byte + Signature *[ed25519.SignatureSize]byte + Error error +} + +func (ts *TestSigner) Public() crypto.PublicKey { + return ed25519.PublicKey(ts.PublicKey[:]) +} + +func (ts *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + return ts.Signature[:], ts.Error +} diff --git a/pkg/mocks/stfe.go b/pkg/mocks/stfe.go new file mode 100644 index 0000000..def5bc6 --- /dev/null +++ b/pkg/mocks/stfe.go @@ -0,0 +1,110 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/system-transparency/stfe/trillian (interfaces: Client) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + types "github.com/system-transparency/stfe/pkg/types" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// AddLeaf mocks base method. +func (m *MockClient) AddLeaf(arg0 context.Context, arg1 *types.LeafRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddLeaf", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddLeaf indicates an expected call of AddLeaf. +func (mr *MockClientMockRecorder) AddLeaf(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLeaf", reflect.TypeOf((*MockClient)(nil).AddLeaf), arg0, arg1) +} + +// GetConsistencyProof mocks base method. +func (m *MockClient) GetConsistencyProof(arg0 context.Context, arg1 *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetConsistencyProof", arg0, arg1) + ret0, _ := ret[0].(*types.ConsistencyProof) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetConsistencyProof indicates an expected call of GetConsistencyProof. +func (mr *MockClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockClient)(nil).GetConsistencyProof), arg0, arg1) +} + +// GetInclusionProof mocks base method. +func (m *MockClient) GetInclusionProof(arg0 context.Context, arg1 *types.InclusionProofRequest) (*types.InclusionProof, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInclusionProof", arg0, arg1) + ret0, _ := ret[0].(*types.InclusionProof) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInclusionProof indicates an expected call of GetInclusionProof. +func (mr *MockClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockClient)(nil).GetInclusionProof), arg0, arg1) +} + +// GetLeaves mocks base method. +func (m *MockClient) GetLeaves(arg0 context.Context, arg1 *types.LeavesRequest) (*types.LeafList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLeaves", arg0, arg1) + ret0, _ := ret[0].(*types.LeafList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLeaves indicates an expected call of GetLeaves. +func (mr *MockClientMockRecorder) GetLeaves(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeaves", reflect.TypeOf((*MockClient)(nil).GetLeaves), arg0, arg1) +} + +// GetTreeHead mocks base method. +func (m *MockClient) GetTreeHead(arg0 context.Context) (*types.TreeHead, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTreeHead", arg0) + ret0, _ := ret[0].(*types.TreeHead) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTreeHead indicates an expected call of GetTreeHead. +func (mr *MockClientMockRecorder) GetTreeHead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTreeHead", reflect.TypeOf((*MockClient)(nil).GetTreeHead), arg0) +} diff --git a/pkg/mocks/trillian.go b/pkg/mocks/trillian.go new file mode 100644 index 0000000..8aa3a58 --- /dev/null +++ b/pkg/mocks/trillian.go @@ -0,0 +1,317 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/google/trillian (interfaces: TrillianLogClient) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + trillian "github.com/google/trillian" + grpc "google.golang.org/grpc" +) + +// MockTrillianLogClient is a mock of TrillianLogClient interface. +type MockTrillianLogClient struct { + ctrl *gomock.Controller + recorder *MockTrillianLogClientMockRecorder +} + +// MockTrillianLogClientMockRecorder is the mock recorder for MockTrillianLogClient. +type MockTrillianLogClientMockRecorder struct { + mock *MockTrillianLogClient +} + +// NewMockTrillianLogClient creates a new mock instance. +func NewMockTrillianLogClient(ctrl *gomock.Controller) *MockTrillianLogClient { + mock := &MockTrillianLogClient{ctrl: ctrl} + mock.recorder = &MockTrillianLogClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrillianLogClient) EXPECT() *MockTrillianLogClientMockRecorder { + return m.recorder +} + +// AddSequencedLeaf mocks base method. +func (m *MockTrillianLogClient) AddSequencedLeaf(arg0 context.Context, arg1 *trillian.AddSequencedLeafRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeafResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "AddSequencedLeaf", varargs...) + ret0, _ := ret[0].(*trillian.AddSequencedLeafResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddSequencedLeaf indicates an expected call of AddSequencedLeaf. +func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaf), varargs...) +} + +// AddSequencedLeaves mocks base method. +func (m *MockTrillianLogClient) AddSequencedLeaves(arg0 context.Context, arg1 *trillian.AddSequencedLeavesRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeavesResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "AddSequencedLeaves", varargs...) + ret0, _ := ret[0].(*trillian.AddSequencedLeavesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddSequencedLeaves indicates an expected call of AddSequencedLeaves. +func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaves), varargs...) +} + +// GetConsistencyProof mocks base method. +func (m *MockTrillianLogClient) GetConsistencyProof(arg0 context.Context, arg1 *trillian.GetConsistencyProofRequest, arg2 ...grpc.CallOption) (*trillian.GetConsistencyProofResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetConsistencyProof", varargs...) + ret0, _ := ret[0].(*trillian.GetConsistencyProofResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetConsistencyProof indicates an expected call of GetConsistencyProof. +func (mr *MockTrillianLogClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetConsistencyProof), varargs...) +} + +// GetEntryAndProof mocks base method. +func (m *MockTrillianLogClient) GetEntryAndProof(arg0 context.Context, arg1 *trillian.GetEntryAndProofRequest, arg2 ...grpc.CallOption) (*trillian.GetEntryAndProofResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetEntryAndProof", varargs...) + ret0, _ := ret[0].(*trillian.GetEntryAndProofResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEntryAndProof indicates an expected call of GetEntryAndProof. +func (mr *MockTrillianLogClientMockRecorder) GetEntryAndProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntryAndProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetEntryAndProof), varargs...) +} + +// GetInclusionProof mocks base method. +func (m *MockTrillianLogClient) GetInclusionProof(arg0 context.Context, arg1 *trillian.GetInclusionProofRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetInclusionProof", varargs...) + ret0, _ := ret[0].(*trillian.GetInclusionProofResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInclusionProof indicates an expected call of GetInclusionProof. +func (mr *MockTrillianLogClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProof), varargs...) +} + +// GetInclusionProofByHash mocks base method. +func (m *MockTrillianLogClient) GetInclusionProofByHash(arg0 context.Context, arg1 *trillian.GetInclusionProofByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofByHashResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetInclusionProofByHash", varargs...) + ret0, _ := ret[0].(*trillian.GetInclusionProofByHashResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInclusionProofByHash indicates an expected call of GetInclusionProofByHash. +func (mr *MockTrillianLogClientMockRecorder) GetInclusionProofByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProofByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProofByHash), varargs...) +} + +// GetLatestSignedLogRoot mocks base method. +func (m *MockTrillianLogClient) GetLatestSignedLogRoot(arg0 context.Context, arg1 *trillian.GetLatestSignedLogRootRequest, arg2 ...grpc.CallOption) (*trillian.GetLatestSignedLogRootResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetLatestSignedLogRoot", varargs...) + ret0, _ := ret[0].(*trillian.GetLatestSignedLogRootResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLatestSignedLogRoot indicates an expected call of GetLatestSignedLogRoot. +func (mr *MockTrillianLogClientMockRecorder) GetLatestSignedLogRoot(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestSignedLogRoot", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLatestSignedLogRoot), varargs...) +} + +// GetLeavesByHash mocks base method. +func (m *MockTrillianLogClient) GetLeavesByHash(arg0 context.Context, arg1 *trillian.GetLeavesByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByHashResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetLeavesByHash", varargs...) + ret0, _ := ret[0].(*trillian.GetLeavesByHashResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLeavesByHash indicates an expected call of GetLeavesByHash. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByHash), varargs...) +} + +// GetLeavesByIndex mocks base method. +func (m *MockTrillianLogClient) GetLeavesByIndex(arg0 context.Context, arg1 *trillian.GetLeavesByIndexRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByIndexResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetLeavesByIndex", varargs...) + ret0, _ := ret[0].(*trillian.GetLeavesByIndexResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLeavesByIndex indicates an expected call of GetLeavesByIndex. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByIndex(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByIndex", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByIndex), varargs...) +} + +// GetLeavesByRange mocks base method. +func (m *MockTrillianLogClient) GetLeavesByRange(arg0 context.Context, arg1 *trillian.GetLeavesByRangeRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByRangeResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetLeavesByRange", varargs...) + ret0, _ := ret[0].(*trillian.GetLeavesByRangeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLeavesByRange indicates an expected call of GetLeavesByRange. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByRange(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByRange", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByRange), varargs...) +} + +// GetSequencedLeafCount mocks base method. +func (m *MockTrillianLogClient) GetSequencedLeafCount(arg0 context.Context, arg1 *trillian.GetSequencedLeafCountRequest, arg2 ...grpc.CallOption) (*trillian.GetSequencedLeafCountResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetSequencedLeafCount", varargs...) + ret0, _ := ret[0].(*trillian.GetSequencedLeafCountResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSequencedLeafCount indicates an expected call of GetSequencedLeafCount. +func (mr *MockTrillianLogClientMockRecorder) GetSequencedLeafCount(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSequencedLeafCount", reflect.TypeOf((*MockTrillianLogClient)(nil).GetSequencedLeafCount), varargs...) +} + +// InitLog mocks base method. +func (m *MockTrillianLogClient) InitLog(arg0 context.Context, arg1 *trillian.InitLogRequest, arg2 ...grpc.CallOption) (*trillian.InitLogResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "InitLog", varargs...) + ret0, _ := ret[0].(*trillian.InitLogResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InitLog indicates an expected call of InitLog. +func (mr *MockTrillianLogClientMockRecorder) InitLog(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitLog", reflect.TypeOf((*MockTrillianLogClient)(nil).InitLog), varargs...) +} + +// QueueLeaf mocks base method. +func (m *MockTrillianLogClient) QueueLeaf(arg0 context.Context, arg1 *trillian.QueueLeafRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeafResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueueLeaf", varargs...) + ret0, _ := ret[0].(*trillian.QueueLeafResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueueLeaf indicates an expected call of QueueLeaf. +func (mr *MockTrillianLogClientMockRecorder) QueueLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaf), varargs...) +} + +// QueueLeaves mocks base method. +func (m *MockTrillianLogClient) QueueLeaves(arg0 context.Context, arg1 *trillian.QueueLeavesRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeavesResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueueLeaves", varargs...) + ret0, _ := ret[0].(*trillian.QueueLeavesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueueLeaves indicates an expected call of QueueLeaves. +func (mr *MockTrillianLogClientMockRecorder) QueueLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaves), varargs...) +} diff --git a/pkg/state/state_manager.go b/pkg/state/state_manager.go new file mode 100644 index 0000000..dfa73f4 --- /dev/null +++ b/pkg/state/state_manager.go @@ -0,0 +1,154 @@ +package state + +import ( + "context" + "crypto" + "fmt" + "reflect" + "sync" + "time" + + "github.com/golang/glog" + "github.com/google/certificate-transparency-go/schedule" + "github.com/system-transparency/stfe/pkg/trillian" + "github.com/system-transparency/stfe/pkg/types" +) + +// StateManager coordinates access to the log's tree heads and (co)signatures +type StateManager interface { + Latest(context.Context) (*types.SignedTreeHead, error) + ToSign(context.Context) (*types.SignedTreeHead, error) + Cosigned(context.Context) (*types.SignedTreeHead, error) + AddCosignature(context.Context, *[types.VerificationKeySize]byte, *[types.SignatureSize]byte) error + Run(context.Context) +} + +// StateManagerSingle implements the StateManager interface. It is assumed that +// the log server is running on a single-instance machine. So, no coordination. +type StateManagerSingle struct { + client trillian.Client + signer crypto.Signer + interval time.Duration + deadline time.Duration + sync.RWMutex + + // cosigned is the current cosigned tree head that is being served + cosigned types.SignedTreeHead + + // tosign is the current tree head that is being cosigned by witnesses + tosign types.SignedTreeHead + + // cosignature keeps track of all cosignatures for the tosign tree head + cosignature map[[types.HashSize]byte]*types.SigIdent +} + +func NewStateManagerSingle(client trillian.Client, signer crypto.Signer, interval, deadline time.Duration) (*StateManagerSingle, error) { + sm := &StateManagerSingle{ + client: client, + signer: signer, + interval: interval, + deadline: deadline, + } + + ctx, _ := context.WithTimeout(context.Background(), sm.deadline) + sth, err := sm.Latest(ctx) + if err != nil { + return nil, fmt.Errorf("Latest: %v", err) + } + + sm.cosigned = *sth + sm.tosign = *sth + sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{ + *sth.SigIdent[0].KeyHash: sth.SigIdent[0], // log signature + } + return sm, nil +} + +func (sm *StateManagerSingle) Run(ctx context.Context) { + schedule.Every(ctx, sm.interval, func(ctx context.Context) { + ictx, _ := context.WithTimeout(ctx, sm.deadline) + nextTreeHead, err := sm.Latest(ictx) + if err != nil { + glog.Warningf("rotate failed: Latest: %v", err) + return + } + + sm.Lock() + defer sm.Unlock() + sm.rotate(nextTreeHead) + }) +} + +func (sm *StateManagerSingle) Latest(ctx context.Context) (*types.SignedTreeHead, error) { + th, err := sm.client.GetTreeHead(ctx) + if err != nil { + return nil, fmt.Errorf("LatestTreeHead: %v", err) + } + sth, err := th.Sign(sm.signer) + if err != nil { + return nil, fmt.Errorf("sign: %v", err) + } + return sth, nil +} + +func (sm *StateManagerSingle) ToSign(_ context.Context) (*types.SignedTreeHead, error) { + sm.RLock() + defer sm.RUnlock() + return &sm.tosign, nil +} + +func (sm *StateManagerSingle) Cosigned(_ context.Context) (*types.SignedTreeHead, error) { + sm.RLock() + defer sm.RUnlock() + return &sm.cosigned, nil +} + +func (sm *StateManagerSingle) AddCosignature(_ context.Context, vk *[types.VerificationKeySize]byte, sig *[types.SignatureSize]byte) error { + sm.Lock() + defer sm.Unlock() + + if err := sm.tosign.TreeHead.Verify(vk, sig); err != nil { + return fmt.Errorf("Verify: %v", err) + } + witness := types.Hash(vk[:]) + if _, ok := sm.cosignature[*witness]; ok { + return fmt.Errorf("signature-signer pair is a duplicate") + } + sm.cosignature[*witness] = &types.SigIdent{ + Signature: sig, + KeyHash: witness, + } + + glog.V(3).Infof("accepted new cosignature from witness: %x", *witness) + return nil +} + +// rotate rotates the log's cosigned and stable STH. The caller must aquire the +// source's read-write lock if there are concurrent reads and/or writes. +func (sm *StateManagerSingle) rotate(next *types.SignedTreeHead) { + if reflect.DeepEqual(sm.cosigned.TreeHead, sm.tosign.TreeHead) { + // cosigned and tosign are the same. So, we need to merge all + // cosignatures that we already had with the new collected ones. + for _, sigident := range sm.cosigned.SigIdent { + if _, ok := sm.cosignature[*sigident.KeyHash]; !ok { + sm.cosignature[*sigident.KeyHash] = sigident + } + } + glog.V(3).Infof("cosigned tree head repeated, merged signatures") + } + var cosignatures []*types.SigIdent + for _, sigident := range sm.cosignature { + cosignatures = append(cosignatures, sigident) + } + + // Update cosigned tree head + sm.cosigned.TreeHead = sm.tosign.TreeHead + sm.cosigned.SigIdent = cosignatures + + // Update to-sign tree head + sm.tosign = *next + sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{ + *next.SigIdent[0].KeyHash: next.SigIdent[0], // log signature + } + glog.V(3).Infof("rotated tree heads") +} diff --git a/pkg/state/state_manager_test.go b/pkg/state/state_manager_test.go new file mode 100644 index 0000000..08990cc --- /dev/null +++ b/pkg/state/state_manager_test.go @@ -0,0 +1,393 @@ +package state + +import ( + "bytes" + "context" + "crypto" + "crypto/ed25519" + "crypto/rand" + "fmt" + "reflect" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/system-transparency/stfe/pkg/mocks" + "github.com/system-transparency/stfe/pkg/types" +) + +var ( + testSig = &[types.SignatureSize]byte{} + testPub = &[types.VerificationKeySize]byte{} + testTH = &types.TreeHead{ + Timestamp: 0, + TreeSize: 0, + RootHash: types.Hash(nil), + } + testSigIdent = &types.SigIdent{ + Signature: testSig, + KeyHash: types.Hash(testPub[:]), + } + testSTH = &types.SignedTreeHead{ + TreeHead: *testTH, + SigIdent: []*types.SigIdent{testSigIdent}, + } + testSignerOK = &mocks.TestSigner{testPub, testSig, nil} + testSignerErr = &mocks.TestSigner{testPub, testSig, fmt.Errorf("something went wrong")} +) + +func TestNewStateManagerSingle(t *testing.T) { + for _, table := range []struct { + description string + signer crypto.Signer + rsp *types.TreeHead + err error + wantErr bool + wantSth *types.SignedTreeHead + }{ + { + description: "invalid: backend failure", + signer: testSignerOK, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "valid", + signer: testSignerOK, + rsp: testTH, + wantSth: testSTH, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + client := mocks.NewMockClient(ctrl) + client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err) + + sm, err := NewStateManagerSingle(client, table.signer, time.Duration(0), time.Duration(0)) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := &sm.cosigned, table.wantSth; !reflect.DeepEqual(got, want) { + t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + if got, want := &sm.tosign, table.wantSth; !reflect.DeepEqual(got, want) { + t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + // we only have log signature on startup + if got, want := len(sm.cosignature), 1; got != want { + t.Errorf("got %d cosignatures but wanted %d in test %q", got, want, table.description) + } + }() + } +} + +func TestLatest(t *testing.T) { + for _, table := range []struct { + description string + signer crypto.Signer + rsp *types.TreeHead + err error + wantErr bool + wantSth *types.SignedTreeHead + }{ + { + description: "invalid: backend failure", + signer: testSignerOK, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: signature failure", + rsp: testTH, + signer: testSignerErr, + wantErr: true, + }, + { + description: "valid", + signer: testSignerOK, + rsp: testTH, + wantSth: testSTH, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + client := mocks.NewMockClient(ctrl) + client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err) + sm := StateManagerSingle{ + client: client, + signer: table.signer, + } + + sth, err := sm.Latest(context.Background()) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := sth, table.wantSth; !reflect.DeepEqual(got, want) { + t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} + +func TestToSign(t *testing.T) { + description := "valid" + sm := StateManagerSingle{ + tosign: *testSTH, + } + sth, err := sm.ToSign(context.Background()) + if err != nil { + t.Errorf("ToSign should not fail with error: %v", err) + return + } + if got, want := sth, testSTH; !reflect.DeepEqual(got, want) { + t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description) + } +} + +func TestCosigned(t *testing.T) { + description := "valid" + sm := StateManagerSingle{ + cosigned: *testSTH, + } + sth, err := sm.Cosigned(context.Background()) + if err != nil { + t.Errorf("Cosigned should not fail with error: %v", err) + return + } + if got, want := sth, testSTH; !reflect.DeepEqual(got, want) { + t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description) + } +} + +func TestAddCosignature(t *testing.T) { + vk, sk, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + if bytes.Equal(vk[:], testPub[:]) { + t.Fatalf("Sampled same key as testPub, aborting...") + } + var vkArray [types.VerificationKeySize]byte + copy(vkArray[:], vk[:]) + + for _, table := range []struct { + description string + signer crypto.Signer + vk *[types.VerificationKeySize]byte + th *types.TreeHead + wantErr bool + }{ + { + description: "invalid: signature error", + signer: sk, + vk: testPub, // wrong key for message + th: testTH, + wantErr: true, + }, + { + description: "valid", + signer: sk, + vk: &vkArray, + th: testTH, + }, + } { + sth, _ := table.th.Sign(testSignerOK) + logKeyHash := sth.SigIdent[0].KeyHash + logSigIdent := sth.SigIdent[0] + sm := &StateManagerSingle{ + signer: testSignerOK, + cosigned: *sth, + tosign: *sth, + cosignature: map[[types.HashSize]byte]*types.SigIdent{ + *logKeyHash: logSigIdent, + }, + } + + // Prepare witness signature + sth, err := table.th.Sign(table.signer) + if err != nil { + t.Fatalf("Sign: %v", err) + } + witnessKeyHash := sth.SigIdent[0].KeyHash + witnessSigIdent := sth.SigIdent[0] + + // Add witness signature + err = sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + + // We should have two signatures (log + witness) + if got, want := len(sm.cosignature), 2; got != want { + t.Errorf("got %d cosignatures but wanted %v in test %q", got, want, table.description) + continue + } + // check that log signature is there + sigident, ok := sm.cosignature[*logKeyHash] + if !ok { + t.Errorf("log signature is missing") + continue + } + if got, want := sigident, logSigIdent; !reflect.DeepEqual(got, want) { + t.Errorf("got log sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + // check that witness signature is there + sigident, ok = sm.cosignature[*witnessKeyHash] + if !ok { + t.Errorf("witness signature is missing") + continue + } + if got, want := sigident, witnessSigIdent; !reflect.DeepEqual(got, want) { + t.Errorf("got witness sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + continue + } + + // Adding a duplicate signature should give an error + if err := sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature); err == nil { + t.Errorf("duplicate witness signature accepted as valid") + } + } +} + +func TestRotate(t *testing.T) { + log := testSigIdent + wit1 := &types.SigIdent{ + Signature: testSig, + KeyHash: types.Hash([]byte("wit1 key")), + } + wit2 := &types.SigIdent{ + Signature: testSig, + KeyHash: types.Hash([]byte("wit2 key")), + } + th0 := testTH + th1 := &types.TreeHead{ + Timestamp: 1, + TreeSize: 1, + RootHash: types.Hash([]byte("1")), + } + th2 := &types.TreeHead{ + Timestamp: 2, + TreeSize: 2, + RootHash: types.Hash([]byte("2")), + } + + for _, table := range []struct { + description string + before, after *StateManagerSingle + next *types.SignedTreeHead + }{ + { + description: "tosign tree head repated, but got one new witnes signature", + before: &StateManagerSingle{ + cosigned: types.SignedTreeHead{ + TreeHead: *th0, + SigIdent: []*types.SigIdent{log, wit1}, + }, + tosign: types.SignedTreeHead{ + TreeHead: *th0, + SigIdent: []*types.SigIdent{log}, + }, + cosignature: map[[types.HashSize]byte]*types.SigIdent{ + *log.KeyHash: log, + *wit2.KeyHash: wit2, // the new witness signature + }, + }, + next: &types.SignedTreeHead{ + TreeHead: *th1, + SigIdent: []*types.SigIdent{log}, + }, + after: &StateManagerSingle{ + cosigned: types.SignedTreeHead{ + TreeHead: *th0, + SigIdent: []*types.SigIdent{log, wit1, wit2}, + }, + tosign: types.SignedTreeHead{ + TreeHead: *th1, + SigIdent: []*types.SigIdent{log}, + }, + cosignature: map[[types.HashSize]byte]*types.SigIdent{ + *log.KeyHash: log, // after rotate we always have log sig + }, + }, + }, + { + description: "tosign tree head did not repeat, it got one witness signature", + before: &StateManagerSingle{ + cosigned: types.SignedTreeHead{ + TreeHead: *th0, + SigIdent: []*types.SigIdent{log, wit1}, + }, + tosign: types.SignedTreeHead{ + TreeHead: *th1, + SigIdent: []*types.SigIdent{log}, + }, + cosignature: map[[types.HashSize]byte]*types.SigIdent{ + *log.KeyHash: log, + *wit2.KeyHash: wit2, // the only witness that signed tosign + }, + }, + next: &types.SignedTreeHead{ + TreeHead: *th2, + SigIdent: []*types.SigIdent{log}, + }, + after: &StateManagerSingle{ + cosigned: types.SignedTreeHead{ + TreeHead: *th1, + SigIdent: []*types.SigIdent{log, wit2}, + }, + tosign: types.SignedTreeHead{ + TreeHead: *th2, + SigIdent: []*types.SigIdent{log}, + }, + cosignature: map[[types.HashSize]byte]*types.SigIdent{ + *log.KeyHash: log, // after rotate we always have log sig + }, + }, + }, + } { + table.before.rotate(table.next) + if got, want := table.before.cosigned.TreeHead, table.after.cosigned.TreeHead; !reflect.DeepEqual(got, want) { + t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + checkWitnessList(t, table.description, table.before.cosigned.SigIdent, table.after.cosigned.SigIdent) + if got, want := table.before.tosign.TreeHead, table.after.tosign.TreeHead; !reflect.DeepEqual(got, want) { + t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + checkWitnessList(t, table.description, table.before.tosign.SigIdent, table.after.tosign.SigIdent) + if got, want := table.before.cosignature, table.after.cosignature; !reflect.DeepEqual(got, want) { + t.Errorf("got cosignature map\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + } +} + +func checkWitnessList(t *testing.T, description string, got, want []*types.SigIdent) { + t.Helper() + for _, si := range got { + found := false + for _, sj := range want { + if reflect.DeepEqual(si, sj) { + found = true + break + } + } + if !found { + t.Errorf("got unexpected signature-signer pair with key hash in test %q: %x", description, si.KeyHash[:]) + } + } + if len(got) != len(want) { + t.Errorf("got %d signature-signer pairs but wanted %d in test %q", len(got), len(want), description) + } +} diff --git a/pkg/trillian/client.go b/pkg/trillian/client.go new file mode 100644 index 0000000..9523e56 --- /dev/null +++ b/pkg/trillian/client.go @@ -0,0 +1,178 @@ +package trillian + +import ( + "context" + "fmt" + + "github.com/golang/glog" + "github.com/google/trillian" + ttypes "github.com/google/trillian/types" + "github.com/system-transparency/stfe/pkg/types" + "google.golang.org/grpc/codes" +) + +type Client interface { + AddLeaf(context.Context, *types.LeafRequest) error + GetConsistencyProof(context.Context, *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) + GetTreeHead(context.Context) (*types.TreeHead, error) + GetInclusionProof(context.Context, *types.InclusionProofRequest) (*types.InclusionProof, error) + GetLeaves(context.Context, *types.LeavesRequest) (*types.LeafList, error) +} + +// TrillianClient is a wrapper around the Trillian gRPC client. +type TrillianClient struct { + // TreeID is a Merkle tree identifier that Trillian uses + TreeID int64 + + // GRPC is a Trillian gRPC client + GRPC trillian.TrillianLogClient +} + +func (c *TrillianClient) AddLeaf(ctx context.Context, req *types.LeafRequest) error { + leaf := types.Leaf{ + Message: req.Message, + SigIdent: types.SigIdent{ + Signature: req.Signature, + KeyHash: types.Hash(req.VerificationKey[:]), + }, + } + serialized := leaf.Marshal() + + glog.V(3).Infof("queueing leaf request: %x", types.HashLeaf(serialized)) + rsp, err := c.GRPC.QueueLeaf(ctx, &trillian.QueueLeafRequest{ + LogId: c.TreeID, + Leaf: &trillian.LogLeaf{ + LeafValue: serialized, + }, + }) + if err != nil { + return fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return fmt.Errorf("no response") + } + if rsp.QueuedLeaf == nil { + return fmt.Errorf("no queued leaf") + } + if codes.Code(rsp.QueuedLeaf.GetStatus().GetCode()) == codes.AlreadyExists { + return fmt.Errorf("leaf is already queued or included") + } + return nil +} + +func (c *TrillianClient) GetTreeHead(ctx context.Context) (*types.TreeHead, error) { + rsp, err := c.GRPC.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{ + LogId: c.TreeID, + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if rsp.SignedLogRoot == nil { + return nil, fmt.Errorf("no signed log root") + } + if rsp.SignedLogRoot.LogRoot == nil { + return nil, fmt.Errorf("no log root") + } + var r ttypes.LogRootV1 + if err := r.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil { + return nil, fmt.Errorf("no log root: unmarshal failed: %v", err) + } + if len(r.RootHash) != types.HashSize { + return nil, fmt.Errorf("unexpected hash length: %d", len(r.RootHash)) + } + return treeHeadFromLogRoot(&r), nil +} + +func (c *TrillianClient) GetConsistencyProof(ctx context.Context, req *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) { + rsp, err := c.GRPC.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{ + LogId: c.TreeID, + FirstTreeSize: int64(req.OldSize), + SecondTreeSize: int64(req.NewSize), + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if rsp.Proof == nil { + return nil, fmt.Errorf("no consistency proof") + } + if len(rsp.Proof.Hashes) == 0 { + return nil, fmt.Errorf("not a consistency proof: empty") + } + path, err := nodePathFromHashes(rsp.Proof.Hashes) + if err != nil { + return nil, fmt.Errorf("not a consistency proof: %v", err) + } + return &types.ConsistencyProof{ + OldSize: req.OldSize, + NewSize: req.NewSize, + Path: path, + }, nil +} + +func (c *TrillianClient) GetInclusionProof(ctx context.Context, req *types.InclusionProofRequest) (*types.InclusionProof, error) { + rsp, err := c.GRPC.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{ + LogId: c.TreeID, + LeafHash: req.LeafHash[:], + TreeSize: int64(req.TreeSize), + OrderBySequence: true, + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if len(rsp.Proof) != 1 { + return nil, fmt.Errorf("bad proof count: %d", len(rsp.Proof)) + } + proof := rsp.Proof[0] + if len(proof.Hashes) == 0 { + return nil, fmt.Errorf("not an inclusion proof: empty") + } + path, err := nodePathFromHashes(proof.Hashes) + if err != nil { + return nil, fmt.Errorf("not an inclusion proof: %v", err) + } + return &types.InclusionProof{ + TreeSize: req.TreeSize, + LeafIndex: uint64(proof.LeafIndex), + Path: path, + }, nil +} + +func (c *TrillianClient) GetLeaves(ctx context.Context, req *types.LeavesRequest) (*types.LeafList, error) { + rsp, err := c.GRPC.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{ + LogId: c.TreeID, + StartIndex: int64(req.StartSize), + Count: int64(req.EndSize-req.StartSize) + 1, + }) + if err != nil { + return nil, fmt.Errorf("backend failure: %v", err) + } + if rsp == nil { + return nil, fmt.Errorf("no response") + } + if got, want := len(rsp.Leaves), int(req.EndSize-req.StartSize+1); got != want { + return nil, fmt.Errorf("unexpected number of leaves: %d", got) + } + var list types.LeafList + for i, leaf := range rsp.Leaves { + leafIndex := int64(req.StartSize + uint64(i)) + if leafIndex != leaf.LeafIndex { + return nil, fmt.Errorf("unexpected leaf(%d): got index %d", leafIndex, leaf.LeafIndex) + } + + var l types.Leaf + if err := l.Unmarshal(leaf.LeafValue); err != nil { + return nil, fmt.Errorf("unexpected leaf(%d): %v", leafIndex, err) + } + list = append(list[:], &l) + } + return &list, nil +} diff --git a/pkg/trillian/client_test.go b/pkg/trillian/client_test.go new file mode 100644 index 0000000..6b3d881 --- /dev/null +++ b/pkg/trillian/client_test.go @@ -0,0 +1,533 @@ +package trillian + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/trillian" + ttypes "github.com/google/trillian/types" + "github.com/system-transparency/stfe/pkg/mocks" + "github.com/system-transparency/stfe/pkg/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestAddLeaf(t *testing.T) { + req := &types.LeafRequest{ + Message: types.Message{ + ShardHint: 0, + Checksum: &[types.HashSize]byte{}, + }, + Signature: &[types.SignatureSize]byte{}, + VerificationKey: &[types.VerificationKeySize]byte{}, + DomainHint: "example.com", + } + for _, table := range []struct { + description string + req *types.LeafRequest + rsp *trillian.QueueLeafResponse + err error + wantErr bool + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: no queued leaf", + req: req, + rsp: &trillian.QueueLeafResponse{}, + wantErr: true, + }, + { + description: "invalid: leaf is already queued or included", + req: req, + rsp: &trillian.QueueLeafResponse{ + QueuedLeaf: &trillian.QueuedLogLeaf{ + Leaf: &trillian.LogLeaf{ + LeafValue: req.Message.Marshal(), + }, + Status: status.New(codes.AlreadyExists, "duplicate").Proto(), + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.QueueLeafResponse{ + QueuedLeaf: &trillian.QueuedLogLeaf{ + Leaf: &trillian.LogLeaf{ + LeafValue: req.Message.Marshal(), + }, + Status: status.New(codes.OK, "ok").Proto(), + }, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().QueueLeaf(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + err := client.AddLeaf(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + }() + } +} + +func TestGetTreeHead(t *testing.T) { + // valid root + root := &ttypes.LogRootV1{ + TreeSize: 0, + RootHash: make([]byte, types.HashSize), + TimestampNanos: 1622585623133599429, + } + buf, err := root.MarshalBinary() + if err != nil { + t.Fatalf("must marshal log root: %v", err) + } + // invalid root + root.RootHash = make([]byte, types.HashSize+1) + bufBadHash, err := root.MarshalBinary() + if err != nil { + t.Fatalf("must marshal log root: %v", err) + } + + for _, table := range []struct { + description string + rsp *trillian.GetLatestSignedLogRootResponse + err error + wantErr bool + wantTh *types.TreeHead + }{ + { + description: "invalid: backend failure", + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + wantErr: true, + }, + { + description: "invalid: no signed log root", + rsp: &trillian.GetLatestSignedLogRootResponse{}, + wantErr: true, + }, + { + description: "invalid: no log root", + rsp: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: &trillian.SignedLogRoot{}, + }, + wantErr: true, + }, + { + description: "invalid: no log root: unmarshal failed", + rsp: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: &trillian.SignedLogRoot{ + LogRoot: buf[1:], + }, + }, + wantErr: true, + }, + { + description: "invalid: unexpected hash length", + rsp: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: &trillian.SignedLogRoot{ + LogRoot: bufBadHash, + }, + }, + wantErr: true, + }, + { + description: "valid", + rsp: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: &trillian.SignedLogRoot{ + LogRoot: buf, + }, + }, + wantTh: &types.TreeHead{ + Timestamp: 1622585623, + TreeSize: 0, + RootHash: &[types.HashSize]byte{}, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + th, err := client.GetTreeHead(context.Background()) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := th, table.wantTh; !reflect.DeepEqual(got, want) { + t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} + +func TestGetConsistencyProof(t *testing.T) { + req := &types.ConsistencyProofRequest{ + OldSize: 1, + NewSize: 3, + } + for _, table := range []struct { + description string + req *types.ConsistencyProofRequest + rsp *trillian.GetConsistencyProofResponse + err error + wantErr bool + wantProof *types.ConsistencyProof + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: no consistency proof", + req: req, + rsp: &trillian.GetConsistencyProofResponse{}, + wantErr: true, + }, + { + description: "invalid: not a consistency proof (1/2)", + req: req, + rsp: &trillian.GetConsistencyProofResponse{ + Proof: &trillian.Proof{ + Hashes: [][]byte{}, + }, + }, + wantErr: true, + }, + { + description: "invalid: not a consistency proof (2/2)", + req: req, + rsp: &trillian.GetConsistencyProofResponse{ + Proof: &trillian.Proof{ + Hashes: [][]byte{ + make([]byte, types.HashSize), + make([]byte, types.HashSize+1), + }, + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.GetConsistencyProofResponse{ + Proof: &trillian.Proof{ + Hashes: [][]byte{ + make([]byte, types.HashSize), + make([]byte, types.HashSize), + }, + }, + }, + wantProof: &types.ConsistencyProof{ + OldSize: 1, + NewSize: 3, + Path: []*[types.HashSize]byte{ + &[types.HashSize]byte{}, + &[types.HashSize]byte{}, + }, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + proof, err := client.GetConsistencyProof(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { + t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} + +func TestGetInclusionProof(t *testing.T) { + req := &types.InclusionProofRequest{ + TreeSize: 4, + LeafHash: &[types.HashSize]byte{}, + } + for _, table := range []struct { + description string + req *types.InclusionProofRequest + rsp *trillian.GetInclusionProofByHashResponse + err error + wantErr bool + wantProof *types.InclusionProof + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: bad proof count", + req: req, + rsp: &trillian.GetInclusionProofByHashResponse{ + Proof: []*trillian.Proof{ + &trillian.Proof{}, + &trillian.Proof{}, + }, + }, + wantErr: true, + }, + { + description: "invalid: not an inclusion proof (1/2)", + req: req, + rsp: &trillian.GetInclusionProofByHashResponse{ + Proof: []*trillian.Proof{ + &trillian.Proof{ + LeafIndex: 1, + Hashes: [][]byte{}, + }, + }, + }, + wantErr: true, + }, + { + description: "invalid: not an inclusion proof (2/2)", + req: req, + rsp: &trillian.GetInclusionProofByHashResponse{ + Proof: []*trillian.Proof{ + &trillian.Proof{ + LeafIndex: 1, + Hashes: [][]byte{ + make([]byte, types.HashSize), + make([]byte, types.HashSize+1), + }, + }, + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.GetInclusionProofByHashResponse{ + Proof: []*trillian.Proof{ + &trillian.Proof{ + LeafIndex: 1, + Hashes: [][]byte{ + make([]byte, types.HashSize), + make([]byte, types.HashSize), + }, + }, + }, + }, + wantProof: &types.InclusionProof{ + TreeSize: 4, + LeafIndex: 1, + Path: []*[types.HashSize]byte{ + &[types.HashSize]byte{}, + &[types.HashSize]byte{}, + }, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + proof, err := client.GetInclusionProof(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { + t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} + +func TestGetLeaves(t *testing.T) { + req := &types.LeavesRequest{ + StartSize: 1, + EndSize: 2, + } + firstLeaf := &types.Leaf{ + Message: types.Message{ + ShardHint: 0, + Checksum: &[types.HashSize]byte{}, + }, + SigIdent: types.SigIdent{ + Signature: &[types.SignatureSize]byte{}, + KeyHash: &[types.HashSize]byte{}, + }, + } + secondLeaf := &types.Leaf{ + Message: types.Message{ + ShardHint: 0, + Checksum: &[types.HashSize]byte{}, + }, + SigIdent: types.SigIdent{ + Signature: &[types.SignatureSize]byte{}, + KeyHash: &[types.HashSize]byte{}, + }, + } + + for _, table := range []struct { + description string + req *types.LeavesRequest + rsp *trillian.GetLeavesByRangeResponse + err error + wantErr bool + wantLeaves *types.LeafList + }{ + { + description: "invalid: backend failure", + req: req, + err: fmt.Errorf("something went wrong"), + wantErr: true, + }, + { + description: "invalid: no response", + req: req, + wantErr: true, + }, + { + description: "invalid: unexpected number of leaves", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + }, + }, + wantErr: true, + }, + { + description: "invalid: unexpected leaf (1/2)", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal(), + LeafIndex: 3, + }, + }, + }, + wantErr: true, + }, + { + description: "invalid: unexpected leaf (2/2)", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal()[1:], + LeafIndex: 2, + }, + }, + }, + wantErr: true, + }, + { + description: "valid", + req: req, + rsp: &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{ + &trillian.LogLeaf{ + LeafValue: firstLeaf.Marshal(), + LeafIndex: 1, + }, + &trillian.LogLeaf{ + LeafValue: secondLeaf.Marshal(), + LeafIndex: 2, + }, + }, + }, + wantLeaves: &types.LeafList{ + firstLeaf, + secondLeaf, + }, + }, + } { + // Run deferred functions at the end of each iteration + func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + grpc := mocks.NewMockTrillianLogClient(ctrl) + grpc.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) + client := TrillianClient{GRPC: grpc} + + leaves, err := client.GetLeaves(context.Background(), table.req) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + return + } + if got, want := leaves, table.wantLeaves; !reflect.DeepEqual(got, want) { + t.Errorf("got leaves\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + }() + } +} diff --git a/pkg/trillian/util.go b/pkg/trillian/util.go new file mode 100644 index 0000000..4cf31fb --- /dev/null +++ b/pkg/trillian/util.go @@ -0,0 +1,33 @@ +package trillian + +import ( + "fmt" + + trillian "github.com/google/trillian/types" + siglog "github.com/system-transparency/stfe/pkg/types" +) + +func treeHeadFromLogRoot(lr *trillian.LogRootV1) *siglog.TreeHead { + var hash [siglog.HashSize]byte + th := siglog.TreeHead{ + Timestamp: uint64(lr.TimestampNanos / 1000 / 1000 / 1000), + TreeSize: uint64(lr.TreeSize), + RootHash: &hash, + } + copy(th.RootHash[:], lr.RootHash) + return &th +} + +func nodePathFromHashes(hashes [][]byte) ([]*[siglog.HashSize]byte, error) { + var path []*[siglog.HashSize]byte + for _, hash := range hashes { + if len(hash) != siglog.HashSize { + return nil, fmt.Errorf("unexpected hash length: %v", len(hash)) + } + + var h [siglog.HashSize]byte + copy(h[:], hash) + path = append(path, &h) + } + return path, nil +} diff --git a/pkg/types/ascii.go b/pkg/types/ascii.go new file mode 100644 index 0000000..d27d79b --- /dev/null +++ b/pkg/types/ascii.go @@ -0,0 +1,421 @@ +package types + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "io/ioutil" + "strconv" +) + +const ( + // Delim is a key-value separator + Delim = "=" + + // EOL is a line sepator + EOL = "\n" + + // NumField* is the number of unique keys in an incoming ASCII message + NumFieldLeaf = 4 + NumFieldSignedTreeHead = 5 + NumFieldConsistencyProof = 3 + NumFieldInclusionProof = 3 + NumFieldLeavesRequest = 2 + NumFieldInclusionProofRequest = 2 + NumFieldConsistencyProofRequest = 2 + NumFieldLeafRequest = 5 + NumFieldCosignatureRequest = 2 + + // New leaf keys + ShardHint = "shard_hint" + Checksum = "checksum" + SignatureOverMessage = "signature_over_message" + VerificationKey = "verification_key" + DomainHint = "domain_hint" + + // Inclusion proof keys + LeafHash = "leaf_hash" + LeafIndex = "leaf_index" + InclusionPath = "inclusion_path" + + // Consistency proof keys + NewSize = "new_size" + OldSize = "old_size" + ConsistencyPath = "consistency_path" + + // Range of leaves keys + StartSize = "start_size" + EndSize = "end_size" + + // Tree head keys + Timestamp = "timestamp" + TreeSize = "tree_size" + RootHash = "root_hash" + + // Signature and signer-identity keys + Signature = "signature" + KeyHash = "key_hash" +) + +// MessageASCI is a wrapper that manages ASCII key-value pairs +type MessageASCII struct { + m map[string][]string +} + +// NewMessageASCII unpacks an incoming ASCII message +func NewMessageASCII(r io.Reader, numFieldExpected int) (*MessageASCII, error) { + buf, err := ioutil.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("ReadAll: %v", err) + } + lines := bytes.Split(buf, []byte(EOL)) + if len(lines) <= 1 { + return nil, fmt.Errorf("Not enough lines: empty") + } + lines = lines[:len(lines)-1] // valid message => split gives empty last line + + msg := MessageASCII{make(map[string][]string)} + for _, line := range lines { + split := bytes.Index(line, []byte(Delim)) + if split == -1 { + return nil, fmt.Errorf("invalid line: %v", string(line)) + } + + key := string(line[:split]) + value := string(line[split+len(Delim):]) + values, ok := msg.m[key] + if !ok { + values = nil + msg.m[key] = values + } + msg.m[key] = append(values, value) + } + + if msg.NumField() != numFieldExpected { + return nil, fmt.Errorf("Unexpected number of keys: %v", msg.NumField()) + } + return &msg, nil +} + +// NumField returns the number of unique keys +func (msg *MessageASCII) NumField() int { + return len(msg.m) +} + +// GetStrings returns a list of strings +func (msg *MessageASCII) GetStrings(key string) []string { + strs, ok := msg.m[key] + if !ok { + return nil + } + return strs +} + +// GetString unpacks a string +func (msg *MessageASCII) GetString(key string) (string, error) { + strs := msg.GetStrings(key) + if len(strs) != 1 { + return "", fmt.Errorf("expected one string: %v", strs) + } + return strs[0], nil +} + +// GetUint64 unpacks an uint64 +func (msg *MessageASCII) GetUint64(key string) (uint64, error) { + str, err := msg.GetString(key) + if err != nil { + return 0, fmt.Errorf("GetString: %v", err) + } + num, err := strconv.ParseUint(str, 10, 64) + if err != nil { + return 0, fmt.Errorf("ParseUint: %v", err) + } + return num, nil +} + +// GetHash unpacks a hash +func (msg *MessageASCII) GetHash(key string) (*[HashSize]byte, error) { + str, err := msg.GetString(key) + if err != nil { + return nil, fmt.Errorf("GetString: %v", err) + } + + var hash [HashSize]byte + if err := decodeHex(str, hash[:]); err != nil { + return nil, fmt.Errorf("decodeHex: %v", err) + } + return &hash, nil +} + +// GetSignature unpacks a signature +func (msg *MessageASCII) GetSignature(key string) (*[SignatureSize]byte, error) { + str, err := msg.GetString(key) + if err != nil { + return nil, fmt.Errorf("GetString: %v", err) + } + + var signature [SignatureSize]byte + if err := decodeHex(str, signature[:]); err != nil { + return nil, fmt.Errorf("decodeHex: %v", err) + } + return &signature, nil +} + +// GetVerificationKey unpacks a verification key +func (msg *MessageASCII) GetVerificationKey(key string) (*[VerificationKeySize]byte, error) { + str, err := msg.GetString(key) + if err != nil { + return nil, fmt.Errorf("GetString: %v", err) + } + + var vk [VerificationKeySize]byte + if err := decodeHex(str, vk[:]); err != nil { + return nil, fmt.Errorf("decodeHex: %v", err) + } + return &vk, nil +} + +// decodeHex decodes a hex-encoded string into an already-sized byte slice +func decodeHex(str string, out []byte) error { + buf, err := hex.DecodeString(str) + if err != nil { + return fmt.Errorf("DecodeString: %v", err) + } + if len(buf) != len(out) { + return fmt.Errorf("invalid length: %v", len(buf)) + } + copy(out, buf) + return nil +} + +/* + * + * MarshalASCII wrappers for types that the log server outputs + * + */ +func (l *Leaf) MarshalASCII(w io.Writer) error { + if err := writeASCII(w, ShardHint, strconv.FormatUint(l.ShardHint, 10)); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + if err := writeASCII(w, Checksum, hex.EncodeToString(l.Checksum[:])); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + if err := writeASCII(w, SignatureOverMessage, hex.EncodeToString(l.Signature[:])); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + if err := writeASCII(w, KeyHash, hex.EncodeToString(l.KeyHash[:])); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + return nil +} + +func (sth *SignedTreeHead) MarshalASCII(w io.Writer) error { + if err := writeASCII(w, Timestamp, strconv.FormatUint(sth.Timestamp, 10)); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + if err := writeASCII(w, TreeSize, strconv.FormatUint(sth.TreeSize, 10)); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + if err := writeASCII(w, RootHash, hex.EncodeToString(sth.RootHash[:])); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + for _, sigident := range sth.SigIdent { + if err := sigident.MarshalASCII(w); err != nil { + return fmt.Errorf("MarshalASCII: %v", err) + } + } + return nil +} + +func (si *SigIdent) MarshalASCII(w io.Writer) error { + if err := writeASCII(w, Signature, hex.EncodeToString(si.Signature[:])); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + if err := writeASCII(w, KeyHash, hex.EncodeToString(si.KeyHash[:])); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + return nil +} + +func (p *ConsistencyProof) MarshalASCII(w io.Writer) error { + if err := writeASCII(w, NewSize, strconv.FormatUint(p.NewSize, 10)); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + if err := writeASCII(w, OldSize, strconv.FormatUint(p.OldSize, 10)); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + for _, hash := range p.Path { + if err := writeASCII(w, ConsistencyPath, hex.EncodeToString(hash[:])); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + } + return nil +} + +func (p *InclusionProof) MarshalASCII(w io.Writer) error { + if err := writeASCII(w, TreeSize, strconv.FormatUint(p.TreeSize, 10)); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + if err := writeASCII(w, LeafIndex, strconv.FormatUint(p.LeafIndex, 10)); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + for _, hash := range p.Path { + if err := writeASCII(w, InclusionPath, hex.EncodeToString(hash[:])); err != nil { + return fmt.Errorf("writeASCII: %v", err) + } + } + return nil +} + +func writeASCII(w io.Writer, key, value string) error { + if _, err := fmt.Fprintf(w, "%s%s%s%s", key, Delim, value, EOL); err != nil { + return fmt.Errorf("Fprintf: %v", err) + } + return nil +} + +/* + * + * Unmarshal ASCII wrappers that the log server and/or log clients receive. + * + */ +func (ll *LeafList) UnmarshalASCII(r io.Reader) error { + return nil +} + +func (sth *SignedTreeHead) UnmarshalASCII(r io.Reader) error { + msg, err := NewMessageASCII(r, NumFieldSignedTreeHead) + if err != nil { + return fmt.Errorf("NewMessageASCII: %v", err) + } + + // TreeHead + if sth.Timestamp, err = msg.GetUint64(Timestamp); err != nil { + return fmt.Errorf("GetUint64(Timestamp): %v", err) + } + if sth.TreeSize, err = msg.GetUint64(TreeSize); err != nil { + return fmt.Errorf("GetUint64(TreeSize): %v", err) + } + if sth.RootHash, err = msg.GetHash(RootHash); err != nil { + return fmt.Errorf("GetHash(RootHash): %v", err) + } + + // SigIdent + signatures := msg.GetStrings(Signature) + if len(signatures) == 0 { + return fmt.Errorf("no signer") + } + keyHashes := msg.GetStrings(KeyHash) + if len(signatures) != len(keyHashes) { + return fmt.Errorf("mismatched signature-signer count") + } + sth.SigIdent = make([]*SigIdent, 0, len(signatures)) + for i, n := 0, len(signatures); i < n; i++ { + var signature [SignatureSize]byte + if err := decodeHex(signatures[i], signature[:]); err != nil { + return fmt.Errorf("decodeHex: %v", err) + } + var hash [HashSize]byte + if err := decodeHex(keyHashes[i], hash[:]); err != nil { + return fmt.Errorf("decodeHex: %v", err) + } + sth.SigIdent = append(sth.SigIdent, &SigIdent{ + Signature: &signature, + KeyHash: &hash, + }) + } + return nil +} + +func (p *InclusionProof) UnmarshalASCII(r io.Reader) error { + return nil +} + +func (p *ConsistencyProof) UnmarshalASCII(r io.Reader) error { + return nil +} + +func (req *InclusionProofRequest) UnmarshalASCII(r io.Reader) error { + msg, err := NewMessageASCII(r, NumFieldInclusionProofRequest) + if err != nil { + return fmt.Errorf("NewMessageASCII: %v", err) + } + + if req.LeafHash, err = msg.GetHash(LeafHash); err != nil { + return fmt.Errorf("GetHash(LeafHash): %v", err) + } + if req.TreeSize, err = msg.GetUint64(TreeSize); err != nil { + return fmt.Errorf("GetUint64(TreeSize): %v", err) + } + return nil +} + +func (req *ConsistencyProofRequest) UnmarshalASCII(r io.Reader) error { + msg, err := NewMessageASCII(r, NumFieldConsistencyProofRequest) + if err != nil { + return fmt.Errorf("NewMessageASCII: %v", err) + } + + if req.NewSize, err = msg.GetUint64(NewSize); err != nil { + return fmt.Errorf("GetUint64(NewSize): %v", err) + } + if req.OldSize, err = msg.GetUint64(OldSize); err != nil { + return fmt.Errorf("GetUint64(OldSize): %v", err) + } + return nil +} + +func (req *LeavesRequest) UnmarshalASCII(r io.Reader) error { + msg, err := NewMessageASCII(r, NumFieldLeavesRequest) + if err != nil { + return fmt.Errorf("NewMessageASCII: %v", err) + } + + if req.StartSize, err = msg.GetUint64(StartSize); err != nil { + return fmt.Errorf("GetUint64(StartSize): %v", err) + } + if req.EndSize, err = msg.GetUint64(EndSize); err != nil { + return fmt.Errorf("GetUint64(EndSize): %v", err) + } + return nil +} + +func (req *LeafRequest) UnmarshalASCII(r io.Reader) error { + msg, err := NewMessageASCII(r, NumFieldLeafRequest) + if err != nil { + return fmt.Errorf("NewMessageASCII: %v", err) + } + + if req.ShardHint, err = msg.GetUint64(ShardHint); err != nil { + return fmt.Errorf("GetUint64(ShardHint): %v", err) + } + if req.Checksum, err = msg.GetHash(Checksum); err != nil { + return fmt.Errorf("GetHash(Checksum): %v", err) + } + if req.Signature, err = msg.GetSignature(SignatureOverMessage); err != nil { + return fmt.Errorf("GetSignature: %v", err) + } + if req.VerificationKey, err = msg.GetVerificationKey(VerificationKey); err != nil { + return fmt.Errorf("GetVerificationKey: %v", err) + } + if req.DomainHint, err = msg.GetString(DomainHint); err != nil { + return fmt.Errorf("GetString(DomainHint): %v", err) + } + return nil +} + +func (req *CosignatureRequest) UnmarshalASCII(r io.Reader) error { + msg, err := NewMessageASCII(r, NumFieldCosignatureRequest) + if err != nil { + return fmt.Errorf("NewMessageASCII: %v", err) + } + + if req.Signature, err = msg.GetSignature(Signature); err != nil { + return fmt.Errorf("GetSignature: %v", err) + } + if req.KeyHash, err = msg.GetHash(KeyHash); err != nil { + return fmt.Errorf("GetHash(KeyHash): %v", err) + } + return nil +} diff --git a/pkg/types/ascii_test.go b/pkg/types/ascii_test.go new file mode 100644 index 0000000..92732f9 --- /dev/null +++ b/pkg/types/ascii_test.go @@ -0,0 +1,465 @@ +package types + +import ( + "bytes" + "fmt" + "io" + "reflect" + "testing" +) + +/* + * + * MessageASCII methods and helpers + * + */ +func TestNewMessageASCII(t *testing.T) { + for _, table := range []struct { + description string + input io.Reader + wantErr bool + wantMap map[string][]string + }{ + { + description: "invalid: not enough lines", + input: bytes.NewBufferString(""), + wantErr: true, + }, + { + description: "invalid: lines must end with new line", + input: bytes.NewBufferString("k1=v1\nk2=v2"), + wantErr: true, + }, + { + description: "invalid: lines must not be empty", + input: bytes.NewBufferString("k1=v1\n\nk2=v2\n"), + wantErr: true, + }, + { + description: "invalid: wrong number of fields", + input: bytes.NewBufferString("k1=v1\n"), + wantErr: true, + }, + { + description: "valid", + input: bytes.NewBufferString("k1=v1\nk2=v2\nk2=v3=4\n"), + wantMap: map[string][]string{ + "k1": []string{"v1"}, + "k2": []string{"v2", "v3=4"}, + }, + }, + } { + msg, err := NewMessageASCII(table.input, len(table.wantMap)) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := msg.m, table.wantMap; !reflect.DeepEqual(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + } +} + +func TestNumField(t *testing.T) {} +func TestGetStrings(t *testing.T) {} +func TestGetString(t *testing.T) {} +func TestGetUint64(t *testing.T) {} +func TestGetHash(t *testing.T) {} +func TestGetSignature(t *testing.T) {} +func TestGetVerificationKey(t *testing.T) {} +func TestDecodeHex(t *testing.T) {} + +/* + * + * MarshalASCII methods and helpers + * + */ +func TestLeafMarshalASCII(t *testing.T) { + description := "valid: two leaves" + leafList := []*Leaf{ + &Leaf{ + Message: Message{ + ShardHint: 123, + Checksum: testBuffer32, + }, + SigIdent: SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + }, + &Leaf{ + Message: Message{ + ShardHint: 456, + Checksum: testBuffer32, + }, + SigIdent: SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + }, + } + wantBuf := bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+ + "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", + // Leaf 1 + ShardHint, Delim, 123, EOL, + Checksum, Delim, testBuffer32[:], EOL, + SignatureOverMessage, Delim, testBuffer64[:], EOL, + KeyHash, Delim, testBuffer32[:], EOL, + // Leaf 2 + ShardHint, Delim, 456, EOL, + Checksum, Delim, testBuffer32[:], EOL, + SignatureOverMessage, Delim, testBuffer64[:], EOL, + KeyHash, Delim, testBuffer32[:], EOL, + )) + buf := bytes.NewBuffer(nil) + for _, leaf := range leafList { + if err := leaf.MarshalASCII(buf); err != nil { + t.Errorf("expected error %v but got %v in test %q: %v", false, true, description, err) + return + } + } + if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) + } +} + +func TestSignedTreeHeadMarshalASCII(t *testing.T) { + description := "valid" + sth := &SignedTreeHead{ + TreeHead: TreeHead{ + Timestamp: 123, + TreeSize: 456, + RootHash: testBuffer32, + }, + SigIdent: []*SigIdent{ + &SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + &SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + }, + } + wantBuf := bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", + Timestamp, Delim, 123, EOL, + TreeSize, Delim, 456, EOL, + RootHash, Delim, testBuffer32[:], EOL, + Signature, Delim, testBuffer64[:], EOL, + KeyHash, Delim, testBuffer32[:], EOL, + Signature, Delim, testBuffer64[:], EOL, + KeyHash, Delim, testBuffer32[:], EOL, + )) + buf := bytes.NewBuffer(nil) + if err := sth.MarshalASCII(buf); err != nil { + t.Errorf("expected error %v but got %v in test %q", false, true, description) + return + } + if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) + } +} + +func TestInclusionProofMarshalASCII(t *testing.T) { + description := "valid" + proof := InclusionProof{ + TreeSize: 321, + LeafIndex: 123, + Path: []*[HashSize]byte{ + testBuffer32, + testBuffer32, + }, + } + wantBuf := bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s", + TreeSize, Delim, 321, EOL, + LeafIndex, Delim, 123, EOL, + InclusionPath, Delim, testBuffer32[:], EOL, + InclusionPath, Delim, testBuffer32[:], EOL, + )) + buf := bytes.NewBuffer(nil) + if err := proof.MarshalASCII(buf); err != nil { + t.Errorf("expected error %v but got %v in test %q", false, true, description) + return + } + if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) + } +} + +func TestConsistencyProofMarshalASCII(t *testing.T) { + description := "valid" + proof := ConsistencyProof{ + NewSize: 321, + OldSize: 123, + Path: []*[HashSize]byte{ + testBuffer32, + testBuffer32, + }, + } + wantBuf := bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s", + NewSize, Delim, 321, EOL, + OldSize, Delim, 123, EOL, + ConsistencyPath, Delim, testBuffer32[:], EOL, + ConsistencyPath, Delim, testBuffer32[:], EOL, + )) + buf := bytes.NewBuffer(nil) + if err := proof.MarshalASCII(buf); err != nil { + t.Errorf("expected error %v but got %v in test %q", false, true, description) + return + } + if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) + } +} + +func TestWriteASCII(t *testing.T) { +} + +/* + * + * UnmarshalASCII methods and helpers + * + */ +func TestLeafListUnmarshalASCII(t *testing.T) {} + +func TestSignedTreeHeadUnmarshalASCII(t *testing.T) { + for _, table := range []struct { + description string + buf io.Reader + wantErr bool + wantSth *SignedTreeHead + }{ + { + description: "valid", + buf: bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", + Timestamp, Delim, 123, EOL, + TreeSize, Delim, 456, EOL, + RootHash, Delim, testBuffer32[:], EOL, + Signature, Delim, testBuffer64[:], EOL, + KeyHash, Delim, testBuffer32[:], EOL, + Signature, Delim, testBuffer64[:], EOL, + KeyHash, Delim, testBuffer32[:], EOL, + )), + wantSth: &SignedTreeHead{ + TreeHead: TreeHead{ + Timestamp: 123, + TreeSize: 456, + RootHash: testBuffer32, + }, + SigIdent: []*SigIdent{ + &SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + &SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + }, + }, + }, + } { + var sth SignedTreeHead + err := sth.UnmarshalASCII(table.buf) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := &sth, table.wantSth; !reflect.DeepEqual(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + } +} + +func TestInclusionProofUnmarshalASCII(t *testing.T) {} +func TestConsistencyProofUnmarshalASCII(t *testing.T) {} + +func TestInclusionProofRequestUnmarshalASCII(t *testing.T) { + for _, table := range []struct { + description string + buf io.Reader + wantErr bool + wantReq *InclusionProofRequest + }{ + { + description: "valid", + buf: bytes.NewBufferString(fmt.Sprintf( + "%s%s%x%s"+"%s%s%d%s", + LeafHash, Delim, testBuffer32[:], EOL, + TreeSize, Delim, 123, EOL, + )), + wantReq: &InclusionProofRequest{ + LeafHash: testBuffer32, + TreeSize: 123, + }, + }, + } { + var req InclusionProofRequest + err := req.UnmarshalASCII(table.buf) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + } +} + +func TestConsistencyProofRequestUnmarshalASCII(t *testing.T) { + for _, table := range []struct { + description string + buf io.Reader + wantErr bool + wantReq *ConsistencyProofRequest + }{ + { + description: "valid", + buf: bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%d%s", + NewSize, Delim, 321, EOL, + OldSize, Delim, 123, EOL, + )), + wantReq: &ConsistencyProofRequest{ + NewSize: 321, + OldSize: 123, + }, + }, + } { + var req ConsistencyProofRequest + err := req.UnmarshalASCII(table.buf) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + } +} + +func TestLeavesRequestUnmarshalASCII(t *testing.T) { + for _, table := range []struct { + description string + buf io.Reader + wantErr bool + wantReq *LeavesRequest + }{ + { + description: "valid", + buf: bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%d%s", + StartSize, Delim, 123, EOL, + EndSize, Delim, 456, EOL, + )), + wantReq: &LeavesRequest{ + StartSize: 123, + EndSize: 456, + }, + }, + } { + var req LeavesRequest + err := req.UnmarshalASCII(table.buf) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + } +} + +func TestLeafRequestUnmarshalASCII(t *testing.T) { + for _, table := range []struct { + description string + buf io.Reader + wantErr bool + wantReq *LeafRequest + }{ + { + description: "valid", + buf: bytes.NewBufferString(fmt.Sprintf( + "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%s%s", + ShardHint, Delim, 123, EOL, + Checksum, Delim, testBuffer32[:], EOL, + SignatureOverMessage, Delim, testBuffer64[:], EOL, + VerificationKey, Delim, testBuffer32[:], EOL, + DomainHint, Delim, "example.com", EOL, + )), + wantReq: &LeafRequest{ + Message: Message{ + ShardHint: 123, + Checksum: testBuffer32, + }, + Signature: testBuffer64, + VerificationKey: testBuffer32, + DomainHint: "example.com", + }, + }, + } { + var req LeafRequest + err := req.UnmarshalASCII(table.buf) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + } +} + +func TestCosignatureRequestUnmarshalASCII(t *testing.T) { + for _, table := range []struct { + description string + buf io.Reader + wantErr bool + wantReq *CosignatureRequest + }{ + { + description: "valid", + buf: bytes.NewBufferString(fmt.Sprintf( + "%s%s%x%s"+"%s%s%x%s", + Signature, Delim, testBuffer64[:], EOL, + KeyHash, Delim, testBuffer32[:], EOL, + )), + wantReq: &CosignatureRequest{ + SigIdent: SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + }, + }, + } { + var req CosignatureRequest + err := req.UnmarshalASCII(table.buf) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { + t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) + } + } +} diff --git a/pkg/types/trunnel.go b/pkg/types/trunnel.go new file mode 100644 index 0000000..268f6f7 --- /dev/null +++ b/pkg/types/trunnel.go @@ -0,0 +1,60 @@ +package types + +import ( + "encoding/binary" + "fmt" +) + +const ( + // MessageSize is the number of bytes in a Trunnel-encoded leaf message + MessageSize = 8 + HashSize + // LeafSize is the number of bytes in a Trunnel-encoded leaf + LeafSize = MessageSize + SignatureSize + HashSize +) + +// Marshal returns a Trunnel-encoded message +func (m *Message) Marshal() []byte { + buf := make([]byte, MessageSize) + binary.BigEndian.PutUint64(buf, m.ShardHint) + copy(buf[8:], m.Checksum[:]) + return buf +} + +// Marshal returns a Trunnel-encoded leaf +func (l *Leaf) Marshal() []byte { + buf := l.Message.Marshal() + buf = append(buf, l.SigIdent.Signature[:]...) + buf = append(buf, l.SigIdent.KeyHash[:]...) + return buf +} + +// Marshal returns a Trunnel-encoded tree head +func (th *TreeHead) Marshal() []byte { + buf := make([]byte, 8+8+HashSize) + binary.BigEndian.PutUint64(buf[0:8], th.Timestamp) + binary.BigEndian.PutUint64(buf[8:16], th.TreeSize) + copy(buf[16:], th.RootHash[:]) + return buf +} + +// Unmarshal parses the Trunnel-encoded buffer as a leaf +func (l *Leaf) Unmarshal(buf []byte) error { + if len(buf) != LeafSize { + return fmt.Errorf("invalid leaf size: %v", len(buf)) + } + // Shard hint + l.ShardHint = binary.BigEndian.Uint64(buf) + offset := 8 + // Checksum + l.Checksum = &[HashSize]byte{} + copy(l.Checksum[:], buf[offset:offset+HashSize]) + offset += HashSize + // Signature + l.Signature = &[SignatureSize]byte{} + copy(l.Signature[:], buf[offset:offset+SignatureSize]) + offset += SignatureSize + // KeyHash + l.KeyHash = &[HashSize]byte{} + copy(l.KeyHash[:], buf[offset:]) + return nil +} diff --git a/pkg/types/trunnel_test.go b/pkg/types/trunnel_test.go new file mode 100644 index 0000000..297578c --- /dev/null +++ b/pkg/types/trunnel_test.go @@ -0,0 +1,114 @@ +package types + +import ( + "bytes" + "reflect" + "testing" +) + +var ( + testBuffer32 = &[32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31} + testBuffer64 = &[64]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63} +) + +func TestMarshalMessage(t *testing.T) { + description := "valid: shard hint 72623859790382856, checksum 0x00,0x01,..." + message := &Message{ + ShardHint: 72623859790382856, + Checksum: testBuffer32, + } + want := bytes.Join([][]byte{ + []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, + testBuffer32[:], + }, nil) + if got := message.Marshal(); !bytes.Equal(got, want) { + t.Errorf("got message\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) + } +} + +func TestMarshalLeaf(t *testing.T) { + description := "valid: shard hint 72623859790382856, buffers 0x00,0x01,..." + leaf := &Leaf{ + Message: Message{ + ShardHint: 72623859790382856, + Checksum: testBuffer32, + }, + SigIdent: SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + } + want := bytes.Join([][]byte{ + []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, + testBuffer32[:], testBuffer64[:], testBuffer32[:], + }, nil) + if got := leaf.Marshal(); !bytes.Equal(got, want) { + t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) + } +} + +func TestMarshalTreeHead(t *testing.T) { + description := "valid: timestamp 16909060, tree size 72623859790382856, root hash 0x00,0x01,..." + th := &TreeHead{ + Timestamp: 16909060, + TreeSize: 72623859790382856, + RootHash: testBuffer32, + } + want := bytes.Join([][]byte{ + []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04}, + []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, + testBuffer32[:], + }, nil) + if got := th.Marshal(); !bytes.Equal(got, want) { + t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) + } +} + +func TestUnmarshalLeaf(t *testing.T) { + for _, table := range []struct { + description string + serialized []byte + wantErr bool + want *Leaf + }{ + { + description: "invalid: not enough bytes", + serialized: make([]byte, LeafSize-1), + wantErr: true, + }, + { + description: "invalid: too many bytes", + serialized: make([]byte, LeafSize+1), + wantErr: true, + }, + { + description: "valid: shard hint 72623859790382856, buffers 0x00,0x01,...", + serialized: bytes.Join([][]byte{ + []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, + testBuffer32[:], testBuffer64[:], testBuffer32[:], + }, nil), + want: &Leaf{ + Message: Message{ + ShardHint: 72623859790382856, + Checksum: testBuffer32, + }, + SigIdent: SigIdent{ + Signature: testBuffer64, + KeyHash: testBuffer32, + }, + }, + }, + } { + var leaf Leaf + err := leaf.Unmarshal(table.serialized) + if got, want := err != nil, table.wantErr; got != want { + t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) + } + if err != nil { + continue + } + if got, want := &leaf, table.want; !reflect.DeepEqual(got, want) { + t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, table.description) + } + } +} diff --git a/pkg/types/types.go b/pkg/types/types.go new file mode 100644 index 0000000..9ca7db8 --- /dev/null +++ b/pkg/types/types.go @@ -0,0 +1,155 @@ +package types + +import ( + "crypto" + "crypto/ed25519" + "crypto/sha256" + "fmt" + "strings" +) + +const ( + HashSize = sha256.Size + SignatureSize = ed25519.SignatureSize + VerificationKeySize = ed25519.PublicKeySize + + EndpointAddLeaf = Endpoint("add-leaf") + EndpointAddCosignature = Endpoint("add-cosignature") + EndpointGetTreeHeadLatest = Endpoint("get-tree-head-latest") + EndpointGetTreeHeadToSign = Endpoint("get-tree-head-to-sign") + EndpointGetTreeHeadCosigned = Endpoint("get-tree-head-cosigned") + EndpointGetProofByHash = Endpoint("get-proof-by-hash") + EndpointGetConsistencyProof = Endpoint("get-consistency-proof") + EndpointGetLeaves = Endpoint("get-leaves") +) + +// Endpoint is a named HTTP API endpoint +type Endpoint string + +// Path joins a number of components to form a full endpoint path. For example, +// EndpointAddLeaf.Path("example.com", "st/v0") -> example.com/st/v0/add-leaf. +func (e Endpoint) Path(components ...string) string { + return strings.Join(append(components, string(e)), "/") +} + +// Leaf is the log's Merkle tree leaf. +type Leaf struct { + Message + SigIdent +} + +// Message is composed of a shard hint and a checksum. The submitter selects +// these values to fit the log's shard interval and the opaque data in question. +type Message struct { + ShardHint uint64 + Checksum *[HashSize]byte +} + +// SigIdent is composed of a signature-signer pair. The signature is computed +// over the Trunnel-serialized leaf message. KeyHash identifies the signer. +type SigIdent struct { + Signature *[SignatureSize]byte + KeyHash *[HashSize]byte +} + +// SignedTreeHead is composed of a tree head and a list of signature-signer +// pairs. Each signature is computed over the Trunnel-serialized tree head. +type SignedTreeHead struct { + TreeHead + SigIdent []*SigIdent +} + +// TreeHead is the log's tree head. +type TreeHead struct { + Timestamp uint64 + TreeSize uint64 + RootHash *[HashSize]byte +} + +// ConsistencyProof is a consistency proof that proves the log's append-only +// property. +type ConsistencyProof struct { + NewSize uint64 + OldSize uint64 + Path []*[HashSize]byte +} + +// InclusionProof is an inclusion proof that proves a leaf is included in the +// log. +type InclusionProof struct { + TreeSize uint64 + LeafIndex uint64 + Path []*[HashSize]byte +} + +// LeafList is a list of leaves +type LeafList []*Leaf + +// ConsistencyProofRequest is a get-consistency-proof request +type ConsistencyProofRequest struct { + NewSize uint64 + OldSize uint64 +} + +// InclusionProofRequest is a get-proof-by-hash request +type InclusionProofRequest struct { + LeafHash *[HashSize]byte + TreeSize uint64 +} + +// LeavesRequest is a get-leaves request +type LeavesRequest struct { + StartSize uint64 + EndSize uint64 +} + +// LeafRequest is an add-leaf request +type LeafRequest struct { + Message + Signature *[SignatureSize]byte + VerificationKey *[VerificationKeySize]byte + DomainHint string +} + +// CosignatureRequest is an add-cosignature request +type CosignatureRequest struct { + SigIdent +} + +// Sign signs the tree head using the log's signature scheme +func (th *TreeHead) Sign(signer crypto.Signer) (*SignedTreeHead, error) { + sig, err := signer.Sign(nil, th.Marshal(), crypto.Hash(0)) + if err != nil { + return nil, fmt.Errorf("Sign: %v", err) + } + + sigident := SigIdent{ + KeyHash: Hash(signer.Public().(ed25519.PublicKey)[:]), + Signature: &[SignatureSize]byte{}, + } + copy(sigident.Signature[:], sig) + return &SignedTreeHead{ + TreeHead: *th, + SigIdent: []*SigIdent{ + &sigident, + }, + }, nil +} + +// Verify verifies the tree head signature using the log's signature scheme +func (th *TreeHead) Verify(vk *[VerificationKeySize]byte, sig *[SignatureSize]byte) error { + if !ed25519.Verify(ed25519.PublicKey(vk[:]), th.Marshal(), sig[:]) { + return fmt.Errorf("invalid tree head signature") + } + return nil +} + +// Verify checks if a leaf is included in the log +func (p *InclusionProof) Verify(leaf *Leaf, th *TreeHead) error { // TODO + return nil +} + +// Verify checks if two tree heads are consistent +func (p *ConsistencyProof) Verify(oldTH, newTH *TreeHead) error { // TODO + return nil +} diff --git a/pkg/types/types_test.go b/pkg/types/types_test.go new file mode 100644 index 0000000..da89c59 --- /dev/null +++ b/pkg/types/types_test.go @@ -0,0 +1,58 @@ +package types + +import ( + "testing" +) + +func TestEndpointPath(t *testing.T) { + base, prefix, proto := "example.com", "log", "st/v0" + for _, table := range []struct { + endpoint Endpoint + want string + }{ + { + endpoint: EndpointAddLeaf, + want: "example.com/log/st/v0/add-leaf", + }, + { + endpoint: EndpointAddCosignature, + want: "example.com/log/st/v0/add-cosignature", + }, + { + endpoint: EndpointGetTreeHeadLatest, + want: "example.com/log/st/v0/get-tree-head-latest", + }, + { + endpoint: EndpointGetTreeHeadToSign, + want: "example.com/log/st/v0/get-tree-head-to-sign", + }, + { + endpoint: EndpointGetTreeHeadCosigned, + want: "example.com/log/st/v0/get-tree-head-cosigned", + }, + { + endpoint: EndpointGetConsistencyProof, + want: "example.com/log/st/v0/get-consistency-proof", + }, + { + endpoint: EndpointGetProofByHash, + want: "example.com/log/st/v0/get-proof-by-hash", + }, + { + endpoint: EndpointGetLeaves, + want: "example.com/log/st/v0/get-leaves", + }, + } { + if got, want := table.endpoint.Path(base+"/"+prefix+"/"+proto), table.want; got != want { + t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\twith one component", got, want) + } + if got, want := table.endpoint.Path(base, prefix, proto), table.want; got != want { + t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\tmultiple components", got, want) + } + } +} + +func TestTreeHeadSign(t *testing.T) {} +func TestTreeHeadVerify(t *testing.T) {} +func TestInclusionProofVerify(t *testing.T) {} +func TestConsistencyProofVerify(t *testing.T) {} diff --git a/pkg/types/util.go b/pkg/types/util.go new file mode 100644 index 0000000..3cd7dfa --- /dev/null +++ b/pkg/types/util.go @@ -0,0 +1,21 @@ +package types + +import ( + "crypto/sha256" +) + +const ( + LeafHashPrefix = 0x00 +) + +func Hash(buf []byte) *[HashSize]byte { + var ret [HashSize]byte + hash := sha256.New() + hash.Write(buf) + copy(ret[:], hash.Sum(nil)) + return &ret +} + +func HashLeaf(buf []byte) *[HashSize]byte { + return Hash(append([]byte{LeafHashPrefix}, buf...)) +} diff --git a/request.go b/request.go deleted file mode 100644 index 763d9ed..0000000 --- a/request.go +++ /dev/null @@ -1,81 +0,0 @@ -package stfe - -import ( - "fmt" - - "crypto/ed25519" - "net/http" - - "github.com/system-transparency/stfe/types" -) - -func (lp *LogParameters) parseAddEntryV1Request(r *http.Request) (*types.Leaf, error) { - var req types.LeafRequest - if err := req.UnmarshalASCII(r.Body); err != nil { - return nil, fmt.Errorf("UnmarshalASCII: %v", err) - } - - if pub, msg, sig := ed25519.PublicKey(req.VerificationKey[:]), req.Message.Marshal(), req.Signature[:]; !ed25519.Verify(pub, msg, sig) { - return nil, fmt.Errorf("Invalid signature") - } - // TODO: check shard hint - // TODO: check domain hint - return &types.Leaf{ - Message: req.Message, - SigIdent: types.SigIdent{ - Signature: req.Signature, - KeyHash: types.Hash(req.VerificationKey[:]), - }, - }, nil -} - -func (lp *LogParameters) parseAddCosignatureRequest(r *http.Request) (*types.CosignatureRequest, error) { - var req types.CosignatureRequest - if err := req.UnmarshalASCII(r.Body); err != nil { - return nil, fmt.Errorf("unpackOctetPost: %v", err) - } - if _, ok := lp.Witnesses[*req.KeyHash]; !ok { - return nil, fmt.Errorf("Unknown witness: %x", req.KeyHash) - } - return &req, nil -} - -func (lp *LogParameters) parseGetConsistencyProofRequest(r *http.Request) (*types.ConsistencyProofRequest, error) { - var req types.ConsistencyProofRequest - if err := req.UnmarshalASCII(r.Body); err != nil { - return nil, fmt.Errorf("UnmarshalASCII: %v", err) - } - if req.OldSize < 1 { - return nil, fmt.Errorf("OldSize(%d) must be larger than zero", req.OldSize) - } - if req.NewSize <= req.OldSize { - return nil, fmt.Errorf("NewSize(%d) must be larger than OldSize(%d)", req.NewSize, req.OldSize) - } - return &req, nil -} - -func (lp *LogParameters) parseGetProofByHashRequest(r *http.Request) (*types.InclusionProofRequest, error) { - var req types.InclusionProofRequest - if err := req.UnmarshalASCII(r.Body); err != nil { - return nil, fmt.Errorf("UnmarshalASCII: %v", err) - } - if req.TreeSize < 1 { - return nil, fmt.Errorf("TreeSize(%d) must be larger than zero", req.TreeSize) - } - return &req, nil -} - -func (lp *LogParameters) parseGetEntriesRequest(r *http.Request) (*types.LeavesRequest, error) { - var req types.LeavesRequest - if err := req.UnmarshalASCII(r.Body); err != nil { - return nil, fmt.Errorf("UnmarshalASCII: %v", err) - } - - if req.StartSize > req.EndSize { - return nil, fmt.Errorf("StartSize(%d) must be less than or equal to EndSize(%d)", req.StartSize, req.EndSize) - } - if req.EndSize-req.StartSize+1 > uint64(lp.MaxRange) { - req.EndSize = req.StartSize + uint64(lp.MaxRange) - 1 - } - return &req, nil -} diff --git a/request_test.go b/request_test.go deleted file mode 100644 index 102c56f..0000000 --- a/request_test.go +++ /dev/null @@ -1,318 +0,0 @@ -package stfe - -import ( - "bytes" - //"fmt" - "reflect" - "testing" - //"testing/iotest" - - "net/http" - - "github.com/system-transparency/stfe/testdata" - "github.com/system-transparency/stfe/types" -) - -func TestParseAddEntryV1Request(t *testing.T) { - lp := newLogParameters(t, nil) - for _, table := range []struct { - description string - breq *bytes.Buffer - wantErr bool - }{ - { - description: "invalid: nothing to unpack", - breq: bytes.NewBuffer(nil), - wantErr: true, - }, - { - description: "invalid: not a signed checksum entry", - breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), - wantErr: true, - }, - { - description: "invalid: untrusted submitter", // only testdata.Ed25519VkSubmitter is registered by default in newLogParameters() - - breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter2, testdata.Ed25519VkSubmitter2), - wantErr: true, - }, - { - description: "invalid: signature does not cover message", - - breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter2, testdata.Ed25519VkSubmitter), - wantErr: true, - }, - { - description: "valid", - breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), - }, // TODO: add test case that disables submitter policy (i.e., unregistered namespaces are accepted) - } { - url := EndpointAddEntry.Path("http://example.com", lp.Prefix) - req, err := http.NewRequest("POST", url, table.breq) - if err != nil { - t.Fatalf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - - _, err = lp.parseAddEntryV1Request(req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) - } - } -} - -func TestParseAddCosignatureV1Request(t *testing.T) { - lp := newLogParameters(t, nil) - for _, table := range []struct { - description string - breq *bytes.Buffer - wantErr bool - }{ - { - description: "invalid: nothing to unpack", - breq: bytes.NewBuffer(nil), - wantErr: true, - }, - { - description: "invalid: not a cosigned sth", - breq: testdata.AddSignedChecksumBuffer(t, testdata.Ed25519SkSubmitter, testdata.Ed25519VkSubmitter), - wantErr: true, - }, - { - description: "invalid: no cosignature", - breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, nil), - wantErr: true, - }, - { - description: "invalid: untrusted witness", // only testdata.Ed25519VkWitness is registered by default in newLogParameters() - breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness2, &testdata.Ed25519VkWitness2), - wantErr: true, - }, - { - description: "invalid: signature does not cover message", - breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness2, &testdata.Ed25519VkWitness), - wantErr: true, - }, - { - description: "valid", - breq: testdata.AddCosignatureBuffer(t, testdata.DefaultSth(t, testdata.Ed25519VkLog), &testdata.Ed25519SkWitness, &testdata.Ed25519VkWitness), - }, // TODO: add test case that disables witness policy (i.e., unregistered namespaces are accepted) - } { - url := EndpointAddCosignature.Path("http://example.com", lp.Prefix) - req, err := http.NewRequest("POST", url, table.breq) - if err != nil { - t.Fatalf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - - _, err = lp.parseAddCosignatureV1Request(req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) - } - } -} - -func TestNewGetConsistencyProofRequest(t *testing.T) { - lp := newLogParameters(t, nil) - for _, table := range []struct { - description string - req *types.GetConsistencyProofV1 - wantErr bool - }{ - { - description: "invalid: nothing to unpack", - req: nil, - wantErr: true, - }, - { - description: "invalid: first must be larger than zero", - req: &types.GetConsistencyProofV1{First: 0, Second: 0}, - wantErr: true, - }, - { - description: "invalid: second must be larger than first", - req: &types.GetConsistencyProofV1{First: 2, Second: 1}, - wantErr: true, - }, - { - description: "valid", - req: &types.GetConsistencyProofV1{First: 1, Second: 2}, - }, - } { - var buf *bytes.Buffer - if table.req == nil { - buf = bytes.NewBuffer(nil) - } else { - buf = bytes.NewBuffer(marshal(t, *table.req)) - } - - url := EndpointGetConsistencyProof.Path("http://example.com", lp.Prefix) - req, err := http.NewRequest("POST", url, buf) - if err != nil { - t.Fatalf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - - _, err = lp.parseGetConsistencyProofV1Request(req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) - } - } -} - -func TestNewGetProofByHashRequest(t *testing.T) { - lp := newLogParameters(t, nil) - for _, table := range []struct { - description string - req *types.GetProofByHashV1 - wantErr bool - }{ - { - description: "invalid: nothing to unpack", - req: nil, - wantErr: true, - }, - { - description: "invalid: no entry in an empty tree", - req: &types.GetProofByHashV1{TreeSize: 0, Hash: testdata.LeafHash}, - wantErr: true, - }, - { - description: "valid", - req: &types.GetProofByHashV1{TreeSize: 1, Hash: testdata.LeafHash}, - }, - } { - var buf *bytes.Buffer - if table.req == nil { - buf = bytes.NewBuffer(nil) - } else { - buf = bytes.NewBuffer(marshal(t, *table.req)) - } - - url := EndpointGetProofByHash.Path("http://example.com", lp.Prefix) - req, err := http.NewRequest("POST", url, buf) - if err != nil { - t.Fatalf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - - _, err = lp.parseGetProofByHashV1Request(req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) - } - } -} - -func TestParseGetEntriesV1Request(t *testing.T) { - lp := newLogParameters(t, nil) - for _, table := range []struct { - description string - req *types.GetEntriesV1 - wantErr bool - wantReq *types.GetEntriesV1 - }{ - { - description: "invalid: nothing to unpack", - req: nil, - wantErr: true, - }, - { - description: "invalid: start must be larger than end", - req: &types.GetEntriesV1{Start: 1, End: 0}, - wantErr: true, - }, - { - description: "valid: want truncated range", - req: &types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange)}, - wantReq: &types.GetEntriesV1{Start: 0, End: uint64(testdata.MaxRange) - 1}, - }, - { - description: "valid", - req: &types.GetEntriesV1{Start: 0, End: 0}, - wantReq: &types.GetEntriesV1{Start: 0, End: 0}, - }, - } { - var buf *bytes.Buffer - if table.req == nil { - buf = bytes.NewBuffer(nil) - } else { - buf = bytes.NewBuffer(marshal(t, *table.req)) - } - - url := EndpointGetEntries.Path("http://example.com", lp.Prefix) - req, err := http.NewRequest("POST", url, buf) - if err != nil { - t.Fatalf("failed creating http request: %v", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - - output, err := lp.parseGetEntriesV1Request(req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got errror %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := output, table.wantReq; !reflect.DeepEqual(got, want) { - t.Errorf("got request\n%v\n\tbut wanted\n%v\n\t in test %q", got, want, table.description) - } - } -} - -func TestUnpackOctetPost(t *testing.T) { - for _, table := range []struct { - description string - req *http.Request - out interface{} - wantErr bool - }{ - //{ - // description: "invalid: cannot read request body", - // req: func() *http.Request { - // req, err := http.NewRequest(http.MethodPost, "", iotest.ErrReader(fmt.Errorf("bad reader"))) - // if err != nil { - // t.Fatalf("must make new http request: %v", err) - // } - // return req - // }(), - // out: &types.StItem{}, - // wantErr: true, - //}, // testcase requires Go 1.16 - { - description: "invalid: cannot unmarshal", - req: func() *http.Request { - req, err := http.NewRequest(http.MethodPost, "", bytes.NewBuffer(nil)) - if err != nil { - t.Fatalf("must make new http request: %v", err) - } - return req - }(), - out: &types.StItem{}, - wantErr: true, - }, - { - description: "valid", - req: func() *http.Request { - req, err := http.NewRequest(http.MethodPost, "", bytes.NewBuffer([]byte{0})) - if err != nil { - t.Fatalf("must make new http request: %v", err) - } - return req - }(), - out: &struct{ SomeUint8 uint8 }{}, - }, - } { - err := unpackOctetPost(table.req, table.out) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q", got, want, table.description) - } - } -} - -func marshal(t *testing.T, out interface{}) []byte { - b, err := types.Marshal(out) - if err != nil { - t.Fatalf("must marshal: %v", err) - } - return b -} diff --git a/server/.gitignore b/server/.gitignore deleted file mode 100644 index 254defd..0000000 --- a/server/.gitignore +++ /dev/null @@ -1 +0,0 @@ -server diff --git a/server/README.md b/server/README.md deleted file mode 100644 index 71bb3ac..0000000 --- a/server/README.md +++ /dev/null @@ -1,60 +0,0 @@ -# Run Trillian + STFE locally -Trillian uses a database. So, we will need to set that up. It is documented -[here](https://github.com/google/trillian#mysql-setup), and how to check that it -is setup properly -[here](https://github.com/google/certificate-transparency-go/blob/master/trillian/docs/ManualDeployment.md#data-storage). - -Other than the database we need the Trillian log signer, Trillian log server, -and STFE server. -``` -$ go install github.com/google/trillian/cmd/trillian_log_signer -$ go install github.com/google/trillian/cmd/trillian_log_server -$ go install -``` - -Start Trillian log signer: -``` -trillian_log_signer --logtostderr -v 9 --force_master --rpc_endpoint=localhost:6961 --http_endpoint=localhost:6964 --num_sequencers 1 --sequencer_interval 100ms --batch_size 100 -``` - -Start Trillian log server: -``` -trillian_log_server --logtostderr -v 9 --rpc_endpoint=localhost:6962 --http_endpoint=localhost:6963 -``` - -As described in more detail -[here](https://github.com/google/certificate-transparency-go/blob/master/trillian/docs/ManualDeployment.md#trillian-services), -we need to provision a Merkle tree once: -``` -$ go install github.com/google/trillian/cmd/createtree -$ createtree --admin_server localhost:6962 - -``` - -Hang on to ``. Our STFE server will use it when talking to the -Trillian log server to specify which Merkle tree we are working against. - -(If you take a look in the `Trees` table you will see that the tree has been -provisioned.) - -We will also need a public key-pair and log identifier for the STFE server. -``` -$ go install github.com/system-transparency/stfe/types/cmd/new-namespace -sk: -vk: -ed25519_v1: -``` - -The log's identifier is `` and contains the public verification key -``. The log's corresponding secret signing key is ``. - -Start STFE server: -``` -$ ./server --logtostderr -v 9 --http_endpoint localhost:6965 --log_rpc_server localhost:6962 --trillian_id --key -``` - -If the log is responsive on, e.g., `GET http://localhost:6965/st/v1/get-latest-sth` you -may want to try running -`github.com/system-transparency/stfe/client/cmd/example.sh`. You need to -configure the log's id though for verification to work (flag `log_id`, which -should be set to the `` output above). diff --git a/server/main.go b/server/main.go deleted file mode 100644 index 1fecb43..0000000 --- a/server/main.go +++ /dev/null @@ -1,165 +0,0 @@ -// Package main provides an STFE server binary -package main - -import ( - "context" - "crypto" - "crypto/ed25519" - "encoding/hex" - "flag" - "fmt" - "net/http" - "os" - "os/signal" - "strings" - "sync" - "syscall" - "time" - - "github.com/golang/glog" - "github.com/google/trillian" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/system-transparency/stfe" - "github.com/system-transparency/stfe/types" - "google.golang.org/grpc" -) - -var ( - httpEndpoint = flag.String("http_endpoint", "localhost:6965", "host:port specification of where stfe serves clients") - rpcBackend = flag.String("log_rpc_server", "localhost:6962", "host:port specification of where Trillian serves clients") - prefix = flag.String("prefix", "st/v0", "a prefix that proceeds each endpoint path") - trillianID = flag.Int64("trillian_id", 0, "log identifier in the Trillian database") - deadline = flag.Duration("deadline", time.Second*10, "deadline for backend requests") - key = flag.String("key", "", "hex-encoded Ed25519 signing key") - witnesses = flag.String("witnesses", "", "comma-separated list of trusted witness verification keys in hex") - maxRange = flag.Int64("max_range", 10, "maximum number of entries that can be retrived in a single request") - interval = flag.Duration("interval", time.Second*30, "interval used to rotate the log's cosigned STH") -) - -func main() { - flag.Parse() - defer glog.Flush() - - // wait for clean-up before exit - var wg sync.WaitGroup - defer wg.Wait() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - glog.V(3).Infof("configuring stfe instance...") - instance, err := setupInstanceFromFlags() - if err != nil { - glog.Errorf("setupInstance: %v", err) - return - } - - glog.V(3).Infof("spawning SthSource") - go func() { - wg.Add(1) - defer wg.Done() - instance.SthSource.Run(ctx) - glog.Errorf("SthSource shutdown") - cancel() // must have SthSource running - }() - - glog.V(3).Infof("spawning await") - server := http.Server{Addr: *httpEndpoint} - go await(ctx, func() { - wg.Add(1) - defer wg.Done() - ctxInner, _ := context.WithTimeout(ctx, time.Second*60) - glog.Infof("Shutting down HTTP server...") - server.Shutdown(ctxInner) - glog.V(3).Infof("HTTP server shutdown") - glog.Infof("Shutting down spawned go routines...") - cancel() - }) - - glog.Infof("Serving on %v/%v", *httpEndpoint, *prefix) - if err = server.ListenAndServe(); err != http.ErrServerClosed { - glog.Errorf("ListenAndServe: %v", err) - } -} - -// SetupInstance sets up a new STFE instance from flags -func setupInstanceFromFlags() (*stfe.Instance, error) { - // Trillian gRPC connection - dialOpts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(*deadline)} - conn, err := grpc.Dial(*rpcBackend, dialOpts...) - if err != nil { - return nil, fmt.Errorf("Dial: %v", err) - } - client := trillian.NewTrillianLogClient(conn) - // HTTP multiplexer - mux := http.NewServeMux() - http.Handle("/", mux) - // Prometheus metrics - glog.V(3).Infof("Adding prometheus handler on path: /metrics") - http.Handle("/metrics", promhttp.Handler()) - // Trusted witnesses - witnesses, err := newWitnessMap(*witnesses) - if err != nil { - return nil, fmt.Errorf("newWitnessMap: %v", err) - } - // Secret signing key - sk, err := hex.DecodeString(*key) - if err != nil { - return nil, fmt.Errorf("sk: DecodeString: %v", err) - } - // Setup log parameters - lp := &stfe.LogParameters{ - LogId: hex.EncodeToString([]byte(ed25519.PrivateKey(sk).Public().(ed25519.PublicKey))), - TreeId: *trillianID, - Prefix: *prefix, - MaxRange: *maxRange, - Deadline: *deadline, - Interval: *interval, - HashType: crypto.SHA256, - Signer: ed25519.PrivateKey(sk), - Witnesses: witnesses, - } - // Setup STH source - source, err := stfe.NewActiveSthSource(client, lp) - if err != nil { - return nil, fmt.Errorf("NewActiveSthSource: %v", err) - } - // Setup log instance - i := &stfe.Instance{client, lp, source} - for _, handler := range i.Handlers() { - glog.V(3).Infof("adding handler: %s", handler.Path()) - mux.Handle(handler.Path(), handler) - } - return i, nil -} - -// newWitnessMap creates a new map of trusted witnesses -func newWitnessMap(witnesses string) (map[[types.HashSize]byte][types.VerificationKeySize]byte, error) { - w := make(map[[types.HashSize]byte][types.VerificationKeySize]byte) - if len(witnesses) > 0 { - for _, witness := range strings.Split(witnesses, ",") { - b, err := hex.DecodeString(witness) - if err != nil { - return nil, fmt.Errorf("DecodeString: %v", err) - } - - var vk [types.VerificationKeySize]byte - if n := copy(vk[:], b); n != types.VerificationKeySize { - return nil, fmt.Errorf("Invalid verification key size: %v", n) - } - w[*types.Hash(vk[:])] = vk - } - } - return w, nil -} - -// await waits for a shutdown signal and then runs a clean-up function -func await(ctx context.Context, done func()) { - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - select { - case <-sigs: - case <-ctx.Done(): - } - glog.V(3).Info("received shutdown signal") - done() -} diff --git a/state/state_manager.go b/state/state_manager.go deleted file mode 100644 index 3199e61..0000000 --- a/state/state_manager.go +++ /dev/null @@ -1,154 +0,0 @@ -package stfe - -import ( - "context" - "crypto" - "fmt" - "reflect" - "sync" - "time" - - "github.com/golang/glog" - "github.com/google/certificate-transparency-go/schedule" - "github.com/system-transparency/stfe/trillian" - "github.com/system-transparency/stfe/types" -) - -// StateManager coordinates access to the log's tree heads and (co)signatures -type StateManager interface { - Latest(context.Context) (*types.SignedTreeHead, error) - ToSign(context.Context) (*types.SignedTreeHead, error) - Cosigned(context.Context) (*types.SignedTreeHead, error) - AddCosignature(context.Context, *[types.VerificationKeySize]byte, *[types.SignatureSize]byte) error - Run(context.Context) -} - -// StateManagerSingle implements the StateManager interface. It is assumed that -// the log server is running on a single-instance machine. So, no coordination. -type StateManagerSingle struct { - client trillian.Client - signer crypto.Signer - interval time.Duration - deadline time.Duration - sync.RWMutex - - // cosigned is the current cosigned tree head that is being served - cosigned types.SignedTreeHead - - // tosign is the current tree head that is being cosigned by witnesses - tosign types.SignedTreeHead - - // cosignature keeps track of all cosignatures for the tosign tree head - cosignature map[[types.HashSize]byte]*types.SigIdent -} - -func NewStateManagerSingle(client trillian.Client, signer crypto.Signer, interval, deadline time.Duration) (*StateManagerSingle, error) { - sm := &StateManagerSingle{ - client: client, - signer: signer, - interval: interval, - deadline: deadline, - } - - ctx, _ := context.WithTimeout(context.Background(), sm.deadline) - sth, err := sm.Latest(ctx) - if err != nil { - return nil, fmt.Errorf("Latest: %v", err) - } - - sm.cosigned = *sth - sm.tosign = *sth - sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{ - *sth.SigIdent[0].KeyHash: sth.SigIdent[0], // log signature - } - return sm, nil -} - -func (sm *StateManagerSingle) Run(ctx context.Context) { - schedule.Every(ctx, sm.interval, func(ctx context.Context) { - ictx, _ := context.WithTimeout(ctx, sm.deadline) - nextTreeHead, err := sm.Latest(ictx) - if err != nil { - glog.Warningf("rotate failed: Latest: %v", err) - return - } - - sm.Lock() - defer sm.Unlock() - sm.rotate(nextTreeHead) - }) -} - -func (sm *StateManagerSingle) Latest(ctx context.Context) (*types.SignedTreeHead, error) { - th, err := sm.client.GetTreeHead(ctx) - if err != nil { - return nil, fmt.Errorf("LatestTreeHead: %v", err) - } - sth, err := th.Sign(sm.signer) - if err != nil { - return nil, fmt.Errorf("sign: %v", err) - } - return sth, nil -} - -func (sm *StateManagerSingle) ToSign(_ context.Context) (*types.SignedTreeHead, error) { - sm.RLock() - defer sm.RUnlock() - return &sm.tosign, nil -} - -func (sm *StateManagerSingle) Cosigned(_ context.Context) (*types.SignedTreeHead, error) { - sm.RLock() - defer sm.RUnlock() - return &sm.cosigned, nil -} - -func (sm *StateManagerSingle) AddCosignature(_ context.Context, vk *[types.VerificationKeySize]byte, sig *[types.SignatureSize]byte) error { - sm.Lock() - defer sm.Unlock() - - if err := sm.tosign.TreeHead.Verify(vk, sig); err != nil { - return fmt.Errorf("Verify: %v", err) - } - witness := types.Hash(vk[:]) - if _, ok := sm.cosignature[*witness]; ok { - return fmt.Errorf("signature-signer pair is a duplicate") - } - sm.cosignature[*witness] = &types.SigIdent{ - Signature: sig, - KeyHash: witness, - } - - glog.V(3).Infof("accepted new cosignature from witness: %x", *witness) - return nil -} - -// rotate rotates the log's cosigned and stable STH. The caller must aquire the -// source's read-write lock if there are concurrent reads and/or writes. -func (sm *StateManagerSingle) rotate(next *types.SignedTreeHead) { - if reflect.DeepEqual(sm.cosigned.TreeHead, sm.tosign.TreeHead) { - // cosigned and tosign are the same. So, we need to merge all - // cosignatures that we already had with the new collected ones. - for _, sigident := range sm.cosigned.SigIdent { - if _, ok := sm.cosignature[*sigident.KeyHash]; !ok { - sm.cosignature[*sigident.KeyHash] = sigident - } - } - glog.V(3).Infof("cosigned tree head repeated, merged signatures") - } - var cosignatures []*types.SigIdent - for _, sigident := range sm.cosignature { - cosignatures = append(cosignatures, sigident) - } - - // Update cosigned tree head - sm.cosigned.TreeHead = sm.tosign.TreeHead - sm.cosigned.SigIdent = cosignatures - - // Update to-sign tree head - sm.tosign = *next - sm.cosignature = map[[types.HashSize]byte]*types.SigIdent{ - *next.SigIdent[0].KeyHash: next.SigIdent[0], // log signature - } - glog.V(3).Infof("rotated tree heads") -} diff --git a/state/state_manager_test.go b/state/state_manager_test.go deleted file mode 100644 index 348074c..0000000 --- a/state/state_manager_test.go +++ /dev/null @@ -1,393 +0,0 @@ -package stfe - -import ( - "bytes" - "context" - "crypto" - "crypto/ed25519" - "crypto/rand" - "fmt" - "reflect" - "testing" - "time" - - "github.com/golang/mock/gomock" - "github.com/system-transparency/stfe/mocks" - "github.com/system-transparency/stfe/types" -) - -var ( - testSig = &[types.SignatureSize]byte{} - testPub = &[types.VerificationKeySize]byte{} - testTH = &types.TreeHead{ - Timestamp: 0, - TreeSize: 0, - RootHash: types.Hash(nil), - } - testSigIdent = &types.SigIdent{ - Signature: testSig, - KeyHash: types.Hash(testPub[:]), - } - testSTH = &types.SignedTreeHead{ - TreeHead: *testTH, - SigIdent: []*types.SigIdent{testSigIdent}, - } - testSignerOK = &mocks.TestSigner{testPub, testSig, nil} - testSignerErr = &mocks.TestSigner{testPub, testSig, fmt.Errorf("something went wrong")} -) - -func TestNewStateManagerSingle(t *testing.T) { - for _, table := range []struct { - description string - signer crypto.Signer - rsp *types.TreeHead - err error - wantErr bool - wantSth *types.SignedTreeHead - }{ - { - description: "invalid: backend failure", - signer: testSignerOK, - err: fmt.Errorf("something went wrong"), - wantErr: true, - }, - { - description: "valid", - signer: testSignerOK, - rsp: testTH, - wantSth: testSTH, - }, - } { - // Run deferred functions at the end of each iteration - func() { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - client := mocks.NewMockClient(ctrl) - client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err) - - sm, err := NewStateManagerSingle(client, table.signer, time.Duration(0), time.Duration(0)) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - return - } - if got, want := &sm.cosigned, table.wantSth; !reflect.DeepEqual(got, want) { - t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - if got, want := &sm.tosign, table.wantSth; !reflect.DeepEqual(got, want) { - t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - // we only have log signature on startup - if got, want := len(sm.cosignature), 1; got != want { - t.Errorf("got %d cosignatures but wanted %d in test %q", got, want, table.description) - } - }() - } -} - -func TestLatest(t *testing.T) { - for _, table := range []struct { - description string - signer crypto.Signer - rsp *types.TreeHead - err error - wantErr bool - wantSth *types.SignedTreeHead - }{ - { - description: "invalid: backend failure", - signer: testSignerOK, - err: fmt.Errorf("something went wrong"), - wantErr: true, - }, - { - description: "invalid: signature failure", - rsp: testTH, - signer: testSignerErr, - wantErr: true, - }, - { - description: "valid", - signer: testSignerOK, - rsp: testTH, - wantSth: testSTH, - }, - } { - // Run deferred functions at the end of each iteration - func() { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - client := mocks.NewMockClient(ctrl) - client.EXPECT().GetTreeHead(gomock.Any()).Return(table.rsp, table.err) - sm := StateManagerSingle{ - client: client, - signer: table.signer, - } - - sth, err := sm.Latest(context.Background()) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - return - } - if got, want := sth, table.wantSth; !reflect.DeepEqual(got, want) { - t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - }() - } -} - -func TestToSign(t *testing.T) { - description := "valid" - sm := StateManagerSingle{ - tosign: *testSTH, - } - sth, err := sm.ToSign(context.Background()) - if err != nil { - t.Errorf("ToSign should not fail with error: %v", err) - return - } - if got, want := sth, testSTH; !reflect.DeepEqual(got, want) { - t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description) - } -} - -func TestCosigned(t *testing.T) { - description := "valid" - sm := StateManagerSingle{ - cosigned: *testSTH, - } - sth, err := sm.Cosigned(context.Background()) - if err != nil { - t.Errorf("Cosigned should not fail with error: %v", err) - return - } - if got, want := sth, testSTH; !reflect.DeepEqual(got, want) { - t.Errorf("got signed tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, description) - } -} - -func TestAddCosignature(t *testing.T) { - vk, sk, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatalf("GenerateKey: %v", err) - } - if bytes.Equal(vk[:], testPub[:]) { - t.Fatalf("Sampled same key as testPub, aborting...") - } - var vkArray [types.VerificationKeySize]byte - copy(vkArray[:], vk[:]) - - for _, table := range []struct { - description string - signer crypto.Signer - vk *[types.VerificationKeySize]byte - th *types.TreeHead - wantErr bool - }{ - { - description: "invalid: signature error", - signer: sk, - vk: testPub, // wrong key for message - th: testTH, - wantErr: true, - }, - { - description: "valid", - signer: sk, - vk: &vkArray, - th: testTH, - }, - } { - sth, _ := table.th.Sign(testSignerOK) - logKeyHash := sth.SigIdent[0].KeyHash - logSigIdent := sth.SigIdent[0] - sm := &StateManagerSingle{ - signer: testSignerOK, - cosigned: *sth, - tosign: *sth, - cosignature: map[[types.HashSize]byte]*types.SigIdent{ - *logKeyHash: logSigIdent, - }, - } - - // Prepare witness signature - sth, err := table.th.Sign(table.signer) - if err != nil { - t.Fatalf("Sign: %v", err) - } - witnessKeyHash := sth.SigIdent[0].KeyHash - witnessSigIdent := sth.SigIdent[0] - - // Add witness signature - err = sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - - // We should have two signatures (log + witness) - if got, want := len(sm.cosignature), 2; got != want { - t.Errorf("got %d cosignatures but wanted %v in test %q", got, want, table.description) - continue - } - // check that log signature is there - sigident, ok := sm.cosignature[*logKeyHash] - if !ok { - t.Errorf("log signature is missing") - continue - } - if got, want := sigident, logSigIdent; !reflect.DeepEqual(got, want) { - t.Errorf("got log sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - // check that witness signature is there - sigident, ok = sm.cosignature[*witnessKeyHash] - if !ok { - t.Errorf("witness signature is missing") - continue - } - if got, want := sigident, witnessSigIdent; !reflect.DeepEqual(got, want) { - t.Errorf("got witness sigident\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - continue - } - - // Adding a duplicate signature should give an error - if err := sm.AddCosignature(context.Background(), table.vk, witnessSigIdent.Signature); err == nil { - t.Errorf("duplicate witness signature accepted as valid") - } - } -} - -func TestRotate(t *testing.T) { - log := testSigIdent - wit1 := &types.SigIdent{ - Signature: testSig, - KeyHash: types.Hash([]byte("wit1 key")), - } - wit2 := &types.SigIdent{ - Signature: testSig, - KeyHash: types.Hash([]byte("wit2 key")), - } - th0 := testTH - th1 := &types.TreeHead{ - Timestamp: 1, - TreeSize: 1, - RootHash: types.Hash([]byte("1")), - } - th2 := &types.TreeHead{ - Timestamp: 2, - TreeSize: 2, - RootHash: types.Hash([]byte("2")), - } - - for _, table := range []struct { - description string - before, after *StateManagerSingle - next *types.SignedTreeHead - }{ - { - description: "tosign tree head repated, but got one new witnes signature", - before: &StateManagerSingle{ - cosigned: types.SignedTreeHead{ - TreeHead: *th0, - SigIdent: []*types.SigIdent{log, wit1}, - }, - tosign: types.SignedTreeHead{ - TreeHead: *th0, - SigIdent: []*types.SigIdent{log}, - }, - cosignature: map[[types.HashSize]byte]*types.SigIdent{ - *log.KeyHash: log, - *wit2.KeyHash: wit2, // the new witness signature - }, - }, - next: &types.SignedTreeHead{ - TreeHead: *th1, - SigIdent: []*types.SigIdent{log}, - }, - after: &StateManagerSingle{ - cosigned: types.SignedTreeHead{ - TreeHead: *th0, - SigIdent: []*types.SigIdent{log, wit1, wit2}, - }, - tosign: types.SignedTreeHead{ - TreeHead: *th1, - SigIdent: []*types.SigIdent{log}, - }, - cosignature: map[[types.HashSize]byte]*types.SigIdent{ - *log.KeyHash: log, // after rotate we always have log sig - }, - }, - }, - { - description: "tosign tree head did not repeat, it got one witness signature", - before: &StateManagerSingle{ - cosigned: types.SignedTreeHead{ - TreeHead: *th0, - SigIdent: []*types.SigIdent{log, wit1}, - }, - tosign: types.SignedTreeHead{ - TreeHead: *th1, - SigIdent: []*types.SigIdent{log}, - }, - cosignature: map[[types.HashSize]byte]*types.SigIdent{ - *log.KeyHash: log, - *wit2.KeyHash: wit2, // the only witness that signed tosign - }, - }, - next: &types.SignedTreeHead{ - TreeHead: *th2, - SigIdent: []*types.SigIdent{log}, - }, - after: &StateManagerSingle{ - cosigned: types.SignedTreeHead{ - TreeHead: *th1, - SigIdent: []*types.SigIdent{log, wit2}, - }, - tosign: types.SignedTreeHead{ - TreeHead: *th2, - SigIdent: []*types.SigIdent{log}, - }, - cosignature: map[[types.HashSize]byte]*types.SigIdent{ - *log.KeyHash: log, // after rotate we always have log sig - }, - }, - }, - } { - table.before.rotate(table.next) - if got, want := table.before.cosigned.TreeHead, table.after.cosigned.TreeHead; !reflect.DeepEqual(got, want) { - t.Errorf("got cosigned tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - checkWitnessList(t, table.description, table.before.cosigned.SigIdent, table.after.cosigned.SigIdent) - if got, want := table.before.tosign.TreeHead, table.after.tosign.TreeHead; !reflect.DeepEqual(got, want) { - t.Errorf("got tosign tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - checkWitnessList(t, table.description, table.before.tosign.SigIdent, table.after.tosign.SigIdent) - if got, want := table.before.cosignature, table.after.cosignature; !reflect.DeepEqual(got, want) { - t.Errorf("got cosignature map\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - } -} - -func checkWitnessList(t *testing.T, description string, got, want []*types.SigIdent) { - t.Helper() - for _, si := range got { - found := false - for _, sj := range want { - if reflect.DeepEqual(si, sj) { - found = true - break - } - } - if !found { - t.Errorf("got unexpected signature-signer pair with key hash in test %q: %x", description, si.KeyHash[:]) - } - } - if len(got) != len(want) { - t.Errorf("got %d signature-signer pairs but wanted %d in test %q", len(got), len(want), description) - } -} diff --git a/sth.go b/sth.go deleted file mode 100644 index 1399241..0000000 --- a/sth.go +++ /dev/null @@ -1,143 +0,0 @@ -package stfe - -import ( - "context" - "crypto/ed25519" - "fmt" - "reflect" - "sync" - - "github.com/golang/glog" - "github.com/google/certificate-transparency-go/schedule" - "github.com/google/trillian" - ttypes "github.com/google/trillian/types" - "github.com/system-transparency/stfe/types" -) - -// SthSource provides access to the log's (co)signed tree heads -type SthSource interface { - Latest(context.Context) (*types.SignedTreeHead, error) - Stable(context.Context) (*types.SignedTreeHead, error) - Cosigned(context.Context) (*types.SignedTreeHead, error) - AddCosignature(context.Context, ed25519.PublicKey, *[types.SignatureSize]byte) error - Run(context.Context) -} - -// ActiveSthSource implements the SthSource interface for an STFE instance that -// accepts new logging requests, i.e., the log is running in read+write mode. -type ActiveSthSource struct { - client trillian.TrillianLogClient - logParameters *LogParameters - sync.RWMutex - - // cosigned is the current cosigned tree head that is served - cosigned types.SignedTreeHead - - // tosign is the current tree head that is being cosigned - tosign types.SignedTreeHead - - // cosignature keeps track of all collected cosignatures for tosign - cosignature map[[types.HashSize]byte]*types.SigIdent -} - -func NewActiveSthSource(cli trillian.TrillianLogClient, lp *LogParameters) (*ActiveSthSource, error) { - s := ActiveSthSource{ - client: cli, - logParameters: lp, - } - - ctx, _ := context.WithTimeout(context.Background(), lp.Deadline) - sth, err := s.Latest(ctx) - if err != nil { - return nil, fmt.Errorf("Latest: %v", err) - } - - s.cosigned = *sth - s.tosign = *sth - s.cosignature = make(map[[types.HashSize]byte]*types.SigIdent) - return &s, nil -} - -func (s *ActiveSthSource) Run(ctx context.Context) { - schedule.Every(ctx, s.logParameters.Interval, func(ctx context.Context) { - // get the next stable sth - ictx, _ := context.WithTimeout(ctx, s.logParameters.Deadline) - sth, err := s.Latest(ictx) - if err != nil { - glog.Warningf("cannot rotate without new sth: Latest: %v", err) - return - } - // rotate - s.Lock() - defer s.Unlock() - s.rotate(sth) - }) -} - -func (s *ActiveSthSource) Latest(ctx context.Context) (*types.SignedTreeHead, error) { - trsp, err := s.client.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{ - LogId: s.logParameters.TreeId, - }) - var lr ttypes.LogRootV1 - if errInner := checkGetLatestSignedLogRoot(s.logParameters, trsp, err, &lr); errInner != nil { - return nil, fmt.Errorf("invalid signed log root response: %v", errInner) - } - return s.logParameters.Sign(NewTreeHeadFromLogRoot(&lr)) -} - -func (s *ActiveSthSource) Stable(_ context.Context) (*types.SignedTreeHead, error) { - s.RLock() - defer s.RUnlock() - return &s.tosign, nil -} - -func (s *ActiveSthSource) Cosigned(_ context.Context) (*types.SignedTreeHead, error) { - s.RLock() - defer s.RUnlock() - return &s.cosigned, nil -} - -func (s *ActiveSthSource) AddCosignature(_ context.Context, vk ed25519.PublicKey, sig *[types.SignatureSize]byte) error { - s.Lock() - defer s.Unlock() - - if msg := s.tosign.TreeHead.Marshal(); !ed25519.Verify(vk, msg, sig[:]) { - return fmt.Errorf("Invalid signature for tree head with timestamp: %d", s.tosign.TreeHead.Timestamp) - } - witness := types.Hash(vk[:]) - if _, ok := s.cosignature[*witness]; ok { - glog.V(3).Infof("received cosignature again (duplicate)") - return nil // duplicate - } - s.cosignature[*witness] = &types.SigIdent{ - Signature: sig, - KeyHash: witness, - } - glog.V(3).Infof("accepted new cosignature") - return nil -} - -// rotate rotates the log's cosigned and stable STH. The caller must aquire the -// source's read-write lock if there are concurrent reads and/or writes. -func (s *ActiveSthSource) rotate(next *types.SignedTreeHead) { - if reflect.DeepEqual(s.cosigned.TreeHead, s.tosign.TreeHead) { - for _, sigident := range s.cosigned.SigIdent[1:] { // skip log sigident - if _, ok := s.cosignature[*sigident.KeyHash]; !ok { - s.cosignature[*sigident.KeyHash] = sigident - } - } - } - var cosignatures []*types.SigIdent - for _, sigident := range s.cosignature { - cosignatures = append(cosignatures, sigident) - } // cosignatures contains all cosignatures, even if repeated tree head - - // Update cosigned tree head - s.cosigned.TreeHead = s.tosign.TreeHead - s.cosigned.SigIdent = append(s.tosign.SigIdent, cosignatures...) - - // Update to-sign tree head - s.tosign = *next - s.cosignature = make(map[[types.HashSize]byte]*types.SigIdent) - glog.V(3).Infof("rotated sth") -} diff --git a/sth_test.go b/sth_test.go deleted file mode 100644 index 0942ea1..0000000 --- a/sth_test.go +++ /dev/null @@ -1,466 +0,0 @@ -package stfe - -import ( - "context" - "crypto" - "fmt" - "reflect" - "testing" - - "github.com/golang/mock/gomock" - cttestdata "github.com/google/certificate-transparency-go/trillian/testdata" - "github.com/google/trillian" - "github.com/system-transparency/stfe/testdata" - "github.com/system-transparency/stfe/types" -) - -func TestNewActiveSthSource(t *testing.T) { - for _, table := range []struct { - description string - signer crypto.Signer - trsp *trillian.GetLatestSignedLogRootResponse - terr error - wantErr bool - wantCosi *types.StItem // current cosigned sth - wantStable *types.StItem // next stable sth that signatures are collected for - }{ - { - description: "invalid: no Trillian response", - signer: cttestdata.NewSignerWithFixedSig(nil, testdata.Signature), - terr: fmt.Errorf("internal server error"), - wantErr: true, - }, - { - description: "valid", - signer: cttestdata.NewSignerWithFixedSig(nil, testdata.Signature), - trsp: testdata.DefaultTSlr(t), - wantCosi: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil), - wantStable: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil), - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, table.signer) - defer ti.ctrl.Finish() - ti.client.EXPECT().GetLatestSignedLogRoot(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) - source, err := NewActiveSthSource(ti.client, ti.instance.LogParameters) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - return - } - - if got, want := source.currCosth, table.wantCosi; !reflect.DeepEqual(got, want) { - t.Errorf("got cosigned sth\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - if got, want := source.nextCosth, table.wantStable; !reflect.DeepEqual(got, want) { - t.Errorf("got stable sth\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - cosignatureFrom := make(map[[types.NamespaceFingerprintSize]byte]bool) - for _, cosig := range table.wantStable.CosignedTreeHeadV1.Cosignatures { - cosignatureFrom[testdata.Fingerprint(t, &cosig.Namespace)] = true - } - if got, want := source.cosignatureFrom, cosignatureFrom; !reflect.DeepEqual(got, want) { - if got == nil { - t.Errorf("got uninitialized witness map\n%v\n\tbut wanted\n%v\n\tin test %q", nil, want, table.description) - } else { - t.Errorf("got witness map\n%v\n\t but wanted\n%v\n\tin test %q", got, want, table.description) - } - } - }() - } -} - -func TestLatest(t *testing.T) { - for _, table := range []struct { - description string - signer crypto.Signer - trsp *trillian.GetLatestSignedLogRootResponse - terr error - wantErr bool - wantRsp *types.StItem - }{ - { - description: "invalid: no Trillian response", - signer: cttestdata.NewSignerWithFixedSig(nil, testdata.Signature), - terr: fmt.Errorf("internal server error"), - wantErr: true, - }, - { - description: "invalid: no signature", - signer: cttestdata.NewSignerWithErr(nil, fmt.Errorf("signing failed")), - terr: fmt.Errorf("internal server error"), - wantErr: true, - }, - { - description: "valid", - signer: cttestdata.NewSignerWithFixedSig(nil, testdata.Signature), - trsp: testdata.DefaultTSlr(t), - wantRsp: testdata.DefaultSth(t, testdata.Ed25519VkLog), - }, - } { - func() { // run deferred functions at the end of each iteration - ti := newTestInstance(t, table.signer) - defer ti.ctrl.Finish() - ti.client.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.trsp, table.terr) // no deadline matcher because context is set by the caller of Latest(), i.e., this test on the line below - sth, err := ti.instance.SthSource.Latest(context.Background()) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - return - } - if got, want := sth, table.wantRsp; !reflect.DeepEqual(got, want) { - t.Errorf("got\n%v\n\tbut wanted\n%v\n\t in test %q", got, want, table.description) - } - }() - } -} - -func TestStable(t *testing.T) { - for _, table := range []struct { - description string - source SthSource - wantRsp *types.StItem - wantErr bool - }{ - { - description: "invalid: no stable sth", - source: &ActiveSthSource{}, - wantErr: true, - }, - { - description: "valid", - source: &ActiveSthSource{ - nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil), - }, - wantRsp: testdata.DefaultSth(t, testdata.Ed25519VkLog), - }, - } { - sth, err := table.source.Stable(context.Background()) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := sth, table.wantRsp; !reflect.DeepEqual(got, want) { - t.Errorf("got\n%v\n\t but wanted\n%v\n\t in test %q", got, want, table.description) - } - } -} - -func TestCosigned(t *testing.T) { - for _, table := range []struct { - description string - source SthSource - wantRsp *types.StItem - wantErr bool - }{ - { - description: "invalid: no cosigned sth: nil", - source: &ActiveSthSource{}, - wantErr: true, - }, - { - description: "valid", - source: &ActiveSthSource{ - currCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), - }, - wantRsp: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), - }, - } { - cosi, err := table.source.Cosigned(context.Background()) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := cosi, table.wantRsp; !reflect.DeepEqual(got, want) { - t.Errorf("got\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - } -} - -func TestAddCosignature(t *testing.T) { - for _, table := range []struct { - description string - source *ActiveSthSource - req *types.StItem - wantWit []*types.Namespace - wantErr bool - }{ - { - description: "invalid: cosignature must target the stable sth", - source: &ActiveSthSource{ - nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil), - cosignatureFrom: make(map[[types.NamespaceFingerprintSize]byte]bool), - }, - req: testdata.DefaultCosth(t, testdata.Ed25519VkLog2, [][32]byte{testdata.Ed25519VkWitness}), - wantErr: true, - }, - { - description: "valid: adding duplicate into a pool of cosignatures", - source: &ActiveSthSource{ - nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), - cosignatureFrom: map[[types.NamespaceFingerprintSize]byte]bool{ - testdata.Fingerprint(t, testdata.NewNamespace(t, testdata.Ed25519VkWitness)): true, - }, - }, - req: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), - wantWit: []*types.Namespace{testdata.NewNamespace(t, testdata.Ed25519VkWitness)}, - }, - { - description: "valid: adding into an empty pool of cosignatures", - source: &ActiveSthSource{ - nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil), - cosignatureFrom: make(map[[types.NamespaceFingerprintSize]byte]bool), - }, - req: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), - wantWit: []*types.Namespace{testdata.NewNamespace(t, testdata.Ed25519VkWitness)}, - }, - { - description: "valid: adding into a pool of cosignatures", - source: &ActiveSthSource{ - nextCosth: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}), - cosignatureFrom: map[[types.NamespaceFingerprintSize]byte]bool{ - testdata.Fingerprint(t, testdata.NewNamespace(t, testdata.Ed25519VkWitness)): true, - }, - }, - req: testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness2}), - wantWit: []*types.Namespace{testdata.NewNamespace(t, testdata.Ed25519VkWitness), testdata.NewNamespace(t, testdata.Ed25519VkWitness2)}, - }, - } { - err := table.source.AddCosignature(context.Background(), table.req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - - // Check that the next cosigned sth is updated - var sigs []types.SignatureV1 - for _, wit := range table.wantWit { - sigs = append(sigs, types.SignatureV1{ - Namespace: *wit, - Signature: testdata.Signature, - }) - } - if got, want := table.source.nextCosth, types.NewCosignedTreeHeadV1(testdata.DefaultSth(t, testdata.Ed25519VkLog).SignedTreeHeadV1, sigs); !reflect.DeepEqual(got, want) { - t.Errorf("got\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - // Check that the map tracking witness signatures is updated - if got, want := len(table.source.cosignatureFrom), len(table.wantWit); got != want { - t.Errorf("witness map got %d cosignatures but wanted %d in test %q", got, want, table.description) - } else { - for _, wit := range table.wantWit { - if _, ok := table.source.cosignatureFrom[testdata.Fingerprint(t, wit)]; !ok { - t.Errorf("missing signature from witness %X in test %q", testdata.Fingerprint(t, wit), table.description) - } - } - } - } -} - -func TestRotate(t *testing.T) { - // distinct sths - sth1 := testdata.DefaultSth(t, testdata.Ed25519VkLog) - sth2 := testdata.DefaultSth(t, testdata.Ed25519VkLog2) - sth3 := testdata.DefaultSth(t, testdata.Ed25519VkLog3) - // distinct witnesses - wit1 := testdata.NewNamespace(t, testdata.Ed25519VkWitness) - wit2 := testdata.NewNamespace(t, testdata.Ed25519VkWitness2) - wit3 := testdata.NewNamespace(t, testdata.Ed25519VkWitness3) - for _, table := range []struct { - description string - source *ActiveSthSource - fixedSth *types.StItem - wantCurrSth *types.StItem - wantNextSth *types.StItem - wantWit []*types.Namespace - }{ - { - description: "not repeated cosigned and not repeated stable", - source: &ActiveSthSource{ - currCosth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, nil), - nextCosth: types.NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - }), - cosignatureFrom: map[[types.NamespaceFingerprintSize]byte]bool{ - testdata.Fingerprint(t, wit1): true, - }, - }, - fixedSth: sth3, - wantCurrSth: types.NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - }), - wantNextSth: types.NewCosignedTreeHeadV1(sth3.SignedTreeHeadV1, nil), - wantWit: nil, // no cosignatures for the next stable sth yet - }, - { - description: "not repeated cosigned and repeated stable", - source: &ActiveSthSource{ - currCosth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, nil), - nextCosth: types.NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - }), - cosignatureFrom: map[[types.NamespaceFingerprintSize]byte]bool{ - testdata.Fingerprint(t, wit1): true, - }, - }, - fixedSth: sth2, - wantCurrSth: types.NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - }), - wantNextSth: types.NewCosignedTreeHeadV1(sth2.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - }), - wantWit: []*types.Namespace{wit1}, - }, - { - description: "repeated cosigned and not repeated stable", - source: &ActiveSthSource{ - currCosth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit2, - Signature: testdata.Signature, - }, - }), - nextCosth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit2, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit3, - Signature: testdata.Signature, - }, - }), - cosignatureFrom: map[[types.NamespaceFingerprintSize]byte]bool{ - testdata.Fingerprint(t, wit2): true, - testdata.Fingerprint(t, wit3): true, - }, - }, - fixedSth: sth3, - wantCurrSth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit2, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit3, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - }), - wantNextSth: types.NewCosignedTreeHeadV1(sth3.SignedTreeHeadV1, nil), - wantWit: nil, // no cosignatures for the next stable sth yet - }, - { - description: "repeated cosigned and repeated stable", - source: &ActiveSthSource{ - currCosth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit2, - Signature: testdata.Signature, - }, - }), - nextCosth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit2, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit3, - Signature: testdata.Signature, - }, - }), - cosignatureFrom: map[[types.NamespaceFingerprintSize]byte]bool{ - testdata.Fingerprint(t, wit2): true, - testdata.Fingerprint(t, wit3): true, - }, - }, - fixedSth: sth1, - wantCurrSth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit2, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit3, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - }), - wantNextSth: types.NewCosignedTreeHeadV1(sth1.SignedTreeHeadV1, []types.SignatureV1{ - types.SignatureV1{ - Namespace: *wit2, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit3, - Signature: testdata.Signature, - }, - types.SignatureV1{ - Namespace: *wit1, - Signature: testdata.Signature, - }, - }), - wantWit: []*types.Namespace{wit1, wit2, wit3}, - }, - } { - table.source.rotate(table.fixedSth) - if got, want := table.source.currCosth, table.wantCurrSth; !reflect.DeepEqual(got, want) { - t.Errorf("got currCosth\n%v\n\tbut wanted \n%v\n\tin test %q", got, want, table.description) - } - if got, want := table.source.nextCosth, table.wantNextSth; !reflect.DeepEqual(got, want) { - t.Errorf("got nextCosth\n%v\n\tbut wanted\n%v\n\tin test %q", got, want, table.description) - } - if got, want := len(table.source.cosignatureFrom), len(table.wantWit); got != want { - t.Errorf("witness map got %d cosignatures but wanted %d in test %q", got, want, table.description) - } else { - for _, wit := range table.wantWit { - if _, ok := table.source.cosignatureFrom[testdata.Fingerprint(t, wit)]; !ok { - t.Errorf("missing signature from witness %X in test %q", testdata.Fingerprint(t, wit), table.description) - } - } - } - // check that adding cosignatures to stable will not effect cosigned sth - wantLen := len(table.source.currCosth.CosignedTreeHeadV1.Cosignatures) - table.source.nextCosth.CosignedTreeHeadV1.Cosignatures = append(table.source.nextCosth.CosignedTreeHeadV1.Cosignatures, types.SignatureV1{Namespace: *wit1, Signature: testdata.Signature}) - if gotLen := len(table.source.currCosth.CosignedTreeHeadV1.Cosignatures); gotLen != wantLen { - t.Errorf("adding cosignatures to the stable sth modifies the fixated cosigned sth in test %q", table.description) - } - } -} diff --git a/testdata/data.go b/testdata/data.go deleted file mode 100644 index ac958e5..0000000 --- a/testdata/data.go +++ /dev/null @@ -1,287 +0,0 @@ -package testdata - -import ( - "bytes" - "testing" - "time" - - "crypto/ed25519" - - "github.com/google/trillian" - ttypes "github.com/google/trillian/types" - "github.com/system-transparency/stfe/types" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -var ( - Ed25519VkLog = [32]byte{} - Ed25519VkLog2 = [32]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - Ed25519VkLog3 = [32]byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2} - //Ed25519VkWitness = [32]byte{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3} - // Ed25519VkWitness2 = [32]byte{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4} - Ed25519VkWitness3 = [32]byte{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5} - //Ed25519VkSubmitter = [32]byte{6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6} - - TreeId = int64(0) - Prefix = "test" - MaxRange = int64(3) - Interval = time.Second * 10 - Deadline = time.Second * 5 - - Timestamp = uint64(0) - TreeSize = uint64(0) - Extension = make([]byte, 0) - NodeHash = make([]byte, 32) - Signature = make([]byte, 64) - Identifier = []byte("foobar-1.2.3") - Checksum = make([]byte, 32) - Index = int64(0) - HashPath = [][]byte{ - NodeHash, - } - NodePath = []types.NodeHash{ - types.NodeHash{NodeHash}, - } - LeafHash = [32]byte{} - - // TODO: make these unique and load more pretty maybe - Ed25519SkWitness = [64]byte{230, 122, 195, 152, 194, 195, 147, 153, 80, 120, 153, 79, 102, 27, 52, 187, 136, 218, 150, 234, 107, 9, 167, 4, 92, 21, 11, 113, 42, 29, 129, 69, 75, 60, 249, 150, 229, 93, 75, 32, 103, 126, 244, 37, 53, 182, 68, 82, 249, 109, 49, 94, 10, 19, 146, 244, 58, 191, 169, 107, 78, 37, 45, 210} - Ed25519VkWitness = [32]byte{75, 60, 249, 150, 229, 93, 75, 32, 103, 126, 244, 37, 53, 182, 68, 82, 249, 109, 49, 94, 10, 19, 146, 244, 58, 191, 169, 107, 78, 37, 45, 210} - - Ed25519SkWitness2 = [64]byte{98, 65, 92, 117, 33, 167, 138, 36, 252, 147, 87, 173, 44, 62, 17, 66, 126, 70, 218, 87, 91, 148, 64, 194, 241, 248, 62, 90, 140, 122, 234, 76, 144, 6, 250, 185, 37, 217, 77, 201, 180, 42, 81, 37, 165, 27, 22, 32, 25, 8, 156, 228, 78, 207, 208, 18, 91, 77, 189, 51, 112, 31, 237, 6} - Ed25519VkWitness2 = [32]byte{144, 6, 250, 185, 37, 217, 77, 201, 180, 42, 81, 37, 165, 27, 22, 32, 25, 8, 156, 228, 78, 207, 208, 18, 91, 77, 189, 51, 112, 31, 237, 6} - - Ed25519SkSubmitter = [64]byte{230, 122, 195, 152, 194, 195, 147, 153, 80, 120, 153, 79, 102, 27, 52, 187, 136, 218, 150, 234, 107, 9, 167, 4, 92, 21, 11, 113, 42, 29, 129, 69, 75, 60, 249, 150, 229, 93, 75, 32, 103, 126, 244, 37, 53, 182, 68, 82, 249, 109, 49, 94, 10, 19, 146, 244, 58, 191, 169, 107, 78, 37, 45, 210} - Ed25519VkSubmitter = [32]byte{75, 60, 249, 150, 229, 93, 75, 32, 103, 126, 244, 37, 53, 182, 68, 82, 249, 109, 49, 94, 10, 19, 146, 244, 58, 191, 169, 107, 78, 37, 45, 210} - Ed25519SkSubmitter2 = [64]byte{98, 65, 92, 117, 33, 167, 138, 36, 252, 147, 87, 173, 44, 62, 17, 66, 126, 70, 218, 87, 91, 148, 64, 194, 241, 248, 62, 90, 140, 122, 234, 76, 144, 6, 250, 185, 37, 217, 77, 201, 180, 42, 81, 37, 165, 27, 22, 32, 25, 8, 156, 228, 78, 207, 208, 18, 91, 77, 189, 51, 112, 31, 237, 6} - Ed25519VkSubmitter2 = [32]byte{144, 6, 250, 185, 37, 217, 77, 201, 180, 42, 81, 37, 165, 27, 22, 32, 25, 8, 156, 228, 78, 207, 208, 18, 91, 77, 189, 51, 112, 31, 237, 6} -) - -// TODO: reorder and docdoc where need be -// -// Helpers that must create default values for different STFE types -// - -func DefaultCosth(t *testing.T, logVk [32]byte, witVk [][32]byte) *types.StItem { - t.Helper() - cosigs := make([]types.SignatureV1, 0) - for _, vk := range witVk { - cosigs = append(cosigs, types.SignatureV1{*NewNamespace(t, vk), Signature}) - } - return types.NewCosignedTreeHeadV1(DefaultSth(t, logVk).SignedTreeHeadV1, cosigs) -} - -func DefaultSth(t *testing.T, vk [32]byte) *types.StItem { - t.Helper() - return types.NewSignedTreeHeadV1(DefaultTh(t), DefaultSig(t, vk)) -} - -func DefaultSignedChecksum(t *testing.T, vk [32]byte) *types.StItem { - t.Helper() - return types.NewSignedChecksumV1(DefaultChecksum(t), DefaultSig(t, vk)) -} - -func DefaultTh(t *testing.T) *types.TreeHeadV1 { - t.Helper() - return types.NewTreeHeadV1(Timestamp, TreeSize, NodeHash, Extension) -} - -func DefaultSig(t *testing.T, vk [32]byte) *types.SignatureV1 { - t.Helper() - return &types.SignatureV1{*NewNamespace(t, vk), Signature} -} - -func DefaultChecksum(t *testing.T) *types.ChecksumV1 { - t.Helper() - return &types.ChecksumV1{Identifier, Checksum} -} - -func AddCosignatureBuffer(t *testing.T, sth *types.StItem, sk *[64]byte, vk *[32]byte) *bytes.Buffer { - t.Helper() - var cosigs []types.SignatureV1 - if vk != nil { - cosigs = []types.SignatureV1{ - types.SignatureV1{ - Namespace: *NewNamespace(t, *vk), - Signature: ed25519.Sign(ed25519.PrivateKey((*sk)[:]), marshal(t, *sth.SignedTreeHeadV1)), - }, - } - } - return bytes.NewBuffer(marshal(t, *types.NewCosignedTreeHeadV1(sth.SignedTreeHeadV1, cosigs))) -} - -func AddSignedChecksumBuffer(t *testing.T, sk [64]byte, vk [32]byte) *bytes.Buffer { - t.Helper() - data := DefaultChecksum(t) - return bytes.NewBuffer(marshal(t, *types.NewSignedChecksumV1( - data, - &types.SignatureV1{ - Namespace: *NewNamespace(t, vk), - Signature: ed25519.Sign(ed25519.PrivateKey(sk[:]), marshal(t, *data)), - }, - ))) -} - -func NewNamespacePool(t *testing.T, namespaces []*types.Namespace) *types.NamespacePool { - pool, err := types.NewNamespacePool(namespaces) - if err != nil { - t.Fatalf("must make namespace pool: %v", err) - } - return pool -} - -func NewNamespace(t *testing.T, vk [32]byte) *types.Namespace { - namespace, err := types.NewNamespaceEd25519V1(vk[:]) - if err != nil { - t.Fatalf("must make Ed25519V1 namespace: %v", err) - } - return namespace -} - -// -// Helpers that must create default values for different Trillian types -// - -// DefaultTLr creates a default Trillian log root -func DefaultTLr(t *testing.T) *ttypes.LogRootV1 { - t.Helper() - return Tlr(t, TreeSize, Timestamp, NodeHash) -} - -// Tlr creates a Trillian log root -func Tlr(t *testing.T, size, timestamp uint64, hash []byte) *ttypes.LogRootV1 { - t.Helper() - return &ttypes.LogRootV1{ - TreeSize: size, - RootHash: hash, - TimestampNanos: timestamp, - Revision: 0, // not used by stfe - Metadata: nil, // not used by stfe - } -} - -// DefaultTSlr creates a default Trillian signed log root -func DefaultTSlr(t *testing.T) *trillian.GetLatestSignedLogRootResponse { - t.Helper() - return Tslr(t, DefaultTLr(t)) -} - -// Tslr creates a Trillian signed log root -func Tslr(t *testing.T, lr *ttypes.LogRootV1) *trillian.GetLatestSignedLogRootResponse { - t.Helper() - b, err := lr.MarshalBinary() - if err != nil { - t.Fatalf("must marshal Trillian log root: %v", err) - } - return &trillian.GetLatestSignedLogRootResponse{ - SignedLogRoot: &trillian.SignedLogRoot{ - KeyHint: nil, // not used by stfe - LogRoot: b, - LogRootSignature: nil, // not used by stfe - }, - Proof: nil, // not used by stfe - } -} - -// DefaultTQlr creates a default Trillian queue leaf response -func DefaultTQlr(t *testing.T, withDupCode bool) *trillian.QueueLeafResponse { - t.Helper() - s := status.New(codes.OK, "ok").Proto() - if withDupCode { - s = status.New(codes.AlreadyExists, "duplicate").Proto() - } - return &trillian.QueueLeafResponse{ - QueuedLeaf: &trillian.QueuedLogLeaf{ - Leaf: &trillian.LogLeaf{ - MerkleLeafHash: nil, // not used by stfe - LeafValue: marshal(t, *DefaultSignedChecksum(t, Ed25519VkSubmitter)), - ExtraData: nil, // not used by stfe - LeafIndex: 0, // not applicable (log is not pre-ordered) - LeafIdentityHash: nil, // not used by stfe - }, - Status: s, - }, - } -} - -// DefaultTglbrr creates a default Trillian get leaves by range response -func DefaultTGlbrr(t *testing.T, start, end int64) *trillian.GetLeavesByRangeResponse { - t.Helper() - leaves := make([]*trillian.LogLeaf, 0, end-start+1) - for i, n := start, end+1; i < n; i++ { - leaves = append(leaves, &trillian.LogLeaf{ - MerkleLeafHash: nil, // not usedb y stfe - LeafValue: marshal(t, *DefaultSignedChecksum(t, Ed25519VkSubmitter)), - ExtraData: nil, // not used by stfe - LeafIndex: i, - LeafIdentityHash: nil, // not used by stfe - }) - } - return &trillian.GetLeavesByRangeResponse{ - Leaves: leaves, - SignedLogRoot: Tslr(t, Tlr(t, uint64(end)+1, Timestamp, NodeHash)).SignedLogRoot, - } -} - -func DefaultStItemList(t *testing.T, start, end uint64) *types.StItemList { - items := make([]types.StItem, 0, end-start+1) - for i, n := start, end+1; i < n; i++ { - items = append(items, *DefaultSignedChecksum(t, Ed25519VkSubmitter)) - } - return &types.StItemList{items} -} - -// DefaultTGipbhr creates a default Trillian get inclusion proof by hash response -func DefaultTGipbhr(t *testing.T) *trillian.GetInclusionProofByHashResponse { - t.Helper() - return &trillian.GetInclusionProofByHashResponse{ - Proof: []*trillian.Proof{ - &trillian.Proof{ - LeafIndex: Index, - Hashes: HashPath, - }, - }, - SignedLogRoot: nil, // not used by stfe - } -} - -func DefaultInclusionProof(t *testing.T, size uint64) *types.StItem { - return types.NewInclusionProofV1(NewNamespace(t, Ed25519VkLog), size, uint64(Index), NodePath) -} - -// DefaultTGcpr creates a default Trillian get consistency proof response -func DefaultTGcpr(t *testing.T) *trillian.GetConsistencyProofResponse { - t.Helper() - return &trillian.GetConsistencyProofResponse{ - Proof: &trillian.Proof{ - LeafIndex: 0, // not applicable for consistency proofs - Hashes: HashPath, - }, - SignedLogRoot: nil, // not used by stfe - } -} - -func DefaultConsistencyProof(t *testing.T, first, second uint64) *types.StItem { - return types.NewConsistencyProofV1(NewNamespace(t, Ed25519VkLog), first, second, NodePath) -} - -// -// Other helpers -// - -func Fingerprint(t *testing.T, namespace *types.Namespace) [types.NamespaceFingerprintSize]byte { - fpr, err := namespace.Fingerprint() - if err != nil { - t.Fatalf("must have namespace fingerprint: %v", err) - } - return *fpr -} - -func marshal(t *testing.T, i interface{}) []byte { - b, err := types.Marshal(i) - if err != nil { - t.Fatalf("must marshal interface: %v", err) - } - return b -} diff --git a/trillian.go b/trillian.go deleted file mode 100644 index f358d4d..0000000 --- a/trillian.go +++ /dev/null @@ -1,125 +0,0 @@ -package stfe - -import ( - "fmt" - - "github.com/golang/glog" - "github.com/google/trillian" - "github.com/google/trillian/types" - stfetypes "github.com/system-transparency/stfe/types" - "google.golang.org/grpc/codes" -) - -func checkQueueLeaf(rsp *trillian.QueueLeafResponse, err error) error { - if err != nil { - return fmt.Errorf("Trillian error: %v", err) - } - if rsp == nil { - return fmt.Errorf("Trillian error: empty response") - } - if rsp.QueuedLeaf == nil { - return fmt.Errorf("Trillian error: empty QueuedLeaf") - } - if codes.Code(rsp.QueuedLeaf.GetStatus().GetCode()) == codes.AlreadyExists { - glog.V(3).Infof("queued leaf is a duplicate => %X", rsp.QueuedLeaf.Leaf.LeafValue) - } - return nil -} - -func checkGetLeavesByRange(req *stfetypes.LeavesRequest, rsp *trillian.GetLeavesByRangeResponse, err error) error { - if err != nil { - return fmt.Errorf("Trillian Error: %v", err) - } - if rsp == nil { - return fmt.Errorf("Trillian error: empty response") - } - if rsp.SignedLogRoot == nil { - return fmt.Errorf("Trillian error: no signed log root") - } - if rsp.SignedLogRoot.LogRoot == nil { - return fmt.Errorf("Trillian error: no log root") - } - if len(rsp.Leaves) == 0 { - return fmt.Errorf("Trillian error: no leaves") - } - if len(rsp.Leaves) > int(req.EndSize-req.StartSize+1) { - return fmt.Errorf("too many leaves: %d for [%d,%d]", len(rsp.Leaves), req.StartSize, req.EndSize) - } - - // Ensure that a bad start parameter results in an error - var lr types.LogRootV1 - if err := lr.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil { - return fmt.Errorf("cannot unmarshal log root: %v", err) - } - if uint64(req.StartSize) >= lr.TreeSize { - return fmt.Errorf("invalid start(%d): tree size is %d", req.StartSize, lr.TreeSize) - } - - // Ensure that we got and return expected leaf indices - for i, leaf := range rsp.Leaves { - if got, want := leaf.LeafIndex, int64(req.StartSize+uint64(i)); got != want { - return fmt.Errorf("invalid leaf index(%d): wanted %d", got, want) - } - } - return nil -} - -func checkGetInclusionProofByHash(lp *LogParameters, rsp *trillian.GetInclusionProofByHashResponse, err error) error { - if err != nil { - return fmt.Errorf("Trillian Error: %v", err) - } - if rsp == nil { - return fmt.Errorf("Trillian error: empty response") - } - if len(rsp.Proof) == 0 { - return fmt.Errorf("Trillian error: no proofs") - } - if rsp.Proof[0] == nil { - return fmt.Errorf("Trillian error: no proof") - } - return checkHashPath(lp.HashType.Size(), rsp.Proof[0].Hashes) -} - -func checkGetConsistencyProof(lp *LogParameters, rsp *trillian.GetConsistencyProofResponse, err error) error { - if err != nil { - return fmt.Errorf("Trillian Error: %v", err) - } - if rsp == nil { - return fmt.Errorf("Trillian error: empty response") - } - if rsp.Proof == nil { - return fmt.Errorf("Trillian error: no proof") - } - return checkHashPath(lp.HashType.Size(), rsp.Proof.Hashes) -} - -func checkGetLatestSignedLogRoot(lp *LogParameters, rsp *trillian.GetLatestSignedLogRootResponse, err error, out *types.LogRootV1) error { - if err != nil { - return fmt.Errorf("Trillian Error: %v", err) - } - if rsp == nil { - return fmt.Errorf("Trillian error: empty response") - } - if rsp.SignedLogRoot == nil { - return fmt.Errorf("Trillian error: no signed log root") - } - if rsp.SignedLogRoot.LogRoot == nil { - return fmt.Errorf("Trillian error: no log root") - } - if err := out.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil { - return fmt.Errorf("cannot unmarshal log root: %v", err) - } - if len(out.RootHash) != lp.HashType.Size() { - return fmt.Errorf("invalid root hash: %v", out.RootHash) - } - return nil -} - -func checkHashPath(hashSize int, path [][]byte) error { - for _, hash := range path { - if len(hash) != hashSize { - return fmt.Errorf("invalid proof: %v", path) - } - } - return nil -} diff --git a/trillian/client.go b/trillian/client.go deleted file mode 100644 index 9ea6a4a..0000000 --- a/trillian/client.go +++ /dev/null @@ -1,178 +0,0 @@ -package trillian - -import ( - "context" - "fmt" - - "github.com/golang/glog" - "github.com/google/trillian" - ttypes "github.com/google/trillian/types" - "github.com/system-transparency/stfe/types" - "google.golang.org/grpc/codes" -) - -type Client interface { - AddLeaf(context.Context, *types.LeafRequest) error - GetConsistencyProof(context.Context, *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) - GetTreeHead(context.Context) (*types.TreeHead, error) - GetInclusionProof(context.Context, *types.InclusionProofRequest) (*types.InclusionProof, error) - GetLeaves(context.Context, *types.LeavesRequest) (*types.LeafList, error) -} - -// TrillianClient is a wrapper around the Trillian gRPC client. -type TrillianClient struct { - // TreeID is a Merkle tree identifier that Trillian uses - TreeID int64 - - // GRPC is a Trillian gRPC client - GRPC trillian.TrillianLogClient -} - -func (c *TrillianClient) AddLeaf(ctx context.Context, req *types.LeafRequest) error { - leaf := types.Leaf{ - Message: req.Message, - SigIdent: types.SigIdent{ - Signature: req.Signature, - KeyHash: types.Hash(req.VerificationKey[:]), - }, - } - serialized := leaf.Marshal() - - glog.V(3).Infof("queueing leaf request: %x", types.HashLeaf(serialized)) - rsp, err := c.GRPC.QueueLeaf(ctx, &trillian.QueueLeafRequest{ - LogId: c.TreeID, - Leaf: &trillian.LogLeaf{ - LeafValue: serialized, - }, - }) - if err != nil { - return fmt.Errorf("backend failure: %v", err) - } - if rsp == nil { - return fmt.Errorf("no response") - } - if rsp.QueuedLeaf == nil { - return fmt.Errorf("no queued leaf") - } - if codes.Code(rsp.QueuedLeaf.GetStatus().GetCode()) == codes.AlreadyExists { - return fmt.Errorf("leaf is already queued or included") - } - return nil -} - -func (c *TrillianClient) GetTreeHead(ctx context.Context) (*types.TreeHead, error) { - rsp, err := c.GRPC.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{ - LogId: c.TreeID, - }) - if err != nil { - return nil, fmt.Errorf("backend failure: %v", err) - } - if rsp == nil { - return nil, fmt.Errorf("no response") - } - if rsp.SignedLogRoot == nil { - return nil, fmt.Errorf("no signed log root") - } - if rsp.SignedLogRoot.LogRoot == nil { - return nil, fmt.Errorf("no log root") - } - var r ttypes.LogRootV1 - if err := r.UnmarshalBinary(rsp.SignedLogRoot.LogRoot); err != nil { - return nil, fmt.Errorf("no log root: unmarshal failed: %v", err) - } - if len(r.RootHash) != types.HashSize { - return nil, fmt.Errorf("unexpected hash length: %d", len(r.RootHash)) - } - return treeHeadFromLogRoot(&r), nil -} - -func (c *TrillianClient) GetConsistencyProof(ctx context.Context, req *types.ConsistencyProofRequest) (*types.ConsistencyProof, error) { - rsp, err := c.GRPC.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{ - LogId: c.TreeID, - FirstTreeSize: int64(req.OldSize), - SecondTreeSize: int64(req.NewSize), - }) - if err != nil { - return nil, fmt.Errorf("backend failure: %v", err) - } - if rsp == nil { - return nil, fmt.Errorf("no response") - } - if rsp.Proof == nil { - return nil, fmt.Errorf("no consistency proof") - } - if len(rsp.Proof.Hashes) == 0 { - return nil, fmt.Errorf("not a consistency proof: empty") - } - path, err := nodePathFromHashes(rsp.Proof.Hashes) - if err != nil { - return nil, fmt.Errorf("not a consistency proof: %v", err) - } - return &types.ConsistencyProof{ - OldSize: req.OldSize, - NewSize: req.NewSize, - Path: path, - }, nil -} - -func (c *TrillianClient) GetInclusionProof(ctx context.Context, req *types.InclusionProofRequest) (*types.InclusionProof, error) { - rsp, err := c.GRPC.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{ - LogId: c.TreeID, - LeafHash: req.LeafHash[:], - TreeSize: int64(req.TreeSize), - OrderBySequence: true, - }) - if err != nil { - return nil, fmt.Errorf("backend failure: %v", err) - } - if rsp == nil { - return nil, fmt.Errorf("no response") - } - if len(rsp.Proof) != 1 { - return nil, fmt.Errorf("bad proof count: %d", len(rsp.Proof)) - } - proof := rsp.Proof[0] - if len(proof.Hashes) == 0 { - return nil, fmt.Errorf("not an inclusion proof: empty") - } - path, err := nodePathFromHashes(proof.Hashes) - if err != nil { - return nil, fmt.Errorf("not an inclusion proof: %v", err) - } - return &types.InclusionProof{ - TreeSize: req.TreeSize, - LeafIndex: uint64(proof.LeafIndex), - Path: path, - }, nil -} - -func (c *TrillianClient) GetLeaves(ctx context.Context, req *types.LeavesRequest) (*types.LeafList, error) { - rsp, err := c.GRPC.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{ - LogId: c.TreeID, - StartIndex: int64(req.StartSize), - Count: int64(req.EndSize-req.StartSize) + 1, - }) - if err != nil { - return nil, fmt.Errorf("backend failure: %v", err) - } - if rsp == nil { - return nil, fmt.Errorf("no response") - } - if got, want := len(rsp.Leaves), int(req.EndSize-req.StartSize+1); got != want { - return nil, fmt.Errorf("unexpected number of leaves: %d", got) - } - var list types.LeafList - for i, leaf := range rsp.Leaves { - leafIndex := int64(req.StartSize + uint64(i)) - if leafIndex != leaf.LeafIndex { - return nil, fmt.Errorf("unexpected leaf(%d): got index %d", leafIndex, leaf.LeafIndex) - } - - var l types.Leaf - if err := l.Unmarshal(leaf.LeafValue); err != nil { - return nil, fmt.Errorf("unexpected leaf(%d): %v", leafIndex, err) - } - list = append(list[:], &l) - } - return &list, nil -} diff --git a/trillian/client_test.go b/trillian/client_test.go deleted file mode 100644 index e9f1ff5..0000000 --- a/trillian/client_test.go +++ /dev/null @@ -1,533 +0,0 @@ -package trillian - -import ( - "context" - "fmt" - "reflect" - "testing" - - "github.com/golang/mock/gomock" - "github.com/google/trillian" - ttypes "github.com/google/trillian/types" - "github.com/system-transparency/stfe/mocks" - "github.com/system-transparency/stfe/types" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func TestAddLeaf(t *testing.T) { - req := &types.LeafRequest{ - Message: types.Message{ - ShardHint: 0, - Checksum: &[types.HashSize]byte{}, - }, - Signature: &[types.SignatureSize]byte{}, - VerificationKey: &[types.VerificationKeySize]byte{}, - DomainHint: "example.com", - } - for _, table := range []struct { - description string - req *types.LeafRequest - rsp *trillian.QueueLeafResponse - err error - wantErr bool - }{ - { - description: "invalid: backend failure", - req: req, - err: fmt.Errorf("something went wrong"), - wantErr: true, - }, - { - description: "invalid: no response", - req: req, - wantErr: true, - }, - { - description: "invalid: no queued leaf", - req: req, - rsp: &trillian.QueueLeafResponse{}, - wantErr: true, - }, - { - description: "invalid: leaf is already queued or included", - req: req, - rsp: &trillian.QueueLeafResponse{ - QueuedLeaf: &trillian.QueuedLogLeaf{ - Leaf: &trillian.LogLeaf{ - LeafValue: req.Message.Marshal(), - }, - Status: status.New(codes.AlreadyExists, "duplicate").Proto(), - }, - }, - wantErr: true, - }, - { - description: "valid", - req: req, - rsp: &trillian.QueueLeafResponse{ - QueuedLeaf: &trillian.QueuedLogLeaf{ - Leaf: &trillian.LogLeaf{ - LeafValue: req.Message.Marshal(), - }, - Status: status.New(codes.OK, "ok").Proto(), - }, - }, - }, - } { - // Run deferred functions at the end of each iteration - func() { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - grpc := mocks.NewMockTrillianLogClient(ctrl) - grpc.EXPECT().QueueLeaf(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) - client := TrillianClient{GRPC: grpc} - - err := client.AddLeaf(context.Background(), table.req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - }() - } -} - -func TestGetTreeHead(t *testing.T) { - // valid root - root := &ttypes.LogRootV1{ - TreeSize: 0, - RootHash: make([]byte, types.HashSize), - TimestampNanos: 1622585623133599429, - } - buf, err := root.MarshalBinary() - if err != nil { - t.Fatalf("must marshal log root: %v", err) - } - // invalid root - root.RootHash = make([]byte, types.HashSize+1) - bufBadHash, err := root.MarshalBinary() - if err != nil { - t.Fatalf("must marshal log root: %v", err) - } - - for _, table := range []struct { - description string - rsp *trillian.GetLatestSignedLogRootResponse - err error - wantErr bool - wantTh *types.TreeHead - }{ - { - description: "invalid: backend failure", - err: fmt.Errorf("something went wrong"), - wantErr: true, - }, - { - description: "invalid: no response", - wantErr: true, - }, - { - description: "invalid: no signed log root", - rsp: &trillian.GetLatestSignedLogRootResponse{}, - wantErr: true, - }, - { - description: "invalid: no log root", - rsp: &trillian.GetLatestSignedLogRootResponse{ - SignedLogRoot: &trillian.SignedLogRoot{}, - }, - wantErr: true, - }, - { - description: "invalid: no log root: unmarshal failed", - rsp: &trillian.GetLatestSignedLogRootResponse{ - SignedLogRoot: &trillian.SignedLogRoot{ - LogRoot: buf[1:], - }, - }, - wantErr: true, - }, - { - description: "invalid: unexpected hash length", - rsp: &trillian.GetLatestSignedLogRootResponse{ - SignedLogRoot: &trillian.SignedLogRoot{ - LogRoot: bufBadHash, - }, - }, - wantErr: true, - }, - { - description: "valid", - rsp: &trillian.GetLatestSignedLogRootResponse{ - SignedLogRoot: &trillian.SignedLogRoot{ - LogRoot: buf, - }, - }, - wantTh: &types.TreeHead{ - Timestamp: 1622585623, - TreeSize: 0, - RootHash: &[types.HashSize]byte{}, - }, - }, - } { - // Run deferred functions at the end of each iteration - func() { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - grpc := mocks.NewMockTrillianLogClient(ctrl) - grpc.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) - client := TrillianClient{GRPC: grpc} - - th, err := client.GetTreeHead(context.Background()) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - return - } - if got, want := th, table.wantTh; !reflect.DeepEqual(got, want) { - t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - }() - } -} - -func TestGetConsistencyProof(t *testing.T) { - req := &types.ConsistencyProofRequest{ - OldSize: 1, - NewSize: 3, - } - for _, table := range []struct { - description string - req *types.ConsistencyProofRequest - rsp *trillian.GetConsistencyProofResponse - err error - wantErr bool - wantProof *types.ConsistencyProof - }{ - { - description: "invalid: backend failure", - req: req, - err: fmt.Errorf("something went wrong"), - wantErr: true, - }, - { - description: "invalid: no response", - req: req, - wantErr: true, - }, - { - description: "invalid: no consistency proof", - req: req, - rsp: &trillian.GetConsistencyProofResponse{}, - wantErr: true, - }, - { - description: "invalid: not a consistency proof (1/2)", - req: req, - rsp: &trillian.GetConsistencyProofResponse{ - Proof: &trillian.Proof{ - Hashes: [][]byte{}, - }, - }, - wantErr: true, - }, - { - description: "invalid: not a consistency proof (2/2)", - req: req, - rsp: &trillian.GetConsistencyProofResponse{ - Proof: &trillian.Proof{ - Hashes: [][]byte{ - make([]byte, types.HashSize), - make([]byte, types.HashSize+1), - }, - }, - }, - wantErr: true, - }, - { - description: "valid", - req: req, - rsp: &trillian.GetConsistencyProofResponse{ - Proof: &trillian.Proof{ - Hashes: [][]byte{ - make([]byte, types.HashSize), - make([]byte, types.HashSize), - }, - }, - }, - wantProof: &types.ConsistencyProof{ - OldSize: 1, - NewSize: 3, - Path: []*[types.HashSize]byte{ - &[types.HashSize]byte{}, - &[types.HashSize]byte{}, - }, - }, - }, - } { - // Run deferred functions at the end of each iteration - func() { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - grpc := mocks.NewMockTrillianLogClient(ctrl) - grpc.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) - client := TrillianClient{GRPC: grpc} - - proof, err := client.GetConsistencyProof(context.Background(), table.req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - return - } - if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { - t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - }() - } -} - -func TestGetInclusionProof(t *testing.T) { - req := &types.InclusionProofRequest{ - TreeSize: 4, - LeafHash: &[types.HashSize]byte{}, - } - for _, table := range []struct { - description string - req *types.InclusionProofRequest - rsp *trillian.GetInclusionProofByHashResponse - err error - wantErr bool - wantProof *types.InclusionProof - }{ - { - description: "invalid: backend failure", - req: req, - err: fmt.Errorf("something went wrong"), - wantErr: true, - }, - { - description: "invalid: no response", - req: req, - wantErr: true, - }, - { - description: "invalid: bad proof count", - req: req, - rsp: &trillian.GetInclusionProofByHashResponse{ - Proof: []*trillian.Proof{ - &trillian.Proof{}, - &trillian.Proof{}, - }, - }, - wantErr: true, - }, - { - description: "invalid: not an inclusion proof (1/2)", - req: req, - rsp: &trillian.GetInclusionProofByHashResponse{ - Proof: []*trillian.Proof{ - &trillian.Proof{ - LeafIndex: 1, - Hashes: [][]byte{}, - }, - }, - }, - wantErr: true, - }, - { - description: "invalid: not an inclusion proof (2/2)", - req: req, - rsp: &trillian.GetInclusionProofByHashResponse{ - Proof: []*trillian.Proof{ - &trillian.Proof{ - LeafIndex: 1, - Hashes: [][]byte{ - make([]byte, types.HashSize), - make([]byte, types.HashSize+1), - }, - }, - }, - }, - wantErr: true, - }, - { - description: "valid", - req: req, - rsp: &trillian.GetInclusionProofByHashResponse{ - Proof: []*trillian.Proof{ - &trillian.Proof{ - LeafIndex: 1, - Hashes: [][]byte{ - make([]byte, types.HashSize), - make([]byte, types.HashSize), - }, - }, - }, - }, - wantProof: &types.InclusionProof{ - TreeSize: 4, - LeafIndex: 1, - Path: []*[types.HashSize]byte{ - &[types.HashSize]byte{}, - &[types.HashSize]byte{}, - }, - }, - }, - } { - // Run deferred functions at the end of each iteration - func() { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - grpc := mocks.NewMockTrillianLogClient(ctrl) - grpc.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) - client := TrillianClient{GRPC: grpc} - - proof, err := client.GetInclusionProof(context.Background(), table.req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - return - } - if got, want := proof, table.wantProof; !reflect.DeepEqual(got, want) { - t.Errorf("got proof\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - }() - } -} - -func TestGetLeaves(t *testing.T) { - req := &types.LeavesRequest{ - StartSize: 1, - EndSize: 2, - } - firstLeaf := &types.Leaf{ - Message: types.Message{ - ShardHint: 0, - Checksum: &[types.HashSize]byte{}, - }, - SigIdent: types.SigIdent{ - Signature: &[types.SignatureSize]byte{}, - KeyHash: &[types.HashSize]byte{}, - }, - } - secondLeaf := &types.Leaf{ - Message: types.Message{ - ShardHint: 0, - Checksum: &[types.HashSize]byte{}, - }, - SigIdent: types.SigIdent{ - Signature: &[types.SignatureSize]byte{}, - KeyHash: &[types.HashSize]byte{}, - }, - } - - for _, table := range []struct { - description string - req *types.LeavesRequest - rsp *trillian.GetLeavesByRangeResponse - err error - wantErr bool - wantLeaves *types.LeafList - }{ - { - description: "invalid: backend failure", - req: req, - err: fmt.Errorf("something went wrong"), - wantErr: true, - }, - { - description: "invalid: no response", - req: req, - wantErr: true, - }, - { - description: "invalid: unexpected number of leaves", - req: req, - rsp: &trillian.GetLeavesByRangeResponse{ - Leaves: []*trillian.LogLeaf{ - &trillian.LogLeaf{ - LeafValue: firstLeaf.Marshal(), - LeafIndex: 1, - }, - }, - }, - wantErr: true, - }, - { - description: "invalid: unexpected leaf (1/2)", - req: req, - rsp: &trillian.GetLeavesByRangeResponse{ - Leaves: []*trillian.LogLeaf{ - &trillian.LogLeaf{ - LeafValue: firstLeaf.Marshal(), - LeafIndex: 1, - }, - &trillian.LogLeaf{ - LeafValue: secondLeaf.Marshal(), - LeafIndex: 3, - }, - }, - }, - wantErr: true, - }, - { - description: "invalid: unexpected leaf (2/2)", - req: req, - rsp: &trillian.GetLeavesByRangeResponse{ - Leaves: []*trillian.LogLeaf{ - &trillian.LogLeaf{ - LeafValue: firstLeaf.Marshal(), - LeafIndex: 1, - }, - &trillian.LogLeaf{ - LeafValue: secondLeaf.Marshal()[1:], - LeafIndex: 2, - }, - }, - }, - wantErr: true, - }, - { - description: "valid", - req: req, - rsp: &trillian.GetLeavesByRangeResponse{ - Leaves: []*trillian.LogLeaf{ - &trillian.LogLeaf{ - LeafValue: firstLeaf.Marshal(), - LeafIndex: 1, - }, - &trillian.LogLeaf{ - LeafValue: secondLeaf.Marshal(), - LeafIndex: 2, - }, - }, - }, - wantLeaves: &types.LeafList{ - firstLeaf, - secondLeaf, - }, - }, - } { - // Run deferred functions at the end of each iteration - func() { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - grpc := mocks.NewMockTrillianLogClient(ctrl) - grpc.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).Return(table.rsp, table.err) - client := TrillianClient{GRPC: grpc} - - leaves, err := client.GetLeaves(context.Background(), table.req) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - return - } - if got, want := leaves, table.wantLeaves; !reflect.DeepEqual(got, want) { - t.Errorf("got leaves\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - }() - } -} diff --git a/trillian/util.go b/trillian/util.go deleted file mode 100644 index 87e64b6..0000000 --- a/trillian/util.go +++ /dev/null @@ -1,33 +0,0 @@ -package trillian - -import ( - "fmt" - - trillian "github.com/google/trillian/types" - siglog "github.com/system-transparency/stfe/types" -) - -func treeHeadFromLogRoot(lr *trillian.LogRootV1) *siglog.TreeHead { - var hash [siglog.HashSize]byte - th := siglog.TreeHead{ - Timestamp: uint64(lr.TimestampNanos / 1000 / 1000 / 1000), - TreeSize: uint64(lr.TreeSize), - RootHash: &hash, - } - copy(th.RootHash[:], lr.RootHash) - return &th -} - -func nodePathFromHashes(hashes [][]byte) ([]*[siglog.HashSize]byte, error) { - var path []*[siglog.HashSize]byte - for _, hash := range hashes { - if len(hash) != siglog.HashSize { - return nil, fmt.Errorf("unexpected hash length: %v", len(hash)) - } - - var h [siglog.HashSize]byte - copy(h[:], hash) - path = append(path, &h) - } - return path, nil -} diff --git a/trillian_test.go b/trillian_test.go deleted file mode 100644 index 1b0c923..0000000 --- a/trillian_test.go +++ /dev/null @@ -1,282 +0,0 @@ -package stfe - -import ( - "fmt" - "testing" - - "github.com/google/trillian" - ttypes "github.com/google/trillian/types" - "github.com/system-transparency/stfe/testdata" - "github.com/system-transparency/stfe/types" -) - -func TestCheckQueueLeaf(t *testing.T) { - for _, table := range []struct { - description string - rsp *trillian.QueueLeafResponse - err error - wantErr bool - }{ - { - description: "invalid: no Trillian response: error", - err: fmt.Errorf("backend error"), - wantErr: true, - }, - { - description: "invalid: no Trillian response: nil", - wantErr: true, - }, - { - description: "invalid: no Trillian response: empty", - rsp: &trillian.QueueLeafResponse{}, - wantErr: true, - }, - { - description: "valid: gRPC status: duplicate", - rsp: testdata.DefaultTQlr(t, true), - }, - { - description: "valid: gRPC status: ok", - rsp: testdata.DefaultTQlr(t, false), - }, - } { - err := checkQueueLeaf(table.rsp, table.err) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q", got, want, table.description) - } - } -} - -func TestCheckGetLeavesByRange(t *testing.T) { - for _, table := range []struct { - description string - req *types.GetEntriesV1 - rsp *trillian.GetLeavesByRangeResponse - err error - wantErr bool - }{ - { - description: "invalid: no Trillian response: error", - req: &types.GetEntriesV1{Start: 0, End: 1}, - err: fmt.Errorf("backend error"), - wantErr: true, - }, - { - description: "invalid: no Trillian response: nil", - req: &types.GetEntriesV1{Start: 0, End: 1}, - wantErr: true, - }, - { - description: "invalid: bad Trillian response: no leaves", - req: &types.GetEntriesV1{Start: 0, End: 1}, - rsp: func(rsp *trillian.GetLeavesByRangeResponse) *trillian.GetLeavesByRangeResponse { - rsp.Leaves = nil - return rsp - }(testdata.DefaultTGlbrr(t, 0, 1)), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: no signed log root", - req: &types.GetEntriesV1{Start: 0, End: 1}, - rsp: func(rsp *trillian.GetLeavesByRangeResponse) *trillian.GetLeavesByRangeResponse { - rsp.SignedLogRoot = nil - return rsp - }(testdata.DefaultTGlbrr(t, 0, 1)), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: no log root", - req: &types.GetEntriesV1{Start: 0, End: 1}, - rsp: func(rsp *trillian.GetLeavesByRangeResponse) *trillian.GetLeavesByRangeResponse { - rsp.SignedLogRoot.LogRoot = nil - return rsp - }(testdata.DefaultTGlbrr(t, 0, 1)), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: truncated log root", - req: &types.GetEntriesV1{Start: 0, End: 1}, - rsp: func(rsp *trillian.GetLeavesByRangeResponse) *trillian.GetLeavesByRangeResponse { - rsp.SignedLogRoot.LogRoot = rsp.SignedLogRoot.LogRoot[1:] - return rsp - }(testdata.DefaultTGlbrr(t, 0, 1)), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: too many leaves", - req: &types.GetEntriesV1{Start: 0, End: 1}, - rsp: testdata.DefaultTGlbrr(t, 0, 2), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: start is not a valid index", - req: &types.GetEntriesV1{Start: 10, End: 10}, - rsp: testdata.DefaultTGlbrr(t, 9, 9), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: invalid leaf indices", - req: &types.GetEntriesV1{Start: 10, End: 11}, - rsp: testdata.DefaultTGlbrr(t, 11, 12), - wantErr: true, - }, - { - description: "valid", - req: &types.GetEntriesV1{Start: 10, End: 20}, - rsp: testdata.DefaultTGlbrr(t, 10, 20), - }, - } { - err := checkGetLeavesByRange(table.req, table.rsp, table.err) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q", got, want, table.description) - } - } -} - -func TestCheckGetInclusionProofByHash(t *testing.T) { - for _, table := range []struct { - description string - rsp *trillian.GetInclusionProofByHashResponse - err error - wantErr bool - }{ - { - description: "invalid: no Trillian response: error", - err: fmt.Errorf("backend failure"), - wantErr: true, - }, - { - description: "invalid: no Trillian response: nil", - wantErr: true, - }, - { - description: "invalid: bad Trillian response: no proofs", - rsp: &trillian.GetInclusionProofByHashResponse{}, - wantErr: true, - }, - { - description: "bad response: no proof", - rsp: func(rsp *trillian.GetInclusionProofByHashResponse) *trillian.GetInclusionProofByHashResponse { - rsp.Proof[0] = nil - return rsp - }(testdata.DefaultTGipbhr(t)), - wantErr: true, - }, - { - description: "bad response: proof with invalid node hash", - rsp: func(rsp *trillian.GetInclusionProofByHashResponse) *trillian.GetInclusionProofByHashResponse { - rsp.Proof[0].Hashes = append(rsp.Proof[0].Hashes, make([]byte, 0)) - return rsp - }(testdata.DefaultTGipbhr(t)), - wantErr: true, - }, - { - description: "valid", - rsp: testdata.DefaultTGipbhr(t), - }, - } { - err := checkGetInclusionProofByHash(newLogParameters(t, nil), table.rsp, table.err) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q", got, want, table.description) - } - } -} - -func TestCheckGetConsistencyProof(t *testing.T) { - for _, table := range []struct { - description string - rsp *trillian.GetConsistencyProofResponse - err error - wantErr bool - }{ - { - description: "invalid: no Trillian response: error", - err: fmt.Errorf("backend failure"), - wantErr: true, - }, - { - description: "invalid: no Trillian response: nil", - wantErr: true, - }, - { - description: "invalid: bad Trillian response: no proof", - rsp: &trillian.GetConsistencyProofResponse{}, - wantErr: true, - }, - { - description: "invalid: bad Trillian response: proof with invalid node hash", - rsp: func(rsp *trillian.GetConsistencyProofResponse) *trillian.GetConsistencyProofResponse { - rsp.Proof.Hashes = append(rsp.Proof.Hashes, make([]byte, 0)) - return rsp - }(testdata.DefaultTGcpr(t)), - wantErr: true, - }, - { - description: "valid", - rsp: testdata.DefaultTGcpr(t), - }, - } { - err := checkGetConsistencyProof(newLogParameters(t, nil), table.rsp, table.err) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q", got, want, table.description) - } - } -} - -func TestCheckGetLatestSignedLogRoot(t *testing.T) { - for _, table := range []struct { - description string - rsp *trillian.GetLatestSignedLogRootResponse - err error - wantErr bool - }{ - { - description: "invalid: no Trillian response: error", - err: fmt.Errorf("backend failure"), - wantErr: true, - }, - { - description: "invalid: no Trillian response: nil", - wantErr: true, - }, - { - description: "invalid: bad Trillian response: no signed log root", - rsp: func(rsp *trillian.GetLatestSignedLogRootResponse) *trillian.GetLatestSignedLogRootResponse { - rsp.SignedLogRoot = nil - return rsp - }(testdata.DefaultTSlr(t)), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: no log root", - rsp: func(rsp *trillian.GetLatestSignedLogRootResponse) *trillian.GetLatestSignedLogRootResponse { - rsp.SignedLogRoot.LogRoot = nil - return rsp - }(testdata.DefaultTSlr(t)), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: truncated log root", - rsp: func(rsp *trillian.GetLatestSignedLogRootResponse) *trillian.GetLatestSignedLogRootResponse { - rsp.SignedLogRoot.LogRoot = rsp.SignedLogRoot.LogRoot[1:] - return rsp - }(testdata.DefaultTSlr(t)), - wantErr: true, - }, - { - description: "invalid: bad Trillian response: truncated root hash", - rsp: testdata.Tslr(t, testdata.Tlr(t, testdata.TreeSize, testdata.Timestamp, make([]byte, 31))), - wantErr: true, - }, - { - description: "valid", - rsp: testdata.DefaultTSlr(t), - }, - } { - var lr ttypes.LogRootV1 - err := checkGetLatestSignedLogRoot(newLogParameters(t, nil), table.rsp, table.err, &lr) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q", got, want, table.description) - } - } -} diff --git a/types/ascii.go b/types/ascii.go deleted file mode 100644 index d27d79b..0000000 --- a/types/ascii.go +++ /dev/null @@ -1,421 +0,0 @@ -package types - -import ( - "bytes" - "encoding/hex" - "fmt" - "io" - "io/ioutil" - "strconv" -) - -const ( - // Delim is a key-value separator - Delim = "=" - - // EOL is a line sepator - EOL = "\n" - - // NumField* is the number of unique keys in an incoming ASCII message - NumFieldLeaf = 4 - NumFieldSignedTreeHead = 5 - NumFieldConsistencyProof = 3 - NumFieldInclusionProof = 3 - NumFieldLeavesRequest = 2 - NumFieldInclusionProofRequest = 2 - NumFieldConsistencyProofRequest = 2 - NumFieldLeafRequest = 5 - NumFieldCosignatureRequest = 2 - - // New leaf keys - ShardHint = "shard_hint" - Checksum = "checksum" - SignatureOverMessage = "signature_over_message" - VerificationKey = "verification_key" - DomainHint = "domain_hint" - - // Inclusion proof keys - LeafHash = "leaf_hash" - LeafIndex = "leaf_index" - InclusionPath = "inclusion_path" - - // Consistency proof keys - NewSize = "new_size" - OldSize = "old_size" - ConsistencyPath = "consistency_path" - - // Range of leaves keys - StartSize = "start_size" - EndSize = "end_size" - - // Tree head keys - Timestamp = "timestamp" - TreeSize = "tree_size" - RootHash = "root_hash" - - // Signature and signer-identity keys - Signature = "signature" - KeyHash = "key_hash" -) - -// MessageASCI is a wrapper that manages ASCII key-value pairs -type MessageASCII struct { - m map[string][]string -} - -// NewMessageASCII unpacks an incoming ASCII message -func NewMessageASCII(r io.Reader, numFieldExpected int) (*MessageASCII, error) { - buf, err := ioutil.ReadAll(r) - if err != nil { - return nil, fmt.Errorf("ReadAll: %v", err) - } - lines := bytes.Split(buf, []byte(EOL)) - if len(lines) <= 1 { - return nil, fmt.Errorf("Not enough lines: empty") - } - lines = lines[:len(lines)-1] // valid message => split gives empty last line - - msg := MessageASCII{make(map[string][]string)} - for _, line := range lines { - split := bytes.Index(line, []byte(Delim)) - if split == -1 { - return nil, fmt.Errorf("invalid line: %v", string(line)) - } - - key := string(line[:split]) - value := string(line[split+len(Delim):]) - values, ok := msg.m[key] - if !ok { - values = nil - msg.m[key] = values - } - msg.m[key] = append(values, value) - } - - if msg.NumField() != numFieldExpected { - return nil, fmt.Errorf("Unexpected number of keys: %v", msg.NumField()) - } - return &msg, nil -} - -// NumField returns the number of unique keys -func (msg *MessageASCII) NumField() int { - return len(msg.m) -} - -// GetStrings returns a list of strings -func (msg *MessageASCII) GetStrings(key string) []string { - strs, ok := msg.m[key] - if !ok { - return nil - } - return strs -} - -// GetString unpacks a string -func (msg *MessageASCII) GetString(key string) (string, error) { - strs := msg.GetStrings(key) - if len(strs) != 1 { - return "", fmt.Errorf("expected one string: %v", strs) - } - return strs[0], nil -} - -// GetUint64 unpacks an uint64 -func (msg *MessageASCII) GetUint64(key string) (uint64, error) { - str, err := msg.GetString(key) - if err != nil { - return 0, fmt.Errorf("GetString: %v", err) - } - num, err := strconv.ParseUint(str, 10, 64) - if err != nil { - return 0, fmt.Errorf("ParseUint: %v", err) - } - return num, nil -} - -// GetHash unpacks a hash -func (msg *MessageASCII) GetHash(key string) (*[HashSize]byte, error) { - str, err := msg.GetString(key) - if err != nil { - return nil, fmt.Errorf("GetString: %v", err) - } - - var hash [HashSize]byte - if err := decodeHex(str, hash[:]); err != nil { - return nil, fmt.Errorf("decodeHex: %v", err) - } - return &hash, nil -} - -// GetSignature unpacks a signature -func (msg *MessageASCII) GetSignature(key string) (*[SignatureSize]byte, error) { - str, err := msg.GetString(key) - if err != nil { - return nil, fmt.Errorf("GetString: %v", err) - } - - var signature [SignatureSize]byte - if err := decodeHex(str, signature[:]); err != nil { - return nil, fmt.Errorf("decodeHex: %v", err) - } - return &signature, nil -} - -// GetVerificationKey unpacks a verification key -func (msg *MessageASCII) GetVerificationKey(key string) (*[VerificationKeySize]byte, error) { - str, err := msg.GetString(key) - if err != nil { - return nil, fmt.Errorf("GetString: %v", err) - } - - var vk [VerificationKeySize]byte - if err := decodeHex(str, vk[:]); err != nil { - return nil, fmt.Errorf("decodeHex: %v", err) - } - return &vk, nil -} - -// decodeHex decodes a hex-encoded string into an already-sized byte slice -func decodeHex(str string, out []byte) error { - buf, err := hex.DecodeString(str) - if err != nil { - return fmt.Errorf("DecodeString: %v", err) - } - if len(buf) != len(out) { - return fmt.Errorf("invalid length: %v", len(buf)) - } - copy(out, buf) - return nil -} - -/* - * - * MarshalASCII wrappers for types that the log server outputs - * - */ -func (l *Leaf) MarshalASCII(w io.Writer) error { - if err := writeASCII(w, ShardHint, strconv.FormatUint(l.ShardHint, 10)); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - if err := writeASCII(w, Checksum, hex.EncodeToString(l.Checksum[:])); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - if err := writeASCII(w, SignatureOverMessage, hex.EncodeToString(l.Signature[:])); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - if err := writeASCII(w, KeyHash, hex.EncodeToString(l.KeyHash[:])); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - return nil -} - -func (sth *SignedTreeHead) MarshalASCII(w io.Writer) error { - if err := writeASCII(w, Timestamp, strconv.FormatUint(sth.Timestamp, 10)); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - if err := writeASCII(w, TreeSize, strconv.FormatUint(sth.TreeSize, 10)); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - if err := writeASCII(w, RootHash, hex.EncodeToString(sth.RootHash[:])); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - for _, sigident := range sth.SigIdent { - if err := sigident.MarshalASCII(w); err != nil { - return fmt.Errorf("MarshalASCII: %v", err) - } - } - return nil -} - -func (si *SigIdent) MarshalASCII(w io.Writer) error { - if err := writeASCII(w, Signature, hex.EncodeToString(si.Signature[:])); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - if err := writeASCII(w, KeyHash, hex.EncodeToString(si.KeyHash[:])); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - return nil -} - -func (p *ConsistencyProof) MarshalASCII(w io.Writer) error { - if err := writeASCII(w, NewSize, strconv.FormatUint(p.NewSize, 10)); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - if err := writeASCII(w, OldSize, strconv.FormatUint(p.OldSize, 10)); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - for _, hash := range p.Path { - if err := writeASCII(w, ConsistencyPath, hex.EncodeToString(hash[:])); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - } - return nil -} - -func (p *InclusionProof) MarshalASCII(w io.Writer) error { - if err := writeASCII(w, TreeSize, strconv.FormatUint(p.TreeSize, 10)); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - if err := writeASCII(w, LeafIndex, strconv.FormatUint(p.LeafIndex, 10)); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - for _, hash := range p.Path { - if err := writeASCII(w, InclusionPath, hex.EncodeToString(hash[:])); err != nil { - return fmt.Errorf("writeASCII: %v", err) - } - } - return nil -} - -func writeASCII(w io.Writer, key, value string) error { - if _, err := fmt.Fprintf(w, "%s%s%s%s", key, Delim, value, EOL); err != nil { - return fmt.Errorf("Fprintf: %v", err) - } - return nil -} - -/* - * - * Unmarshal ASCII wrappers that the log server and/or log clients receive. - * - */ -func (ll *LeafList) UnmarshalASCII(r io.Reader) error { - return nil -} - -func (sth *SignedTreeHead) UnmarshalASCII(r io.Reader) error { - msg, err := NewMessageASCII(r, NumFieldSignedTreeHead) - if err != nil { - return fmt.Errorf("NewMessageASCII: %v", err) - } - - // TreeHead - if sth.Timestamp, err = msg.GetUint64(Timestamp); err != nil { - return fmt.Errorf("GetUint64(Timestamp): %v", err) - } - if sth.TreeSize, err = msg.GetUint64(TreeSize); err != nil { - return fmt.Errorf("GetUint64(TreeSize): %v", err) - } - if sth.RootHash, err = msg.GetHash(RootHash); err != nil { - return fmt.Errorf("GetHash(RootHash): %v", err) - } - - // SigIdent - signatures := msg.GetStrings(Signature) - if len(signatures) == 0 { - return fmt.Errorf("no signer") - } - keyHashes := msg.GetStrings(KeyHash) - if len(signatures) != len(keyHashes) { - return fmt.Errorf("mismatched signature-signer count") - } - sth.SigIdent = make([]*SigIdent, 0, len(signatures)) - for i, n := 0, len(signatures); i < n; i++ { - var signature [SignatureSize]byte - if err := decodeHex(signatures[i], signature[:]); err != nil { - return fmt.Errorf("decodeHex: %v", err) - } - var hash [HashSize]byte - if err := decodeHex(keyHashes[i], hash[:]); err != nil { - return fmt.Errorf("decodeHex: %v", err) - } - sth.SigIdent = append(sth.SigIdent, &SigIdent{ - Signature: &signature, - KeyHash: &hash, - }) - } - return nil -} - -func (p *InclusionProof) UnmarshalASCII(r io.Reader) error { - return nil -} - -func (p *ConsistencyProof) UnmarshalASCII(r io.Reader) error { - return nil -} - -func (req *InclusionProofRequest) UnmarshalASCII(r io.Reader) error { - msg, err := NewMessageASCII(r, NumFieldInclusionProofRequest) - if err != nil { - return fmt.Errorf("NewMessageASCII: %v", err) - } - - if req.LeafHash, err = msg.GetHash(LeafHash); err != nil { - return fmt.Errorf("GetHash(LeafHash): %v", err) - } - if req.TreeSize, err = msg.GetUint64(TreeSize); err != nil { - return fmt.Errorf("GetUint64(TreeSize): %v", err) - } - return nil -} - -func (req *ConsistencyProofRequest) UnmarshalASCII(r io.Reader) error { - msg, err := NewMessageASCII(r, NumFieldConsistencyProofRequest) - if err != nil { - return fmt.Errorf("NewMessageASCII: %v", err) - } - - if req.NewSize, err = msg.GetUint64(NewSize); err != nil { - return fmt.Errorf("GetUint64(NewSize): %v", err) - } - if req.OldSize, err = msg.GetUint64(OldSize); err != nil { - return fmt.Errorf("GetUint64(OldSize): %v", err) - } - return nil -} - -func (req *LeavesRequest) UnmarshalASCII(r io.Reader) error { - msg, err := NewMessageASCII(r, NumFieldLeavesRequest) - if err != nil { - return fmt.Errorf("NewMessageASCII: %v", err) - } - - if req.StartSize, err = msg.GetUint64(StartSize); err != nil { - return fmt.Errorf("GetUint64(StartSize): %v", err) - } - if req.EndSize, err = msg.GetUint64(EndSize); err != nil { - return fmt.Errorf("GetUint64(EndSize): %v", err) - } - return nil -} - -func (req *LeafRequest) UnmarshalASCII(r io.Reader) error { - msg, err := NewMessageASCII(r, NumFieldLeafRequest) - if err != nil { - return fmt.Errorf("NewMessageASCII: %v", err) - } - - if req.ShardHint, err = msg.GetUint64(ShardHint); err != nil { - return fmt.Errorf("GetUint64(ShardHint): %v", err) - } - if req.Checksum, err = msg.GetHash(Checksum); err != nil { - return fmt.Errorf("GetHash(Checksum): %v", err) - } - if req.Signature, err = msg.GetSignature(SignatureOverMessage); err != nil { - return fmt.Errorf("GetSignature: %v", err) - } - if req.VerificationKey, err = msg.GetVerificationKey(VerificationKey); err != nil { - return fmt.Errorf("GetVerificationKey: %v", err) - } - if req.DomainHint, err = msg.GetString(DomainHint); err != nil { - return fmt.Errorf("GetString(DomainHint): %v", err) - } - return nil -} - -func (req *CosignatureRequest) UnmarshalASCII(r io.Reader) error { - msg, err := NewMessageASCII(r, NumFieldCosignatureRequest) - if err != nil { - return fmt.Errorf("NewMessageASCII: %v", err) - } - - if req.Signature, err = msg.GetSignature(Signature); err != nil { - return fmt.Errorf("GetSignature: %v", err) - } - if req.KeyHash, err = msg.GetHash(KeyHash); err != nil { - return fmt.Errorf("GetHash(KeyHash): %v", err) - } - return nil -} diff --git a/types/ascii_test.go b/types/ascii_test.go deleted file mode 100644 index 92732f9..0000000 --- a/types/ascii_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package types - -import ( - "bytes" - "fmt" - "io" - "reflect" - "testing" -) - -/* - * - * MessageASCII methods and helpers - * - */ -func TestNewMessageASCII(t *testing.T) { - for _, table := range []struct { - description string - input io.Reader - wantErr bool - wantMap map[string][]string - }{ - { - description: "invalid: not enough lines", - input: bytes.NewBufferString(""), - wantErr: true, - }, - { - description: "invalid: lines must end with new line", - input: bytes.NewBufferString("k1=v1\nk2=v2"), - wantErr: true, - }, - { - description: "invalid: lines must not be empty", - input: bytes.NewBufferString("k1=v1\n\nk2=v2\n"), - wantErr: true, - }, - { - description: "invalid: wrong number of fields", - input: bytes.NewBufferString("k1=v1\n"), - wantErr: true, - }, - { - description: "valid", - input: bytes.NewBufferString("k1=v1\nk2=v2\nk2=v3=4\n"), - wantMap: map[string][]string{ - "k1": []string{"v1"}, - "k2": []string{"v2", "v3=4"}, - }, - }, - } { - msg, err := NewMessageASCII(table.input, len(table.wantMap)) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := msg.m, table.wantMap; !reflect.DeepEqual(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - } -} - -func TestNumField(t *testing.T) {} -func TestGetStrings(t *testing.T) {} -func TestGetString(t *testing.T) {} -func TestGetUint64(t *testing.T) {} -func TestGetHash(t *testing.T) {} -func TestGetSignature(t *testing.T) {} -func TestGetVerificationKey(t *testing.T) {} -func TestDecodeHex(t *testing.T) {} - -/* - * - * MarshalASCII methods and helpers - * - */ -func TestLeafMarshalASCII(t *testing.T) { - description := "valid: two leaves" - leafList := []*Leaf{ - &Leaf{ - Message: Message{ - ShardHint: 123, - Checksum: testBuffer32, - }, - SigIdent: SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - }, - &Leaf{ - Message: Message{ - ShardHint: 456, - Checksum: testBuffer32, - }, - SigIdent: SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - }, - } - wantBuf := bytes.NewBufferString(fmt.Sprintf( - "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+ - "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", - // Leaf 1 - ShardHint, Delim, 123, EOL, - Checksum, Delim, testBuffer32[:], EOL, - SignatureOverMessage, Delim, testBuffer64[:], EOL, - KeyHash, Delim, testBuffer32[:], EOL, - // Leaf 2 - ShardHint, Delim, 456, EOL, - Checksum, Delim, testBuffer32[:], EOL, - SignatureOverMessage, Delim, testBuffer64[:], EOL, - KeyHash, Delim, testBuffer32[:], EOL, - )) - buf := bytes.NewBuffer(nil) - for _, leaf := range leafList { - if err := leaf.MarshalASCII(buf); err != nil { - t.Errorf("expected error %v but got %v in test %q: %v", false, true, description, err) - return - } - } - if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) - } -} - -func TestSignedTreeHeadMarshalASCII(t *testing.T) { - description := "valid" - sth := &SignedTreeHead{ - TreeHead: TreeHead{ - Timestamp: 123, - TreeSize: 456, - RootHash: testBuffer32, - }, - SigIdent: []*SigIdent{ - &SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - &SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - }, - } - wantBuf := bytes.NewBufferString(fmt.Sprintf( - "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", - Timestamp, Delim, 123, EOL, - TreeSize, Delim, 456, EOL, - RootHash, Delim, testBuffer32[:], EOL, - Signature, Delim, testBuffer64[:], EOL, - KeyHash, Delim, testBuffer32[:], EOL, - Signature, Delim, testBuffer64[:], EOL, - KeyHash, Delim, testBuffer32[:], EOL, - )) - buf := bytes.NewBuffer(nil) - if err := sth.MarshalASCII(buf); err != nil { - t.Errorf("expected error %v but got %v in test %q", false, true, description) - return - } - if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) - } -} - -func TestInclusionProofMarshalASCII(t *testing.T) { - description := "valid" - proof := InclusionProof{ - TreeSize: 321, - LeafIndex: 123, - Path: []*[HashSize]byte{ - testBuffer32, - testBuffer32, - }, - } - wantBuf := bytes.NewBufferString(fmt.Sprintf( - "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s", - TreeSize, Delim, 321, EOL, - LeafIndex, Delim, 123, EOL, - InclusionPath, Delim, testBuffer32[:], EOL, - InclusionPath, Delim, testBuffer32[:], EOL, - )) - buf := bytes.NewBuffer(nil) - if err := proof.MarshalASCII(buf); err != nil { - t.Errorf("expected error %v but got %v in test %q", false, true, description) - return - } - if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) - } -} - -func TestConsistencyProofMarshalASCII(t *testing.T) { - description := "valid" - proof := ConsistencyProof{ - NewSize: 321, - OldSize: 123, - Path: []*[HashSize]byte{ - testBuffer32, - testBuffer32, - }, - } - wantBuf := bytes.NewBufferString(fmt.Sprintf( - "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s", - NewSize, Delim, 321, EOL, - OldSize, Delim, 123, EOL, - ConsistencyPath, Delim, testBuffer32[:], EOL, - ConsistencyPath, Delim, testBuffer32[:], EOL, - )) - buf := bytes.NewBuffer(nil) - if err := proof.MarshalASCII(buf); err != nil { - t.Errorf("expected error %v but got %v in test %q", false, true, description) - return - } - if got, want := buf.Bytes(), wantBuf.Bytes(); !bytes.Equal(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", string(got), string(want), description) - } -} - -func TestWriteASCII(t *testing.T) { -} - -/* - * - * UnmarshalASCII methods and helpers - * - */ -func TestLeafListUnmarshalASCII(t *testing.T) {} - -func TestSignedTreeHeadUnmarshalASCII(t *testing.T) { - for _, table := range []struct { - description string - buf io.Reader - wantErr bool - wantSth *SignedTreeHead - }{ - { - description: "valid", - buf: bytes.NewBufferString(fmt.Sprintf( - "%s%s%d%s"+"%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s", - Timestamp, Delim, 123, EOL, - TreeSize, Delim, 456, EOL, - RootHash, Delim, testBuffer32[:], EOL, - Signature, Delim, testBuffer64[:], EOL, - KeyHash, Delim, testBuffer32[:], EOL, - Signature, Delim, testBuffer64[:], EOL, - KeyHash, Delim, testBuffer32[:], EOL, - )), - wantSth: &SignedTreeHead{ - TreeHead: TreeHead{ - Timestamp: 123, - TreeSize: 456, - RootHash: testBuffer32, - }, - SigIdent: []*SigIdent{ - &SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - &SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - }, - }, - }, - } { - var sth SignedTreeHead - err := sth.UnmarshalASCII(table.buf) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := &sth, table.wantSth; !reflect.DeepEqual(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - } -} - -func TestInclusionProofUnmarshalASCII(t *testing.T) {} -func TestConsistencyProofUnmarshalASCII(t *testing.T) {} - -func TestInclusionProofRequestUnmarshalASCII(t *testing.T) { - for _, table := range []struct { - description string - buf io.Reader - wantErr bool - wantReq *InclusionProofRequest - }{ - { - description: "valid", - buf: bytes.NewBufferString(fmt.Sprintf( - "%s%s%x%s"+"%s%s%d%s", - LeafHash, Delim, testBuffer32[:], EOL, - TreeSize, Delim, 123, EOL, - )), - wantReq: &InclusionProofRequest{ - LeafHash: testBuffer32, - TreeSize: 123, - }, - }, - } { - var req InclusionProofRequest - err := req.UnmarshalASCII(table.buf) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - } -} - -func TestConsistencyProofRequestUnmarshalASCII(t *testing.T) { - for _, table := range []struct { - description string - buf io.Reader - wantErr bool - wantReq *ConsistencyProofRequest - }{ - { - description: "valid", - buf: bytes.NewBufferString(fmt.Sprintf( - "%s%s%d%s"+"%s%s%d%s", - NewSize, Delim, 321, EOL, - OldSize, Delim, 123, EOL, - )), - wantReq: &ConsistencyProofRequest{ - NewSize: 321, - OldSize: 123, - }, - }, - } { - var req ConsistencyProofRequest - err := req.UnmarshalASCII(table.buf) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - } -} - -func TestLeavesRequestUnmarshalASCII(t *testing.T) { - for _, table := range []struct { - description string - buf io.Reader - wantErr bool - wantReq *LeavesRequest - }{ - { - description: "valid", - buf: bytes.NewBufferString(fmt.Sprintf( - "%s%s%d%s"+"%s%s%d%s", - StartSize, Delim, 123, EOL, - EndSize, Delim, 456, EOL, - )), - wantReq: &LeavesRequest{ - StartSize: 123, - EndSize: 456, - }, - }, - } { - var req LeavesRequest - err := req.UnmarshalASCII(table.buf) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - } -} - -func TestLeafRequestUnmarshalASCII(t *testing.T) { - for _, table := range []struct { - description string - buf io.Reader - wantErr bool - wantReq *LeafRequest - }{ - { - description: "valid", - buf: bytes.NewBufferString(fmt.Sprintf( - "%s%s%d%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%x%s"+"%s%s%s%s", - ShardHint, Delim, 123, EOL, - Checksum, Delim, testBuffer32[:], EOL, - SignatureOverMessage, Delim, testBuffer64[:], EOL, - VerificationKey, Delim, testBuffer32[:], EOL, - DomainHint, Delim, "example.com", EOL, - )), - wantReq: &LeafRequest{ - Message: Message{ - ShardHint: 123, - Checksum: testBuffer32, - }, - Signature: testBuffer64, - VerificationKey: testBuffer32, - DomainHint: "example.com", - }, - }, - } { - var req LeafRequest - err := req.UnmarshalASCII(table.buf) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - } -} - -func TestCosignatureRequestUnmarshalASCII(t *testing.T) { - for _, table := range []struct { - description string - buf io.Reader - wantErr bool - wantReq *CosignatureRequest - }{ - { - description: "valid", - buf: bytes.NewBufferString(fmt.Sprintf( - "%s%s%x%s"+"%s%s%x%s", - Signature, Delim, testBuffer64[:], EOL, - KeyHash, Delim, testBuffer32[:], EOL, - )), - wantReq: &CosignatureRequest{ - SigIdent: SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - }, - }, - } { - var req CosignatureRequest - err := req.UnmarshalASCII(table.buf) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := &req, table.wantReq; !reflect.DeepEqual(got, want) { - t.Errorf("got\n\t%v\nbut wanted\n\t%v\nin test %q", got, want, table.description) - } - } -} diff --git a/types/trunnel.go b/types/trunnel.go deleted file mode 100644 index 268f6f7..0000000 --- a/types/trunnel.go +++ /dev/null @@ -1,60 +0,0 @@ -package types - -import ( - "encoding/binary" - "fmt" -) - -const ( - // MessageSize is the number of bytes in a Trunnel-encoded leaf message - MessageSize = 8 + HashSize - // LeafSize is the number of bytes in a Trunnel-encoded leaf - LeafSize = MessageSize + SignatureSize + HashSize -) - -// Marshal returns a Trunnel-encoded message -func (m *Message) Marshal() []byte { - buf := make([]byte, MessageSize) - binary.BigEndian.PutUint64(buf, m.ShardHint) - copy(buf[8:], m.Checksum[:]) - return buf -} - -// Marshal returns a Trunnel-encoded leaf -func (l *Leaf) Marshal() []byte { - buf := l.Message.Marshal() - buf = append(buf, l.SigIdent.Signature[:]...) - buf = append(buf, l.SigIdent.KeyHash[:]...) - return buf -} - -// Marshal returns a Trunnel-encoded tree head -func (th *TreeHead) Marshal() []byte { - buf := make([]byte, 8+8+HashSize) - binary.BigEndian.PutUint64(buf[0:8], th.Timestamp) - binary.BigEndian.PutUint64(buf[8:16], th.TreeSize) - copy(buf[16:], th.RootHash[:]) - return buf -} - -// Unmarshal parses the Trunnel-encoded buffer as a leaf -func (l *Leaf) Unmarshal(buf []byte) error { - if len(buf) != LeafSize { - return fmt.Errorf("invalid leaf size: %v", len(buf)) - } - // Shard hint - l.ShardHint = binary.BigEndian.Uint64(buf) - offset := 8 - // Checksum - l.Checksum = &[HashSize]byte{} - copy(l.Checksum[:], buf[offset:offset+HashSize]) - offset += HashSize - // Signature - l.Signature = &[SignatureSize]byte{} - copy(l.Signature[:], buf[offset:offset+SignatureSize]) - offset += SignatureSize - // KeyHash - l.KeyHash = &[HashSize]byte{} - copy(l.KeyHash[:], buf[offset:]) - return nil -} diff --git a/types/trunnel_test.go b/types/trunnel_test.go deleted file mode 100644 index 297578c..0000000 --- a/types/trunnel_test.go +++ /dev/null @@ -1,114 +0,0 @@ -package types - -import ( - "bytes" - "reflect" - "testing" -) - -var ( - testBuffer32 = &[32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31} - testBuffer64 = &[64]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63} -) - -func TestMarshalMessage(t *testing.T) { - description := "valid: shard hint 72623859790382856, checksum 0x00,0x01,..." - message := &Message{ - ShardHint: 72623859790382856, - Checksum: testBuffer32, - } - want := bytes.Join([][]byte{ - []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, - testBuffer32[:], - }, nil) - if got := message.Marshal(); !bytes.Equal(got, want) { - t.Errorf("got message\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) - } -} - -func TestMarshalLeaf(t *testing.T) { - description := "valid: shard hint 72623859790382856, buffers 0x00,0x01,..." - leaf := &Leaf{ - Message: Message{ - ShardHint: 72623859790382856, - Checksum: testBuffer32, - }, - SigIdent: SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - } - want := bytes.Join([][]byte{ - []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, - testBuffer32[:], testBuffer64[:], testBuffer32[:], - }, nil) - if got := leaf.Marshal(); !bytes.Equal(got, want) { - t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) - } -} - -func TestMarshalTreeHead(t *testing.T) { - description := "valid: timestamp 16909060, tree size 72623859790382856, root hash 0x00,0x01,..." - th := &TreeHead{ - Timestamp: 16909060, - TreeSize: 72623859790382856, - RootHash: testBuffer32, - } - want := bytes.Join([][]byte{ - []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04}, - []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, - testBuffer32[:], - }, nil) - if got := th.Marshal(); !bytes.Equal(got, want) { - t.Errorf("got tree head\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, description) - } -} - -func TestUnmarshalLeaf(t *testing.T) { - for _, table := range []struct { - description string - serialized []byte - wantErr bool - want *Leaf - }{ - { - description: "invalid: not enough bytes", - serialized: make([]byte, LeafSize-1), - wantErr: true, - }, - { - description: "invalid: too many bytes", - serialized: make([]byte, LeafSize+1), - wantErr: true, - }, - { - description: "valid: shard hint 72623859790382856, buffers 0x00,0x01,...", - serialized: bytes.Join([][]byte{ - []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, - testBuffer32[:], testBuffer64[:], testBuffer32[:], - }, nil), - want: &Leaf{ - Message: Message{ - ShardHint: 72623859790382856, - Checksum: testBuffer32, - }, - SigIdent: SigIdent{ - Signature: testBuffer64, - KeyHash: testBuffer32, - }, - }, - }, - } { - var leaf Leaf - err := leaf.Unmarshal(table.serialized) - if got, want := err != nil, table.wantErr; got != want { - t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) - } - if err != nil { - continue - } - if got, want := &leaf, table.want; !reflect.DeepEqual(got, want) { - t.Errorf("got leaf\n\t%v\nbut wanted\n\t%v\nin test %q\n", got, want, table.description) - } - } -} diff --git a/types/types.go b/types/types.go deleted file mode 100644 index 9ca7db8..0000000 --- a/types/types.go +++ /dev/null @@ -1,155 +0,0 @@ -package types - -import ( - "crypto" - "crypto/ed25519" - "crypto/sha256" - "fmt" - "strings" -) - -const ( - HashSize = sha256.Size - SignatureSize = ed25519.SignatureSize - VerificationKeySize = ed25519.PublicKeySize - - EndpointAddLeaf = Endpoint("add-leaf") - EndpointAddCosignature = Endpoint("add-cosignature") - EndpointGetTreeHeadLatest = Endpoint("get-tree-head-latest") - EndpointGetTreeHeadToSign = Endpoint("get-tree-head-to-sign") - EndpointGetTreeHeadCosigned = Endpoint("get-tree-head-cosigned") - EndpointGetProofByHash = Endpoint("get-proof-by-hash") - EndpointGetConsistencyProof = Endpoint("get-consistency-proof") - EndpointGetLeaves = Endpoint("get-leaves") -) - -// Endpoint is a named HTTP API endpoint -type Endpoint string - -// Path joins a number of components to form a full endpoint path. For example, -// EndpointAddLeaf.Path("example.com", "st/v0") -> example.com/st/v0/add-leaf. -func (e Endpoint) Path(components ...string) string { - return strings.Join(append(components, string(e)), "/") -} - -// Leaf is the log's Merkle tree leaf. -type Leaf struct { - Message - SigIdent -} - -// Message is composed of a shard hint and a checksum. The submitter selects -// these values to fit the log's shard interval and the opaque data in question. -type Message struct { - ShardHint uint64 - Checksum *[HashSize]byte -} - -// SigIdent is composed of a signature-signer pair. The signature is computed -// over the Trunnel-serialized leaf message. KeyHash identifies the signer. -type SigIdent struct { - Signature *[SignatureSize]byte - KeyHash *[HashSize]byte -} - -// SignedTreeHead is composed of a tree head and a list of signature-signer -// pairs. Each signature is computed over the Trunnel-serialized tree head. -type SignedTreeHead struct { - TreeHead - SigIdent []*SigIdent -} - -// TreeHead is the log's tree head. -type TreeHead struct { - Timestamp uint64 - TreeSize uint64 - RootHash *[HashSize]byte -} - -// ConsistencyProof is a consistency proof that proves the log's append-only -// property. -type ConsistencyProof struct { - NewSize uint64 - OldSize uint64 - Path []*[HashSize]byte -} - -// InclusionProof is an inclusion proof that proves a leaf is included in the -// log. -type InclusionProof struct { - TreeSize uint64 - LeafIndex uint64 - Path []*[HashSize]byte -} - -// LeafList is a list of leaves -type LeafList []*Leaf - -// ConsistencyProofRequest is a get-consistency-proof request -type ConsistencyProofRequest struct { - NewSize uint64 - OldSize uint64 -} - -// InclusionProofRequest is a get-proof-by-hash request -type InclusionProofRequest struct { - LeafHash *[HashSize]byte - TreeSize uint64 -} - -// LeavesRequest is a get-leaves request -type LeavesRequest struct { - StartSize uint64 - EndSize uint64 -} - -// LeafRequest is an add-leaf request -type LeafRequest struct { - Message - Signature *[SignatureSize]byte - VerificationKey *[VerificationKeySize]byte - DomainHint string -} - -// CosignatureRequest is an add-cosignature request -type CosignatureRequest struct { - SigIdent -} - -// Sign signs the tree head using the log's signature scheme -func (th *TreeHead) Sign(signer crypto.Signer) (*SignedTreeHead, error) { - sig, err := signer.Sign(nil, th.Marshal(), crypto.Hash(0)) - if err != nil { - return nil, fmt.Errorf("Sign: %v", err) - } - - sigident := SigIdent{ - KeyHash: Hash(signer.Public().(ed25519.PublicKey)[:]), - Signature: &[SignatureSize]byte{}, - } - copy(sigident.Signature[:], sig) - return &SignedTreeHead{ - TreeHead: *th, - SigIdent: []*SigIdent{ - &sigident, - }, - }, nil -} - -// Verify verifies the tree head signature using the log's signature scheme -func (th *TreeHead) Verify(vk *[VerificationKeySize]byte, sig *[SignatureSize]byte) error { - if !ed25519.Verify(ed25519.PublicKey(vk[:]), th.Marshal(), sig[:]) { - return fmt.Errorf("invalid tree head signature") - } - return nil -} - -// Verify checks if a leaf is included in the log -func (p *InclusionProof) Verify(leaf *Leaf, th *TreeHead) error { // TODO - return nil -} - -// Verify checks if two tree heads are consistent -func (p *ConsistencyProof) Verify(oldTH, newTH *TreeHead) error { // TODO - return nil -} diff --git a/types/types_test.go b/types/types_test.go deleted file mode 100644 index da89c59..0000000 --- a/types/types_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package types - -import ( - "testing" -) - -func TestEndpointPath(t *testing.T) { - base, prefix, proto := "example.com", "log", "st/v0" - for _, table := range []struct { - endpoint Endpoint - want string - }{ - { - endpoint: EndpointAddLeaf, - want: "example.com/log/st/v0/add-leaf", - }, - { - endpoint: EndpointAddCosignature, - want: "example.com/log/st/v0/add-cosignature", - }, - { - endpoint: EndpointGetTreeHeadLatest, - want: "example.com/log/st/v0/get-tree-head-latest", - }, - { - endpoint: EndpointGetTreeHeadToSign, - want: "example.com/log/st/v0/get-tree-head-to-sign", - }, - { - endpoint: EndpointGetTreeHeadCosigned, - want: "example.com/log/st/v0/get-tree-head-cosigned", - }, - { - endpoint: EndpointGetConsistencyProof, - want: "example.com/log/st/v0/get-consistency-proof", - }, - { - endpoint: EndpointGetProofByHash, - want: "example.com/log/st/v0/get-proof-by-hash", - }, - { - endpoint: EndpointGetLeaves, - want: "example.com/log/st/v0/get-leaves", - }, - } { - if got, want := table.endpoint.Path(base+"/"+prefix+"/"+proto), table.want; got != want { - t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\twith one component", got, want) - } - if got, want := table.endpoint.Path(base, prefix, proto), table.want; got != want { - t.Errorf("got endpoint\n%s\n\tbut wanted\n%s\n\tmultiple components", got, want) - } - } -} - -func TestTreeHeadSign(t *testing.T) {} -func TestTreeHeadVerify(t *testing.T) {} -func TestInclusionProofVerify(t *testing.T) {} -func TestConsistencyProofVerify(t *testing.T) {} diff --git a/types/util.go b/types/util.go deleted file mode 100644 index 3cd7dfa..0000000 --- a/types/util.go +++ /dev/null @@ -1,21 +0,0 @@ -package types - -import ( - "crypto/sha256" -) - -const ( - LeafHashPrefix = 0x00 -) - -func Hash(buf []byte) *[HashSize]byte { - var ret [HashSize]byte - hash := sha256.New() - hash.Write(buf) - copy(ret[:], hash.Sum(nil)) - return &ret -} - -func HashLeaf(buf []byte) *[HashSize]byte { - return Hash(append([]byte{LeafHashPrefix}, buf...)) -} diff --git a/util.go b/util.go deleted file mode 100644 index a8c918e..0000000 --- a/util.go +++ /dev/null @@ -1,27 +0,0 @@ -package stfe - -import ( - ttypes "github.com/google/trillian/types" - "github.com/system-transparency/stfe/types" -) - -func NewTreeHeadFromLogRoot(lr *ttypes.LogRootV1) *types.TreeHead { - var hash [types.HashSize]byte - th := types.TreeHead{ - Timestamp: uint64(lr.TimestampNanos / 1000 / 1000 / 1000), - TreeSize: uint64(lr.TreeSize), - RootHash: &hash, - } - copy(th.RootHash[:], lr.RootHash) - return &th -} - -func NodePathFromHashes(hashes [][]byte) []*[types.HashSize]byte { - var path []*[types.HashSize]byte - for _, hash := range hashes { - var h [types.HashSize]byte - copy(h[:], hash) - path = append(path, &h) - } - return path -} diff --git a/util_test.go b/util_test.go deleted file mode 100644 index b40a672..0000000 --- a/util_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package stfe - -import ( - "testing" -) - -// TODO: TestNewTreeHeadV1FromLogRoot -func TestNewTreeHeadV1FromLogRoot(t *testing.T) { -} - -// TODO: TestNewNodePathFromHashPath -func TestNewNodePathFromHashPath(t *testing.T) { -} - -// TODO: TestStItemListFromLeaves -func TestStItemListFromLeaves(t *testing.T) { -} -- cgit v1.2.3