aboutsummaryrefslogtreecommitdiff
path: root/instance_test.go
blob: de539a16caa9fd65e3e278ca928a55cb265a3b24 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
package stfe

import (
	"crypto"
	"testing"

	"net/http"
	"net/http/httptest"

	"github.com/golang/mock/gomock"
	"github.com/google/certificate-transparency-go/trillian/mockclient"
	"github.com/system-transparency/stfe/testdata"
	"github.com/system-transparency/stfe/types"
)

// testInstance bundles an Instance under test together with the gomock
// controller and the mocked Trillian client it was built with, so that tests
// can set expectations on the client and verify them through the controller
// (tests in this file call ctrl.Finish() when done).
type testInstance struct {
	// ctrl is the gomock controller that owns client.
	ctrl *gomock.Controller
	// client is the mocked Trillian log client; it is also wired into
	// instance (and its SthSource) by newTestInstance.
	client *mockclient.MockTrillianLogClient
	// instance is the Instance under test.
	instance *Instance
}

// newTestInstance sets up a test instance that uses default log parameters
// with an optional signer, see newLogParameters() for further details.  The
// SthSource is instantiated with an ActiveSthSource that has (i) the default
// STH as the currently cosigned STH based on testdata.Ed25519VkWitness, and
// (ii) the default STH without any cosignatures as the currently stable STH.
//
// The Instance and its ActiveSthSource share one set of log parameters so
// that they cannot drift apart.  Callers should call ctrl.Finish() on the
// returned testInstance when done.
func newTestInstance(t *testing.T, signer crypto.Signer) *testInstance {
	t.Helper()
	ctrl := gomock.NewController(t)
	client := mockclient.NewMockTrillianLogClient(ctrl)
	// Build the log parameters once; previously the Instance and the
	// SthSource each got an independently constructed copy.
	lp := newLogParameters(t, signer)
	return &testInstance{
		ctrl:   ctrl,
		client: client,
		instance: &Instance{
			Client:        client,
			LogParameters: lp,
			SthSource: &ActiveSthSource{
				client:          client,
				logParameters:   lp,
				currCosth:       testdata.DefaultCosth(t, testdata.Ed25519VkLog, [][32]byte{testdata.Ed25519VkWitness}),
				nextCosth:       testdata.DefaultCosth(t, testdata.Ed25519VkLog, nil),
				cosignatureFrom: make(map[[types.NamespaceFingerprintSize]byte]bool),
			},
		},
	}
}

// getHandlers maps every endpoint served over HTTP GET to a Handler that is
// bound to the test instance.
func (ti *testInstance) getHandlers(t *testing.T) map[Endpoint]Handler {
	t.Helper()
	handlers := map[Endpoint]Handler{
		EndpointGetLatestSth:   {Handler: getLatestSth},
		EndpointGetStableSth:   {Handler: getStableSth},
		EndpointGetCosignedSth: {Handler: getCosignedSth},
	}
	// Fill in the fields shared by every GET handler.  Each entry's Endpoint
	// is its own map key.  Only existing keys are overwritten, which is safe
	// while ranging.
	for endpoint, handler := range handlers {
		handler.Instance = ti.instance
		handler.Endpoint = endpoint
		handler.Method = http.MethodGet
		handlers[endpoint] = handler
	}
	return handlers
}

// postHandlers maps every endpoint served over HTTP POST to a Handler that is
// bound to the test instance.
func (ti *testInstance) postHandlers(t *testing.T) map[Endpoint]Handler {
	t.Helper()
	handlers := map[Endpoint]Handler{
		EndpointAddEntry:            {Handler: addEntry},
		EndpointAddCosignature:      {Handler: addCosignature},
		EndpointGetConsistencyProof: {Handler: getConsistencyProof},
		EndpointGetProofByHash:      {Handler: getProofByHash},
		EndpointGetEntries:          {Handler: getEntries},
	}
	// Fill in the fields shared by every POST handler.  Each entry's Endpoint
	// is its own map key.  Only existing keys are overwritten, which is safe
	// while ranging.
	for endpoint, handler := range handlers {
		handler.Instance = ti.instance
		handler.Endpoint = endpoint
		handler.Method = http.MethodPost
		handlers[endpoint] = handler
	}
	return handlers
}

