Skip to content

Commit 964a930

Browse files
committed
[CIR] Implement NotEqualOp for ComplexType
1 parent e4d8e06 commit 964a930

File tree

5 files changed

+147
-3
lines changed

5 files changed

+147
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,6 +2481,31 @@ def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> {
24812481
}];
24822482
}
24832483

2484+
//===----------------------------------------------------------------------===//
2485+
// ComplexNotEqualOp
2486+
//===----------------------------------------------------------------------===//
2487+
2488+
def ComplexNotEqualOp : CIR_Op<"complex.neq", [Pure, SameTypeOperands]> {
2489+
2490+
let summary = "Computes whether two complex values are not equal";
2491+
let description = [{
2492+
The `complex.equal` op takes two complex numbers and returns whether
2493+
they are not equal.
2494+
2495+
```mlir
2496+
%r = cir.complex.neq %a, %b : !cir.complex<!cir.float>
2497+
```
2498+
}];
2499+
2500+
let results = (outs CIR_BoolType:$result);
2501+
let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
2502+
2503+
let assemblyFormat = [{
2504+
$lhs `,` $rhs
2505+
`:` qualified(type($lhs)) attr-dict
2506+
}];
2507+
}
2508+
24842509
//===----------------------------------------------------------------------===//
24852510
// Assume Operations
24862511
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -905,9 +905,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
905905
result =
906906
builder.create<cir::ComplexEqualOp>(loc, boInfo.lhs, boInfo.rhs);
907907
} else {
908-
assert(!cir::MissingFeatures::complexType());
909-
cgf.cgm.errorNYI(loc, "complex not equal");
910-
result = builder.getBool(false, loc);
908+
result =
909+
builder.create<cir::ComplexNotEqualOp>(loc, boInfo.lhs, boInfo.rhs);
911910
}
912911
}
913912

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,6 +1903,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
19031903
CIRToLLVMComplexCreateOpLowering,
19041904
CIRToLLVMComplexEqualOpLowering,
19051905
CIRToLLVMComplexImagOpLowering,
1906+
CIRToLLVMComplexNotEqualOpLowering,
19061907
CIRToLLVMComplexRealOpLowering,
19071908
CIRToLLVMConstantOpLowering,
19081909
CIRToLLVMExpectOpLowering,
@@ -2282,6 +2283,43 @@ mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite(
22822283
return mlir::success();
22832284
}
22842285

2286+
mlir::LogicalResult CIRToLLVMComplexNotEqualOpLowering::matchAndRewrite(
2287+
cir::ComplexNotEqualOp op, OpAdaptor adaptor,
2288+
mlir::ConversionPatternRewriter &rewriter) const {
2289+
mlir::Value lhs = adaptor.getLhs();
2290+
mlir::Value rhs = adaptor.getRhs();
2291+
2292+
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2293+
mlir::Type complexElemTy =
2294+
getTypeConverter()->convertType(complexType.getElementType());
2295+
2296+
mlir::Location loc = op.getLoc();
2297+
auto lhsReal =
2298+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2299+
auto lhsImag =
2300+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2301+
auto rhsReal =
2302+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2303+
auto rhsImag =
2304+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2305+
2306+
if (complexElemTy.isInteger()) {
2307+
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2308+
loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal);
2309+
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
2310+
loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag);
2311+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, realCmp, imagCmp);
2312+
return mlir::success();
2313+
}
2314+
2315+
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2316+
loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal);
2317+
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
2318+
loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag);
2319+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, realCmp, imagCmp);
2320+
return mlir::success();
2321+
}
2322+
22852323
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
22862324
return std::make_unique<ConvertCIRToLLVMPass>();
22872325
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,16 @@ class CIRToLLVMComplexEqualOpLowering
473473
mlir::ConversionPatternRewriter &) const override;
474474
};
475475

