Skip to content

Commit a6494a3

Browse files
authored
[HLSL][RootSignature] Allow for multiple parsing errors in RootSignatureParser (#147832)
This pr implements returning multiple parsing errors at the granularity of a `RootElement` This is achieved by adding a new interface onto `RootSignatureParser`, namely, `skipUntilExpectedToken`. This will be used to consume all the intermediate tokens between when an error has occurred and when the next `RootElement` begins. At this granularity, the implementation is somewhat straight forward, as we can just implement this `skip` function when we return from a `parse[RootElement]` method and continue in the main `parse` loop. With the exception that the `parseDescriptorTable` will also have to skip ahead to the next expected closing `')'`. If we want to provide any finer granularity, then the skip logic becomes significantly more complicated. Skipping to the next root element will provide a good ratio of user experience benefit to complexity of implementation. For more context see linked issue. - Updates `HLSLRootSignatureParser` with a `skipUntilExpectedToken` and `skipUntilClosedParen` interface - Updates the `parse` loops to use the skip interface when an error is found on parsing a root element - Updates `parseDescriptorTable` to skip ahead to the next `')'` if it was inside a clause - Adds test-case to demonstrate multiple error being reported Resolves: #145818
1 parent ef48b7f commit a6494a3

File tree

3 files changed

+155
-16
lines changed

3 files changed

+155
-16
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,21 @@ class RootSignatureParser {
198198
bool tryConsumeExpectedToken(RootSignatureToken::Kind Expected);
199199
bool tryConsumeExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);
200200

201+
/// Consume tokens until the expected token has been peeked to be next
202+
/// or we have reached the end of the stream. Note that this means the
203+
/// expected token will be the next token not CurToken.
204+
///
205+
/// Returns true if it found a token of the given type.
206+
bool skipUntilExpectedToken(RootSignatureToken::Kind Expected);
207+
bool skipUntilExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);
208+
209+
/// Consume tokens until we reach a closing right paren, ')', or, until we
210+
/// have reached the end of the stream. This will place the current token
211+
/// to be the end of stream or the right paren.
212+
///
213+
/// Returns true if it is closed before the end of stream.
214+
bool skipUntilClosedParens(uint32_t NumParens = 1);
215+
201216
/// Convert the token's offset in the signature string to its SourceLocation
202217
///
203218
/// This allows to currently retrieve the location for multi-token

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 88 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ namespace hlsl {
1717

1818
using TokenKind = RootSignatureToken::Kind;
1919

20+
static const TokenKind RootElementKeywords[] = {
21+
TokenKind::kw_RootFlags,
22+
TokenKind::kw_CBV,
23+
TokenKind::kw_UAV,
24+
TokenKind::kw_SRV,
25+
TokenKind::kw_DescriptorTable,
26+
TokenKind::kw_StaticSampler,
27+
};
28+
2029
RootSignatureParser::RootSignatureParser(
2130
llvm::dxbc::RootSignatureVersion Version,
2231
SmallVector<RootSignatureElement> &Elements, StringLiteral *Signature,
@@ -27,51 +36,76 @@ RootSignatureParser::RootSignatureParser(
2736
bool RootSignatureParser::parse() {
2837
// Iterate as many RootSignatureElements as possible, until we hit the
2938
// end of the stream
39+
bool HadError = false;
3040
while (!peekExpectedToken(TokenKind::end_of_stream)) {
3141
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
3242
SourceLocation ElementLoc = getTokenLocation(CurToken);
3343
auto Flags = parseRootFlags();
34-
if (!Flags.has_value())
35-
return true;
44+
if (!Flags.has_value()) {
45+
HadError = true;
46+
skipUntilExpectedToken(RootElementKeywords);
47+
continue;
48+
}
49+
3650
Elements.emplace_back(ElementLoc, *Flags);
3751
} else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
3852
SourceLocation ElementLoc = getTokenLocation(CurToken);
3953
auto Constants = parseRootConstants();
40-
if (!Constants.has_value())
41-
return true;
54+
if (!Constants.has_value()) {
55+
HadError = true;
56+
skipUntilExpectedToken(RootElementKeywords);
57+
continue;
58+
}
4259
Elements.emplace_back(ElementLoc, *Constants);
4360
} else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
4461
SourceLocation ElementLoc = getTokenLocation(CurToken);
4562
auto Table = parseDescriptorTable();
46-
if (!Table.has_value())
47-
return true;
63+
if (!Table.has_value()) {
64+
HadError = true;
65+
// We are within a DescriptorTable, we will do our best to recover
66+
// by skipping until we encounter the expected closing ')'.
67+
skipUntilClosedParens();
68+
consumeNextToken();
69+
skipUntilExpectedToken(RootElementKeywords);
70+
continue;
71+
}
4872
Elements.emplace_back(ElementLoc, *Table);
4973
} else if (tryConsumeExpectedToken(
5074
{TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
5175
SourceLocation ElementLoc = getTokenLocation(CurToken);
5276
auto Descriptor = parseRootDescriptor();
53-
if (!Descriptor.has_value())
54-
return true;
77+
if (!Descriptor.has_value()) {
78+
HadError = true;
79+
skipUntilExpectedToken(RootElementKeywords);
80+
continue;
81+
}
5582
Elements.emplace_back(ElementLoc, *Descriptor);
5683
} else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
5784
SourceLocation ElementLoc = getTokenLocation(CurToken);
5885
auto Sampler = parseStaticSampler();
59-
if (!Sampler.has_value())
60-
return true;
86+
if (!Sampler.has_value()) {
87+
HadError = true;
88+
skipUntilExpectedToken(RootElementKeywords);
89+
continue;
90+
}
6191
Elements.emplace_back(ElementLoc, *Sampler);
6292
} else {
93+
HadError = true;
6394
consumeNextToken(); // let diagnostic be at the start of invalid token
6495
reportDiag(diag::err_hlsl_invalid_token)
6596
<< /*parameter=*/0 << /*param of*/ TokenKind::kw_RootSignature;
66-
return true;
97+
skipUntilExpectedToken(RootElementKeywords);
98+
continue;
6799
}
68100

69-
// ',' denotes another element, otherwise, expected to be at end of stream
70-
if (!tryConsumeExpectedToken(TokenKind::pu_comma))
101+
if (!tryConsumeExpectedToken(TokenKind::pu_comma)) {
102+
// ',' denotes another element, otherwise, expected to be at end of stream
71103
break;
104+
}
72105
}
73106

74-
return consumeExpectedToken(TokenKind::end_of_stream,
107+
return HadError ||
108+
consumeExpectedToken(TokenKind::end_of_stream,
75109
diag::err_expected_either, TokenKind::pu_comma);
76110
}
77111

@@ -262,8 +296,13 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
262296
// DescriptorTableClause - CBV, SRV, UAV, or Sampler
263297
SourceLocation ElementLoc = getTokenLocation(CurToken);
264298
auto Clause = parseDescriptorTableClause();
265-
if (!Clause.has_value())
299+
if (!Clause.has_value()) {
300+
// We are within a DescriptorTableClause, we will do our best to recover
301+
// by skipping until we encounter the expected closing ')'
302+
skipUntilExpectedToken(TokenKind::pu_r_paren);
303+
consumeNextToken();
266304
return std::nullopt;
305+
}
267306
Elements.emplace_back(ElementLoc, *Clause);
268307
Table.NumClauses++;
269308
} else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
@@ -1371,6 +1410,40 @@ bool RootSignatureParser::tryConsumeExpectedToken(
13711410
return true;
13721411
}
13731412

1413+
bool RootSignatureParser::skipUntilExpectedToken(TokenKind Expected) {
1414+
return skipUntilExpectedToken(ArrayRef{Expected});
1415+
}
1416+
1417+
bool RootSignatureParser::skipUntilExpectedToken(
1418+
ArrayRef<TokenKind> AnyExpected) {
1419+
1420+
while (!peekExpectedToken(AnyExpected)) {
1421+
if (peekExpectedToken(TokenKind::end_of_stream))
1422+
return false;
1423+
consumeNextToken();
1424+
}
1425+
1426+
return true;
1427+
}
1428+
1429+
bool RootSignatureParser::skipUntilClosedParens(uint32_t NumParens) {
1430+
TokenKind ParenKinds[] = {
1431+
TokenKind::pu_l_paren,
1432+
TokenKind::pu_r_paren,
1433+
};
1434+
while (skipUntilExpectedToken(ParenKinds)) {
1435+
consumeNextToken();
1436+
if (CurToken.TokKind == TokenKind::pu_r_paren)
1437+
NumParens--;
1438+
else
1439+
NumParens++;
1440+
if (NumParens == 0)
1441+
return true;
1442+
}
1443+
1444+
return false;
1445+
}
1446+
13741447
SourceLocation RootSignatureParser::getTokenLocation(RootSignatureToken Tok) {
13751448
return Signature->getLocationOfByte(Tok.LocOffset, PP.getSourceManager(),
13761449
PP.getLangOpts(), PP.getTargetInfo());

clang/test/SemaHLSL/RootSignature-err.hlsl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,57 @@ void bad_root_signature_22() {}
104104
[RootSignature("RootFlags(local_root_signature | root_flag_typo)")]
105105
void bad_root_signature_23() {}
106106

107+
#define DemoMultipleErrorsRootSignature \
108+
"CBV(b0, space = invalid)," \
109+
"StaticSampler()" \
110+
"DescriptorTable(" \
111+
" visibility = SHADER_VISIBILITY_ALL," \
112+
" visibility = SHADER_VISIBILITY_DOMAIN," \
113+
")," \
114+
"SRV(t0, space = 28947298374912374098172)" \
115+
"UAV(u0, flags = 3)" \
116+
"DescriptorTable(Sampler(s0 flags = DATA_VOLATILE))," \
117+
"CBV(b0),,"
118+
119+
// expected-error@+7 {{expected integer literal after '='}}
120+
// expected-error@+6 {{did not specify mandatory parameter 's register'}}
121+
// expected-error@+5 {{specified the same parameter 'visibility' multiple times}}
122+
// expected-error@+4 {{integer literal is too large to be represented as a 32-bit signed integer type}}
123+
// expected-error@+3 {{flag value is neither a literal 0 nor a named value}}
124+
// expected-error@+2 {{expected ')' or ','}}
125+
// expected-error@+1 {{invalid parameter of RootSignature}}
126+
[RootSignature(DemoMultipleErrorsRootSignature)]
127+
void multiple_errors() {}
128+
129+
#define DemoGranularityRootSignature \
130+
"CBV(b0, reported_diag, flags = skipped_diag)," \
131+
"DescriptorTable( " \
132+
" UAV(u0, reported_diag), " \
133+
" SRV(t0, skipped_diag), " \
134+
")," \
135+
"StaticSampler(s0, reported_diag, SRV(t0, reported_diag)" \
136+
""
137+
138+
// expected-error@+4 {{invalid parameter of CBV}}
139+
// expected-error@+3 {{invalid parameter of UAV}}
140+
// expected-error@+2 {{invalid parameter of StaticSampler}}
141+
// expected-error@+1 {{invalid parameter of SRV}}
142+
[RootSignature(DemoGranularityRootSignature)]
143+
void granularity_errors() {}
144+
145+
#define TestTableScope \
146+
"DescriptorTable( " \
147+
" UAV(u0, reported_diag), " \
148+
" SRV(t0, skipped_diag), " \
149+
" Sampler(s0, skipped_diag), " \
150+
")," \
151+
"CBV(s0, reported_diag)"
152+
153+
// expected-error@+2 {{invalid parameter of UAV}}
154+
// expected-error@+1 {{invalid parameter of CBV}}
155+
[RootSignature(TestTableScope)]
156+
void recover_scope_errors() {}
157+
107158
// Basic validation of register value and space
108159

109160
// expected-error@+2 {{value must be in the range [0, 4294967294]}}
@@ -138,4 +189,4 @@ void basic_validation_5() {}
138189

139190
// expected-error@+1 {{value must be in the range [-16.00, 15.99]}}
140191
[RootSignature("StaticSampler(s0, mipLODBias = 15.990001)")]
141-
void basic_validation_6() {}
192+
void basic_validation_6() {}

0 commit comments

Comments
 (0)