diff options
Diffstat (limited to 'types/ascii.go')
| -rw-r--r-- | types/ascii.go | 413 | 
1 files changed, 413 insertions, 0 deletions
| diff --git a/types/ascii.go b/types/ascii.go new file mode 100644 index 0000000..88b372e --- /dev/null +++ b/types/ascii.go @@ -0,0 +1,413 @@ +package types + +import ( +	"bytes" +	"encoding/hex" +	"fmt" +	"io" +	"strconv" +) + +const ( +	// Delim is a key-value separator +	Delim = "=" + +	// EOL is a line sepator +	EOL = "\n" + +	// NumField* is the number of unique keys in an incoming ASCII message +	NumFieldLeaf                    = 4 +	NumFieldSignedTreeHead          = 5 +	NumFieldConsistencyProof        = 3 +	NumFieldInclusionProof          = 3 +	NumFieldLeavesRequest           = 2 +	NumFieldInclusionProofRequest   = 2 +	NumFieldConsistencyProofRequest = 2 +	NumFieldLeafRequest             = 5 +	NumFieldCosignatureRequest      = 2 + +	// New leaf keys +	ShardHint            = "shard_hint" +	Checksum             = "checksum" +	SignatureOverMessage = "signature_over_message" +	VerificationKey      = "verification_key" +	DomainHint           = "domain_hint" + +	// Inclusion proof keys +	LeafHash      = "leaf_hash" +	LeafIndex     = "leaf_index" +	InclusionPath = "inclusion_path" + +	// Consistency proof keys +	NewSize         = "new_size" +	OldSize         = "old_size" +	ConsistencyPath = "consistency_path" + +	// Range of leaves keys +	StartSize = "start_size" +	EndSize   = "end_size" + +	// Tree head keys +	Timestamp = "timestamp" +	TreeSize  = "tree_size" +	RootHash  = "root_hash" + +	// Signature and signer-identity keys +	Signature = "signature" +	KeyHash   = "key_hash" +) + +// MessageASCI is a wrapper that manages ASCII key-value pairs +type MessageASCII struct { +	m map[string][]string +} + +// NewMessageASCII unpacks an incoming ASCII message +func NewMessageASCII(r io.Reader, numFieldExpected int) (*MessageASCII, error) { +	buf, err := io.ReadAll(r) +	if err != nil { +		return nil, fmt.Errorf("ReadAll: %v", err) +	} +	lines := bytes.Split(buf, []byte(EOL)) +	if len(lines) <= 1 { +		return nil, fmt.Errorf("Not enough lines: empty") +	} +	lines = lines[:len(lines)-1] // valid message => split gives empty last line + +	msg := MessageASCII{make(map[string][]string)} +	for _, line := range lines { +		split := bytes.Index(line, []byte(Delim)) +		if split == -1 { +			return nil, fmt.Errorf("invalid line: %v", string(line)) +		} + +		key := string(line[:split]) +		value := string(line[split+len(Delim):]) +		values, ok := msg.m[key] +		if !ok { +			values = nil +			msg.m[key] = values +		} +		msg.m[key] = append(values, value) +	} + +	if msg.NumField() != numFieldExpected { +		return nil, fmt.Errorf("Unexpected number of keys: %v", msg.NumField()) +	} +	return &msg, nil +} + +// NumField returns the number of unique keys +func (msg *MessageASCII) NumField() int { +	return len(msg.m) +} + +// GetStrings returns a list of strings +func (msg *MessageASCII) GetStrings(key string) []string { +	strs, ok := msg.m[key] +	if !ok { +		return nil +	} +	return strs +} + +// GetString unpacks a string +func (msg *MessageASCII) GetString(key string) (string, error) { +	strs := msg.GetStrings(key) +	if len(strs) != 1 { +		return "", fmt.Errorf("expected one string: %v", strs) +	} +	return strs[0], nil +} + +// GetUint64 unpacks an uint64 +func (msg *MessageASCII) GetUint64(key string) (uint64, error) { +	str, err := msg.GetString(key) +	if err != nil { +		return 0, fmt.Errorf("GetString: %v", err) +	} +	num, err := strconv.ParseUint(str, 10, 64) +	if err != nil { +		return 0, fmt.Errorf("ParseUint: %v", err) +	} +	return num, nil +} + +// GetHash unpacks a hash +func (msg *MessageASCII) GetHash(key string) (*[HashSize]byte, error) { +	str, err := msg.GetString(key) +	if err != nil { +		return nil, fmt.Errorf("GetString: %v", err) +	} + +	var hash [HashSize]byte +	if err := decodeHex(str, hash[:]); err != nil { +		return nil, fmt.Errorf("decodeHex: %v", err) +	} +	return &hash, nil +} + +// GetSignature unpacks a signature +func (msg *MessageASCII) GetSignature(key string) (*[SignatureSize]byte, error) { +	str, err := msg.GetString(key) +	if err != nil { +		return nil, fmt.Errorf("GetString: %v", err) +	} + +	var signature [SignatureSize]byte +	if err := decodeHex(str, signature[:]); err != nil { +		return nil, fmt.Errorf("decodeHex: %v", err) +	} +	return &signature, nil +} + +// GetVerificationKey unpacks a verification key +func (msg *MessageASCII) GetVerificationKey(key string) (*[VerificationKeySize]byte, error) { +	str, err := msg.GetString(key) +	if err != nil { +		return nil, fmt.Errorf("GetString: %v", err) +	} + +	var vk [VerificationKeySize]byte +	if err := decodeHex(str, vk[:]); err != nil { +		return nil, fmt.Errorf("decodeHex: %v", err) +	} +	return &vk, nil +} + +// decodeHex decodes a hex-encoded string into an already-sized byte slice +func decodeHex(str string, out []byte) error { +	buf, err := hex.DecodeString(str) +	if err != nil { +		return fmt.Errorf("DecodeString: %v", err) +	} +	if len(buf) != len(out) { +		return fmt.Errorf("invalid length: %v", len(buf)) +	} +	copy(out, buf) +	return nil +} + +/* + * + * MarshalASCII wrappers for types that the log server outputs + * + */ +func (l *Leaf) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, ShardHint, strconv.FormatUint(l.ShardHint, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, Checksum, hex.EncodeToString(l.Checksum[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, Signature, hex.EncodeToString(l.Signature[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, KeyHash, hex.EncodeToString(l.KeyHash[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	return nil +} + +func (sth *SignedTreeHead) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, Timestamp, strconv.FormatUint(sth.Timestamp, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, TreeSize, strconv.FormatUint(sth.TreeSize, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, RootHash, hex.EncodeToString(sth.RootHash[:])); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	for _, sigident := range sth.SigIdent { +		if err := writeASCII(w, Signature, hex.EncodeToString(sigident.Signature[:])); err != nil { +			return fmt.Errorf("writeASCII: %v", err) +		} +		if err := writeASCII(w, KeyHash, hex.EncodeToString(sigident.KeyHash[:])); err != nil { +			return fmt.Errorf("writeASCII: %v", err) +		} +	} +	return nil +} + +func (p *ConsistencyProof) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, NewSize, strconv.FormatUint(p.NewSize, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, OldSize, strconv.FormatUint(p.OldSize, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	for _, hash := range p.Path { +		if err := writeASCII(w, ConsistencyPath, hex.EncodeToString(hash[:])); err != nil { +			return fmt.Errorf("writeASCII: %v", err) +		} +	} +	return nil +} + +func (p *InclusionProof) MarshalASCII(w io.Writer) error { +	if err := writeASCII(w, TreeSize, strconv.FormatUint(p.TreeSize, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	if err := writeASCII(w, LeafIndex, strconv.FormatUint(p.LeafIndex, 10)); err != nil { +		return fmt.Errorf("writeASCII: %v", err) +	} +	for _, hash := range p.Path { +		if err := writeASCII(w, InclusionPath, hex.EncodeToString(hash[:])); err != nil { +			return fmt.Errorf("writeASCII: %v", err) +		} +	} +	return nil +} + +func writeASCII(w io.Writer, key, value string) error { +	if _, err := fmt.Fprintf(w, "%s%s%s%s", key, Delim, value, EOL); err != nil { +		return fmt.Errorf("Fprintf: %v", err) +	} +	return nil +} + +/* + * + * Unmarshal ASCII wrappers that the log server and/or log clients receive. + * + */ +func (ll *LeafList) UnmarshalASCII(r io.Reader) error { +	return nil +} + +func (sth *SignedTreeHead) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldSignedTreeHead) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	// TreeHead +	if sth.Timestamp, err = msg.GetUint64(Timestamp); err != nil { +		return fmt.Errorf("GetUint64(Timestamp): %v", err) +	} +	if sth.TreeSize, err = msg.GetUint64(TreeSize); err != nil { +		return fmt.Errorf("GetUint64(TreeSize): %v", err) +	} +	if sth.RootHash, err = msg.GetHash(RootHash); err != nil { +		return fmt.Errorf("GetHash(RootHash): %v", err) +	} + +	// SigIdent +	signatures := msg.GetStrings(Signature) +	if len(signatures) == 0 { +		return fmt.Errorf("no signer") +	} +	keyHashes := msg.GetStrings(KeyHash) +	if len(signatures) != len(keyHashes) { +		return fmt.Errorf("mismatched signature-signer count") +	} +	sth.SigIdent = make([]*SigIdent, 0, len(signatures)) +	for i, n := 0, len(signatures); i < n; i++ { +		var signature [SignatureSize]byte +		if err := decodeHex(signatures[i], signature[:]); err != nil { +			return fmt.Errorf("decodeHex: %v", err) +		} +		var hash [HashSize]byte +		if err := decodeHex(keyHashes[i], hash[:]); err != nil { +			return fmt.Errorf("decodeHex: %v", err) +		} +		sth.SigIdent = append(sth.SigIdent, &SigIdent{ +			Signature: &signature, +			KeyHash:   &hash, +		}) +	} +	return nil +} + +func (p *InclusionProof) UnmarshalASCII(r io.Reader) error { +	return nil +} + +func (p *ConsistencyProof) UnmarshalASCII(r io.Reader) error { +	return nil +} + +func (req *InclusionProofRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldInclusionProofRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.LeafHash, err = msg.GetHash(LeafHash); err != nil { +		return fmt.Errorf("GetHash(LeafHash): %v", err) +	} +	if req.TreeSize, err = msg.GetUint64(TreeSize); err != nil { +		return fmt.Errorf("GetUint64(TreeSize): %v", err) +	} +	return nil +} + +func (req *ConsistencyProofRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldConsistencyProofRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.NewSize, err = msg.GetUint64(NewSize); err != nil { +		return fmt.Errorf("GetUint64(NewSize): %v", err) +	} +	if req.OldSize, err = msg.GetUint64(OldSize); err != nil { +		return fmt.Errorf("GetUint64(OldSize): %v", err) +	} +	return nil +} + +func (req *LeavesRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldLeavesRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.StartSize, err = msg.GetUint64(StartSize); err != nil { +		return fmt.Errorf("GetUint64(StartSize): %v", err) +	} +	if req.EndSize, err = msg.GetUint64(EndSize); err != nil { +		return fmt.Errorf("GetUint64(EndSize): %v", err) +	} +	return nil +} + +func (req *LeafRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldLeafRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.ShardHint, err = msg.GetUint64(ShardHint); err != nil { +		return fmt.Errorf("GetUint64(ShardHint): %v", err) +	} +	if req.Checksum, err = msg.GetHash(Checksum); err != nil { +		return fmt.Errorf("GetHash(Checksum): %v", err) +	} +	if req.Signature, err = msg.GetSignature(Signature); err != nil { +		return fmt.Errorf("GetSignature: %v", err) +	} +	if req.VerificationKey, err = msg.GetVerificationKey(VerificationKey); err != nil { +		return fmt.Errorf("GetVerificationKey: %v", err) +	} +	if req.DomainHint, err = msg.GetString(DomainHint); err != nil { +		return fmt.Errorf("GetString(DomainHint): %v", err) +	} +	return nil +} + +func (req *CosignatureRequest) UnmarshalASCII(r io.Reader) error { +	msg, err := NewMessageASCII(r, NumFieldCosignatureRequest) +	if err != nil { +		return fmt.Errorf("NewMessageASCII: %v", err) +	} + +	if req.Signature, err = msg.GetSignature(Signature); err != nil { +		return fmt.Errorf("GetSignature: %v", err) +	} +	if req.KeyHash, err = msg.GetHash(KeyHash); err != nil { +		return fmt.Errorf("GetHash(KeyHash): %v", err) +	} +	return nil +} | 
