diff options
-rwxr-xr-x | sigsum-witness.py | 53 | ||||
-rw-r--r-- | sigsum/__init__.py | 0 | ||||
-rw-r--r-- | sigsum/ascii.py | 93 | ||||
-rw-r--r-- | sigsum/ascii_test.py | 119 |
4 files changed, 236 insertions, 29 deletions
diff --git a/sigsum-witness.py b/sigsum-witness.py index 5de433a..3f5ac38 100755 --- a/sigsum-witness.py +++ b/sigsum-witness.py @@ -23,7 +23,7 @@ import struct import sys import threading import time -from binascii import hexlify, unhexlify +from binascii import hexlify from hashlib import sha256 from math import floor from pathlib import PurePath @@ -34,6 +34,7 @@ import nacl.signing import prometheus_client as prometheus import requests +from sigsum import ascii from tools.libsigntools import ssh_to_sign BASE_URL_DEFAULT = 'http://poc.sigsum.org:4780/' @@ -169,44 +170,39 @@ def parse_keyval(text): class TreeHead: def __init__(self, sth_data): - self._text = parse_keyval(sth_data) - assert(len(self._text) == 4) - assert('timestamp' in self._text) - assert('tree_size' in self._text) - assert('root_hash' in self._text) - assert('signature' in self._text) + self._data = ascii.loads(sth_data) + assert(len(self._data) == 4) + assert('timestamp' in self._data) + assert('tree_size' in self._data) + assert('root_hash' in self._data) + assert('signature' in self._data) @property def timestamp(self): - return int(self._text['timestamp']) + return self._data.getint('timestamp') @property def tree_size(self): - return int(self._text['tree_size']) + return self._data.getint('tree_size') @property def root_hash(self): - return unhexlify(self._text['root_hash']) + return self._data.getbytes('root_hash') def text(self): - text = 'timestamp={}\n'.format(self._text['timestamp']) - text += 'tree_size={}\n'.format(self._text['tree_size']) - text += 'root_hash={}\n'.format(self._text['root_hash']) - text += 'signature={}\n'.format(self._text['signature']) - return text.encode('ascii') + return ascii.dumps(self._data).encode('ascii') def to_signed_data(self, pubkey): namespace = 'tree_head:v0:{}@sigsum.org'.format(hexlify(sha256(pubkey.encode()).digest()).decode()) msg = struct.pack('!QQ', self.timestamp, self.tree_size) - msg += unhexlify(self._text['root_hash']) + msg += self.root_hash assert(len(msg) == 8 + 8 + 32) return ssh_to_sign(namespace, 'sha256', sha256(msg).digest()) def signature_valid(self, pubkey): # Guard against tree head with >1 signature -- don't try to # validate a cosigned tree head. - assert(type(self._text['signature']) is str) - sig = unhexlify(self._text['signature']) + sig = self._data.getbytes('signature') assert(len(sig) == 64) data = self.to_signed_data(pubkey) try: @@ -286,19 +282,18 @@ class ConsistencyProof(): def __init__(self, old_size, new_size, consistency_proof_data): self._old_size = old_size self._new_size = new_size - self._text = parse_keyval(consistency_proof_data) - assert(len(self._text) == 1) - assert('consistency_path' in self._text) + self._data = ascii.loads(consistency_proof_data) + assert(len(self._data) == 1) + assert('consistency_path' in self._data) def old_size(self): return self._old_size def new_size(self): return self._new_size + def path(self): - if type(self._text['consistency_path']) is list: - return [unhexlify(e) for e in self._text['consistency_path']] - else: - return [unhexlify(self._text['consistency_path'])] + return self._data.getbytes('consistency_path', many=True) + def make_base_dir_maybe(): dirname = os.path.expanduser(g_args.base_dir) @@ -410,10 +405,10 @@ def consistency_proof_valid(first, second, proof): def sign_send_store_tree_head(signing_key, log_key, tree_head): signature = signing_key.sign(tree_head.to_signed_data(log_key)).signature hash = sha256(signing_key.verify_key.encode()) - - post_data = 'cosignature={}\n'.format(hexlify(signature).decode('ascii')) - post_data += 'key_hash={}\n'.format(hash.hexdigest()) - + post_data = ascii.dumps({ + 'cosignature': signature.hex(), + 'key_hash': hash.hexdigest(), + }) try: req = requests.post(g_args.base_url + 'sigsum/v0/add-cosignature', post_data) except requests.ConnectionError as err: diff --git a/sigsum/__init__.py b/sigsum/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/sigsum/__init__.py diff --git a/sigsum/ascii.py b/sigsum/ascii.py new file mode 100644 index 0000000..f7f378c --- /dev/null +++ b/sigsum/ascii.py @@ -0,0 +1,93 @@ +import io + + +def dumps(data): + """ + dumps takes a key/values mapping and serializes it to ASCII. + If one of the values is not of type str, int or bytes (or a list of those) + a TypeError is raised. + """ + res = io.StringIO() + for key in data: + values = data[key] + if not isinstance(values, list): + values = [values] + for val in values: + if isinstance(val, (int, str)): + res.write(f"{key}={val}\n") + elif isinstance(val, bytes): + res.write(f"{key}={val.hex()}\n") + else: + raise TypeError( + f"Object of type {type(val).__name__} is not ASCII serializable" + ) + res.seek(0) + return res.read() + + +def loads(txt): + """ + loads deserialized the given string into an ASCIIValue. + """ + kv = [] + for lno, line in enumerate(txt.splitlines(), 1): + if "=" not in line: + raise ASCIIDecodeError("Expecting '=' delimiter line 1") + (key, val) = line.rstrip().split("=", 1) + if val == "": + raise ASCIIDecodeError("Expecting value after '=' line 1") + kv.append((key, val)) + return ASCIIValue(kv) + + +class ASCIIDecodeError(Exception): + """ + ASCIIDecodeError indicates that loads couldn't deserialize the given input. + """ + + +class ASCIIValue: + """ + ASCIIValue implements Mapping[str, List[str]] with convenience getters to + parse sigsum types. + """ + + def __init__(self, data): + self._d = {} + for k, v in data: + self._d.setdefault(k, []).append(v) + + def __getitem__(self, k): + return self._d.__getitem__(k) + + def __len__(self): + return self._d.__len__() + + def __iter__(self): + return self._d.__iter__() + + def getone(self, k): + v = self._d[k] + if len(v) > 1: + raise ValueError(f"{k}: expected a single value, got {len(v)}") + return self._d[k][0] + + def getint(self, k, many=False): + if many: + return [int(x) for x in self._d[k]] + return int(self.getone(k)) + + def getbytes(self, k, many=False): + if many: + return [bytes.fromhex(x) for x in self._d[k]] + return bytes.fromhex(self.getone(k)) + + def __repr__(self): + return f'ASCIIValue([{", ".join(f"({k!r}, {v!r})" for k,vs in self._d.items() for v in vs)}])' + + def __eq__(self, other): + if isinstance(other, ASCIIValue): + return self._d.__eq__(other._d) + if isinstance(other, dict): + return self._d.__eq__(other) + return NotImplemented diff --git a/sigsum/ascii_test.py b/sigsum/ascii_test.py new file mode 100644 index 0000000..6dfe025 --- /dev/null +++ b/sigsum/ascii_test.py @@ -0,0 +1,119 @@ +import io +import operator +from operator import methodcaller as M + +from . import ascii + + +def test(): + pass + + +import pytest + + +@pytest.mark.parametrize( + "txt, expected", + [ + ("", {}), + ("foo=bar", {"foo": ["bar"]}), + ("foo=bar\nqux=42", {"foo": ["bar"], "qux": ["42"]}), + ("foo=bar\nfoo=biz", {"foo": ["bar", "biz"]}), + ("error=something went wrong", {"error": ["something went wrong"]}), + ("error=a message with an = sign", {"error": ["a message with an = sign"]}), + ], +) +def test_loads(txt, expected): + assert ascii.loads(txt) == expected + + +@pytest.mark.parametrize( + "txt, message", + [ + ("foo", "Expecting '=' delimiter line 1"), + ("foo=", "Expecting value after '=' line 1"), + ], +) +def test_loads_error(txt, message): + with pytest.raises(ascii.ASCIIDecodeError, match=message): + ascii.loads(txt) + + +@pytest.mark.parametrize( + "data, expected", + [ + ({}, ""), + ({"foo": ["bar"], "baz": ["biz"]}, "foo=bar\nbaz=biz\n"), + ({"foo": ["bar", "baz"]}, "foo=bar\nfoo=baz\n"), + ({"foo": [42]}, "foo=42\n"), + ({"foo": [b"\xDE\xAD\xBE\xEF"]}, "foo=deadbeef\n"), + ({"foo": "bar"}, "foo=bar\n"), + ], + ids=["empty", "simple", "list", "int", "bytes", "single-value-shortcut"], +) +def test_dumps(data, expected): + assert ascii.dumps(data) == expected + + +def test_dumps_type_error(): + with pytest.raises( + TypeError, match="Object of type object is not ASCII serializable" + ): + ascii.dumps({"foo": [object()]}) + + +@pytest.mark.parametrize( + "data, func, expected", + [ + # Check that it behave like a Mapping[str, List[str]] + ([("foo", "bar"), ("foo", "baz")], operator.itemgetter("foo"), ["bar", "baz"]), + ([("foo", "bar"), ("foo", "baz")], len, 1), + ([("foo", "bar"), ("foo", "baz")], lambda x: list(iter(x)), ["foo"]), + # Check accessors + ([("foo", "bar")], M("getone", "foo"), "bar"), + ([("foo", "42")], M("getint", "foo"), 42), + ([("foo", "deadbeef")], M("getbytes", "foo"), b"\xDE\xAD\xBE\xEF"), + ([("foo", "42"), ("foo", "0")], M("getint", "foo", True), [42, 0]), + ( + [("foo", "dead"), ("foo", "beef")], + M("getbytes", "foo", True), + [b"\xDE\xAD", b"\xBE\xEF"], + ), + ], +) +def test_asciivalue_getters(data, func, expected): + kv = ascii.ASCIIValue(data) + assert func(kv) == expected + + +@pytest.mark.parametrize( + "data, func, error", + [ + # missing key + ([], M("getone", "foo"), KeyError), + ([], M("getint", "foo"), KeyError), + ([], M("getbytes", "foo"), KeyError), + # too many values + ([("foo", "bar"), ("foo", "baz")], M("getone", "foo"), ValueError), + ([("foo", "42"), ("foo", "0")], M("getint", "foo"), ValueError), + ([("foo", "dead"), ("foo", "beef")], M("getbytes", "foo"), ValueError), + # strconv errors + ([("foo", "xx")], M("getint", "foo"), ValueError), + ([("foo", "xx")], M("getbytes", "foo"), ValueError), + ], +) +def test_asciivalue_getters_errorrs(data, func, error): + kv = ascii.ASCIIValue(data) + with pytest.raises(error): + func(kv) + + +def test_asciivalue_repr(): + v = ascii.ASCIIValue([("foo", "bar"), ("foo", "baz"), ("qux", "quux")]) + assert repr(v) == "ASCIIValue([('foo', 'bar'), ('foo', 'baz'), ('qux', 'quux')])" + + +def test_asciivalue_eq(): + v = ascii.ASCIIValue([("foo", "bar"), ("foo", "baz"), ("qux", "quux")]) + assert v == ascii.ASCIIValue([("foo", "bar"), ("foo", "baz"), ("qux", "quux")]) + assert v == {"foo": ["bar", "baz"], "qux": ["quux"]} |