Skip to content

Commit e9be528

Browse files
authored
[CIR] Implement NotEqualOp for ComplexType (#146129)
This change adds support for the not equal operation for ComplexType #141365
1 parent 6a97b56 commit e9be528

File tree

5 files changed

+142
-89
lines changed

5 files changed

+142
-89
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,31 +2456,6 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
24562456
let hasFolder = 1;
24572457
}
24582458

2459-
//===----------------------------------------------------------------------===//
2460-
// ComplexEqualOp
2461-
//===----------------------------------------------------------------------===//
2462-
2463-
def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> {
2464-
2465-
let summary = "Computes whether two complex values are equal";
2466-
let description = [{
2467-
The `complex.equal` op takes two complex numbers and returns whether
2468-
they are equal.
2469-
2470-
```mlir
2471-
%r = cir.complex.eq %a, %b : !cir.complex<!cir.float>
2472-
```
2473-
}];
2474-
2475-
let results = (outs CIR_BoolType:$result);
2476-
let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
2477-
2478-
let assemblyFormat = [{
2479-
$lhs `,` $rhs
2480-
`:` qualified(type($lhs)) attr-dict
2481-
}];
2482-
}
2483-
24842459
//===----------------------------------------------------------------------===//
24852460
// Assume Operations
24862461
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -907,14 +907,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
907907
assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
908908

909909
BinOpInfo boInfo = emitBinOps(e);
910-
if (e->getOpcode() == BO_EQ) {
911-
result =
912-
builder.create<cir::ComplexEqualOp>(loc, boInfo.lhs, boInfo.rhs);
913-
} else {
914-
assert(!cir::MissingFeatures::complexType());
915-
cgf.cgm.errorNYI(loc, "complex not equal");
916-
result = builder.getBool(false, loc);
917-
}
910+
result = builder.create<cir::CmpOp>(loc, kind, boInfo.lhs, boInfo.rhs);
918911
}
919912

920913
return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 65 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,6 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
15861586
assert(!cir::MissingFeatures::dataMemberType());
15871587
assert(!cir::MissingFeatures::methodType());
15881588

1589-
// Lower to LLVM comparison op.
15901589
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
15911590
bool isSigned = mlir::isa<cir::IntType>(type)
15921591
? mlir::cast<cir::IntType>(type).isSigned()
@@ -1595,22 +1594,82 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
15951594
convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
15961595
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
15971596
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1598-
} else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
1597+
return mlir::success();
1598+
}
1599+
1600+
if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
15991601
mlir::LLVM::ICmpPredicate kind =
16001602
convertCmpKindToICmpPredicate(cmpOp.getKind(),
16011603
/* isSigned=*/false);
16021604
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
16031605
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1604-
} else if (mlir::isa<cir::FPTypeInterface>(type)) {
1606+
return mlir::success();
1607+
}
1608+
1609+
if (mlir::isa<cir::FPTypeInterface>(type)) {
16051610
mlir::LLVM::FCmpPredicate kind =
16061611
convertCmpKindToFCmpPredicate(cmpOp.getKind());
16071612
rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
16081613
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1609-
} else {
1610-
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
1614+
return mlir::success();
16111615
}
16121616

1613-
return mlir::success();
1617+
if (mlir::isa<cir::ComplexType>(type)) {
1618+
mlir::Value lhs = adaptor.getLhs();
1619+
mlir::Value rhs = adaptor.getRhs();
1620+
mlir::Location loc = cmpOp.getLoc();
1621+
1622+
auto complexType = mlir::cast<cir::ComplexType>(cmpOp.getLhs().getType());
1623+
mlir::Type complexElemTy =
1624+
getTypeConverter()->convertType(complexType.getElementType());
1625+
1626+
auto lhsReal =
1627+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
1628+
auto lhsImag =
1629+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
1630+
auto rhsReal =
1631+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
1632+
auto rhsImag =
1633+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
1634+
1635+
if (cmpOp.getKind() == cir::CmpOpKind::eq) {
1636+
if (complexElemTy.isInteger()) {
1637+
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1638+
loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
1639+
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1640+
loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
1641+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp);
1642+
return mlir::success();
1643+
}
1644+
1645+
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1646+
loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
1647+
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1648+
loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
1649+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp);
1650+
return mlir::success();
1651+
}
1652+
1653+
if (cmpOp.getKind() == cir::CmpOpKind::ne) {
1654+
if (complexElemTy.isInteger()) {
1655+
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1656+
loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal);
1657+
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1658+
loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag);
1659+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp);
1660+
return mlir::success();
1661+
}
1662+
1663+
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1664+
loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal);
1665+
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1666+
loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag);
1667+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp);
1668+
return mlir::success();
1669+
}
1670+
}
1671+
1672+
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
16141673
}
16151674

16161675
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
@@ -1901,7 +1960,6 @@ void ConvertCIRToLLVMPass::runOnOperation() {
19011960
CIRToLLVMCallOpLowering,
19021961
CIRToLLVMCmpOpLowering,
19031962
CIRToLLVMComplexCreateOpLowering,
1904-
CIRToLLVMComplexEqualOpLowering,
19051963
CIRToLLVMComplexImagOpLowering,
19061964
CIRToLLVMComplexRealOpLowering,
19071965
CIRToLLVMConstantOpLowering,
@@ -2245,43 +2303,6 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
22452303
return mlir::success();
22462304
}
22472305

