Skip to content

Commit 706cc11

Browse files
committed
[CIR] Upstream VectorType support in helper function
1 parent 459de73 commit 706cc11

File tree

6 files changed

+176
-15
lines changed

6 files changed

+176
-15
lines changed

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ class CIR_ConfinedType<Type type, list<Pred> preds, string summary = "">
3131
: Type<And<[type.predicate, CIR_CastedSelfsToType<type.cppType, preds>]>,
3232
summary, type.cppType>;
3333

34+
// Generates a type summary.
35+
// - For a single type: returns its summary.
36+
// - For multiple types: returns `any of <comma-separated summaries>`.
37+
class CIR_TypeSummaries<list<Type> types> {
38+
assert !not(!empty(types)), "expects non-empty list of types";
39+
40+
list<string> summaries = !foreach(type, types, type.summary);
41+
string joined = !interleave(summaries, ", ");
42+
43+
string value = !if(!eq(!size(types), 1), joined, "any of " # joined);
44+
}
45+
3446
//===----------------------------------------------------------------------===//
3547
// Bool Type predicates
3648
//===----------------------------------------------------------------------===//
@@ -184,6 +196,8 @@ def CIR_PtrToVoidPtrType
184196
// Vector Type predicates
185197
//===----------------------------------------------------------------------===//
186198

199+
def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;
200+
187201
// Vector of integral type
188202
def IntegerVector : Type<
189203
And<[
@@ -211,4 +225,27 @@ def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
211225
let cppFunctionName = "isScalarType";
212226
}
213227

228+
//===----------------------------------------------------------------------===//
229+
// Element type constraint bases
230+
//===----------------------------------------------------------------------===//
231+
232+
class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
233+
"::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
234+
235+
class CIR_VectorTypeOf<list<Type> types, string summary = "">
236+
: CIR_ConfinedType<CIR_AnyVectorType,
237+
[Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
238+
!if(!empty(summary),
239+
"vector of " # CIR_TypeSummaries<types>.value,
240+
summary)>;
241+
242+
// Vector of type constraints
243+
def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
244+
245+
def CIR_AnyFloatOrVecOfFloatType
246+
: AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
247+
"floating point or vector of floating point type"> {
248+
let cppFunctionName = "isFPOrVectorOfFPType";
249+
}
250+
214251
#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD

clang/include/clang/CIR/Dialect/IR/CIRTypes.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/BuiltinAttributes.h"
1717
#include "mlir/IR/Types.h"
1818
#include "mlir/Interfaces/DataLayoutInterfaces.h"
19+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
1920
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h"
2021

2122
namespace cir {
@@ -26,8 +27,6 @@ struct RecordTypeStorage;
2627

2728
bool isValidFundamentalIntWidth(unsigned width);
2829

29-
bool isFPOrFPVectorTy(mlir::Type);
30-
3130
} // namespace cir
3231

3332
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,7 +1311,7 @@ mlir::Value ScalarExprEmitter::emitMul(const BinOpInfo &ops) {
13111311
!canElideOverflowCheck(cgf.getContext(), ops))
13121312
cgf.cgm.errorNYI("unsigned int overflow sanitizer");
13131313

1314-
if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
1314+
if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
13151315
assert(!cir::MissingFeatures::cgFPOptionsRAII());
13161316
return builder.createFMul(loc, ops.lhs, ops.rhs);
13171317
}
@@ -1370,7 +1370,7 @@ mlir::Value ScalarExprEmitter::emitAdd(const BinOpInfo &ops) {
13701370
!canElideOverflowCheck(cgf.getContext(), ops))
13711371
cgf.cgm.errorNYI("unsigned int overflow sanitizer");
13721372

1373-
if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
1373+
if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
13741374
assert(!cir::MissingFeatures::cgFPOptionsRAII());
13751375
return builder.createFAdd(loc, ops.lhs, ops.rhs);
13761376
}
@@ -1418,7 +1418,7 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &ops) {
14181418
!canElideOverflowCheck(cgf.getContext(), ops))
14191419
cgf.cgm.errorNYI("unsigned int overflow sanitizer");
14201420

1421-
if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
1421+
if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
14221422
assert(!cir::MissingFeatures::cgFPOptionsRAII());
14231423
return builder.createFSub(loc, ops.lhs, ops.rhs);
14241424
}

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -552,15 +552,6 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
552552
.getABIAlignment(dataLayout, params);
553553
}
554554

