Skip to content

Commit 7b1900d

Browse files
authored
Merge pull request #2845 from cesanta/tls
Misc TLS cleanups
2 parents 8c4cfc8 + 5a8c56e commit 7b1900d

File tree

4 files changed

+76
-28
lines changed

4 files changed

+76
-28
lines changed

mongoose.c

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8145,7 +8145,9 @@ bool mg_match(struct mg_str s, struct mg_str p, struct mg_str *caps) {
81458145
size_t i = 0, j = 0, ni = 0, nj = 0;
81468146
if (caps) caps->buf = NULL, caps->len = 0;
81478147
while (i < p.len || j < s.len) {
8148-
if (i < p.len && j < s.len && (p.buf[i] == '?' || s.buf[j] == p.buf[i])) {
8148+
if (i < p.len && j < s.len &&
8149+
(p.buf[i] == '?' ||
8150+
(p.buf[i] != '*' && p.buf[i] != '#' && s.buf[j] == p.buf[i]))) {
81498151
if (caps == NULL) {
81508152
} else if (p.buf[i] == '?') {
81518153
caps->buf = &s.buf[j], caps->len = 1; // Finalize `?` cap
@@ -8188,10 +8190,10 @@ bool mg_span(struct mg_str s, struct mg_str *a, struct mg_str *b, char sep) {
81888190

81898191
bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
81908192
size_t i = 0, ndigits = 0;
8191-
uint64_t max = val_len == sizeof(uint8_t) ? 0xFF
8193+
uint64_t max = val_len == sizeof(uint8_t) ? 0xFF
81928194
: val_len == sizeof(uint16_t) ? 0xFFFF
81938195
: val_len == sizeof(uint32_t) ? 0xFFFFFFFF
8194-
: (uint64_t) ~0;
8196+
: (uint64_t) ~0;
81958197
uint64_t result = 0;
81968198
if (max == (uint64_t) ~0 && val_len != sizeof(uint64_t)) return false;
81978199
if (base == 0 && str.len >= 2) {
@@ -8207,7 +8209,7 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
82078209
case 2:
82088210
while (i < str.len && (str.buf[i] == '0' || str.buf[i] == '1')) {
82098211
uint64_t digit = (uint64_t) (str.buf[i] - '0');
8210-
if (result > max/2) return false; // Overflow
8212+
if (result > max / 2) return false; // Overflow
82118213
result *= 2;
82128214
if (result > max - digit) return false; // Overflow
82138215
result += digit;
@@ -8217,12 +8219,12 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
82178219
case 10:
82188220
while (i < str.len && str.buf[i] >= '0' && str.buf[i] <= '9') {
82198221
uint64_t digit = (uint64_t) (str.buf[i] - '0');
8220-
if (result > max/10) return false; // Overflow
8222+
if (result > max / 10) return false; // Overflow
82218223
result *= 10;
82228224
if (result > max - digit) return false; // Overflow
82238225
result += digit;
82248226
i++, ndigits++;
8225-
}
8227+
}
82268228
break;
82278229
case 16:
82288230
while (i < str.len) {
@@ -8232,7 +8234,7 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
82328234
: (c >= 'a' && c <= 'f') ? (uint64_t) (c - 'W')
82338235
: (uint64_t) ~0;
82348236
if (digit == (uint64_t) ~0) break;
8235-
if (result > max/16) return false; // Overflow
8237+
if (result > max / 16) return false; // Overflow
82368238
result *= 16;
82378239
if (result > max - digit) return false; // Overflow
82388240
result += digit;
@@ -9651,13 +9653,15 @@ static void mg_tls_drop_record(struct mg_connection *c) {
96519653
static void mg_tls_drop_message(struct mg_connection *c) {
96529654
uint32_t len;
96539655
struct tls_data *tls = (struct tls_data *) c->tls;
9654-
if (tls->recv.len == 0) {
9656+
if (tls->recv.len == 0) return;
9657+
len = MG_LOAD_BE24(tls->recv.buf + 1) + TLS_MSGHDR_SIZE;
9658+
if (tls->recv.len < len) {
9659+
mg_error(c, "wrong size");
96559660
return;
96569661
}
9657-
len = MG_LOAD_BE24(tls->recv.buf + 1);
9658-
mg_sha256_update(&tls->sha256, tls->recv.buf, len + TLS_MSGHDR_SIZE);
9659-
tls->recv.buf += len + TLS_MSGHDR_SIZE;
9660-
tls->recv.len -= len + TLS_MSGHDR_SIZE;
9662+
mg_sha256_update(&tls->sha256, tls->recv.buf, len);
9663+
tls->recv.buf += len;
9664+
tls->recv.len -= len;
96619665
if (tls->recv.len == 0) {
96629666
mg_tls_drop_record(c);
96639667
}
@@ -9918,6 +9922,10 @@ static int mg_tls_recv_record(struct mg_connection *c) {
99189922
free(dec);
99199923
}
99209924
#else
9925+
if (msgsz < 16) {
9926+
mg_error(c, "wrong size");
9927+
return -1;
9928+
}
99219929
mg_aes_gcm_decrypt(msg, msg, msgsz - 16, key, 16, nonce, sizeof(nonce));
99229930
#endif
99239931
r = msgsz - 16 - 1;
@@ -9981,8 +9989,10 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
99819989
MG_INFO(("bad session id len"));
99829990
}
99839991
cipher_suites_len = MG_LOAD_BE16(rio->buf + 44 + session_id_len);
9992+
if (cipher_suites_len > (rio->len - 46 - session_id_len)) goto fail;
99849993
ext_len = MG_LOAD_BE16(rio->buf + 48 + session_id_len + cipher_suites_len);
99859994
ext = rio->buf + 50 + session_id_len + cipher_suites_len;
9995+
if (ext_len > (rio->len - 52 - session_id_len - cipher_suites_len)) goto fail;
99869996
for (j = 0; j < ext_len;) {
99879997
uint16_t k;
99889998
uint16_t key_exchange_len;
@@ -9993,10 +10003,14 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
999310003
j += (uint16_t) (n + 4);
999410004
continue;
999510005
}
9996-
key_exchange_len = MG_LOAD_BE16(ext + j + 5);
10006+
key_exchange_len = MG_LOAD_BE16(ext + j + 4);
999710007
key_exchange = ext + j + 6;
10008+
if (key_exchange_len >
10009+
rio->len - (uint16_t) ((size_t) key_exchange - (size_t) rio->buf) - 2)
10010+
goto fail;
999810011
for (k = 0; k < key_exchange_len;) {
999910012
uint16_t m = MG_LOAD_BE16(key_exchange + k + 2);
10013+
if (m > (key_exchange_len - k - 4)) goto fail;
1000010014
if (m == 32 && key_exchange[k] == 0x00 && key_exchange[k + 1] == 0x1d) {
1000110015
memmove(tls->x25519_cli, key_exchange + k + 4, m);
1000210016
mg_tls_drop_record(c);
@@ -10006,6 +10020,7 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
1000610020
}
1000710021
j += (uint16_t) (n + 4);
1000810022
}
10023+
fail:
1000910024
mg_error(c, "bad client hello");
1001010025
return -1;
1001110026
}
@@ -10324,13 +10339,15 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
1032410339

1032510340
ext_len = MG_LOAD_BE16(rio->buf + 5 + 39 + 32 + 3);
1032610341
ext = rio->buf + 5 + 39 + 32 + 3 + 2;
10342+
if (ext_len > (rio->len - (5 + 39 + 32 + 3 + 2))) goto fail;
1032710343

1032810344
for (j = 0; j < ext_len;) {
1032910345
uint16_t ext_type = MG_LOAD_BE16(ext + j);
1033010346
uint16_t ext_len2 = MG_LOAD_BE16(ext + j + 2);
1033110347
uint16_t group;
1033210348
uint8_t *key_exchange;
1033310349
uint16_t key_exchange_len;
10350+
if (ext_len2 > (ext_len - j - 4)) goto fail;
1033410351
if (ext_type != 0x0033) { // not a key share extension, ignore
1033510352
j += (uint16_t) (ext_len2 + 4);
1033610353
continue;
@@ -10353,6 +10370,7 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
1035310370
mg_tls_generate_handshake_keys(c);
1035410371
return 0;
1035510372
}
10373+
fail:
1035610374
mg_error(c, "bad client hello");
1035710375
return -1;
1035810376
}
@@ -10663,7 +10681,7 @@ static int mg_parse_pem(const struct mg_str pem, const struct mg_str label,
1066310681
size_t n = 0, m = 0;
1066410682
char *s;
1066510683
const char *c;
10666-
struct mg_str caps[5];
10684+
struct mg_str caps[6]; // number of wildcards + 1
1066710685
if (!mg_match(pem, mg_str("#-----BEGIN #-----#-----END #-----#"), caps)) {
1066810686
*der = mg_strdup(pem);
1066910687
return 0;
@@ -10713,6 +10731,7 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
1071310731
if (opts->name.len > 0) {
1071410732
if (opts->name.len >= sizeof(tls->hostname) - 1) {
1071510733
mg_error(c, "hostname too long");
10734+
return;
1071610735
}
1071710736
strncpy((char *) tls->hostname, opts->name.buf, sizeof(tls->hostname) - 1);
1071810737
tls->hostname[opts->name.len] = 0;

src/str.c

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ bool mg_match(struct mg_str s, struct mg_str p, struct mg_str *caps) {
6969
size_t i = 0, j = 0, ni = 0, nj = 0;
7070
if (caps) caps->buf = NULL, caps->len = 0;
7171
while (i < p.len || j < s.len) {
72-
if (i < p.len && j < s.len && (p.buf[i] == '?' || s.buf[j] == p.buf[i])) {
72+
if (i < p.len && j < s.len &&
73+
(p.buf[i] == '?' ||
74+
(p.buf[i] != '*' && p.buf[i] != '#' && s.buf[j] == p.buf[i]))) {
7375
if (caps == NULL) {
7476
} else if (p.buf[i] == '?') {
7577
caps->buf = &s.buf[j], caps->len = 1; // Finalize `?` cap
@@ -112,10 +114,10 @@ bool mg_span(struct mg_str s, struct mg_str *a, struct mg_str *b, char sep) {
112114

113115
bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
114116
size_t i = 0, ndigits = 0;
115-
uint64_t max = val_len == sizeof(uint8_t) ? 0xFF
117+
uint64_t max = val_len == sizeof(uint8_t) ? 0xFF
116118
: val_len == sizeof(uint16_t) ? 0xFFFF
117119
: val_len == sizeof(uint32_t) ? 0xFFFFFFFF
118-
: (uint64_t) ~0;
120+
: (uint64_t) ~0;
119121
uint64_t result = 0;
120122
if (max == (uint64_t) ~0 && val_len != sizeof(uint64_t)) return false;
121123
if (base == 0 && str.len >= 2) {
@@ -131,7 +133,7 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
131133
case 2:
132134
while (i < str.len && (str.buf[i] == '0' || str.buf[i] == '1')) {
133135
uint64_t digit = (uint64_t) (str.buf[i] - '0');
134-
if (result > max/2) return false; // Overflow
136+
if (result > max / 2) return false; // Overflow
135137
result *= 2;
136138
if (result > max - digit) return false; // Overflow
137139
result += digit;
@@ -141,12 +143,12 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
141143
case 10:
142144
while (i < str.len && str.buf[i] >= '0' && str.buf[i] <= '9') {
143145
uint64_t digit = (uint64_t) (str.buf[i] - '0');
144-
if (result > max/10) return false; // Overflow
146+
if (result > max / 10) return false; // Overflow
145147
result *= 10;
146148
if (result > max - digit) return false; // Overflow
147149
result += digit;
148150
i++, ndigits++;
149-
}
151+
}
150152
break;
151153
case 16:
152154
while (i < str.len) {
@@ -156,7 +158,7 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
156158
: (c >= 'a' && c <= 'f') ? (uint64_t) (c - 'W')
157159
: (uint64_t) ~0;
158160
if (digit == (uint64_t) ~0) break;
159-
if (result > max/16) return false; // Overflow
161+
if (result > max / 16) return false; // Overflow
160162
result *= 16;
161163
if (result > max - digit) return false; // Overflow
162164
result += digit;

src/tls_builtin.c

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,15 @@ static void mg_tls_drop_record(struct mg_connection *c) {
214214
static void mg_tls_drop_message(struct mg_connection *c) {
215215
uint32_t len;
216216
struct tls_data *tls = (struct tls_data *) c->tls;
217-
if (tls->recv.len == 0) {
217+
if (tls->recv.len == 0) return;
218+
len = MG_LOAD_BE24(tls->recv.buf + 1) + TLS_MSGHDR_SIZE;
219+
if (tls->recv.len < len) {
220+
mg_error(c, "wrong size");
218221
return;
219222
}
220-
len = MG_LOAD_BE24(tls->recv.buf + 1);
221-
mg_sha256_update(&tls->sha256, tls->recv.buf, len + TLS_MSGHDR_SIZE);
222-
tls->recv.buf += len + TLS_MSGHDR_SIZE;
223-
tls->recv.len -= len + TLS_MSGHDR_SIZE;
223+
mg_sha256_update(&tls->sha256, tls->recv.buf, len);
224+
tls->recv.buf += len;
225+
tls->recv.len -= len;
224226
if (tls->recv.len == 0) {
225227
mg_tls_drop_record(c);
226228
}
@@ -481,6 +483,10 @@ static int mg_tls_recv_record(struct mg_connection *c) {
481483
free(dec);
482484
}
483485
#else
486+
if (msgsz < 16) {
487+
mg_error(c, "wrong size");
488+
return -1;
489+
}
484490
mg_aes_gcm_decrypt(msg, msg, msgsz - 16, key, 16, nonce, sizeof(nonce));
485491
#endif
486492
r = msgsz - 16 - 1;
@@ -544,8 +550,10 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
544550
MG_INFO(("bad session id len"));
545551
}
546552
cipher_suites_len = MG_LOAD_BE16(rio->buf + 44 + session_id_len);
553+
if (cipher_suites_len > (rio->len - 46 - session_id_len)) goto fail;
547554
ext_len = MG_LOAD_BE16(rio->buf + 48 + session_id_len + cipher_suites_len);
548555
ext = rio->buf + 50 + session_id_len + cipher_suites_len;
556+
if (ext_len > (rio->len - 52 - session_id_len - cipher_suites_len)) goto fail;
549557
for (j = 0; j < ext_len;) {
550558
uint16_t k;
551559
uint16_t key_exchange_len;
@@ -556,10 +564,14 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
556564
j += (uint16_t) (n + 4);
557565
continue;
558566
}
559-
key_exchange_len = MG_LOAD_BE16(ext + j + 5);
567+
key_exchange_len = MG_LOAD_BE16(ext + j + 4);
560568
key_exchange = ext + j + 6;
569+
if (key_exchange_len >
570+
rio->len - (uint16_t) ((size_t) key_exchange - (size_t) rio->buf) - 2)
571+
goto fail;
561572
for (k = 0; k < key_exchange_len;) {
562573
uint16_t m = MG_LOAD_BE16(key_exchange + k + 2);
574+
if (m > (key_exchange_len - k - 4)) goto fail;
563575
if (m == 32 && key_exchange[k] == 0x00 && key_exchange[k + 1] == 0x1d) {
564576
memmove(tls->x25519_cli, key_exchange + k + 4, m);
565577
mg_tls_drop_record(c);
@@ -569,6 +581,7 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
569581
}
570582
j += (uint16_t) (n + 4);
571583
}
584+
fail:
572585
mg_error(c, "bad client hello");
573586
return -1;
574587
}
@@ -887,13 +900,15 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
887900

888901
ext_len = MG_LOAD_BE16(rio->buf + 5 + 39 + 32 + 3);
889902
ext = rio->buf + 5 + 39 + 32 + 3 + 2;
903+
if (ext_len > (rio->len - (5 + 39 + 32 + 3 + 2))) goto fail;
890904

891905
for (j = 0; j < ext_len;) {
892906
uint16_t ext_type = MG_LOAD_BE16(ext + j);
893907
uint16_t ext_len2 = MG_LOAD_BE16(ext + j + 2);
894908
uint16_t group;
895909
uint8_t *key_exchange;
896910
uint16_t key_exchange_len;
911+
if (ext_len2 > (ext_len - j - 4)) goto fail;
897912
if (ext_type != 0x0033) { // not a key share extension, ignore
898913
j += (uint16_t) (ext_len2 + 4);
899914
continue;
@@ -916,6 +931,7 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
916931
mg_tls_generate_handshake_keys(c);
917932
return 0;
918933
}
934+
fail:
919935
mg_error(c, "bad client hello");
920936
return -1;
921937
}
@@ -1226,7 +1242,7 @@ static int mg_parse_pem(const struct mg_str pem, const struct mg_str label,
12261242
size_t n = 0, m = 0;
12271243
char *s;
12281244
const char *c;
1229-
struct mg_str caps[5];
1245+
struct mg_str caps[6]; // number of wildcards + 1
12301246
if (!mg_match(pem, mg_str("#-----BEGIN #-----#-----END #-----#"), caps)) {
12311247
*der = mg_strdup(pem);
12321248
return 0;
@@ -1276,6 +1292,7 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
12761292
if (opts->name.len > 0) {
12771293
if (opts->name.len >= sizeof(tls->hostname) - 1) {
12781294
mg_error(c, "hostname too long");
1295+
return;
12791296
}
12801297
strncpy((char *) tls->hostname, opts->name.buf, sizeof(tls->hostname) - 1);
12811298
tls->hostname[opts->name.len] = 0;

test/unit_test.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ static void test_match(void) {
9595
ASSERT(mg_strcmp(caps[0], mg_str("a")) == 0);
9696
ASSERT(mg_strcmp(caps[1], mg_str("bc")) == 0);
9797
ASSERT(mg_strcmp(caps[2], mg_str("")) == 0);
98+
99+
ASSERT(mg_match(mg_str("a#c"), mg_str("?#"), caps) == true);
100+
ASSERT(mg_strcmp(caps[0], mg_str("a")) == 0);
101+
ASSERT(mg_strcmp(caps[1], mg_str("#c")) == 0);
102+
ASSERT(mg_strcmp(caps[2], mg_str("")) == 0);
103+
104+
ASSERT(mg_match(mg_str("a*c"), mg_str("?*"), caps) == true);
105+
ASSERT(mg_strcmp(caps[0], mg_str("a")) == 0);
106+
ASSERT(mg_strcmp(caps[1], mg_str("*c")) == 0);
107+
ASSERT(mg_strcmp(caps[2], mg_str("")) == 0);
98108
}
99109
}
100110

0 commit comments

Comments
 (0)