// getHandler looks up the HTTP GET handler for endpoint and fails the test
// immediately if no such handler exists.
func (ti *testInstance) getHandler(t *testing.T, endpoint Endpoint) Handler {
	t.Helper()
	if h, found := ti.getHandlers(t)[endpoint]; found {
		return h
	}
	t.Fatalf("must return HTTP GET handler for endpoint: %s", endpoint)
	return Handler{} // unreachable: Fatalf stops the calling goroutine
}

// postHandler looks up the HTTP POST handler for endpoint and fails the test
// immediately if no such handler exists.
func (ti *testInstance) postHandler(t *testing.T, endpoint Endpoint) Handler {
	t.Helper()
	if h, found := ti.postHandlers(t)[endpoint]; found {
		return h
	}
	t.Fatalf("must return HTTP POST handler for endpoint: %s", endpoint)
	return Handler{} // unreachable: Fatalf stops the calling goroutine
}

// TestHandlers checks that we configured all endpoints, that there are no
// unexpected ones, and that no endpoint is configured more than once.
func TestHandlers(t *testing.T) {
	endpoints := map[Endpoint]bool{
		EndpointAddEntry:            false,
		EndpointAddCosignature:      false,
		EndpointGetLatestSth:        false,
		EndpointGetStableSth:        false,
		EndpointGetCosignedSth:      false,
		EndpointGetConsistencyProof: false,
		EndpointGetProofByHash:      false,
		EndpointGetEntries:          false,
	}
	// Keyed literal: only the log parameters are needed to list handlers,
	// and this does not break if Instance gains or reorders fields.
	i := &Instance{LogParameters: newLogParameters(t, nil)}
	for _, handler := range i.Handlers() {
		configured, ok := endpoints[handler.Endpoint]
		switch {
		case !ok:
			t.Errorf("got unexpected endpoint: %s", handler.Endpoint)
		case configured:
			t.Errorf("endpoint %s is configured more than once", handler.Endpoint)
		}
		endpoints[handler.Endpoint] = true
	}
	for endpoint, ok := range endpoints {
		if !ok {
			t.Errorf("endpoint %s is not configured", endpoint)
		}
	}
}

// TestGetHandlersRejectPost checks that all get handlers reject post requests
// with 405 Method Not Allowed.
func TestGetHandlersRejectPost(t *testing.T) {
	ti := newTestInstance(t, nil)
	defer ti.ctrl.Finish()

	for endpoint, handler := range ti.getHandlers(t) {
		t.Run(string(endpoint), func(t *testing.T) {
			s := httptest.NewServer(handler)
			defer s.Close()

			url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix)
			rsp, err := http.Post(url, "application/json", nil)
			if err != nil {
				t.Fatalf("http.Post(%s)=(_,%q), want (_,nil)", url, err)
			}
			// The caller must close the response body; otherwise the
			// client-side connection and its resources are leaked.  Deferred
			// after s.Close, so it runs first.
			defer rsp.Body.Close()
			if rsp.StatusCode != http.StatusMethodNotAllowed {
				t.Errorf("http.Post(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed)
			}
		})
	}
}

// TestPostHandlersRejectGet checks that all post handlers reject get requests
// with 405 Method Not Allowed.
func TestPostHandlersRejectGet(t *testing.T) {
	ti := newTestInstance(t, nil)
	defer ti.ctrl.Finish()

	for endpoint, handler := range ti.postHandlers(t) {
		t.Run(string(endpoint), func(t *testing.T) {
			s := httptest.NewServer(handler)
			defer s.Close()

			url := endpoint.Path(s.URL, ti.instance.LogParameters.Prefix)
			rsp, err := http.Get(url)
			if err != nil {
				t.Fatalf("http.Get(%s)=(_,%q), want (_,nil)", url, err)
			}
			// The caller must close the response body; otherwise the
			// client-side connection and its resources are leaked.  Deferred
			// after s.Close, so it runs first.
			defer rsp.Body.Close()
			if rsp.StatusCode != http.StatusMethodNotAllowed {
				t.Errorf("http.Get(%s)=(%d,nil), want (%d, nil)", url, rsp.StatusCode, http.StatusMethodNotAllowed)
			}
		})
	}
}

// TestHandlerPath is a placeholder.  Presumably it should check that each
// handler is served under its expected path (TODO: confirm and implement).
// Until then it is skipped so that it is reported as SKIP rather than
// silently passing.
func TestHandlerPath(t *testing.T) {
	t.Skip("TODO: implement TestHandlerPath")
}