aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorLinus Nordberg <linus@nordberg.se>2022-04-28 15:46:01 +0200
committerLinus Nordberg <linus@nordberg.se>2022-04-28 15:46:01 +0200
commit47490448be1b7006641e6badc6a84b1441b27698 (patch)
treefb386e9e6ccb90b368da63d0a8085d114fd8431c /internal
parent2dcd7bca2f3e69fb6f1770ec0bf740d8956978ca (diff)
parentb270a4c0d10947fe480bad7330b31bb793225968 (diff)
Merge branch 'merge/sigsum-debug'
Diffstat (limited to 'internal')
-rw-r--r--internal/fmtio/fmtio.go79
-rw-r--r--internal/options/options.go65
2 files changed, 144 insertions, 0 deletions
diff --git a/internal/fmtio/fmtio.go b/internal/fmtio/fmtio.go
new file mode 100644
index 0000000..0e252d4
--- /dev/null
+++ b/internal/fmtio/fmtio.go
@@ -0,0 +1,79 @@
+// package fmtio provides basic utilities to format input and output
+package fmtio
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/ed25519"
+ "fmt"
+ "io/ioutil"
+ "os"
+
+ "git.sigsum.org/sigsum-go/pkg/hex"
+ "git.sigsum.org/sigsum-go/pkg/types"
+)
+
+func BytesFromStdin() ([]byte, error) {
+ b, err := ioutil.ReadAll(os.Stdin)
+ if err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+// StringFromStdin reads bytes from stdin, parsing them as a string without
+// leading and trailing white space
+func StringFromStdin() (string, error) {
+ b, err := ioutil.ReadAll(os.Stdin)
+ if err != nil {
+ return "", err
+ }
+ return string(bytes.TrimSpace(b)), nil
+}
+
+func SignerFromHex(s string) (crypto.Signer, error) {
+ b, err := hex.Deserialize(s)
+ if err != nil {
+ return nil, err
+ }
+ if n := len(b); n != ed25519.PrivateKeySize {
+ return nil, fmt.Errorf("invalid size %d", n)
+ }
+ return ed25519.PrivateKey(b), nil
+}
+
+func PublicKeyFromHex(s string) (pub types.PublicKey, err error) {
+ b, err := hex.Deserialize(s)
+ if err != nil {
+ return pub, err
+ }
+ if n := len(b); n != types.PublicKeySize {
+ return pub, fmt.Errorf("invalid size %d", n)
+ }
+ copy(pub[:], b)
+ return
+}
+
+func KeyHashFromHex(s string) (h types.Hash, err error) {
+ b, err := hex.Deserialize(s)
+ if err != nil {
+ return h, err
+ }
+ if n := len(b); n != types.HashSize {
+ return h, fmt.Errorf("invalid size %d", n)
+ }
+ copy(h[:], b)
+ return
+}
+
+func SignatureFromHex(s string) (sig types.Signature, err error) {
+ b, err := hex.Deserialize(s)
+ if err != nil {
+ return sig, err
+ }
+ if n := len(b); n != types.SignatureSize {
+ return sig, fmt.Errorf("invalid size %d", n)
+ }
+ copy(sig[:], b)
+ return
+}
diff --git a/internal/options/options.go b/internal/options/options.go
new file mode 100644
index 0000000..8e4ab0c
--- /dev/null
+++ b/internal/options/options.go
@@ -0,0 +1,65 @@
+package options
+
+import (
+ "flag"
+ "fmt"
+)
+
+const (
+ DefaultString = "default string"
+ DefaultUint64 = 18446744073709551615
+)
+
+// New initializes a flag set using the provided arguments.
+//
+// - args should start with the (sub)command's name
+// - usage is a function that prints a usage message
+// - set is a function that sets the command's flag arguments
+//
+func New(args []string, usage func(), set func(*flag.FlagSet)) *flag.FlagSet {
+ if len(args) == 0 {
+ args = append(args, "")
+ }
+
+ fs := flag.NewFlagSet(args[0], flag.ExitOnError)
+ fs.Usage = func() {
+ usage()
+ }
+ set(fs)
+ fs.Parse(args[1:])
+ return fs
+}
+
+// AddString adds a string option to a flag set
+func AddString(fs *flag.FlagSet, opt *string, short, long, value string) {
+ fs.StringVar(opt, short, value, "")
+ fs.StringVar(opt, long, value, "")
+}
+
+// AddUint64 adds an uint64 option to a flag set
+func AddUint64(fs *flag.FlagSet, opt *uint64, short, long string, value uint64) {
+ fs.Uint64Var(opt, short, value, "")
+ fs.Uint64Var(opt, long, value, "")
+}
+
+// CheckString checks that a string option has a non-default value
+func CheckString(optionName, value string, err error) error {
+ if err != nil {
+ return err
+ }
+ if value == DefaultString {
+ return fmt.Errorf("%s is a required option", optionName)
+ }
+ return nil
+}
+
+// CheckUint64 checks that an uint64 option has a non-default value
+func CheckUint64(optionName string, value uint64, err error) error {
+ if err != nil {
+ return err
+ }
+ if value == DefaultUint64 {
+ return fmt.Errorf("%s is a required option", optionName)
+ }
+ return nil
+}