2121 */
2222
2323#include <TrezorCrypto/schnorr.h>
24- #include <TrezorCrypto/memzero.h>
25- #include "stdio.h"
26-
27- static void hex_to_raw_bignum (const char * str , uint8_t bn_raw [32 ]) {
28- for (size_t i = 0 ; i < 32 ; i ++ ) {
29- uint8_t c = 0 ;
30- if (str [i * 2 ] >= '0' && str [i * 2 ] <= '9' ) c += (str [i * 2 ] - '0' ) << 4 ;
31- if ((str [i * 2 ] & ~0x20 ) >= 'A' && (str [i * 2 ] & ~0x20 ) <= 'F' )
32- c += (10 + (str [i * 2 ] & ~0x20 ) - 'A' ) << 4 ;
33- if (str [i * 2 + 1 ] >= '0' && str [i * 2 + 1 ] <= '9' )
34- c += (str [i * 2 + 1 ] - '0' );
35- if ((str [i * 2 + 1 ] & ~0x20 ) >= 'A' && (str [i * 2 + 1 ] & ~0x20 ) <= 'F' )
36- c += (10 + (str [i * 2 + 1 ] & ~0x20 ) - 'A' );
37- bn_raw [i ] = c ;
38- }
39- }
40-
41- static void bn_be_to_hex_string (const bignum256 * b , char result [64 ]) {
42- uint8_t raw_number [32 ] = {0 };
43- bn_write_be (b , raw_number );
44- for (int i = 0 ; i < 32 ; ++ i )
45- sprintf (result + i * 2 , "%02x" , ((unsigned char * )raw_number )[i ]);
46- }
47-
48- void schnorr_to_hex_str (const schnorr_sign_pair * sign , char hex_str [128 ]) {
49- bn_be_to_hex_string (& sign -> r , hex_str );
50- bn_be_to_hex_string (& sign -> s , hex_str + 64 );
51- }
52-
53- void schnorr_from_hex_str (const char hex_str [128 ], schnorr_sign_pair * sign ) {
54- uint8_t buf [32 ];
55- hex_to_raw_bignum (hex_str , buf );
56- bn_read_be (buf , & sign -> r );
57- hex_to_raw_bignum (hex_str + 64 , buf );
58- bn_read_be (buf , & sign -> s );
59- }
6024
6125// r = H(Q, kpub, m)
6226static void calc_r (const curve_point * Q , const uint8_t pub_key [33 ],
@@ -71,69 +35,101 @@ static void calc_r(const curve_point *Q, const uint8_t pub_key[33],
7135 sha256_Update (& ctx , pub_key , 33 );
7236 sha256_Update (& ctx , msg , msg_len );
7337 sha256_Final (& ctx , digest );
38+
39+ // Convert the raw bigendian 256 bit value to a normalized, partly reduced bignum
7440 bn_read_be (digest , r );
7541}
7642
77- // returns 0 if signing succeeded
43+ // Returns 0 if signing succeeded
7844int schnorr_sign (const ecdsa_curve * curve , const uint8_t * priv_key ,
7945 const bignum256 * k , const uint8_t * msg , const uint32_t msg_len ,
8046 schnorr_sign_pair * result ) {
47+ uint8_t pub_key [33 ];
48+ curve_point Q ;
8149 bignum256 private_key_scalar ;
50+ bignum256 r_temp ;
51+ bignum256 s_temp ;
52+ bignum256 r_kpriv_result ;
53+
8254 bn_read_be (priv_key , & private_key_scalar );
83- uint8_t pub_key [33 ];
8455 ecdsa_get_public_key33 (curve , priv_key , pub_key );
8556
86- /* Q = kG */
87- curve_point Q ;
88- scalar_multiply (curve , k , & Q );
57+ // Compute commitment Q = kG
58+ point_multiply (curve , k , & curve -> G , & Q );
8959
90- /* r = H(Q, kpub, m) */
91- calc_r (& Q , pub_key , msg , msg_len , & result -> r );
60+ // Compute challenge r = H(Q, kpub, m)
61+ calc_r (& Q , pub_key , msg , msg_len , & r_temp );
62+
63+ // Fully reduce the bignum
64+ bn_mod (& r_temp , & curve -> order );
9265
93- /* s = k - r*kpriv mod(order) */
94- bignum256 s_temp ;
95- bn_copy (& result -> r , & s_temp );
96- bn_multiply (& private_key_scalar , & s_temp , & curve -> order );
97- bn_subtractmod (k , & s_temp , & result -> s , & curve -> order );
98- memzero (& private_key_scalar , sizeof (private_key_scalar ));
66+ // Convert the normalized, fully reduced bignum to a raw bigendian 256 bit value
67+ bn_write_be (& r_temp , result -> r );
9968
100- while (bn_is_less (& curve -> order , & result -> s )) {
101- bn_mod (& result -> s , & curve -> order );
102- }
69+ // Compute s = k - r*kpriv
70+ bn_copy (& r_temp , & r_kpriv_result );
10371
104- if (bn_is_zero (& result -> s ) || bn_is_zero (& result -> r )) {
105- return 1 ;
106- }
72+ // r*kpriv result is partly reduced
73+ bn_multiply (& private_key_scalar , & r_kpriv_result , & curve -> order );
74+
75+ // k - r*kpriv result is normalized but not reduced
76+ bn_subtractmod (k , & r_kpriv_result , & s_temp , & curve -> order );
77+
78+ // Partly reduce the result
79+ bn_fast_mod (& s_temp , & curve -> order );
80+
81+ // Fully reduce the result
82+ bn_mod (& s_temp , & curve -> order );
83+
84+ // Convert the normalized, fully reduced bignum to a raw bigendian 256 bit value
85+ bn_write_be (& s_temp , result -> s );
86+
87+ if (bn_is_zero (& r_temp ) || bn_is_zero (& s_temp )) return 1 ;
10788
10889 return 0 ;
10990}
11091
111- // returns 0 if verification succeeded
92+ // Returns 0 if verification succeeded
11293int schnorr_verify (const ecdsa_curve * curve , const uint8_t * pub_key ,
11394 const uint8_t * msg , const uint32_t msg_len ,
11495 const schnorr_sign_pair * sign ) {
96+ curve_point pub_key_point ;
97+ curve_point sG , Q ;
98+ bignum256 r_temp ;
99+ bignum256 s_temp ;
100+ bignum256 r_computed ;
101+
115102 if (msg_len == 0 ) return 1 ;
116- if (bn_is_zero (& sign -> r )) return 2 ;
117- if (bn_is_zero (& sign -> s )) return 3 ;
118- if (bn_is_less (& curve -> order , & sign -> r )) return 4 ;
119- if (bn_is_less (& curve -> order , & sign -> s )) return 5 ;
120103
121- curve_point pub_key_point ;
104+ // Convert the raw bigendian 256 bit values to normalized, partly reduced bignums
105+ bn_read_be (sign -> r , & r_temp );
106+ bn_read_be (sign -> s , & s_temp );
107+
108+ // Check if r,s are in [1, ..., order-1]
109+ if (bn_is_zero (& r_temp )) return 2 ;
110+ if (bn_is_zero (& s_temp )) return 3 ;
111+ if (bn_is_less (& curve -> order , & r_temp )) return 4 ;
112+ if (bn_is_less (& curve -> order , & s_temp )) return 5 ;
113+ if (bn_is_equal (& curve -> order , & r_temp )) return 6 ;
114+ if (bn_is_equal (& curve -> order , & s_temp )) return 7 ;
115+
122116 if (!ecdsa_read_pubkey (curve , pub_key , & pub_key_point )) {
123- return 6 ;
117+ return 8 ;
124118 }
125119
126120 // Compute Q = sG + r*kpub
127- curve_point sG , Q ;
128- scalar_multiply (curve , & sign -> s , & sG );
129- point_multiply (curve , & sign -> r , & pub_key_point , & Q );
121+ point_multiply (curve , & s_temp , & curve -> G , & sG );
122+ point_multiply (curve , & r_temp , & pub_key_point , & Q );
130123 point_add (curve , & sG , & Q );
131124
132- /* r = H(Q, kpub, m) */
133- bignum256 r ;
134- calc_r (& Q , pub_key , msg , msg_len , & r );
125+ // Compute r' = H(Q, kpub, m)
126+ calc_r (& Q , pub_key , msg , msg_len , & r_computed );
127+
128+ // Fully reduce the bignum
129+ bn_mod (& r_computed , & curve -> order );
135130
136- if (bn_is_equal (& r , & sign -> r )) return 0 ; // success
131+ // Check r == r'
132+ if (bn_is_equal (& r_temp , & r_computed )) return 0 ; // success
137133
138134 return 10 ;
139135}
0 commit comments