From b540f681b4cdf740f9b8d1e584fd2b107fc1b090 Mon Sep 17 00:00:00 2001
From: Rasmus Dahlberg <rasmus.dahlberg@kau.se>
Date: Tue, 1 Dec 2020 20:42:21 +0100
Subject: started to clean-up instance

Things like opening files is better place in the server package.  Any
code that is difficult to test should also not be in the STFE package.
---
 handler.go      |  14 ++++--
 handler_test.go |  24 +++++-----
 instance.go     | 141 ++++++++++++++++++++------------------------------------
 server/main.go  |  45 ++++++++++++++++--
 4 files changed, 112 insertions(+), 112 deletions(-)

diff --git a/handler.go b/handler.go
index bd1fbdb..58771c8 100644
--- a/handler.go
+++ b/handler.go
@@ -12,16 +12,22 @@ import (
 	"github.com/google/trillian/types"
 )
 
-// handler implements the http.Handler interface, and contains a reference
+// Handler implements the http.Handler interface, and contains a reference
 // to an STFE server instance as well as a function that uses it.
-type handler struct {
+type Handler struct {
 	instance *Instance // STFE server instance
 	endpoint Endpoint  // e.g., add-entry
 	method   string    // e.g., GET
 	handler  func(context.Context, *Instance, http.ResponseWriter, *http.Request) (int, error)
 }
 
-func (a handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+// Path returns a path that should be configured for this handler
+func (h Handler) Path() string {
+	return h.endpoint.Path("", h.instance.LogParameters.Prefix)
+}
+
+// ServeHTTP is part of the http.Handler interface
+func (a Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	// export prometheus metrics
 	var now time.Time = time.Now()
 	var statusCode int
@@ -47,7 +53,7 @@ func (a handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func (a handler) sendHTTPError(w http.ResponseWriter, statusCode int, err error) {
+func (a Handler) sendHTTPError(w http.ResponseWriter, statusCode int, err error) {
 	http.Error(w, http.StatusText(statusCode), statusCode)
 }
 
diff --git a/handler_test.go b/handler_test.go
index 689541b..3bcc702 100644
--- a/handler_test.go
+++ b/handler_test.go
@@ -43,17 +43,17 @@ func newTestHandler(t *testing.T, signer crypto.Signer) *testHandler {
 	}
 }
 
-func (th *testHandler) getHandlers(t *testing.T) map[Endpoint]handler {
-	return map[Endpoint]handler{
-		EndpointGetSth:              handler{instance: th.instance, handler: getSth, endpoint: EndpointGetSth, method: http.MethodGet},
-		EndpointGetConsistencyProof: handler{instance: th.instance, handler: getConsistencyProof, endpoint: EndpointGetConsistencyProof, method: http.MethodGet},
-		EndpointGetProofByHash:      handler{instance: th.instance, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet},
-		EndpointGetAnchors:          handler{instance: th.instance, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet},
-		EndpointGetEntries:          handler{instance: th.instance, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet},
+func (th *testHandler) getHandlers(t *testing.T) map[Endpoint]Handler {
+	return map[Endpoint]Handler{
+		EndpointGetSth:              Handler{instance: th.instance, handler: getSth, endpoint: EndpointGetSth, method: http.MethodGet},
+		EndpointGetConsistencyProof: Handler{instance: th.instance, handler: getConsistencyProof, endpoint: EndpointGetConsistencyProof, method: http.MethodGet},
+		EndpointGetProofByHash:      Handler{instance: th.instance, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet},
+		EndpointGetAnchors:          Handler{instance: th.instance, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet},
+		EndpointGetEntries:          Handler{instance: th.instance, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet},
 	}
 }
 
-func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) handler {
+func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) Handler {
 	handler, ok := th.getHandlers(t)[endpoint]
 	if !ok {
 		t.Fatalf("no such get endpoint: %s", endpoint)
@@ -61,13 +61,13 @@ func (th *testHandler) getHandler(t *testing.T, endpoint Endpoint) handler {
 	return handler
 }
 
-func (th *testHandler) postHandlers(t *testing.T) map[Endpoint]handler {
-	return map[Endpoint]handler{
-		EndpointAddEntry: handler{instance: th.instance, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost},
+func (th *testHandler) postHandlers(t *testing.T) map[Endpoint]Handler {
+	return map[Endpoint]Handler{
+		EndpointAddEntry: Handler{instance: th.instance, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost},
 	}
 }
 
-func (th *testHandler) postHandler(t *testing.T, endpoint Endpoint) handler {
+func (th *testHandler) postHandler(t *testing.T, endpoint Endpoint) Handler {
 	handler, ok := th.postHandlers(t)[endpoint]
 	if !ok {
 		t.Fatalf("no such post endpoint: %s", endpoint)
diff --git a/instance.go b/instance.go
index 2c108aa..b0a2d9e 100644
--- a/instance.go
+++ b/instance.go
@@ -9,31 +9,12 @@ import (
 	"crypto/sha256"
 	"crypto/x509"
 	"encoding/base64"
-	"io/ioutil"
 	"net/http"
 
-	"github.com/golang/glog"
 	"github.com/google/trillian"
 	"github.com/system-transparency/stfe/x509util"
 )
 
-type Endpoint string
-
-const (
-	EndpointAddEntry            = Endpoint("add-entry")
-	EndpointGetEntries          = Endpoint("get-entries")
-	EndpointGetAnchors          = Endpoint("get-anchors")
-	EndpointGetProofByHash      = Endpoint("get-proof-by-hash")
-	EndpointGetConsistencyProof = Endpoint("get-consistency-proof")
-	EndpointGetSth              = Endpoint("get-sth")
-)
-
-// Path joins a number of components to form a full endpoint path, e.g., base
-// ("example.com"), prefix ("st/v1"), and the endpoint itself ("get-sth").
-func (e Endpoint) Path(components ...string) string {
-	return strings.Join(append(components, string(e)), "/")
-}
-
 // Instance is an instance of a particular log front-end
 type Instance struct {
 	LogParameters *LogParameters
@@ -55,6 +36,18 @@ type LogParameters struct {
 	HashType   crypto.Hash // hash function used by Trillian
 }
 
+// Endpoint is a named HTTP API endpoint
+type Endpoint string
+
+const (
+	EndpointAddEntry            = Endpoint("add-entry")
+	EndpointGetEntries          = Endpoint("get-entries")
+	EndpointGetAnchors          = Endpoint("get-anchors")
+	EndpointGetProofByHash      = Endpoint("get-proof-by-hash")
+	EndpointGetConsistencyProof = Endpoint("get-consistency-proof")
+	EndpointGetSth              = Endpoint("get-sth")
+)
+
 func (i Instance) String() string {
 	return fmt.Sprintf("%s Deadline(%v)\n", i.LogParameters, i.Deadline)
 }
@@ -63,103 +56,67 @@ func (p LogParameters) String() string {
 	return fmt.Sprintf("LogId(%s) TreeId(%d) Prefix(%s) NumAnchors(%d)", base64.StdEncoding.EncodeToString(p.LogId), p.TreeId, p.Prefix, len(p.AnchorList))
 }
 
-func (i *LogParameters) id() string {
-	return base64.StdEncoding.EncodeToString(i.LogId)
+func (e Endpoint) String() string {
+	return string(e)
 }
 
-// NewInstance returns a new STFE Instance
-func NewInstance(lp *LogParameters, client trillian.TrillianLogClient, deadline time.Duration, mux *http.ServeMux) (*Instance, error) {
-	i := &Instance{
+// NewInstance creates a new STFE instance
+func NewInstance(lp *LogParameters, client trillian.TrillianLogClient, deadline time.Duration, mux *http.ServeMux) *Instance {
+	return &Instance{
 		LogParameters: lp,
 		Client:        client,
 		Deadline:      deadline,
 	}
-	i.registerHandlers(mux)
-	return i, nil
 }
 
-// NewLogParameters initializes log parameters, assuming ed25519 signatures.
-func NewLogParameters(treeId int64, prefix string, anchorPath, keyPath string, maxRange, maxChain int64) (*LogParameters, error) {
-	anchorList, anchorPool, err := loadTrustAnchors(anchorPath)
+// NewLogParameters creates new log parameters.  Note that the signer is
+// assumed to be an ed25519 signing key.  Could be fixed at some point.
+func NewLogParameters(treeId int64, prefix string, anchors []*x509.Certificate, signer crypto.Signer, maxRange, maxChain int64) (*LogParameters, error) {
+	pub, err := x509.MarshalPKIXPublicKey(signer.Public())
 	if err != nil {
-		return nil, err
-	}
-
-	pem, err := ioutil.ReadFile(keyPath)
-	if err != nil {
-		return nil, fmt.Errorf("failed reading %s: %v", keyPath, err)
+		return nil, fmt.Errorf("failed DER encoding SubjectPublicKeyInfo: %v", err)
 	}
-	key, err := x509util.NewEd25519PrivateKey(pem)
-	if err != nil {
-		return nil, err
+	if maxRange < 1 {
+		return nil, fmt.Errorf("invalid max range: must be at least 1")
 	}
-
-	pub, err := x509.MarshalPKIXPublicKey(key.Public())
-	if err != nil {
-		return nil, fmt.Errorf("failed DER encoding SubjectPublicKeyInfo: %v", err)
+	if maxChain < 1 {
+		return nil, fmt.Errorf("invalid max chain: must be at least 1")
 	}
 	hasher := sha256.New()
 	hasher.Write(pub)
-	logId := hasher.Sum(nil)
-
 	return &LogParameters{
-		LogId:      logId,
+		LogId:      hasher.Sum(nil),
 		TreeId:     treeId,
 		Prefix:     prefix,
 		MaxRange:   maxRange,
 		MaxChain:   maxChain,
-		AnchorPool: anchorPool,
-		AnchorList: anchorList,
+		AnchorPool: x509util.NewCertPool(anchors),
+		AnchorList: anchors,
 		KeyUsage:   []x509.ExtKeyUsage{}, // placeholder, must be tested if used
-		Signer:     key,
-		HashType:   crypto.SHA256,
+		Signer:     signer,
+		HashType:   crypto.SHA256, // STFE assumes RFC 6962 hashing
 	}, nil
 }
 
-func (i *Instance) registerHandlers(mux *http.ServeMux) {
-	for _, endpoint := range []struct {
-		path    string
-		handler handler
-	}{
-		{
-			EndpointAddEntry.Path("", i.LogParameters.Prefix),
-			handler{instance: i, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost},
-		},
-		{
-			EndpointGetEntries.Path("", i.LogParameters.Prefix),
-			handler{instance: i, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet},
-		},
-		{
-			EndpointGetAnchors.Path("", i.LogParameters.Prefix),
-			handler{instance: i, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet},
-		},
-		{
-			EndpointGetProofByHash.Path("", i.LogParameters.Prefix),
-			handler{instance: i, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet},
-		},
-		{
-			EndpointGetConsistencyProof.Path("", i.LogParameters.Prefix),
-			handler{instance: i, handler: getConsistencyProof, endpoint: EndpointGetConsistencyProof, method: http.MethodGet},
-		},
-		{
-			EndpointGetSth.Path("", i.LogParameters.Prefix),
-			handler{instance: i, handler: getSth, endpoint: EndpointGetSth, method: http.MethodGet},
-		},
-	} {
-		glog.Infof("adding handler for %v", endpoint.path)
-		mux.Handle(endpoint.path, endpoint.handler)
-	}
+// Path joins a number of components to form a full endpoint path, e.g., base
+// ("example.com"), prefix ("st/v1"), and the endpoint itself ("get-sth").
+func (e Endpoint) Path(components ...string) string {
+	return strings.Join(append(components, string(e)), "/")
 }
 
-// loadTrustAnchors loads a list of PEM-encoded certificates from file
-func loadTrustAnchors(path string) ([]*x509.Certificate, *x509.CertPool, error) {
-	pem, err := ioutil.ReadFile(path)
-	if err != nil {
-		return nil, nil, fmt.Errorf("failed reading trust anchors: %v", err)
-	}
-	anchorList, err := x509util.NewCertificateList(pem)
-	if err != nil || len(anchorList) == 0 {
-		return nil, nil, fmt.Errorf("failed parsing trust anchors: %v", err)
+// TODO: id() docdoc
+func (i *LogParameters) id() string {
+	return base64.StdEncoding.EncodeToString(i.LogId)
+}
+
+// Handlers returns a list of STFE handlers
+func (i *Instance) Handlers() []Handler {
+	return []Handler{
+		Handler{instance: i, handler: addEntry, endpoint: EndpointAddEntry, method: http.MethodPost},
+		Handler{instance: i, handler: getEntries, endpoint: EndpointGetEntries, method: http.MethodGet},
+		Handler{instance: i, handler: getAnchors, endpoint: EndpointGetAnchors, method: http.MethodGet},
+		Handler{instance: i, handler: getProofByHash, endpoint: EndpointGetProofByHash, method: http.MethodGet},
+		Handler{instance: i, handler: getConsistencyProof, endpoint: EndpointGetConsistencyProof, method: http.MethodGet},
+		Handler{instance: i, handler: getSth, endpoint: EndpointGetSth, method: http.MethodGet},
 	}
-	return anchorList, x509util.NewCertPool(anchorList), nil
 }
diff --git a/server/main.go b/server/main.go
index d6a7aa5..c60f95d 100644
--- a/server/main.go
+++ b/server/main.go
@@ -3,14 +3,18 @@ package main
 
 import (
 	"flag"
+	"fmt"
 	"time"
 
+	"crypto/x509"
+	"io/ioutil"
 	"net/http"
 
 	"github.com/golang/glog"
 	"github.com/google/trillian"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/system-transparency/stfe"
+	"github.com/system-transparency/stfe/x509util"
 	"google.golang.org/grpc"
 )
 
@@ -44,14 +48,31 @@ func main() {
 	glog.Info("Adding prometheus handler on path: /metrics")
 	http.Handle("/metrics", promhttp.Handler())
 
-	lp, err := stfe.NewLogParameters(*trillianID, *prefix, *anchorPath, *keyPath, *maxRange, *maxChain)
+	glog.Infof("Loading trust anchors from file: %s", *anchorPath)
+	anchors, err := loadCertificates(*anchorPath)
 	if err != nil {
-		glog.Fatalf("failed setting up log parameters: %v", err)
+		glog.Fatalf("no trust anchors: %v", err)
+	}
+
+	glog.Infof("Loading Ed25519 signing key from file: %s", *keyPath)
+	pem, err := ioutil.ReadFile(*keyPath)
+	if err != nil {
+		glog.Fatalf("no signing key: %v", err)
+	}
+	signer, err := x509util.NewEd25519PrivateKey(pem)
+	if err != nil {
+		glog.Fatalf("no signing key: %v", err)
 	}
 
-	i, err := stfe.NewInstance(lp, client, *rpcDeadline, mux)
+	lp, err := stfe.NewLogParameters(*trillianID, *prefix, anchors, signer, *maxRange, *maxChain)
 	if err != nil {
-		glog.Fatalf("failed setting up log instance: %v", err)
+		glog.Fatalf("failed setting up log parameters: %v", err)
+	}
+
+	i := stfe.NewInstance(lp, client, *rpcDeadline, mux)
+	for _, handler := range i.Handlers() {
+		glog.Infof("adding handler: %s", handler.Path())
+		mux.Handle(handler.Path(), handler)
 	}
 	glog.Infof("Configured: %s", i)
 
@@ -64,3 +85,19 @@ func main() {
 
 	glog.Flush()
 }
+
+// loadCertificates loads a non-empty list of PEM-encoded certificates from file
+func loadCertificates(path string) ([]*x509.Certificate, error) {
+	pem, err := ioutil.ReadFile(path)
+	if err != nil {
+		return nil, fmt.Errorf("failed reading %s: %v", path, err)
+	}
+	anchors, err := x509util.NewCertificateList(pem)
+	if err != nil {
+		return nil, fmt.Errorf("failed parsing: %v", err)
+	}
+	if len(anchors) == 0 {
+		return nil, fmt.Errorf("no trust anchors")
+	}
+	return anchors, nil
+}
-- 
cgit v1.2.3