Skip to content

Commit a4ceec1

Browse files
authored
Merge pull request #3 from jamesob/jamesob-25-03-0.0.2
0.0.2: base58, memzero
2 parents 8b76ce0 + fb2f6a6 commit a4ceec1

File tree

7 files changed

+171
-38
lines changed

7 files changed

+171
-38
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
## 0.0.2
2+
3+
- Changed `bip32_serialze` str_len to a pointer which returns the final length of the
4+
base58-encoded out string.
5+
- Added some precautionary `sodium_memzero()` calls.
6+
- Made `bip32_b58_encode()` and `bip32_b58_decode()` a public part of the API.

bip32.c

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ int bip32_from_seed(bip32_key *key, const unsigned char *seed, size_t seed_len)
6666
memcpy(key->chain_code, output + BIP32_PRIVKEY_SIZE, BIP32_CHAINCODE_SIZE);
6767

6868
exit:
69+
sodium_memzero(output, crypto_auth_hmacsha512_BYTES);
6970
secp256k1_context_destroy(ctx);
7071
return retcode;
7172
}
@@ -164,6 +165,9 @@ int bip32_index_derive(bip32_key *target, const bip32_key *source, uint32_t inde
164165

165166
bip32_hmac_sha512(output, source->chain_code, BIP32_CHAINCODE_SIZE, hmac_msg, hmac_msg_len);
166167

168+
// hmac_msg potentially contains privkey bytes.
169+
sodium_memzero(hmac_msg, hmac_msg_len);
170+
167171
memcpy(target->chain_code, output + BIP32_PRIVKEY_SIZE, BIP32_CHAINCODE_SIZE);
168172

169173
if (source->is_private) {
@@ -222,19 +226,6 @@ int bip32_index_derive(bip32_key *target, const bip32_key *source, uint32_t inde
222226
return retcode;
223227
}
224228

225-
226-
// Returns true if invalid path characters are detected in a path string.
227-
static bool has_invalid_path_characters(const char* str) {
228-
const char* valid = "m/0123456789hH'pP";
229-
while (*str) {
230-
if (!strchr(valid, *str)) {
231-
return true;
232-
}
233-
str++;
234-
}
235-
return false;
236-
}
237-
238229
int bip32_derive_from_str(bip32_key* target, const char* source, const char* path) {
239230
if (!target || !source || !path || strncmp(path, "m", 1) != 0) {
240231
return 0;
@@ -250,6 +241,7 @@ int bip32_derive_from_str(bip32_key* target, const char* source, const char* pat
250241
strncmp(source, "xpub", 4) == 0 ||
251242
strncmp(source, "tpub", 4) == 0) {
252243
if (!bip32_deserialize(&basekey, source, strlen(source))) {
244+
sodium_memzero(&basekey, sizeof(bip32_key));
253245
return 0;
254246
}
255247
}
@@ -261,6 +253,7 @@ int bip32_derive_from_str(bip32_key* target, const char* source, const char* pat
261253
return 0;
262254
}
263255
if (!bip32_from_seed(&basekey, seedbytes, bin_len)) {
256+
sodium_memzero(&basekey, sizeof(bip32_key));
264257
return 0;
265258
}
266259
} else {
@@ -269,6 +262,7 @@ int bip32_derive_from_str(bip32_key* target, const char* source, const char* pat
269262

270263
if (bip32_derive(&basekey, path)) {
271264
memcpy(target, &basekey, sizeof(bip32_key));
265+
sodium_memzero(&basekey, sizeof(bip32_key));
272266
return 1;
273267
}
274268
return 0;
@@ -284,6 +278,18 @@ int bip32_derive_from_seed(bip32_key* target, const unsigned char* seed, size_t
284278
return 0;
285279
}
286280

281+
// Returns true if invalid path characters are detected in a path string.
282+
static bool has_invalid_path_characters(const char* str) {
283+
const char* valid = "m/0123456789hH'pP";
284+
while (*str) {
285+
if (!strchr(valid, *str)) {
286+
return true;
287+
}
288+
str++;
289+
}
290+
return false;
291+
}
292+
287293
// Do an in-place derivation on `key`.
288294
int bip32_derive(bip32_key* key, const char* path) {
289295
if (!path || strncmp(path, "m", 1) != 0 || has_invalid_path_characters(path)) {
@@ -314,6 +320,7 @@ int bip32_derive(bip32_key* key, const char* path) {
314320
if (bip32_index_derive(key, &tmp, path_index) != 1) {
315321
return 0;
316322
}
323+
sodium_memzero(&tmp, sizeof(bip32_key));
317324
p = strchr(end, '/');
318325
}
319326

@@ -323,7 +330,7 @@ int bip32_derive(bip32_key* key, const char* path) {
323330
#define SER_SIZE 78
324331
#define SER_PLUS_CHECKSUM_SIZE (SER_SIZE + 4)
325332

326-
int bip32_serialize(const bip32_key *key, char *str, size_t str_len) {
333+
int bip32_serialize(const bip32_key *key, char *str, size_t* str_len) {
327334
unsigned char data[SER_PLUS_CHECKSUM_SIZE];
328335
uint32_t version;
329336

@@ -362,7 +369,10 @@ int bip32_serialize(const bip32_key *key, char *str, size_t str_len) {
362369
bip32_sha256_double(hash, data, 78);
363370
memcpy(data + SER_SIZE, hash, 4);
364371

365-
return b58enc(str, &str_len, data, SER_PLUS_CHECKSUM_SIZE);
372+
bool b58_ok = bip32_b58_encode(str, str_len, data, SER_PLUS_CHECKSUM_SIZE);
373+
sodium_memzero(data, SER_PLUS_CHECKSUM_SIZE);
374+
375+
return b58_ok ? 1 : 0;
366376
}
367377

368378
#define BIP32_BASE58_BYTES_LEN 82
@@ -371,7 +381,8 @@ int bip32_deserialize(bip32_key *key, const char *str, const size_t str_len) {
371381
unsigned char data[BIP32_BASE58_BYTES_LEN];
372382
size_t data_len = BIP32_BASE58_BYTES_LEN;
373383

374-
if (!b58tobin(data, &data_len, str, str_len) || data_len != BIP32_BASE58_BYTES_LEN) {
384+
if (!bip32_b58_decode(data, &data_len, str, str_len) || data_len != BIP32_BASE58_BYTES_LEN) {
385+
sodium_memzero(data, BIP32_BASE58_BYTES_LEN);
375386
return 0;
376387
}
377388

@@ -426,16 +437,22 @@ int bip32_deserialize(bip32_key *key, const char *str, const size_t str_len) {
426437

427438
if (key->is_private) {
428439
if (data[45] != 0) {
440+
sodium_memzero(data, BIP32_BASE58_BYTES_LEN);
429441
secp256k1_context_destroy(ctx);
430442
return 0;
431443
}
444+
432445
memcpy(key->key.privkey, data + 46, BIP32_PRIVKEY_SIZE);
446+
sodium_memzero(data, BIP32_BASE58_BYTES_LEN);
447+
433448
if (!secp256k1_ec_seckey_verify(ctx, key->key.privkey)) {
434449
secp256k1_context_destroy(ctx);
435450
return 0;
436451
}
437452
} else {
438453
memcpy(key->key.pubkey, data + 45, BIP32_PUBKEY_SIZE);
454+
sodium_memzero(data, BIP32_BASE58_BYTES_LEN);
455+
439456
secp256k1_pubkey pubkey;
440457
if (!secp256k1_ec_pubkey_parse(ctx, &pubkey, key->key.pubkey, BIP32_PUBKEY_SIZE)) {
441458
secp256k1_context_destroy(ctx);
@@ -484,3 +501,11 @@ void bip32_hmac_sha512(
484501
crypto_auth_hmacsha512_update(&state, msg, msg_len);
485502
crypto_auth_hmacsha512_final(&state, hmac_out);
486503
}
504+
505+
bool bip32_b58_encode(char* str_out, size_t* out_size, const unsigned char* data, size_t data_size) {
506+
return b58enc(str_out, out_size, data, data_size);
507+
}
508+
509+
bool bip32_b58_decode(unsigned char* bin_out, size_t* out_size, const char* str_in, size_t str_size) {
510+
return b58tobin(bin_out, out_size, str_in, str_size);
511+
}

bip32.h

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
/**
2+
* This API is generally designed to mimic the libsecp256k1 library, in terms of
3+
* argument order and return conventions.
4+
*/
15
#include <stdbool.h>
26
#include <stdint.h>
37
#include <stdio.h>
@@ -60,11 +64,14 @@ int bip32_derive_from_str(bip32_key *target, const char* source, const char* pat
6064
*/
6165
int bip32_derive(bip32_key *target, const char* path);
6266

63-
/** Serialize a BIP32 key to its base58 string representation.
67+
/** Serialize a BIP32 key to its base58 string representation. Writes the resulting
68+
* string to `str`, and the length of the resulting string to `str_len`.
69+
*
70+
* `str_len` must initially be set to the maximum length of the `str` buffer.
6471
*
6572
* Returns 1 if successful.
6673
*/
67-
int bip32_serialize(const bip32_key *key, char *str, size_t str_len);
74+
int bip32_serialize(const bip32_key *key, char *str, size_t* str_len);
6875

6976
/** Deserialize a BIP32 key from its base58 string representation.
7077
*
@@ -106,6 +113,18 @@ void bip32_hmac_sha512(
106113
size_t msg_len
107114
);
108115

116+
/** Encode some bytes to a base58 string.
117+
*
118+
* Returns true if successful.
119+
*/
120+
bool bip32_b58_encode(char* str_out, size_t* out_size, const unsigned char* data, size_t data_size);
121+
122+
/** Decode some bytes from a base58 string.
123+
*
124+
* Returns true if successful.
125+
*/
126+
bool bip32_b58_decode(unsigned char* bin_out, size_t* out_size, const char* str_in, size_t str_size);
127+
109128
#ifdef __cplusplus
110129
}
111130
#endif

examples/cli.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ int main(int argc, char *argv[]) {
1515
return 1;
1616
}
1717

18-
if (!bip32_serialize(&key, serialized, sizeof(serialized))) {
18+
size_t out_size;
19+
if (!bip32_serialize(&key, serialized, &out_size)) {
1920
fprintf(stderr, "Serialization failed\n");
2021
return 1;
2122
}

examples/py/bindings.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,17 @@
33
from functools import lru_cache
44
from pathlib import Path
55
from ctypes import (
6-
c_uint8, c_uint32, c_size_t, c_char_p, c_ubyte, c_void_p,
7-
Structure, Union, POINTER, create_string_buffer
8-
)
6+
c_uint8,
7+
c_uint32,
8+
c_size_t,
9+
c_char_p,
10+
c_ubyte,
11+
c_void_p,
12+
Structure,
13+
Union,
14+
POINTER,
15+
create_string_buffer,
16+
byref)
917

1018

1119
@lru_cache
@@ -20,7 +28,9 @@ def get_bip32_module():
2028
bip32_lib.bip32_init.argtypes = [POINTER(BIP32Key)]
2129
bip32_lib.bip32_init.restype = None
2230

23-
bip32_lib.bip32_derive_from_seed.argtypes = [POINTER(BIP32Key), POINTER(c_ubyte), c_size_t, c_char_p]
31+
bip32_lib.bip32_derive_from_seed.argtypes = [
32+
POINTER(BIP32Key), POINTER(c_ubyte), c_size_t, c_char_p
33+
]
2434
bip32_lib.bip32_derive_from_seed.restype = ctypes.c_int
2535

2636
bip32_lib.bip32_derive_from_str.argtypes = [POINTER(BIP32Key), c_char_p, c_char_p]
@@ -29,7 +39,9 @@ def get_bip32_module():
2939
bip32_lib.bip32_derive.argtypes = [POINTER(BIP32Key), c_char_p]
3040
bip32_lib.bip32_derive.restype = ctypes.c_int
3141

32-
bip32_lib.bip32_serialize.argtypes = [POINTER(BIP32Key), c_char_p, c_size_t]
42+
bip32_lib.bip32_serialize.argtypes = [
43+
POINTER(BIP32Key), c_char_p, POINTER(c_size_t)
44+
]
3345
bip32_lib.bip32_serialize.restype = ctypes.c_bool
3446

3547
bip32_lib.bip32_deserialize.argtypes = [POINTER(BIP32Key), c_char_p, c_size_t]
@@ -38,14 +50,22 @@ def get_bip32_module():
3850
bip32_lib.bip32_get_public.argtypes = [POINTER(BIP32Key), POINTER(BIP32Key)]
3951
bip32_lib.bip32_get_public.restype = ctypes.c_int
4052

53+
bip32_lib.bip32_b58_encode.argtypes = [
54+
c_char_p, POINTER(c_size_t), POINTER(c_ubyte), c_size_t
55+
]
56+
bip32_lib.bip32_b58_encode.restype = ctypes.c_bool
57+
58+
bip32_lib.bip32_b58_decode.argtypes = [
59+
POINTER(c_ubyte), POINTER(c_size_t), c_char_p, c_size_t
60+
]
61+
bip32_lib.bip32_b58_decode.restype = ctypes.c_bool
62+
4163
return bip32_lib
4264

4365

4466
class KeyUnion(Union):
45-
_fields_ = [
46-
('privkey', c_uint8 * 32),
47-
('pubkey', c_uint8 * 33)
48-
]
67+
_fields_ = [('privkey', c_uint8 * 32), ('pubkey', c_uint8 * 33)]
68+
4969

5070
class BIP32Key(Structure):
5171
_fields_ = [
@@ -65,6 +85,7 @@ def print(self):
6585

6686

6787
class BIP32:
88+
6889
def __init__(self):
6990
self.key = BIP32Key()
7091
self.bip32_lib = get_bip32_module()
@@ -80,7 +101,9 @@ def derive(self, path: str) -> 'BIP32':
80101

81102
def serialize(self):
82103
buf = create_string_buffer(200) # Standard BIP32 serialization length
83-
if not self.bip32_lib.bip32_serialize(self.key, buf, len(buf)):
104+
out_len = c_size_t(len(buf))
105+
106+
if not self.bip32_lib.bip32_serialize(self.key, buf, byref(out_len)):
84107
raise ValueError("Serialization failed")
85108
return buf.value.decode()
86109

@@ -106,7 +129,8 @@ def derive(source: str, path: str = 'm') -> BIP32:
106129
107130
"""
108131
b = BIP32()
109-
if not get_bip32_module().bip32_derive_from_str(b.key, source.encode(), path.encode()):
132+
if not get_bip32_module().bip32_derive_from_str(
133+
b.key, source.encode(), path.encode()):
110134
raise ValueError("failed")
111135
return b
112136

@@ -115,6 +139,35 @@ def derive_from_seed(seed: bytes, path: str = 'm') -> BIP32:
115139
b = BIP32()
116140
c_seed = ctypes.c_char_p(seed)
117141
seed_ptr = ctypes.cast(c_seed, POINTER(c_ubyte))
118-
if not get_bip32_module().bip32_derive_from_seed(b.key, seed_ptr, len(seed), path.encode()):
142+
if not get_bip32_module().bip32_derive_from_seed(
143+
b.key, seed_ptr, len(seed), path.encode()):
119144
raise ValueError("failed")
120145
return b
146+
147+
148+
def b58_encode(inp: bytes) -> str:
149+
data_len = len(inp)
150+
data_arr = (c_ubyte * data_len)(*inp)
151+
152+
out_size = c_size_t(data_len * 2)
153+
str_out = ctypes.create_string_buffer(out_size.value)
154+
155+
if not get_bip32_module().bip32_b58_encode(
156+
str_out, byref(out_size), data_arr, data_len):
157+
raise ValueError("base58 encoding failed")
158+
159+
return str_out.value[:out_size.value].decode('utf-8')
160+
161+
162+
def b58_decode(in_str: str) -> bytes:
163+
str_bytes = c_char_p(in_str.encode('utf-8'))
164+
str_len = len(in_str)
165+
166+
out_size = c_size_t(str_len * 2)
167+
bin_out = (c_ubyte * out_size.value)()
168+
169+
if not get_bip32_module().bip32_b58_decode(
170+
bin_out, byref(out_size), str_bytes, str_len):
171+
raise ValueError("base58 decoding failed")
172+
173+
return bytes(bin_out[-out_size.value:])

examples/py/test_fuzz_cross_impl.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919
from hypothesis import given, strategies as st, target, settings
2020

21-
from bindings import derive, derive_from_seed
21+
from bindings import derive, derive_from_seed, b58_decode, b58_encode
2222

2323
log = logging.getLogger(__name__)
2424
logging.basicConfig()
@@ -181,5 +181,32 @@ def test_xpub_impls(bip32_path):
181181
assert ours == pys
182182

183183

184+
185+
@given(b58_data=st.binary(min_size=0, max_size=1000))
186+
@settings(max_examples=1000)
187+
def test_base58(b58_data: bytes):
188+
if b58_data and len(b58_data) >= 2:
189+
# TODO: figure out why the base58 impl is failing on example b':'
190+
assert b58_decode(b58_encode(b58_data)) == b58_data
191+
192+
193+
def test_base58_known_vectors():
194+
cases = [
195+
(bytes.fromhex(""), ""),
196+
(bytes.fromhex("00"), "1"),
197+
(bytes.fromhex("0000"), "11"),
198+
(bytes.fromhex("68656c6c6f20776f726c64"), "StV1DL6CwTryKyV"),
199+
(bytes.fromhex("0068656c6c6f20776f726c64"), "1StV1DL6CwTryKyV"),
200+
(bytes.fromhex("000068656c6c6f20776f726c64"), "11StV1DL6CwTryKyV"),
201+
]
202+
203+
for raw, encoded in cases:
204+
if raw: # Skip empty input for encoding test
205+
assert b58_encode(raw) == encoded
206+
207+
if encoded: # Skip empty input for decoding test
208+
assert b58_decode(encoded) == raw
209+
210+
184211
if __name__ == "__main__":
185212
pytest.main([__file__, "-v", "--capture=no", "--hypothesis-show-statistics", "-x"] + sys.argv[1:])

0 commit comments

Comments
 (0)