aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xsigsum-witness.py53
-rw-r--r--sigsum/__init__.py0
-rw-r--r--sigsum/ascii.py93
-rw-r--r--sigsum/ascii_test.py119
4 files changed, 236 insertions, 29 deletions
diff --git a/sigsum-witness.py b/sigsum-witness.py
index 5de433a..3f5ac38 100755
--- a/sigsum-witness.py
+++ b/sigsum-witness.py
@@ -23,7 +23,7 @@ import struct
import sys
import threading
import time
-from binascii import hexlify, unhexlify
+from binascii import hexlify
from hashlib import sha256
from math import floor
from pathlib import PurePath
@@ -34,6 +34,7 @@ import nacl.signing
import prometheus_client as prometheus
import requests
+from sigsum import ascii
from tools.libsigntools import ssh_to_sign
BASE_URL_DEFAULT = 'http://poc.sigsum.org:4780/'
@@ -169,44 +170,39 @@ def parse_keyval(text):
class TreeHead:
def __init__(self, sth_data):
- self._text = parse_keyval(sth_data)
- assert(len(self._text) == 4)
- assert('timestamp' in self._text)
- assert('tree_size' in self._text)
- assert('root_hash' in self._text)
- assert('signature' in self._text)
+ self._data = ascii.loads(sth_data)
+ assert(len(self._data) == 4)
+ assert('timestamp' in self._data)
+ assert('tree_size' in self._data)
+ assert('root_hash' in self._data)
+ assert('signature' in self._data)
@property
def timestamp(self):
- return int(self._text['timestamp'])
+ return self._data.getint('timestamp')
@property
def tree_size(self):
- return int(self._text['tree_size'])
+ return self._data.getint('tree_size')
@property
def root_hash(self):
- return unhexlify(self._text['root_hash'])
+ return self._data.getbytes('root_hash')
def text(self):
- text = 'timestamp={}\n'.format(self._text['timestamp'])
- text += 'tree_size={}\n'.format(self._text['tree_size'])
- text += 'root_hash={}\n'.format(self._text['root_hash'])
- text += 'signature={}\n'.format(self._text['signature'])
- return text.encode('ascii')
+ return ascii.dumps(self._data).encode('ascii')
def to_signed_data(self, pubkey):
namespace = 'tree_head:v0:{}@sigsum.org'.format(hexlify(sha256(pubkey.encode()).digest()).decode())
msg = struct.pack('!QQ', self.timestamp, self.tree_size)
- msg += unhexlify(self._text['root_hash'])
+ msg += self.root_hash
assert(len(msg) == 8 + 8 + 32)
return ssh_to_sign(namespace, 'sha256', sha256(msg).digest())
def signature_valid(self, pubkey):
# Guard against tree head with >1 signature -- don't try to
# validate a cosigned tree head.
- assert(type(self._text['signature']) is str)
- sig = unhexlify(self._text['signature'])
+ sig = self._data.getbytes('signature')
assert(len(sig) == 64)
data = self.to_signed_data(pubkey)
try:
@@ -286,19 +282,18 @@ class ConsistencyProof():
def __init__(self, old_size, new_size, consistency_proof_data):
self._old_size = old_size
self._new_size = new_size
- self._text = parse_keyval(consistency_proof_data)
- assert(len(self._text) == 1)
- assert('consistency_path' in self._text)
+ self._data = ascii.loads(consistency_proof_data)
+ assert(len(self._data) == 1)
+ assert('consistency_path' in self._data)
def old_size(self):
return self._old_size
def new_size(self):
return self._new_size
+
def path(self):
- if type(self._text['consistency_path']) is list:
- return [unhexlify(e) for e in self._text['consistency_path']]
- else:
- return [unhexlify(self._text['consistency_path'])]
+ return self._data.getbytes('consistency_path', many=True)
+
def make_base_dir_maybe():
dirname = os.path.expanduser(g_args.base_dir)
@@ -410,10 +405,10 @@ def consistency_proof_valid(first, second, proof):
def sign_send_store_tree_head(signing_key, log_key, tree_head):
signature = signing_key.sign(tree_head.to_signed_data(log_key)).signature
hash = sha256(signing_key.verify_key.encode())
-
- post_data = 'cosignature={}\n'.format(hexlify(signature).decode('ascii'))
- post_data += 'key_hash={}\n'.format(hash.hexdigest())
-
+ post_data = ascii.dumps({
+ 'cosignature': signature.hex(),
+ 'key_hash': hash.hexdigest(),
+ })
try:
req = requests.post(g_args.base_url + 'sigsum/v0/add-cosignature', post_data)
except requests.ConnectionError as err:
diff --git a/sigsum/__init__.py b/sigsum/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/sigsum/__init__.py
diff --git a/sigsum/ascii.py b/sigsum/ascii.py
new file mode 100644
index 0000000..f7f378c
--- /dev/null
+++ b/sigsum/ascii.py
@@ -0,0 +1,93 @@
+import io
+
+
+def dumps(data):
+ """
+ dumps takes a key/values mapping and serializes it to ASCII.
+ If one of the values is not of type str, int or bytes (or a list of those)
+ a TypeError is raised.
+ """
+ res = io.StringIO()
+ for key in data:
+ values = data[key]
+ if not isinstance(values, list):
+ values = [values]
+ for val in values:
+ if isinstance(val, (int, str)):
+ res.write(f"{key}={val}\n")
+ elif isinstance(val, bytes):
+ res.write(f"{key}={val.hex()}\n")
+ else:
+ raise TypeError(
+ f"Object of type {type(val).__name__} is not ASCII serializable"
+ )
+ res.seek(0)
+ return res.read()
+
+
+def loads(txt):
+ """
+ loads deserialized the given string into an ASCIIValue.
+ """
+ kv = []
+ for lno, line in enumerate(txt.splitlines(), 1):
+ if "=" not in line:
+ raise ASCIIDecodeError("Expecting '=' delimiter line 1")
+ (key, val) = line.rstrip().split("=", 1)
+ if val == "":
+ raise ASCIIDecodeError("Expecting value after '=' line 1")
+ kv.append((key, val))
+ return ASCIIValue(kv)
+
+
+class ASCIIDecodeError(Exception):
+ """
+ ASCIIDecodeError indicates that loads couldn't deserialize the given input.
+ """
+
+
+class ASCIIValue:
+ """
+ ASCIIValue implements Mapping[str, List[str]] with convenience getters to
+ parse sigsum types.
+ """
+
+ def __init__(self, data):
+ self._d = {}
+ for k, v in data:
+ self._d.setdefault(k, []).append(v)
+
+ def __getitem__(self, k):
+ return self._d.__getitem__(k)
+
+ def __len__(self):
+ return self._d.__len__()
+
+ def __iter__(self):
+ return self._d.__iter__()
+
+ def getone(self, k):
+ v = self._d[k]
+ if len(v) > 1:
+ raise ValueError(f"{k}: expected a single value, got {len(v)}")
+ return self._d[k][0]
+
+ def getint(self, k, many=False):
+ if many:
+ return [int(x) for x in self._d[k]]
+ return int(self.getone(k))
+
+ def getbytes(self, k, many=False):
+ if many:
+ return [bytes.fromhex(x) for x in self._d[k]]
+ return bytes.fromhex(self.getone(k))
+
+ def __repr__(self):
+ return f'ASCIIValue([{", ".join(f"({k!r}, {v!r})" for k,vs in self._d.items() for v in vs)}])'
+
+ def __eq__(self, other):
+ if isinstance(other, ASCIIValue):
+ return self._d.__eq__(other._d)
+ if isinstance(other, dict):
+ return self._d.__eq__(other)
+ return NotImplemented
diff --git a/sigsum/ascii_test.py b/sigsum/ascii_test.py
new file mode 100644
index 0000000..6dfe025
--- /dev/null
+++ b/sigsum/ascii_test.py
@@ -0,0 +1,119 @@
+import io
+import operator
+from operator import methodcaller as M
+
+from . import ascii
+
+
+def test():
+ pass
+
+
+import pytest
+
+
+@pytest.mark.parametrize(
+ "txt, expected",
+ [
+ ("", {}),
+ ("foo=bar", {"foo": ["bar"]}),
+ ("foo=bar\nqux=42", {"foo": ["bar"], "qux": ["42"]}),
+ ("foo=bar\nfoo=biz", {"foo": ["bar", "biz"]}),
+ ("error=something went wrong", {"error": ["something went wrong"]}),
+ ("error=a message with an = sign", {"error": ["a message with an = sign"]}),
+ ],
+)
+def test_loads(txt, expected):
+ assert ascii.loads(txt) == expected
+
+
+@pytest.mark.parametrize(
+ "txt, message",
+ [
+ ("foo", "Expecting '=' delimiter line 1"),
+ ("foo=", "Expecting value after '=' line 1"),
+ ],
+)
+def test_loads_error(txt, message):
+ with pytest.raises(ascii.ASCIIDecodeError, match=message):
+ ascii.loads(txt)
+
+
+@pytest.mark.parametrize(
+ "data, expected",
+ [
+ ({}, ""),
+ ({"foo": ["bar"], "baz": ["biz"]}, "foo=bar\nbaz=biz\n"),
+ ({"foo": ["bar", "baz"]}, "foo=bar\nfoo=baz\n"),
+ ({"foo": [42]}, "foo=42\n"),
+ ({"foo": [b"\xDE\xAD\xBE\xEF"]}, "foo=deadbeef\n"),
+ ({"foo": "bar"}, "foo=bar\n"),
+ ],
+ ids=["empty", "simple", "list", "int", "bytes", "single-value-shortcut"],
+)
+def test_dumps(data, expected):
+ assert ascii.dumps(data) == expected
+
+
+def test_dumps_type_error():
+ with pytest.raises(
+ TypeError, match="Object of type object is not ASCII serializable"
+ ):
+ ascii.dumps({"foo": [object()]})
+
+
+@pytest.mark.parametrize(
+ "data, func, expected",
+ [
+ # Check that it behave like a Mapping[str, List[str]]
+ ([("foo", "bar"), ("foo", "baz")], operator.itemgetter("foo"), ["bar", "baz"]),
+ ([("foo", "bar"), ("foo", "baz")], len, 1),
+ ([("foo", "bar"), ("foo", "baz")], lambda x: list(iter(x)), ["foo"]),
+ # Check accessors
+ ([("foo", "bar")], M("getone", "foo"), "bar"),
+ ([("foo", "42")], M("getint", "foo"), 42),
+ ([("foo", "deadbeef")], M("getbytes", "foo"), b"\xDE\xAD\xBE\xEF"),
+ ([("foo", "42"), ("foo", "0")], M("getint", "foo", True), [42, 0]),
+ (
+ [("foo", "dead"), ("foo", "beef")],
+ M("getbytes", "foo", True),
+ [b"\xDE\xAD", b"\xBE\xEF"],
+ ),
+ ],
+)
+def test_asciivalue_getters(data, func, expected):
+ kv = ascii.ASCIIValue(data)
+ assert func(kv) == expected
+
+
+@pytest.mark.parametrize(
+ "data, func, error",
+ [
+ # missing key
+ ([], M("getone", "foo"), KeyError),
+ ([], M("getint", "foo"), KeyError),
+ ([], M("getbytes", "foo"), KeyError),
+ # too many values
+ ([("foo", "bar"), ("foo", "baz")], M("getone", "foo"), ValueError),
+ ([("foo", "42"), ("foo", "0")], M("getint", "foo"), ValueError),
+ ([("foo", "dead"), ("foo", "beef")], M("getbytes", "foo"), ValueError),
+ # strconv errors
+ ([("foo", "xx")], M("getint", "foo"), ValueError),
+ ([("foo", "xx")], M("getbytes", "foo"), ValueError),
+ ],
+)
+def test_asciivalue_getters_errorrs(data, func, error):
+ kv = ascii.ASCIIValue(data)
+ with pytest.raises(error):
+ func(kv)
+
+
+def test_asciivalue_repr():
+ v = ascii.ASCIIValue([("foo", "bar"), ("foo", "baz"), ("qux", "quux")])
+ assert repr(v) == "ASCIIValue([('foo', 'bar'), ('foo', 'baz'), ('qux', 'quux')])"
+
+
+def test_asciivalue_eq():
+ v = ascii.ASCIIValue([("foo", "bar"), ("foo", "baz"), ("qux", "quux")])
+ assert v == ascii.ASCIIValue([("foo", "bar"), ("foo", "baz"), ("qux", "quux")])
+ assert v == {"foo": ["bar", "baz"], "qux": ["quux"]}