Skip to content

Commit e8486db

Browse files
committed
Replace ComplexEqual and ComplexNotEqual by CmpOp
1 parent f6097ef commit e8486db

File tree

5 files changed

+73
-163
lines changed

5 files changed

+73
-163
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,56 +2456,6 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
24562456
let hasFolder = 1;
24572457
}
24582458

2459-
//===----------------------------------------------------------------------===//
2460-
// ComplexEqualOp
2461-
//===----------------------------------------------------------------------===//
2462-
2463-
def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> {
2464-
2465-
let summary = "Computes whether two complex values are equal";
2466-
let description = [{
2467-
The `cir.complex.eq` op takes two complex numbers and returns whether
2468-
they are equal.
2469-
2470-
```mlir
2471-
%r = cir.complex.eq %a, %b : !cir.complex<!cir.float>
2472-
```
2473-
}];
2474-
2475-
let results = (outs CIR_BoolType:$result);
2476-
let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
2477-
2478-
let assemblyFormat = [{
2479-
$lhs `,` $rhs
2480-
`:` qualified(type($lhs)) attr-dict
2481-
}];
2482-
}
2483-
2484-
//===----------------------------------------------------------------------===//
2485-
// ComplexNotEqualOp
2486-
//===----------------------------------------------------------------------===//
2487-
2488-
def ComplexNotEqualOp : CIR_Op<"complex.neq", [Pure, SameTypeOperands]> {
2489-
2490-
let summary = "Computes whether two complex values are not equal";
2491-
let description = [{
2492-
The `cir.complex.neq` op takes two complex numbers and returns whether
2493-
they are not equal.
2494-
2495-
```mlir
2496-
%r = cir.complex.neq %a, %b : !cir.complex<!cir.float>
2497-
```
2498-
}];
2499-
2500-
let results = (outs CIR_BoolType:$result);
2501-
let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
2502-
2503-
let assemblyFormat = [{
2504-
$lhs `,` $rhs
2505-
`:` qualified(type($lhs)) attr-dict
2506-
}];
2507-
}
2508-
25092459
//===----------------------------------------------------------------------===//
25102460
// Assume Operations
25112461
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -901,13 +901,9 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
901901
assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
902902

903903
BinOpInfo boInfo = emitBinOps(e);
904-
if (e->getOpcode() == BO_EQ) {
905-
result =
906-
builder.create<cir::ComplexEqualOp>(loc, boInfo.lhs, boInfo.rhs);
907-
} else {
908-
result =
909-
builder.create<cir::ComplexNotEqualOp>(loc, boInfo.lhs, boInfo.rhs);
910-
}
904+
cir::CmpOpKind opKind =
905+
e->getOpcode() == BO_EQ ? cir::CmpOpKind::eq : cir::CmpOpKind::ne;
906+
result = builder.create<cir::CmpOp>(loc, opKind, boInfo.lhs, boInfo.rhs);
911907
}
912908

913909
return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 65 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,6 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
15861586
assert(!cir::MissingFeatures::dataMemberType());
15871587
assert(!cir::MissingFeatures::methodType());
15881588

1589-
// Lower to LLVM comparison op.
15901589
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
15911590
bool isSigned = mlir::isa<cir::IntType>(type)
15921591
? mlir::cast<cir::IntType>(type).isSigned()
@@ -1595,22 +1594,82 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
15951594
convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
15961595
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
15971596
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1598-
} else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
1597+
return mlir::success();
1598+
}
1599+
1600+
if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
15991601
mlir::LLVM::ICmpPredicate kind =
16001602
convertCmpKindToICmpPredicate(cmpOp.getKind(),
16011603
/* isSigned=*/false);
16021604
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
16031605
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1604-
} else if (mlir::isa<cir::FPTypeInterface>(type)) {
1606+
return mlir::success();
1607+
}
1608+
1609+
if (mlir::isa<cir::FPTypeInterface>(type)) {
16051610
mlir::LLVM::FCmpPredicate kind =
16061611
convertCmpKindToFCmpPredicate(cmpOp.getKind());
16071612
rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
16081613
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1609-
} else {
1610-
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
1614+
return mlir::success();
16111615
}
16121616

1613-
return mlir::success();
1617+
if (mlir::isa<cir::ComplexType>(type)) {
1618+
mlir::Value lhs = adaptor.getLhs();
1619+
mlir::Value rhs = adaptor.getRhs();
1620+
mlir::Location loc = cmpOp.getLoc();
1621+
1622+
auto complexType = mlir::cast<cir::ComplexType>(cmpOp.getLhs().getType());
1623+
mlir::Type complexElemTy =
1624+
getTypeConverter()->convertType(complexType.getElementType());
1625+
1626+
auto lhsReal =
1627+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
1628+
auto lhsImag =
1629+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
1630+
auto rhsReal =
1631+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
1632+
auto rhsImag =
1633+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
1634+
1635+
if (cmpOp.getKind() == cir::CmpOpKind::eq) {
1636+
if (complexElemTy.isInteger()) {
1637+
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1638+
loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
1639+
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1640+
loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
1641+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp);
1642+
return mlir::success();
1643+
}
1644+
1645+
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1646+
loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
1647+
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1648+
loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
1649+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp);
1650+
return mlir::success();
1651+
}
1652+
1653+
if (cmpOp.getKind() == cir::CmpOpKind::ne) {
1654+
if (complexElemTy.isInteger()) {
1655+
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1656+
loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal);
1657+
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1658+
loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag);
1659+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp);
1660+
return mlir::success();
1661+
}
1662+
1663+
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1664+
loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal);
1665+
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1666+
loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag);
1667+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp);
1668+
return mlir::success();
1669+
}
1670+
}
1671+
1672+
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
16141673
}
16151674