476+
class CIRToLLVMComplexNotEqualOpLowering
477+
: public mlir::OpConversionPattern<cir::ComplexNotEqualOp> {
478+
public:
479+
using mlir::OpConversionPattern<cir::ComplexNotEqualOp>::OpConversionPattern;
480+
481+
mlir::LogicalResult
482+
matchAndRewrite(cir::ComplexNotEqualOp op, OpAdaptor,
483+
mlir::ConversionPatternRewriter &) const override;
484+
};
485+
476486
} // namespace direct
477487
} // namespace cir
478488

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,75 @@ bool foo19(double _Complex a, double _Complex b) {
442442
// OGCG: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
443443
// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
444444

445+
446+
bool foo20(int _Complex a, int _Complex b) {
447+
return a != b;
448+
}
449+
450+
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
451+
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
452+
// CIR: %[[RESULT:.*]] = cir.complex.neq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i>
453+
454+
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
455+
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
456+
// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0
457+
// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1
458+
// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0
459+
// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1
460+
// LLVM: %[[CMP_REAL:.*]] = icmp ne i32 %[[A_REAL]], %[[B_REAL]]
461+
// LLVM: %[[CMP_IMAG:.*]] = icmp ne i32 %[[A_IMAG]], %[[B_IMAG]]
462+
// LLVM: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]]
463+
464+
// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
465+
// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
466+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0
467+
// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4
468+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1
469+
// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4
470+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
471+
// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4
472+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
473+
// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4
474+
// OGCG: %[[CMP_REAL:.*]] = icmp ne i32 %[[A_REAL]], %[[B_REAL]]
475+
// OGCG: %[[CMP_IMAG:.*]] = icmp ne i32 %[[A_IMAG]], %[[B_IMAG]]
476+
// OGCG: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]]
477+
478+
bool foo21(double _Complex a, double _Complex b) {
479+
return a != b;
480+
}
481+
482+
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
483+
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
484+
// CIR: %[[RESULT:.*]] = cir.complex.neq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double>
485+
486+
// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
487+
// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8
488+
// LLVM: %[[A_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 0
489+
// LLVM: %[[A_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 1
490+
// LLVM: %[[B_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 0
491+
// LLVM: %[[B_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 1
492+
// LLVM: %[[CMP_REAL:.*]] = fcmp une double %[[A_REAL]], %[[B_REAL]]
493+
// LLVM: %[[CMP_IMAG:.*]] = fcmp une double %[[A_IMAG]], %[[B_IMAG]]
494+
// LLVM: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]]
495+
496+
// OGCG: %[[COMPLEX_A:.*]] = alloca { double, double }, align 8
497+
// OGCG: %[[COMPLEX_B:.*]] = alloca { double, double }, align 8
498+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
499+
// OGCG: store double {{.*}}, ptr %[[A_REAL_PTR]], align 8
500+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
501+
// OGCG: store double {{.*}}, ptr %[[A_IMAG_PTR]], align 8
502+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
503+
// OGCG: store double {{.*}}, ptr %[[B_REAL_PTR]], align 8
504+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
505+
// OGCG: store double {{.*}}, ptr %[[B_IMAG_PTR]], align 8
506+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
507+
// OGCG: %[[A_REAL:.*]] = load double, ptr %[[A_REAL_PTR]], align 8
508+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
509+
// OGCG: %[[A_IMAG:.*]] = load double, ptr %[[A_IMAG_PTR]], align 8
510+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
511+
// OGCG: %[[B_REAL:.*]] = load double, ptr %[[B_REAL_PTR]], align 8
512+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
513+
// OGCG: %[[B_IMAG:.*]] = load double, ptr %[[B_IMAG_PTR]], align 8
514+
// OGCG: %[[CMP_REAL:.*]] = fcmp une double %[[A_REAL]], %[[B_REAL]]
515+
// OGCG: %[[CMP_IMAG:.*]] = fcmp une double %[[A_IMAG]], %[[B_IMAG]]
516+
// OGCG: %[[RESULT:.*]] = or i1 %[[CMP_REAL]], %[[CMP_IMAG]]

0 commit comments

Comments
 (0)