@@ -819,65 +819,180 @@ Expression *ExpressionUtils::rewriteRelExpr(const Expression *expr) {
819819 return simplifiedExpr;
820820 }
821821 // Move all evaluable expression to the right side
822- auto relRightOperandExpr = relExpr->right ()->clone ();
823- auto relLeftOperandExpr = rewriteRelExprHelper (relExpr->left (), relRightOperandExpr);
824- return RelationalExpression::makeKind (
825- pool, relExpr->kind (), relLeftOperandExpr->clone (), relRightOperandExpr->clone ());
822+ return rewriteRelExprHelper (pool, relExpr);
823+ // auto relRightOperandExpr = relExpr->right()->clone();
824+ // auto relLeftOperandExpr = rewriteRelExprHelper(relExpr->left(), relRightOperandExpr);
825+ // return RelationalExpression::makeKind(
826+ // pool, relExpr->kind(), relLeftOperandExpr->clone(), relRightOperandExpr->clone());
826827 };
827828
828829 return RewriteVisitor::transform (expr, matcher, rewriter);
829830}
830831
831- Expression *ExpressionUtils::rewriteRelExprHelper (const Expression *expr,
832- Expression *&relRightOperandExpr) {
833- ObjectPool *pool = expr->getObjPool ();
834- // TODO: Support rewrite mul/div expression after fixing overflow
835- auto matcher = [](const Expression *e) -> bool {
836- if (!e->isArithmeticExpr () || e->kind () == Expression::Kind::kMultiply ||
837- e->kind () == Expression::Kind::kDivision )
838- return false ;
839- auto arithExpr = static_cast <const ArithmeticExpression *>(e);
832+ static std::optional<Expression::Kind> invertRelOp (Expression::Kind kind) {
833+ switch (kind) {
834+ case Expression::Kind::kRelLT :
835+ return Expression::Kind::kRelGT ;
836+ case Expression::Kind::kRelLE :
837+ return Expression::Kind::kRelGE ;
838+ case Expression::Kind::kRelGT :
839+ return Expression::Kind::kRelLT ;
840+ case Expression::Kind::kRelGE :
841+ return Expression::Kind::kRelLE ;
842+ case Expression::Kind::kRelEQ :
843+ return Expression::Kind::kRelEQ ;
844+ case Expression::Kind::kRelNE :
845+ return Expression::Kind::kRelNE ;
846+ default :
847+ // we don't handle this cases
848+ return std::nullopt ;
849+ }
850+ }
840851
841- return ExpressionUtils::isEvaluableExpr (arithExpr->left ()) ||
842- ExpressionUtils::isEvaluableExpr (arithExpr->right ());
843- };
852+ static Expression *rewriteArith (ObjectPool *pool, const Expression *expr, Expression *&rhs) {
853+ if (!expr->isArithmeticExpr ()) {
854+ return const_cast <Expression *>(expr);
855+ }
844856
845- if (!matcher (expr)) {
857+ auto *arith = static_cast <const ArithmeticExpression *>(expr);
858+ auto k = arith->kind ();
859+ // only support add and minus for now
860+ if (k != Expression::Kind::kAdd && k != Expression::Kind::kMinus ) {
846861 return const_cast <Expression *>(expr);
847862 }
848863
849- auto arithExpr = static_cast <const ArithmeticExpression *>(expr);
850- auto kind = getNegatedArithmeticType (arithExpr->kind ());
851- auto lexpr = relRightOperandExpr->clone ();
852- const Expression *root = nullptr ;
853- Expression *rexpr = nullptr ;
864+ auto *l = arith->left ();
865+ auto *r = arith->right ();
866+ bool le = ExpressionUtils::isEvaluableExpr (l);
867+ bool re = ExpressionUtils::isEvaluableExpr (r);
854868
855- // Use left operand as root
856- if (ExpressionUtils::isEvaluableExpr (arithExpr->right ())) {
857- rexpr = arithExpr->right ()->clone ();
858- root = arithExpr->left ();
859- } else {
860- rexpr = arithExpr->left ()->clone ();
861- root = arithExpr->right ();
869+ // finish remove if both sides are not evaluable
870+ if (!le && !re) {
871+ return const_cast <Expression *>(expr);
862872 }
863- switch (kind) {
864- case Expression::Kind::kAdd :
865- relRightOperandExpr = ArithmeticExpression::makeAdd (pool, lexpr, rexpr);
866- break ;
867- case Expression::Kind::kMinus :
868- relRightOperandExpr = ArithmeticExpression::makeMinus (pool, lexpr, rexpr);
869- break ;
870- // Unsupported arithmetic kind
871- // case Expression::Kind::kMultiply:
872- // case Expression::Kind::kDivision:
873- default :
874- DLOG (ERROR) << " Unsupported expression kind: " << static_cast <uint8_t >(kind);
875- break ;
873+
874+ if (k == Expression::Kind::kAdd ) {
875+ if (le && !re) {
876+ // swap to make sure r is evaluable
877+ std::swap (l, r);
878+ std::swap (le, re);
879+ }
880+ DCHECK (re && ExpressionUtils::isEvaluableExpr (r));
881+ // a + C <op> rhs --> a <op> rhs - C
882+ rhs = ArithmeticExpression::makeMinus (pool, rhs->clone (), r->clone ());
883+ // nested rewrite the left operand
884+ return rewriteArith (pool, l, rhs);
885+ }
886+
887+ if (k == Expression::Kind::kMinus ) {
888+ if (re) {
889+ // a - C <op> rhs --> a <op> rhs + C
890+ rhs = ArithmeticExpression::makeAdd (pool, rhs->clone (), r->clone ());
891+ // nested rewrite the left operand
892+ return rewriteArith (pool, l, rhs);
893+ }
876894 }
877895
878- return rewriteRelExprHelper (root, relRightOperandExpr );
896+ return const_cast <Expression *>(expr );
879897}
880898
899+ Expression *ExpressionUtils::rewriteRelExprHelper (ObjectPool *pool, const Expression *expr) {
900+ if (!expr->isRelExpr ()) {
901+ return const_cast <Expression *>(expr);
902+ }
903+
904+ const RelationalExpression *relExpr = static_cast <const RelationalExpression *>(expr);
905+ auto *lhs = relExpr->left ();
906+ auto *rhs = relExpr->right ();
907+ auto relKind = relExpr->kind ();
908+
909+ // specially handle C - a <op> rhs case
910+ if (lhs->isArithmeticExpr ()) {
911+ auto *arith = static_cast <const ArithmeticExpression *>(lhs);
912+ if (arith->kind () == Expression::Kind::kMinus &&
913+ ExpressionUtils::isEvaluableExpr (arith->left ()) &&
914+ !ExpressionUtils::isEvaluableExpr (arith->right ())) {
915+ // C - a <op> rhs --> a <invert op> C - rhs
916+ auto *constantPart = arith->left ()->clone ();
917+ auto res = invertRelOp (relKind);
918+ if (!res.has_value ()) {
919+ return const_cast <Expression *>(expr);
920+ }
921+ relKind = invertRelOp (relKind).value ();
922+ // let the variable part be negative and move it to the left side of the relational expression
923+ lhs = arith->right ();
924+ rhs = ArithmeticExpression::makeMinus (pool, constantPart, rhs->clone ());
925+ }
926+ }
927+
928+ // swap the left and right if left is evaluable but right is not
929+ if (ExpressionUtils::isEvaluableExpr (lhs) && !ExpressionUtils::isEvaluableExpr (rhs)) {
930+ std::swap (lhs, rhs);
931+ auto res = invertRelOp (relKind);
932+ if (res.has_value ()) {
933+ relKind = res.value ();
934+ } else {
935+ return const_cast <Expression *>(expr);
936+ }
937+ }
938+
939+ Expression *re = rhs->clone ();
940+ // move evaluable part from left to right, rewrite the right in place and return the new left
941+ Expression *le = rewriteArith (pool, lhs, re);
942+
943+ return RelationalExpression::makeKind (pool, relKind, le->clone (), re->clone ());
944+ }
945+
946+ // Expression *ExpressionUtils::rewriteRelExprHelper(const Expression *expr,
947+ // Expression *&relRightOperandExpr) {
948+ // ObjectPool *pool = expr->getObjPool();
949+ // // TODO: Support rewrite mul/div expression after fixing overflow
950+ // auto matcher = [](const Expression *e) -> bool {
951+ // if (!e->isArithmeticExpr() || e->kind() == Expression::Kind::kMultiply ||
952+ // e->kind() == Expression::Kind::kDivision)
953+ // return false;
954+ // auto arithExpr = static_cast<const ArithmeticExpression *>(e);
955+
956+ // return ExpressionUtils::isEvaluableExpr(arithExpr->left()) ||
957+ // ExpressionUtils::isEvaluableExpr(arithExpr->right());
958+ // };
959+
960+ // if (!matcher(expr)) {
961+ // return const_cast<Expression *>(expr);
962+ // }
963+
964+ // auto arithExpr = static_cast<const ArithmeticExpression *>(expr);
965+ // auto kind = getNegatedArithmeticType(arithExpr->kind());
966+ // auto lexpr = relRightOperandExpr->clone();
967+ // const Expression *root = nullptr;
968+ // Expression *rexpr = nullptr;
969+
970+ // // Use left operand as root
971+ // if (ExpressionUtils::isEvaluableExpr(arithExpr->right())) {
972+ // rexpr = arithExpr->right()->clone();
973+ // root = arithExpr->left();
974+ // } else {
975+ // rexpr = arithExpr->left()->clone();
976+ // root = arithExpr->right();
977+ // }
978+ // switch (kind) {
979+ // case Expression::Kind::kAdd:
980+ // relRightOperandExpr = ArithmeticExpression::makeAdd(pool, lexpr, rexpr);
981+ // break;
982+ // case Expression::Kind::kMinus:
983+ // relRightOperandExpr = ArithmeticExpression::makeMinus(pool, lexpr, rexpr);
984+ // break;
985+ // // Unsupported arithmetic kind
986+ // // case Expression::Kind::kMultiply:
987+ // // case Expression::Kind::kDivision:
988+ // default:
989+ // DLOG(ERROR) << "Unsupported expression kind: " << static_cast<uint8_t>(kind);
990+ // break;
991+ // }
992+
993+ // return rewriteRelExprHelper(root, relRightOperandExpr);
994+ // }
995+
881996StatusOr<Expression *> ExpressionUtils::filterTransform (const Expression *filter) {
882997 // Check if any overflow happen before filter transform
883998 auto initialConstFold = foldConstantExpr (filter);
0 commit comments