From 868366cf35e2e649fe7265201d527dbb76bbaf68 Mon Sep 17 00:00:00 2001 From: Rasmus Dahlberg Date: Mon, 30 Nov 2020 20:42:05 +0100 Subject: added an endpoint type --- client/client.go | 19 +++++++++++++------ handler.go | 10 +++++----- handler_test.go | 56 ++++++++++++++++++++++++++++---------------------------- instance.go | 54 +++++++++++++++++++++++++++++++++++++++++++++--------- instance_test.go | 2 +- reqres_test.go | 10 +++++++--- server/main.go | 4 ++-- 7 files changed, 101 insertions(+), 54 deletions(-) diff --git a/client/client.go b/client/client.go index 6eb99c2..7c78034 100644 --- a/client/client.go +++ b/client/client.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "strings" "crypto/ed25519" "crypto/tls" @@ -98,7 +99,8 @@ func (c *Client) AddEntry(ctx context.Context, name, checksum []byte) (*stfe.StI } glog.V(3).Infof("created post data: %s", string(data)) - req, err := http.NewRequest("POST", c.protocol()+c.Log.BaseUrl+"/add-entry", bytes.NewBuffer(data)) + url := c.protocol() + strings.Join([]string{c.Log.BaseUrl, stfe.EndpointAddEntry.String()}, "/") + req, err := http.NewRequest("POST", url, bytes.NewBuffer(data)) if err != nil { return nil, fmt.Errorf("failed creating http request: %v", err) } @@ -124,7 +126,8 @@ func (c *Client) AddEntry(ctx context.Context, name, checksum []byte) (*stfe.StI // GetSth fetches and verifies the most recent STH. Safe to use without a // client chain and corresponding private key. func (c *Client) GetSth(ctx context.Context) (*stfe.StItem, error) { - req, err := http.NewRequest("GET", c.protocol()+c.Log.BaseUrl+"/get-sth", nil) + url := c.protocol() + strings.Join([]string{c.Log.BaseUrl, stfe.EndpointGetSth.String()}, "/") + req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("failed creating http request: %v", err) } @@ -149,7 +152,8 @@ func (c *Client) GetSth(ctx context.Context) (*stfe.StItem, error) { // GetConsistencyProof fetches and verifies a consistency proof between two // STHs. Safe to use without a client chain and corresponding private key. func (c *Client) GetConsistencyProof(ctx context.Context, first, second *stfe.StItem) (*stfe.StItem, error) { - req, err := http.NewRequest("GET", c.protocol()+c.Log.BaseUrl+"/get-consistency-proof", nil) + url := c.protocol() + strings.Join([]string{c.Log.BaseUrl, stfe.EndpointGetConsistencyProof.String()}, "/") + req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("failed creating http request: %v", err) } @@ -177,7 +181,8 @@ func (c *Client) GetConsistencyProof(ctx context.Context, first, second *stfe.St // STH. Safe to use without a client chain and corresponding private key. func (c *Client) GetProofByHash(ctx context.Context, treeSize uint64, rootHash, leaf []byte) (*stfe.StItem, error) { leafHash := rfc6962.DefaultHasher.HashLeaf(leaf) - req, err := http.NewRequest("GET", c.protocol()+c.Log.BaseUrl+"/get-proof-by-hash", nil) + url := c.protocol() + strings.Join([]string{c.Log.BaseUrl, stfe.EndpointGetProofByHash.String()}, "/") + req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("failed creating http request: %v", err) } @@ -209,7 +214,8 @@ func (c *Client) GetProofByHash(ctx context.Context, treeSize uint64, rootHash, // Note that a certificate chain is considered valid if it is chained correctly. // In other words, the caller may want to check whether the anchor is trusted. func (c *Client) GetEntries(ctx context.Context, start, end uint64) ([]*stfe.GetEntryResponse, error) { - req, err := http.NewRequest("GET", c.protocol()+c.Log.BaseUrl+"/get-entries", nil) + url := c.protocol() + strings.Join([]string{c.Log.BaseUrl, stfe.EndpointGetEntries.String()}, "/") + req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("failed creating http request: %v", err) } @@ -246,7 +252,8 @@ func (c *Client) GetEntries(ctx context.Context, start, end uint64) ([]*stfe.Get // GetAnchors fetches the log's trust anchors. Safe to use without a client // chain and corresponding private key. func (c *Client) GetAnchors(ctx context.Context) ([]*x509.Certificate, error) { - req, err := http.NewRequest("GET", c.protocol()+c.Log.BaseUrl+"/get-anchors", nil) + url := c.protocol() + strings.Join([]string{c.Log.BaseUrl, stfe.EndpointGetAnchors.String()}, "/") + req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("failed creating http request: %v", err) } diff --git a/handler.go b/handler.go index 00fd686..d6d4224 100644 --- a/handler.go +++ b/handler.go @@ -16,7 +16,7 @@ import ( // to an STFE server instance as well as a function that uses it. type handler struct { instance *Instance // STFE server instance - endpoint string // e.g., add-entry + endpoint Endpoint // e.g., add-entry method string // e.g., GET handler func(context.Context, *Instance, http.ResponseWriter, *http.Request) (int, error) } @@ -26,16 +26,16 @@ func (a handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var now time.Time = time.Now() var statusCode int defer func() { - rspcnt.Inc(a.instance.LogParameters.id(), a.endpoint, fmt.Sprintf("%d", statusCode)) - latency.Observe(time.Now().Sub(now).Seconds(), a.instance.LogParameters.id(), a.endpoint, fmt.Sprintf("%d", statusCode)) + rspcnt.Inc(a.instance.LogParameters.id(), a.endpoint.String(), fmt.Sprintf("%d", statusCode)) + latency.Observe(time.Now().Sub(now).Seconds(), a.instance.LogParameters.id(), a.endpoint.String(), fmt.Sprintf("%d", statusCode)) }() - reqcnt.Inc(a.instance.LogParameters.id(), a.endpoint) + reqcnt.Inc(a.instance.LogParameters.id(), a.endpoint.String()) ctx, cancel := context.WithDeadline(r.Context(), now.Add(a.instance.Deadline)) defer cancel() if r.Method != a.method { - glog.Warningf("%s: got HTTP %s, wanted HTTP %s", a.instance.LogParameters.Prefix+a.endpoint, r.Method, a.method) + glog.Warningf("%s: got HTTP %s, wanted HTTP %s", a.instance.LogParameters.Prefix+a.endpoint.String(), r.Method, a.method) a.sendHTTPError(w, http.StatusMethodNotAllowed, fmt.Errorf("method not allowed: %s", r.Method)) return } diff --git a/handler_test.go b/handler_test.go index 5390d1c..bdf0752 100644 --- a/handler_test.go +++ b/handler_test.go @@ -44,17 +44,17 @@ func newTestHandler(t *testing.T, signer crypto.Signer) *testHandler { } } -func (th *testHandler) getHandlers(t *testing.T) map[string]handler { - return map[string]handler{ - "get-sth": handler{instance: th.instance, handler: getSth, endpoint: "get-sth", method: http.MethodGet}, - "get-consistency-proof": handler{instance: th.instance, handler: getConsistencyProof, endpoint: "get-consistency-proof", method: http.MethodGet}, - "get-proof-by-hash": handler{instance: th.instance, handler: getProofByHash, endpoint: "get-proof-by-hash", method: http.MethodGet}, - "get-anchors": handler{instance: th.instance, handler: getAnchors, endpoint: "get-anchors", method: http.MethodGet}, - "get-entries": handler{instance: th.instance, handler: getEntries, endpoint: "get-entries", method: http.MethodGet}, +func (th *testHandler) getHandlers(t *testing.T) map[Endpoint]handler { + return map[Endpoint]handler{ + EndpointGetSth: handler{instance: th.instance, handler: getSth, endpoint: EndpointGetSth, method: http.MethodGet}, + EndpointGetConsistencyProof: handler{instance: th.instance, handler: getConsistencyProof, endpoint: EndpointGetConsistencyProof, method: http.MethodGet}, + EndpointGetProofByHash: handler{instance: th.instance, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet}, + EndpointGetAnchors: handler{instance: th.instance, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet}, + EndpointGetEntries: handler{instance: th.instance, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet}, } } -func (th *testHandler) getHandler(t *testing.T, endpoint string) handler { +func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) handler { handler, ok := th.getHandlers(t)[endpoint] if !ok { t.Fatalf("no such get endpoint: %s", endpoint) @@ -62,13 +62,13 @@ func (th *testHandler) getHandler(t *testing.T, endpoint string) handler { return handler } -func (th *testHandler) postHandlers(t *testing.T) map[string]handler { - return map[string]handler{ - "add-entry": handler{instance: th.instance, handler: addEntry, endpoint: "add-entry", method: http.MethodPost}, +func (th *testHandler) postHandlers(t *testing.T) map[Endpoint]handler { + return map[Endpoint]handler{ + EndpointAddEntry: handler{instance: th.instance, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost}, } } -func (th *testHandler) postHandler(t *testing.T, endpoint string) handler { +func (th *testHandler) postHandler(t *testing.T, endpoint Endpoint) handler { handler, ok := th.postHandlers(t)[endpoint] if !ok { t.Fatalf("no such post endpoint: %s", endpoint) @@ -82,11 +82,11 @@ func TestGetHandlersRejectPost(t *testing.T) { defer th.mockCtrl.Finish() for endpoint, handler := range th.getHandlers(t) { - t.Run(endpoint, func(t *testing.T) { + t.Run(endpoint.String(), func(t *testing.T) { s := httptest.NewServer(handler) defer s.Close() - url := s.URL + strings.Join([]string{th.instance.LogParameters.Prefix, endpoint}, "/") + url := strings.Join([]string{s.URL, th.instance.LogParameters.Prefix, endpoint.String()}, "/") if rsp, err := http.Post(url, "application/json", nil); err != nil { t.Fatalf("http.Post(%s)=(_,%q), want (_,nil)", url, err) } else if rsp.StatusCode != http.StatusMethodNotAllowed { @@ -102,11 +102,11 @@ func TestPostHandlersRejectGet(t *testing.T) { defer th.mockCtrl.Finish() for endpoint, handler := range th.postHandlers(t) { - t.Run(endpoint, func(t *testing.T) { + t.Run(endpoint.String(), func(t *testing.T) { s := httptest.NewServer(handler) defer s.Close() - url := s.URL + strings.Join([]string{th.instance.LogParameters.Prefix, endpoint}, "/") + url := strings.Join([]string{s.URL, th.instance.LogParameters.Prefix, endpoint.String()}, "/") if rsp, err := http.Get(url); err != nil { t.Fatalf("http.Get(%s)=(_,%q), want (_,nil)", url, err) } else if rsp.StatusCode != http.StatusMethodNotAllowed { @@ -121,14 +121,14 @@ func TestGetAnchors(t *testing.T) { th := newTestHandler(t, nil) defer th.mockCtrl.Finish() - url := "http://example.com" + th.instance.LogParameters.Prefix + "/get-anchors" + url := strings.Join([]string{"http://example.com", th.instance.LogParameters.Prefix, EndpointGetAnchors.String()}, "/") req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) } w := httptest.NewRecorder() - th.getHandler(t, "get-anchors").ServeHTTP(w, req) + th.getHandler(t, EndpointGetAnchors).ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, http.StatusOK) return @@ -200,7 +200,7 @@ func TestGetEntries(t *testing.T) { th := newTestHandler(t, nil) defer th.mockCtrl.Finish() - url := "http://example.com" + th.instance.LogParameters.Prefix + "/get-entries" + url := strings.Join([]string{"http://example.com", th.instance.LogParameters.Prefix, EndpointGetEntries.String()}, "/") req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) @@ -214,7 +214,7 @@ func TestGetEntries(t *testing.T) { th.client.EXPECT().GetLeavesByRange(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() - th.getHandler(t, "get-entries").ServeHTTP(w, req) + th.getHandler(t, EndpointGetEntries).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } @@ -313,7 +313,7 @@ func TestAddEntry(t *testing.T) { th := newTestHandler(t, table.signer) defer th.mockCtrl.Finish() - url := "http://example.com" + th.instance.LogParameters.Prefix + "/add-entry" + url := strings.Join([]string{"http://example.com", th.instance.LogParameters.Prefix, EndpointAddEntry.String()}, "/") req, err := http.NewRequest("POST", url, table.breq) if err != nil { t.Fatalf("failed creating http request: %v", err) @@ -324,7 +324,7 @@ func TestAddEntry(t *testing.T) { th.client.EXPECT().QueueLeaf(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } w := httptest.NewRecorder() - th.postHandler(t, "add-entry").ServeHTTP(w, req) + th.postHandler(t, EndpointAddEntry).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } @@ -407,7 +407,7 @@ func TestGetSth(t *testing.T) { th := newTestHandler(t, table.signer) defer th.mockCtrl.Finish() - url := "http://example.com" + th.instance.LogParameters.Prefix + "/get-sth" + url := strings.Join([]string{"http://example.com", th.instance.LogParameters.Prefix, EndpointGetSth.String()}, "/") req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) @@ -415,7 +415,7 @@ func TestGetSth(t *testing.T) { w := httptest.NewRecorder() th.client.EXPECT().GetLatestSignedLogRoot(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) - th.getHandler(t, "get-sth").ServeHTTP(w, req) + th.getHandler(t, EndpointGetSth).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } @@ -511,7 +511,7 @@ func TestGetConsistencyProof(t *testing.T) { th := newTestHandler(t, nil) defer th.mockCtrl.Finish() - url := "http://example.com" + th.instance.LogParameters.Prefix + "/get-consistency-proof" + url := strings.Join([]string{"http://example.com", th.instance.LogParameters.Prefix, EndpointGetConsistencyProof.String()}, "/") req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) @@ -525,7 +525,7 @@ func TestGetConsistencyProof(t *testing.T) { if table.trsp != nil || table.terr != nil { th.client.EXPECT().GetConsistencyProof(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } - th.getHandler(t, "get-consistency-proof").ServeHTTP(w, req) + th.getHandler(t, EndpointGetConsistencyProof).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } @@ -620,7 +620,7 @@ func TestGetProofByHash(t *testing.T) { th := newTestHandler(t, nil) defer th.mockCtrl.Finish() - url := "http://example.com" + th.instance.LogParameters.Prefix + "/get-proof-by-hash" + url := strings.Join([]string{"http://example.com", th.instance.LogParameters.Prefix, EndpointGetProofByHash.String()}, "/") req, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("failed creating http request: %v", err) @@ -634,7 +634,7 @@ func TestGetProofByHash(t *testing.T) { if table.trsp != nil || table.terr != nil { th.client.EXPECT().GetInclusionProofByHash(newDeadlineMatcher(), gomock.Any()).Return(table.trsp, table.terr) } - th.getHandler(t, "get-proof-by-hash").ServeHTTP(w, req) + th.getHandler(t, EndpointGetProofByHash).ServeHTTP(w, req) if w.Code != table.wantCode { t.Errorf("GET(%s)=%d, want http status code %d", url, w.Code, table.wantCode) } diff --git a/instance.go b/instance.go index 510c7ae..8d9070a 100644 --- a/instance.go +++ b/instance.go @@ -3,6 +3,7 @@ package stfe import ( "crypto" "fmt" + "strings" "time" "crypto/sha256" @@ -16,6 +17,23 @@ import ( "github.com/system-transparency/stfe/x509util" ) +type Endpoint string + +const ( + EndpointAddEntry = Endpoint("add-entry") + EndpointGetEntries = Endpoint("get-entries") + EndpointGetAnchors = Endpoint("get-anchors") + EndpointGetProofByHash = Endpoint("get-proof-by-hash") + EndpointGetConsistencyProof = Endpoint("get-consistency-proof") + EndpointGetSth = Endpoint("get-sth") +) + +func (e Endpoint) String() string { + return string(e) +} + +// TODO: type EndpointParam string? + // Instance is an instance of a particular log front-end type Instance struct { LogParameters *LogParameters @@ -25,9 +43,9 @@ type Instance struct { // LogParameters is a collection of log parameters type LogParameters struct { - LogId []byte // used externally by everyone - TreeId int64 // used internally by Trillian - Prefix string + LogId []byte // used externally by everyone + TreeId int64 // used internally by Trillian + Prefix string // e.g., "test" for /test MaxRange int64 // max entries per get-entries request MaxChain int64 // max submitter certificate chain length AnchorPool *x509.CertPool // for chain verification @@ -103,12 +121,30 @@ func (i *Instance) registerHandlers(mux *http.ServeMux) { path string handler handler }{ - {i.LogParameters.Prefix + "/add-entry", handler{instance: i, handler: addEntry, endpoint: "add-entry", method: http.MethodPost}}, - {i.LogParameters.Prefix + "/get-entries", handler{instance: i, handler: getEntries, endpoint: "get-entries", method: http.MethodGet}}, - {i.LogParameters.Prefix + "/get-anchors", handler{instance: i, handler: getAnchors, endpoint: "get-anchors", method: http.MethodGet}}, - {i.LogParameters.Prefix + "/get-proof-by-hash", handler{instance: i, handler: getProofByHash, endpoint: "get-proof-by-hash", method: http.MethodGet}}, - {i.LogParameters.Prefix + "/get-consistency-proof", handler{instance: i, handler: getConsistencyProof, endpoint: "get-consistency-proof", method: http.MethodGet}}, - {i.LogParameters.Prefix + "/get-sth", handler{instance: i, handler: getSth, endpoint: "get-sth", method: http.MethodGet}}, + { + strings.Join([]string{"", i.LogParameters.Prefix, EndpointAddEntry.String()}, "/"), + handler{instance: i, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost}, + }, + { + strings.Join([]string{"", i.LogParameters.Prefix, EndpointGetEntries.String()}, "/"), + handler{instance: i, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet}, + }, + { + strings.Join([]string{"", i.LogParameters.Prefix, EndpointGetAnchors.String()}, "/"), + handler{instance: i, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet}, + }, + { + strings.Join([]string{"", i.LogParameters.Prefix, EndpointGetProofByHash.String()}, "/"), + handler{instance: i, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet}, + }, + { + strings.Join([]string{"", i.LogParameters.Prefix, EndpointGetConsistencyProof.String()}, "/"), + handler{instance: i, handler: getConsistencyProof, endpoint: EndpointGetConsistencyProof, method: http.MethodGet}, + }, + { + strings.Join([]string{"", i.LogParameters.Prefix, EndpointGetSth.String()}, "/"), + handler{instance: i, handler: getSth, endpoint: EndpointGetSth, method: http.MethodGet}, + }, } { glog.Infof("adding handler for %v", endpoint.path) mux.Handle(endpoint.path, endpoint.handler) diff --git a/instance_test.go b/instance_test.go index 582b232..7facdd6 100644 --- a/instance_test.go +++ b/instance_test.go @@ -15,7 +15,7 @@ var ( testMaxRange = int64(3) testMaxChain = int64(3) testTreeId = int64(0) - testPrefix = "/test" + testPrefix = "test" testHashType = crypto.SHA256 testExtKeyUsage = []x509.ExtKeyUsage{} ) diff --git a/reqres_test.go b/reqres_test.go index 1a6304b..8334bd7 100644 --- a/reqres_test.go +++ b/reqres_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "strconv" + "strings" "testing" "crypto/x509" @@ -65,7 +66,8 @@ func TestNewGetEntriesRequest(t *testing.T) { end: fmt.Sprintf("%d", testMaxRange-1), }, } { - r, err := http.NewRequest("GET", "http://example.com/"+lp.Prefix+"/get-entries", nil) + url := strings.Join([]string{"http://example.com/", lp.Prefix, EndpointGetEntries.String()}, "/") + r, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("must make http request in test %q: %v", table.description, err) } @@ -135,7 +137,8 @@ func TestNewGetProofByHashRequest(t *testing.T) { hash: b64(testNodeHash), }, } { - r, err := http.NewRequest("GET", "http://example.com/"+lp.Prefix+"/get-proof-by-hash", nil) + url := strings.Join([]string{"http://example.com/", lp.Prefix, EndpointGetProofByHash.String()}, "/") + r, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("must make http request in test %q: %v", table.description, err) } @@ -199,7 +202,8 @@ func TestNewGetConsistencyProofRequest(t *testing.T) { second: "2", }, } { - r, err := http.NewRequest("GET", "http://example.com/"+lp.Prefix+"/get-consistency-proof", nil) + url := strings.Join([]string{"http://example.com/", lp.Prefix, EndpointGetConsistencyProof.String()}, "/") + r, err := http.NewRequest("GET", url, nil) if err != nil { t.Fatalf("must make http request in test %q: %v", table.description, err) } diff --git a/server/main.go b/server/main.go index f98b114..d6a7aa5 100644 --- a/server/main.go +++ b/server/main.go @@ -17,7 +17,7 @@ import ( var ( httpEndpoint = flag.String("http_endpoint", "localhost:6965", "host:port specification of where stfe serves clients") rpcBackend = flag.String("log_rpc_server", "localhost:6962", "host:port specification of where Trillian serves clients") - prefix = flag.String("prefix", "/st/v1", "a prefix that proceeds each endpoint path") + prefix = flag.String("prefix", "st/v1", "a prefix that proceeds each endpoint path") trillianID = flag.Int64("trillian_id", 5991359069696313945, "log identifier in the Trillian database") rpcDeadline = flag.Duration("rpc_deadline", time.Second*10, "deadline for backend RPC requests") anchorPath = flag.String("anchor_path", "../x509util/testdata/anchors.pem", "path to a file containing PEM-encoded X.509 root certificates") @@ -55,7 +55,7 @@ func main() { } glog.Infof("Configured: %s", i) - glog.Infof("Serving on %v%v", *httpEndpoint, *prefix) + glog.Infof("Serving on %v/%v", *httpEndpoint, *prefix) srv := http.Server{Addr: *httpEndpoint} err = srv.ListenAndServe() if err != http.ErrServerClosed { -- cgit v1.2.3