16161675
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
@@ -1901,9 +1960,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
19011960
CIRToLLVMCallOpLowering,
19021961
CIRToLLVMCmpOpLowering,
19031962
CIRToLLVMComplexCreateOpLowering,
1904-
CIRToLLVMComplexEqualOpLowering,
19051963
CIRToLLVMComplexImagOpLowering,
1906-
CIRToLLVMComplexNotEqualOpLowering,
19071964
CIRToLLVMComplexRealOpLowering,
19081965
CIRToLLVMConstantOpLowering,
19091966
CIRToLLVMExpectOpLowering,
@@ -2246,80 +2303,6 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
22462303
return mlir::success();
22472304
}
22482305

2249-
mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite(
2250-
cir::ComplexEqualOp op, OpAdaptor adaptor,
2251-
mlir::ConversionPatternRewriter &rewriter) const {
2252-
mlir::Value lhs = adaptor.getLhs();
2253-
mlir::Value rhs = adaptor.getRhs();
2254-
2255-
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2256-
mlir::Type complexElemTy =
2257-
getTypeConverter()->convertType(complexType.getElementType());
2258-
2259-
mlir::Location loc = op.getLoc();
2260-
auto lhsReal =
2261-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2262-
auto lhsImag =
2263-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2264-
auto rhsReal =
2265-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2266-
auto rhsImag =
2267-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2268-
2269-
if (complexElemTy.isInteger()) {
2270-
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2271-
loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
2272-
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2273-
loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
2274-
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
2275-
return mlir::success();
2276-
}
2277-
2278-
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2279-
loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
2280-
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2281-
loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
2282-
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
2283-
return mlir::success();
2284-
}
2285-
2286-
mlir::LogicalResult CIRToLLVMComplexNotEqualOpLowering::matchAndRewrite(
2287-
cir::ComplexNotEqualOp op, OpAdaptor adaptor,
2288-
mlir::ConversionPatternRewriter &rewriter) const {
2289-
mlir::Value lhs = adaptor.getLhs();
2290-
mlir::Value rhs = adaptor.getRhs();
2291-
2292-
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2293-
mlir::Type complexElemTy =
2294-
getTypeConverter()->convertType(complexType.getElementType());
2295-
2296-
mlir::Location loc = op.getLoc();
2297-
auto lhsReal =
2298-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2299-
auto lhsImag =
2300-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2301-
auto rhsReal =
2302-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2303-
auto rhsImag =
2304-
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2305-
2306-
if (complexElemTy.isInteger()) {
2307-
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2308-
loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal);
2309-
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2310-
loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag);
2311-
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, realCmp, imagCmp);
2312-
return mlir::success();
2313-
}
2314-
2315-
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2316-
loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal);
2317-
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2318-
loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag);
2319-
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, realCmp, imagCmp);
2320-
return mlir::success();
2321-
}
2322-
23232306
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
23242307
return std::make_unique<ConvertCIRToLLVMPass>();
23252308
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -463,26 +463,6 @@ class CIRToLLVMComplexImagOpLowering
463463
mlir::ConversionPatternRewriter &) const override;
464464
};
465465

466-
class CIRToLLVMComplexEqualOpLowering
467-
: public mlir::OpConversionPattern<cir::ComplexEqualOp> {
468-
public:
469-
using mlir::OpConversionPattern<cir::ComplexEqualOp>::OpConversionPattern;
470-
471-
mlir::LogicalResult
472-
matchAndRewrite(cir::ComplexEqualOp op, OpAdaptor,
473-
mlir::ConversionPatternRewriter &) const override;
474-
};
475-
476-
class CIRToLLVMComplexNotEqualOpLowering
477-
: public mlir::OpConversionPattern<cir::ComplexNotEqualOp> {
478-
public:
479-
using mlir::OpConversionPattern<cir::ComplexNotEqualOp>::OpConversionPattern;
480-
481-
mlir::LogicalResult
482-
matchAndRewrite(cir::ComplexNotEqualOp op, OpAdaptor,
483-
mlir::ConversionPatternRewriter &) const override;
484-
};
485-
486466
} // namespace direct
487467
} // namespace cir
488468

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ bool foo18(int _Complex a, int _Complex b) {
376376

377377
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
378378
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
379-
// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i>
379+
// CIR: %[[RESULT:.*]] = cir.cmp(eq, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex<!s32i>, !cir.bool
380380

381381
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
382382
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
@@ -408,7 +408,8 @@ bool foo19(double _Complex a, double _Complex b) {
408408

409409
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
410410
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
411-
// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double>
411+
// CIR: %[[RESULT:.*]] = cir.cmp(eq, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex<!cir.double>, !cir.bool
412+
412413

413414
// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
414415
// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8
@@ -449,7 +450,7 @@ bool foo20(int _Complex a, int _Complex b) {
449450

450451
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
451452
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
452-
// CIR: %[[RESULT:.*]] = cir.complex.neq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i>
453+
// CIR: %[[RESULT:.*]] = cir.cmp(ne, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex<!s32i>, !cir.bool
453454

454455
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
455456
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
@@ -481,7 +482,7 @@ bool foo21(double _Complex a, double _Complex b) {
481482

482483
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
483484
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
484-
// CIR: %[[RESULT:.*]] = cir.complex.neq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double>
485+
// CIR: %[[RESULT:.*]] = cir.cmp(ne, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex<!cir.double>, !cir.bool
485486

486487
// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
487488
// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8

0 commit comments

Comments
 (0)