2248-
mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite(
2249-
cir::ComplexEqualOp op, OpAdaptor adaptor,
2250-
mlir::ConversionPatternRewriter &rewriter) const {
2251-
mlir::Value lhs = adaptor.getLhs();
2252-
mlir::Value rhs = adaptor.getRhs();
2253-
2254-
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2255-
mlir::Type complexElemTy =
2256-
getTypeConverter()->convertType(complexType.getElementType());
2257-
2258-
mlir::Location loc = op.getLoc();
2259-
auto lhsReal =
2260-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2261-
auto lhsImag =
2262-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2263-
auto rhsReal =
2264-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2265-
auto rhsImag =
2266-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2267-
2268-
if (complexElemTy.isInteger()) {
2269-
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2270-
loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
2271-
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2272-
loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
2273-
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
2274-
return mlir::success();
2275-
}
2276-
2277-
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2278-
loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
2279-
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2280-
loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
2281-
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
2282-
return mlir::success();
2283-
}
2284-
22852306
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
22862307
return std::make_unique<ConvertCIRToLLVMPass>();
22872308
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -463,16 +463,6 @@ class CIRToLLVMComplexImagOpLowering
463463
mlir::ConversionPatternRewriter &) const override;
464464
};
465465

466-
class CIRToLLVMComplexEqualOpLowering
467-
: public mlir::OpConversionPattern<cir::ComplexEqualOp> {
468-
public:
469-
using mlir::OpConversionPattern<cir::ComplexEqualOp>::OpConversionPattern;
470-
471-
mlir::LogicalResult
472-
matchAndRewrite(cir::ComplexEqualOp op, OpAdaptor,
473-
mlir::ConversionPatternRewriter &) const override;
474-
};
475-
476466
} // namespace direct
477467
} // namespace cir
478468

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ bool foo18(int _Complex a, int _Complex b) {
376376

377377
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
378378
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
379-
// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i>
379+
// CIR: %[[RESULT:.*]] = cir.cmp(eq, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex<!s32i>, !cir.bool
380380

381381
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
382382
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
@@ -408,7 +408,8 @@ bool foo19(double _Complex a, double _Complex b) {
408408

409409
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
410410
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
411-
// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double>
411+
// CIR: %[[RESULT:.*]] = cir.cmp(eq, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex<!cir.double>, !cir.bool
412+
412413

413414
// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
414415
// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8
@@ -442,6 +443,79 @@ bool foo19(double _Complex a, double _Complex b) {
442443
// OGCG: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
443444
// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
444445

446+
447+
bool foo20(int _Complex a, int _Complex b) {
448+
return a != b;
449+
}
450+
451+
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
452+
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
453+
// CIR: %[[RESULT:.*]] = cir.cmp(ne, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex<!s32i>, !cir.bool
454+
455+
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
456+
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
457+
// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0
458+
// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1
459+
// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0
460+
// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1
461+
// LLVM: %[[CMP_REAL:.*]] = icmp ne i32 %[[A_REAL]], %[[B_REAL]]
462+
// LLVM: %[[CMP_IMAG:.*]] = icmp ne i32 %[[A_IMAG]], %[[B_IMAG]]
463+
// LLVM: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]]
464+
465+
// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
466+
// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
467+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0
468+
// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4
469+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1
470+
// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4
471+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
472+
// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4
473+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
474+
// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4
475+
// OGCG: %[[CMP_REAL:.*]] = icmp ne i32 %[[A_REAL]], %[[B_REAL]]
476+
// OGCG: %[[CMP_IMAG:.*]] = icmp ne i32 %[[A_IMAG]], %[[B_IMAG]]
477+
// OGCG: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]]
478+
479+
bool foo21(double _Complex a, double _Complex b) {
480+
return a != b;
481+
}
482+
483+
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
484+
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
485+
// CIR: %[[RESULT:.*]] = cir.cmp(ne, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex<!cir.double>, !cir.bool
486+
487+
// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
488+
// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8
489+
// LLVM: %[[A_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 0
490+
// LLVM: %[[A_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 1
491+
// LLVM: %[[B_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 0
492+
// LLVM: %[[B_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 1
493+
// LLVM: %[[CMP_REAL:.*]] = fcmp une double %[[A_REAL]], %[[B_REAL]]
494+
// LLVM: %[[CMP_IMAG:.*]] = fcmp une double %[[A_IMAG]], %[[B_IMAG]]
495+
// LLVM: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]]
496+
497+
// OGCG: %[[COMPLEX_A:.*]] = alloca { double, double }, align 8
498+
// OGCG: %[[COMPLEX_B:.*]] = alloca { double, double }, align 8
499+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
500+
// OGCG: store double {{.*}}, ptr %[[A_REAL_PTR]], align 8
501+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
502+
// OGCG: store double {{.*}}, ptr %[[A_IMAG_PTR]], align 8
503+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
504+
// OGCG: store double {{.*}}, ptr %[[B_REAL_PTR]], align 8
505+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
506+
// OGCG: store double {{.*}}, ptr %[[B_IMAG_PTR]], align 8
507+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
508+
// OGCG: %[[A_REAL:.*]] = load double, ptr %[[A_REAL_PTR]], align 8
509+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
510+
// OGCG: %[[A_IMAG:.*]] = load double, ptr %[[A_IMAG_PTR]], align 8
511+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
512+
// OGCG: %[[B_REAL:.*]] = load double, ptr %[[B_REAL_PTR]], align 8
513+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
514+
// OGCG: %[[B_IMAG:.*]] = load double, ptr %[[B_IMAG_PTR]], align 8
515+
// OGCG: %[[CMP_REAL:.*]] = fcmp une double %[[A_REAL]], %[[B_REAL]]
516+
// OGCG: %[[CMP_IMAG:.*]] = fcmp une double %[[A_IMAG]], %[[B_IMAG]]
517+
// OGCG: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]]
518+
445519
void foo22(int _Complex a, int _Complex b) {
446520
int _Complex c = (a, b);
447521
}

0 commit comments

Comments
 (0)