diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 4d3ebfb93615d..4afacab80957a 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2456,31 +2456,6 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> { let hasFolder = 1; } -//===----------------------------------------------------------------------===// -// ComplexEqualOp -//===----------------------------------------------------------------------===// - -def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> { - - let summary = "Computes whether two complex values are equal"; - let description = [{ - The `complex.equal` op takes two complex numbers and returns whether - they are equal. - - ```mlir - %r = cir.complex.eq %a, %b : !cir.complex - ``` - }]; - - let results = (outs CIR_BoolType:$result); - let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs); - - let assemblyFormat = [{ - $lhs `,` $rhs - `:` qualified(type($lhs)) attr-dict - }]; -} - //===----------------------------------------------------------------------===// // Assume Operations //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 6eeecca7e7c8f..b76b703a79fe8 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -907,14 +907,7 @@ class ScalarExprEmitter : public StmtVisitor { assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE); BinOpInfo boInfo = emitBinOps(e); - if (e->getOpcode() == BO_EQ) { - result = - builder.create(loc, boInfo.lhs, boInfo.rhs); - } else { - assert(!cir::MissingFeatures::complexType()); - cgf.cgm.errorNYI(loc, "complex not equal"); - result = builder.getBool(false, loc); - } + result = builder.create(loc, kind, boInfo.lhs, boInfo.rhs); } return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(), diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 1034b8780c03c..bdc7492d48211 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1586,7 +1586,6 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( assert(!cir::MissingFeatures::dataMemberType()); assert(!cir::MissingFeatures::methodType()); - // Lower to LLVM comparison op. if (mlir::isa(type)) { bool isSigned = mlir::isa(type) ? mlir::cast(type).isSigned() @@ -1595,22 +1594,82 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned); rewriter.replaceOpWithNewOp( cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); - } else if (auto ptrTy = mlir::dyn_cast(type)) { + return mlir::success(); + } + + if (auto ptrTy = mlir::dyn_cast(type)) { mlir::LLVM::ICmpPredicate kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), /* isSigned=*/false); rewriter.replaceOpWithNewOp( cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); - } else if (mlir::isa(type)) { + return mlir::success(); + } + + if (mlir::isa(type)) { mlir::LLVM::FCmpPredicate kind = convertCmpKindToFCmpPredicate(cmpOp.getKind()); rewriter.replaceOpWithNewOp( cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); - } else { - return cmpOp.emitError() << "unsupported type for CmpOp: " << type; + return mlir::success(); } - return mlir::success(); + if (mlir::isa(type)) { + mlir::Value lhs = adaptor.getLhs(); + mlir::Value rhs = adaptor.getRhs(); + mlir::Location loc = cmpOp.getLoc(); + + auto complexType = mlir::cast(cmpOp.getLhs().getType()); + mlir::Type complexElemTy = + getTypeConverter()->convertType(complexType.getElementType()); + + auto lhsReal = + rewriter.create(loc, complexElemTy, lhs, 0); + auto lhsImag = + rewriter.create(loc, complexElemTy, lhs, 1); + auto rhsReal = + rewriter.create(loc, complexElemTy, rhs, 0); + auto rhsImag = + rewriter.create(loc, complexElemTy, rhs, 1); + + if (cmpOp.getKind() == cir::CmpOpKind::eq) { + if (complexElemTy.isInteger()) { + auto realCmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal); + auto imagCmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag); + rewriter.replaceOpWithNewOp(cmpOp, realCmp, imagCmp); + return mlir::success(); + } + + auto realCmp = rewriter.create( + loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal); + auto imagCmp = rewriter.create( + loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag); + rewriter.replaceOpWithNewOp(cmpOp, realCmp, imagCmp); + return mlir::success(); + } + + if (cmpOp.getKind() == cir::CmpOpKind::ne) { + if (complexElemTy.isInteger()) { + auto realCmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal); + auto imagCmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag); + rewriter.replaceOpWithNewOp(cmpOp, realCmp, imagCmp); + return mlir::success(); + } + + auto realCmp = rewriter.create( + loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal); + auto imagCmp = rewriter.create( + loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag); + rewriter.replaceOpWithNewOp(cmpOp, realCmp, imagCmp); + return mlir::success(); + } + } + + return cmpOp.emitError() << "unsupported type for CmpOp: " << type; } mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite( @@ -1901,7 +1960,6 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMCallOpLowering, CIRToLLVMCmpOpLowering, CIRToLLVMComplexCreateOpLowering, - CIRToLLVMComplexEqualOpLowering, CIRToLLVMComplexImagOpLowering, CIRToLLVMComplexRealOpLowering, CIRToLLVMConstantOpLowering, @@ -2245,43 +2303,6 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite( return mlir::success(); } -mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite( - cir::ComplexEqualOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const { - mlir::Value lhs = adaptor.getLhs(); - mlir::Value rhs = adaptor.getRhs(); - - auto complexType = mlir::cast(op.getLhs().getType()); - mlir::Type complexElemTy = - getTypeConverter()->convertType(complexType.getElementType()); - - mlir::Location loc = op.getLoc(); - auto lhsReal = - rewriter.create(loc, complexElemTy, lhs, 0); - auto lhsImag = - rewriter.create(loc, complexElemTy, lhs, 1); - auto rhsReal = - rewriter.create(loc, complexElemTy, rhs, 0); - auto rhsImag = - rewriter.create(loc, complexElemTy, rhs, 1); - - if (complexElemTy.isInteger()) { - auto realCmp = rewriter.create( - loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal); - auto imagCmp = rewriter.create( - loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag); - rewriter.replaceOpWithNewOp(op, realCmp, imagCmp); - return mlir::success(); - } - - auto realCmp = rewriter.create( - loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal); - auto imagCmp = rewriter.create( - loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag); - rewriter.replaceOpWithNewOp(op, realCmp, imagCmp); - return mlir::success(); -} - std::unique_ptr createConvertCIRToLLVMPass() { return std::make_unique(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 25cf218cf8b6c..8502cb1ae5d9f 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -463,16 +463,6 @@ class CIRToLLVMComplexImagOpLowering mlir::ConversionPatternRewriter &) const override; }; -class CIRToLLVMComplexEqualOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ComplexEqualOp op, OpAdaptor, - mlir::ConversionPatternRewriter &) const override; -}; - } // namespace direct } // namespace cir diff --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp index e75ba5e92f99b..78d7a2024490b 100644 --- a/clang/test/CIR/CodeGen/complex.cpp +++ b/clang/test/CIR/CodeGen/complex.cpp @@ -376,7 +376,7 @@ bool foo18(int _Complex a, int _Complex b) { // CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr>, !cir.complex // CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr>, !cir.complex -// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex +// CIR: %[[RESULT:.*]] = cir.cmp(eq, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex, !cir.bool // LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4 // LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4 @@ -408,7 +408,8 @@ bool foo19(double _Complex a, double _Complex b) { // CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr>, !cir.complex // CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr>, !cir.complex -// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex +// CIR: %[[RESULT:.*]] = cir.cmp(eq, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex, !cir.bool + // LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8 // LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8 @@ -442,6 +443,79 @@ bool foo19(double _Complex a, double _Complex b) { // OGCG: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]] // OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]] + +bool foo20(int _Complex a, int _Complex b) { + return a != b; +} + +// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr>, !cir.complex +// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr>, !cir.complex +// CIR: %[[RESULT:.*]] = cir.cmp(ne, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex, !cir.bool + +// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4 +// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4 +// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0 +// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1 +// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0 +// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1 +// LLVM: %[[CMP_REAL:.*]] = icmp ne i32 %[[A_REAL]], %[[B_REAL]] +// LLVM: %[[CMP_IMAG:.*]] = icmp ne i32 %[[A_IMAG]], %[[B_IMAG]] +// LLVM: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]] + +// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4 +// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4 +// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0 +// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4 +// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1 +// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4 +// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0 +// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4 +// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1 +// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4 +// OGCG: %[[CMP_REAL:.*]] = icmp ne i32 %[[A_REAL]], %[[B_REAL]] +// OGCG: %[[CMP_IMAG:.*]] = icmp ne i32 %[[A_IMAG]], %[[B_IMAG]] +// OGCG: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]] + +bool foo21(double _Complex a, double _Complex b) { + return a != b; +} + +// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr>, !cir.complex +// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr>, !cir.complex +// CIR: %[[RESULT:.*]] = cir.cmp(ne, %[[COMPLEX_A]], %[[COMPLEX_B]]) : !cir.complex, !cir.bool + +// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8 +// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8 +// LLVM: %[[A_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 0 +// LLVM: %[[A_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 1 +// LLVM: %[[B_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 0 +// LLVM: %[[B_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 1 +// LLVM: %[[CMP_REAL:.*]] = fcmp une double %[[A_REAL]], %[[B_REAL]] +// LLVM: %[[CMP_IMAG:.*]] = fcmp une double %[[A_IMAG]], %[[B_IMAG]] +// LLVM: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]] + +// OGCG: %[[COMPLEX_A:.*]] = alloca { double, double }, align 8 +// OGCG: %[[COMPLEX_B:.*]] = alloca { double, double }, align 8 +// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0 +// OGCG: store double {{.*}}, ptr %[[A_REAL_PTR]], align 8 +// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1 +// OGCG: store double {{.*}}, ptr %[[A_IMAG_PTR]], align 8 +// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0 +// OGCG: store double {{.*}}, ptr %[[B_REAL_PTR]], align 8 +// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1 +// OGCG: store double {{.*}}, ptr %[[B_IMAG_PTR]], align 8 +// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0 +// OGCG: %[[A_REAL:.*]] = load double, ptr %[[A_REAL_PTR]], align 8 +// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1 +// OGCG: %[[A_IMAG:.*]] = load double, ptr %[[A_IMAG_PTR]], align 8 +// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0 +// OGCG: %[[B_REAL:.*]] = load double, ptr %[[B_REAL_PTR]], align 8 +// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1 +// OGCG: %[[B_IMAG:.*]] = load double, ptr %[[B_IMAG_PTR]], align 8 +// OGCG: %[[CMP_REAL:.*]] = fcmp une double %[[A_REAL]], %[[B_REAL]] +// OGCG: %[[CMP_IMAG:.*]] = fcmp une double %[[A_IMAG]], %[[B_IMAG]] +// OGCG: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]] + void foo22(int _Complex a, int _Complex b) { int _Complex c = (a, b); }