diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/state/single.go | 62 | ||||
-rw-r--r-- | internal/state/single_test.go | 8 |
2 files changed, 57 insertions, 13 deletions
diff --git a/internal/state/single.go b/internal/state/single.go index fd73b3f..2c6bb4b 100644 --- a/internal/state/single.go +++ b/internal/state/single.go @@ -1,10 +1,12 @@ package state import ( + "bytes" "context" "crypto" "crypto/ed25519" "fmt" + "os" "sync" "time" @@ -24,6 +26,7 @@ type StateManagerSingle struct { interval time.Duration deadline time.Duration secondary client.Client + sthFile *os.File // Lock-protected access to pointers. A write lock is only obtained once // per interval when doing pointer rotation. All endpoints are readers. @@ -39,7 +42,7 @@ type StateManagerSingle struct { // NewStateManagerSingle() sets up a new state manager, in particular its // signedTreeHead. An optional secondary node can be used to ensure that // a newer primary tree is not signed unless it has been replicated. -func NewStateManagerSingle(dbcli db.Client, signer crypto.Signer, interval, deadline time.Duration, secondary client.Client) (*StateManagerSingle, error) { +func NewStateManagerSingle(dbcli db.Client, signer crypto.Signer, interval, deadline time.Duration, secondary client.Client, sthFile *os.File) (*StateManagerSingle, error) { sm := &StateManagerSingle{ client: dbcli, signer: signer, @@ -47,16 +50,22 @@ func NewStateManagerSingle(dbcli db.Client, signer crypto.Signer, interval, dead interval: interval, deadline: deadline, secondary: secondary, + sthFile: sthFile, } - sth, err := sm.restoreTreeHead() - if err != nil { - return nil, fmt.Errorf("restore signed tree head: %v", err) + var err error + if sm.signedTreeHead, err = sm.restoreSTH(); err != nil { + return nil, err } - sm.signedTreeHead = sth - - ictx, cancel := context.WithTimeout(context.Background(), sm.deadline) - defer cancel() - return sm, sm.tryRotate(ictx) + ctx := context.Background() + for { + err := sm.tryRotate(ctx) + if err == nil { + break + } + log.Warning("restore signed tree head: %v", err) + time.Sleep(time.Second * 3) + } + return sm, nil } func (sm *StateManagerSingle) ToCosignTreeHead() *types.SignedTreeHead { @@ -123,6 +132,10 @@ func (sm *StateManagerSingle) tryRotate(ctx context.Context) error { } log.Debug("wanted to advance to size %d, chose size %d", th.TreeSize, nextSTH.TreeSize) + if err := sm.storeSTH(nextSTH); err != nil { + return err + } + sm.rotate(nextSTH) return nil } @@ -250,9 +263,34 @@ func (sm *StateManagerSingle) treeStatusString() string { return fmt.Sprintf("signed at %d, cosigned at %d", sm.signedTreeHead.TreeSize, cosigned) } -func (sm *StateManagerSingle) restoreTreeHead() (*types.SignedTreeHead, error) { - th := zeroTreeHead() // TODO: restore from disk, stored when advanced the tree; zero tree head if "bootstrap" - return refreshTreeHead(*th).Sign(sm.signer, &sm.namespace) +func (sm *StateManagerSingle) restoreSTH() (*types.SignedTreeHead, error) { + var th types.TreeHead + b := make([]byte, 1024) + n, err := sm.sthFile.Read(b) + if err != nil { + th = *zeroTreeHead() + } else if err := th.FromASCII(bytes.NewBuffer(b[:n])); err != nil { + th = *zeroTreeHead() + } + th = *refreshTreeHead(th) + return th.Sign(sm.signer, &sm.namespace) +} + +func (sm *StateManagerSingle) storeSTH(sth *types.SignedTreeHead) error { + buf := bytes.NewBuffer(nil) + if err := sth.ToASCII(buf); err != nil { + return err + } + if err := sm.sthFile.Truncate(int64(buf.Len())); err != nil { + return err + } + if _, err := sm.sthFile.WriteAt(buf.Bytes(), 0); err != nil { + return err + } + if err := sm.sthFile.Sync(); err != nil { + return err + } + return nil } func zeroTreeHead() *types.TreeHead { diff --git a/internal/state/single_test.go b/internal/state/single_test.go index 9442fdc..a60795c 100644 --- a/internal/state/single_test.go +++ b/internal/state/single_test.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "fmt" "io" + "os" "reflect" "testing" "time" @@ -65,7 +66,12 @@ func TestNewStateManagerSingle(t *testing.T) { secondary.EXPECT().Initiated().Return(false) } - sm, err := NewStateManagerSingle(trillianClient, table.signer, time.Duration(0), time.Duration(0), secondary) + tmpFile, err := os.CreateTemp("", "sigsum-log-test-sth") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + sm, err := NewStateManagerSingle(trillianClient, table.signer, time.Duration(0), time.Duration(0), secondary, tmpFile) if got, want := err != nil, table.description != "valid"; got != want { t.Errorf("got error %v but wanted %v in test %q: %v", got, want, table.description, err) } |