555-
//===----------------------------------------------------------------------===//
556-
// Floating-point and Float-point Vector type helpers
557-
//===----------------------------------------------------------------------===//
558-
559-
bool cir::isFPOrFPVectorTy(mlir::Type t) {
560-
assert(!cir::MissingFeatures::vectorType());
561-
return isAnyFloatingPointType(t);
562-
}
563-
564555
//===----------------------------------------------------------------------===//
565556
// FuncType Definitions
566557
//===----------------------------------------------------------------------===//

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,3 +1091,70 @@ void foo17() {
10911091
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
10921092
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
10931093
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1094+
1095+
void foo19() {
1096+
vf4 a;
1097+
vf4 b;
1098+
1099+
vf4 c = a + b;
1100+
vf4 d = a - b;
1101+
vf4 e = a * b;
1102+
vf4 f = a / b;
1103+
}
1104+
1105+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>, ["a"]
1106+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>, ["b"]
1107+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1108+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1109+
// CIR: %[[ADD:.*]] = cir.binop(add, %[[TMP_A]], %[[TMP_B]]) : !cir.vector<4 x !cir.float>
1110+
// CIR: cir.store{{.*}} %[[ADD]], {{.*}} : !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>
1111+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1112+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1113+
// CIR: %[[SUB:.*]] = cir.binop(sub, %[[TMP_A]], %[[TMP_B]]) : !cir.vector<4 x !cir.float>
1114+
// CIR: cir.store{{.*}} %[[SUB]], {{.*}} : !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>
1115+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1116+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1117+
// CIR: %[[MUL:.*]] = cir.binop(mul, %[[TMP_A]], %[[TMP_B]]) : !cir.vector<4 x !cir.float>
1118+
// CIR: cir.store{{.*}} %[[MUL]], {{.*}} : !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>
1119+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1120+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1121+
// CIR: %[[DIV:.*]] = cir.binop(div, %[[TMP_A]], %[[TMP_B]]) : !cir.vector<4 x !cir.float>
1122+
// CIR: cir.store{{.*}} %[[DIV]], {{.*}} : !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>
1123+
1124+
// LLVM: %[[VEC_A:.*]] = alloca <4 x float>, i64 1, align 16
1125+
// LLVM: %[[VEC_B:.*]] = alloca <4 x float>, i64 1, align 16
1126+
// LLVM: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1127+
// LLVM: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1128+
// LLVM: %[[ADD:.*]] = fadd <4 x float> %[[TMP_A]], %[[TMP_B]]
1129+
// LLVM: store <4 x float> %[[ADD]], ptr {{.*}}, align 16
1130+
// LLVM: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1131+
// LLVM: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1132+
// LLVM: %[[SUB:.*]] = fsub <4 x float> %[[TMP_A]], %[[TMP_B]]
1133+
// LLVM: store <4 x float> %[[SUB]], ptr {{.*}}, align 16
1134+
// LLVM: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1135+
// LLVM: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1136+
// LLVM: %[[MUL:.*]] = fmul <4 x float> %[[TMP_A]], %[[TMP_B]]
1137+
// LLVM: store <4 x float> %[[MUL]], ptr {{.*}}, align 16
1138+
// LLVM: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1139+
// LLVM: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1140+
// LLVM: %[[DIV:.*]] = fdiv <4 x float> %[[TMP_A]], %[[TMP_B]]
1141+
// LLVM: store <4 x float> %[[DIV]], ptr {{.*}}, align 16
1142+
1143+
// OGCG: %[[VEC_A:.*]] = alloca <4 x float>, align 16
1144+
// OGCG: %[[VEC_B:.*]] = alloca <4 x float>, align 16
1145+
// OGCG: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1146+
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1147+
// OGCG: %[[ADD:.*]] = fadd <4 x float> %[[TMP_A]], %[[TMP_B]]
1148+
// OGCG: store <4 x float> %[[ADD]], ptr {{.*}}, align 16
1149+
// OGCG: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1150+
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1151+
// OGCG: %[[SUB:.*]] = fsub <4 x float> %[[TMP_A]], %[[TMP_B]]
1152+
// OGCG: store <4 x float> %[[SUB]], ptr {{.*}}, align 16
1153+
// OGCG: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1154+
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1155+
// OGCG: %[[MUL:.*]] = fmul <4 x float> %[[TMP_A]], %[[TMP_B]]
1156+
// OGCG: store <4 x float> %[[MUL]], ptr {{.*}}, align 16
1157+
// OGCG: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1158+
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1159+
// OGCG: %[[DIV:.*]] = fdiv <4 x float> %[[TMP_A]], %[[TMP_B]]
1160+
// OGCG: store <4 x float> %[[DIV]], ptr {{.*}}, align 16

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1069,4 +1069,71 @@ void foo17() {
10691069

10701070
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
10711071
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
1072-
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1072+
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1073+
1074+
void foo19() {
1075+
vf4 a;
1076+
vf4 b;
1077+
1078+
vf4 c = a + b;
1079+
vf4 d = a - b;
1080+
vf4 e = a * b;
1081+
vf4 f = a / b;
1082+
}
1083+
1084+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>, ["a"]
1085+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>, ["b"]
1086+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1087+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1088+
// CIR: %[[ADD:.*]] = cir.binop(add, %[[TMP_A]], %[[TMP_B]]) : !cir.vector<4 x !cir.float>
1089+
// CIR: cir.store{{.*}} %[[ADD]], {{.*}} : !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>
1090+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1091+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1092+
// CIR: %[[SUB:.*]] = cir.binop(sub, %[[TMP_A]], %[[TMP_B]]) : !cir.vector<4 x !cir.float>
1093+
// CIR: cir.store{{.*}} %[[SUB]], {{.*}} : !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>
1094+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1095+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1096+
// CIR: %[[MUL:.*]] = cir.binop(mul, %[[TMP_A]], %[[TMP_B]]) : !cir.vector<4 x !cir.float>
1097+
// CIR: cir.store{{.*}} %[[MUL]], {{.*}} : !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>
1098+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1099+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !cir.float>>, !cir.vector<4 x !cir.float>
1100+
// CIR: %[[DIV:.*]] = cir.binop(div, %[[TMP_A]], %[[TMP_B]]) : !cir.vector<4 x !cir.float>
1101+
// CIR: cir.store{{.*}} %[[DIV]], {{.*}} : !cir.vector<4 x !cir.float>, !cir.ptr<!cir.vector<4 x !cir.float>>
1102+
1103+
// LLVM: %[[VEC_A:.*]] = alloca <4 x float>, i64 1, align 16
1104+
// LLVM: %[[VEC_B:.*]] = alloca <4 x float>, i64 1, align 16
1105+
// LLVM: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1106+
// LLVM: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1107+
// LLVM: %[[ADD:.*]] = fadd <4 x float> %[[TMP_A]], %[[TMP_B]]
1108+
// LLVM: store <4 x float> %[[ADD]], ptr {{.*}}, align 16
1109+
// LLVM: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1110+
// LLVM: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1111+
// LLVM: %[[SUB:.*]] = fsub <4 x float> %[[TMP_A]], %[[TMP_B]]
1112+
// LLVM: store <4 x float> %[[SUB]], ptr {{.*}}, align 16
1113+
// LLVM: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1114+
// LLVM: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1115+
// LLVM: %[[MUL:.*]] = fmul <4 x float> %[[TMP_A]], %[[TMP_B]]
1116+
// LLVM: store <4 x float> %[[MUL]], ptr {{.*}}, align 16
1117+
// LLVM: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1118+
// LLVM: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1119+
// LLVM: %[[DIV:.*]] = fdiv <4 x float> %[[TMP_A]], %[[TMP_B]]
1120+
// LLVM: store <4 x float> %[[DIV]], ptr {{.*}}, align 16
1121+
1122+
// OGCG: %[[VEC_A:.*]] = alloca <4 x float>, align 16
1123+
// OGCG: %[[VEC_B:.*]] = alloca <4 x float>, align 16
1124+
// OGCG: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1125+
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1126+
// OGCG: %[[ADD:.*]] = fadd <4 x float> %[[TMP_A]], %[[TMP_B]]
1127+
// OGCG: store <4 x float> %[[ADD]], ptr {{.*}}, align 16
1128+
// OGCG: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1129+
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1130+
// OGCG: %[[SUB:.*]] = fsub <4 x float> %[[TMP_A]], %[[TMP_B]]
1131+
// OGCG: store <4 x float> %[[SUB]], ptr {{.*}}, align 16
1132+
// OGCG: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1133+
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1134+
// OGCG: %[[MUL:.*]] = fmul <4 x float> %[[TMP_A]], %[[TMP_B]]
1135+
// OGCG: store <4 x float> %[[MUL]], ptr {{.*}}, align 16
1136+
// OGCG: %[[TMP_A:.*]] = load <4 x float>, ptr %[[VEC_A]], align 16
1137+
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
1138+
// OGCG: %[[DIV:.*]] = fdiv <4 x float> %[[TMP_A]], %[[TMP_B]]
1139+
// OGCG: store <4 x float> %[[DIV]], ptr {{.*}}, align 16

0 commit comments

Comments
 (0)