Skip to content

Quote pattern matching runtime spec #8687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 15, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 200 additions & 28 deletions library/src/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
@@ -4,6 +4,135 @@ import scala.annotation.internal.sharable

import scala.quoted._

/** Matches a quoted tree against a quoted pattern tree.
* A quoted pattern tree may have type and term holes in addition to normal terms.
*
*
* Semantics:
*
* We use `'{..}` for expression, `'[..]` for types and `⟨..⟩` for patterns nested in expressions.
* The semantics are defined as a list of reduction rules that are tried one by one until one matches.
*
* Operations:
* - `s =?= p` checks if a scrutinee `s` matches the pattern `p` while accumulating extracted parts of the code.
* - `isColosedUnder(x1, .., xn)('{e})` returns true if and only if all the references in `e` to names defined in the patttern are contained in the set `{x1, ... xn}`.
* - `lift(x1, .., xn)('{e})` returns `(y1, ..., yn) => [xi = $yi]'{e}` where `yi` is an `Expr` of the type of `xi`.
* - `withEnv(x1 -> y1, ..., xn -> yn)(matching)` evaluates mathing recording that `xi` is equivalent to `yi`.
* - `matched` denotes that the the match succedded and `matched('{e})` denotes that a matech succeded and extracts `'{e}`
* - `&&&` matches if both sides match. Concatenates the extracted expressions of both sides.
*
* Note: that not all quoted terms bellow are valid expressions
*
* ```scala
* /* Term hole */
* '{ e } =?= '{ hole[T] } && typeOf('{e}) <:< T && isColosedUnder()('{e}) ===> matched('{e})
*
* /* Higher order term hole */
* '{ e } =?= '{ hole[(T1, ..., Tn) => T](x1, ..., xn) } && isColosedUnder(x1, ... xn)('{e}) ===> matched(lift(x1, ..., xn)('{e}))
*
* /* Match literal */
* '{ lit } =?= '{ lit } ===> matched
*
* /* Match type ascription (a) */
* '{ e: T } =?= '{ p } ===> '{e} =?= '{p}
*
* /* Match type ascription (b) */
* '{ e } =?= '{ p: P } ===> '{e} =?= '{p}
*
* /* Match selection */
* '{ e.x } =?= '{ p.x } ===> '{e} =?= '{p}
*
* /* Match reference */
* '{ x } =?= '{ x } ===> matched
*
* /* Match application */
* '{e0(e1, ..., en)} =?= '{p0(p1, ..., p2)} ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} &&& ... %% '{en} =?= '{pn}
*
* /* Match type application */
* '{e[T1, ..., Tn]} =?= '{p[P1, ..., Pn]} ===> '{e} =?= '{p} &&& '[T1] =?= '{P1} &&& ... %% '[Tn] =?= '[Pn]
*
* /* Match block flattening */
* '{ {e0; e1; ...; en}; em } =?= '{ {p0; p1; ...; pm}; em } ===> '{ e0; {e1; ...; en; em} } =?= '{ p0; {p1; ...; pm; em} }
*
* /* Match block */
* '{ e1; e2 } =?= '{ p1; p2 } ===> '{e1} =?= '{p1} &&& '{e2} =?= '{p2}
*
* /* Match def block */
* '{ e1; e2 } =?= '{ p1; p2 } ===> withEnv(symOf(e1) -> symOf(p1))('{e1} =?= '{p1} &&& '{e2} =?= '{p2})
*
* /* Match if */
* '{ if e0 then e1 else e2 } =?= '{ if p0 then p1 else p2 } ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2}
*
* /* Match while */
* '{ while e0 do e1 } =?= '{ while p0 do p1 } ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1}
*
* /* Match assign */
* '{ e0 = e1 } =?= '{ p0 = p1 } && '{e0} =?= '{p0} ===> '{e1} =?= '{p1}
*
* /* Match new */
* '{ new T } =?= '{ new T } ===> matched
*
* /* Match this */
* '{ C.this } =?= '{ C.this } ===> matched
*
* /* Match super */
* '{ e.super } =?= '{ p.super } ===> '{e} =?= '{p}
*
* /* Match varargs */
* '{ e: _* } =?= '{ p: _* } ===> '{e} =?= '{p}
*
* /* Match val */
* '{ val x: T = e1; e2 } =?= '{ val y: P = p1; p2 } ===> withEnv(x -> y)('[T] =?= '[P] &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2})
*
* /* Match def */
* '{ def x0(x1: T1, ..., xn: Tn): T0 = e1; e2 } =?= '{ def y0(y1: P1, ..., yn: Pn): P0 = p1; p2 } ===> withEnv(x0 -> y0, ..., xn -> yn)('[T0] =?= '[P0] &&& ... &&& '[Tn] =?= '[Pn] &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2})
*
* /* Match match */
* '{ e0 match { case u1 => e1; ...; case un => en } } =?= '{ p0 match { case q1 => p1; ...; case qn => pn } } ===>
* '{e0} =?= '{p0} &&& ... &&& '{en} =?= '{pn} &&& '⟨u1⟩ =?= '⟨q1⟩ &&& ... &&& '⟨un⟩ =?= '⟨qn⟩
*
* /* Match try */
* '{ try e0 catch { case u1 => e1; ...; case un => en } finally ef } =?= '{ try p0 catch { case q1 => p1; ...; case qn => pn } finally pf } ===> '{e0} =?= '{p0} &&& ... &&& '{en} =?= '{pn} &&& '⟨u1⟩ =?= '⟨q1⟩ &&& ... &&& '⟨un⟩ =?= '⟨qn⟩ &&& '{ef} =?= '{pf}
*
* // Types
*
* /* Match type */
* '[T] =?= '[P] && T <:< P ===> matched
*
* /* Match applied type */
* '[ T0[T1, ..., Tn] ] =?= '[ P0[P1, ..., Pn] ] ===> '[T0] =?= '[P0] &&& ... &&& '[Tn] =?= '[Pn]
*
* /* Match annot (a) */
* '[T @annot] =?= '[P] ===> '[T] =?= '[P]
*
* /* Match annot (b) */
* '[T] =?= '[P @annot] ===> '[T] =?= '[P]
*
* // Patterns
*
* /* Match pattern whildcard */
* '⟨ _ ⟩ =?= '⟨ _ ⟩ ===> matched
*
* /* Match pattern bind */
* '⟨ x @ e ⟩ =?= '⟨ y @ p ⟩ ===> withEnv(x -> y)('⟨e⟩ =?= '⟨p⟩)
*
* /* Match pattern unapply */
* '⟨ e0(e1, ..., en)(using i1, ..., im ) ⟩ =?= '⟨ p0(p1, ..., pn)(using q1, ..., 1m) ⟩ ===> '⟨e0⟩ =?= '⟨p0⟩ &&& ... &&& '⟨en⟩ =?= '⟨pn⟩ &&& '{i1} =?= '{q1} &&& ... &&& '{im} =?= '{qm}
*
* /* Match pattern alternatives */
* '⟨ e1 | ... | en ⟩ =?= '⟨ p1 | ... | pn ⟩ ===> '⟨e1⟩ =?= '⟨p1⟩ &&& ... &&& '⟨en⟩ =?= '⟨pn⟩
*
* /* Match pattern type test */
* '⟨ e: T ⟩ =?= '⟨ p: U ⟩ ===> '⟨e⟩ =?= '⟨p⟩ &&& '[T] =?= [U]
*
* /* Match pattern ref */
* '⟨ `x` ⟩ =?= '⟨ `x` ⟩ ===> matched
*
* /* Match pattern ref splice */
* '⟨ `x` ⟩ =?= '⟨ hole ⟩ ===> matched('{`x`})
*
* ```
*/
private[quoted] object Matcher {

class QuoteMatcher[QCtx <: QuoteContext & Singleton](using val qctx: QCtx) {
@@ -83,15 +212,15 @@ private[quoted] object Matcher {
case annot => annot.symbol.owner == internal.Definitions_InternalQuoted_fromAboveAnnot
}

/** Check that all trees match with `mtch` and concatenate the results with && */
/** Check that all trees match with `mtch` and concatenate the results with &&& */
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match {
case (x :: xs, y :: ys) => mtch(x, y) && matchLists(xs, ys)(mtch)
case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch)
case (Nil, Nil) => matched
case _ => notMatched
}

private extension treeListOps on (scrutinees: List[Tree]) {
/** Check that all trees match with =?= and concatenate the results with && */
/** Check that all trees match with =?= and concatenate the results with &&& */
def =?= (patterns: List[Tree])(using Context, Env): Matching =
matchLists(scrutinees, patterns)(_ =?= _)
}
@@ -108,6 +237,7 @@ private[quoted] object Matcher {
*/
def =?= (pattern0: Tree)(using Context, Env): Matching = {

/* Match block flattening */ // TODO move to cases
/** Normalize the tree */
def normalize(tree: Tree): Tree = tree match {
case Block(Nil, expr) => normalize(expr)
@@ -129,20 +259,24 @@ private[quoted] object Matcher {

(scrutinee, pattern) match {

/* Term hole */
// Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree
case (scrutinee @ Typed(s, tpt1), Typed(TypeApply(patternHole, tpt :: Nil), tpt2))
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
s.tpe <:< tpt.tpe &&
tpt2.tpe.derivesFrom(defn.RepeatedParamClass) =>
matched(scrutinee.seal)

/* Term hole */
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
scrutinee.tpe <:< tpt.tpe =>
matched(scrutinee.seal)

/* Higher order term hole */
// Matches an open term and wraps it into a lambda that provides the free variables
// TODO do not encode with `hole`. Maybe use `higherOrderHole[(T1, ..., Tn) => R]((x1: T1, ..., xn: Tn)): R`
case (scrutinee, pattern @ Apply(Select(TypeApply(patternHole, List(Inferred())), "apply"), args0 @ IdentArgs(args)))
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole =>
def bodyFn(lambdaArgs: List[Tree]): Tree = {
@@ -164,34 +298,47 @@ private[quoted] object Matcher {
// Match two equivalent trees
//

/* Match literal */
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
matched

/* Match type ascription (a) */
case (Typed(expr1, _), pattern) =>
expr1 =?= pattern

/* Match type ascription (b) */
case (scrutinee, Typed(expr2, _)) =>
scrutinee =?= expr2

case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) =>
matched

/* Match selection */
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
qual1 =?= qual2

/* Match reference */
// TODO could be subsumed by the next case
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) =>
matched

/* Match reference */
case (_: Ref, _: Ref) if scrutinee.symbol == pattern.symbol =>
matched

/* Match application */
// TODO may not need to check the symbol (done in fn1 =?= fn2)
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
fn1 =?= fn2 && args1 =?= args2
fn1 =?= fn2 &&& args1 =?= args2

/* Match type application */
// TODO may not need to check the symbol (done in fn1 =?= fn2)
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
fn1 =?= fn2 && args1 =?= args2
fn1 =?= fn2 &&& args1 =?= args2

case (Block(stats1, expr1), Block(binding :: stats2, expr2)) if isTypeBinding(binding) =>
qctx.tasty.internal.Context_GADT_addToConstraint(summon[Context])(binding.symbol :: Nil)
matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block(stats1, expr1) =?= Block(stats2, expr2)
matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) &&& Block(stats1, expr1) =?= Block(stats2, expr2)

/* Match block */
case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) =>
val newEnv = (stat1, stat2) match {
case (stat1: Definition, stat2: Definition) =>
@@ -200,48 +347,62 @@ private[quoted] object Matcher {
summon[Env]
}
withEnv(newEnv) {
stat1 =?= stat2 && Block(stats1, expr1) =?= Block(stats2, expr2)
stat1 =?= stat2 &&& Block(stats1, expr1) =?= Block(stats2, expr2)
}

case (scrutinee, Block(typeBindings, expr2)) if typeBindings.forall(isTypeBinding) =>
val bindingSymbols = typeBindings.map(_.symbol)
qctx.tasty.internal.Context_GADT_addToConstraint(summon[Context])(bindingSymbols)
bindingSymbols.foldRight(scrutinee =?= expr2)((x, acc) => matched(new SymBinding(x, hasFromAboveAnnotation(x))) && acc)
bindingSymbols.foldRight(scrutinee =?= expr2)((x, acc) => matched(new SymBinding(x, hasFromAboveAnnotation(x))) &&& acc)

/* Match if */
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
cond1 =?= cond2 && thenp1 =?= thenp2 && elsep1 =?= elsep2
cond1 =?= cond2 &&& thenp1 =?= thenp2 &&& elsep1 =?= elsep2

/* Match while */
case (While(cond1, body1), While(cond2, body2)) =>
cond1 =?= cond2 &&& body1 =?= body2

/* Match assign */
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
val lhsMatch =
if ((lhs1 =?= lhs2).isMatch) matched
else notMatched
lhsMatch && rhs1 =?= rhs2

case (While(cond1, body1), While(cond2, body2)) =>
cond1 =?= cond2 && body1 =?= body2
// TODO lhs1 =?= lhs2 &&& rhs1 =?= rhs2
lhsMatch &&& rhs1 =?= rhs2

/* Match new */
case (New(tpt1), New(tpt2)) if tpt1.tpe.typeSymbol == tpt2.tpe.typeSymbol =>
matched

/* Match this */
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
matched

/* Match super */
case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
qual1 =?= qual2

/* Match varargs */
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
elems1 =?= elems2

/* Match type */
// TODO remove this?
case (scrutinee: TypeTree, pattern: TypeTree) if scrutinee.tpe <:< pattern.tpe =>
matched

/* Match applied type */
// TODO remove this?
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
tycon1 =?= tycon2 && args1 =?= args2
tycon1 =?= tycon2 &&& args1 =?= args2

/* Match val */
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
def rhsEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol)
tpt1 =?= tpt2 && treeOptMatches(rhs1, rhs2)(using summon[Context], rhsEnv)
tpt1 =?= tpt2 &&& treeOptMatches(rhs1, rhs2)(using summon[Context], rhsEnv)

/* Match def */
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
def rhsEnv =
val oldEnv: Env = summon[Env]
@@ -251,23 +412,28 @@ private[quoted] object Matcher {
oldEnv ++ newEnv

typeParams1 =?= typeParams2
&& matchLists(paramss1, paramss2)(_ =?= _)
&& tpt1 =?= tpt2
&& withEnv(rhsEnv)(rhs1 =?= rhs2)
&&& matchLists(paramss1, paramss2)(_ =?= _)
&&& tpt1 =?= tpt2
&&& withEnv(rhsEnv)(rhs1 =?= rhs2)

case (Closure(_, tpt1), Closure(_, tpt2)) =>
// TODO match tpt1 with tpt2?
matched

/* Match match */
case (Match(scru1, cases1), Match(scru2, cases2)) =>
scru1 =?= scru2 && matchLists(cases1, cases2)(caseMatches)
scru1 =?= scru2 &&& matchLists(cases1, cases2)(caseMatches)

/* Match try */
case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) =>
body1 =?= body2 && matchLists(cases1, cases2)(caseMatches) && treeOptMatches(finalizer1, finalizer2)
body1 =?= body2 &&& matchLists(cases1, cases2)(caseMatches) &&& treeOptMatches(finalizer1, finalizer2)

// Ignore type annotations
// TODO remove this
/* Match annot (a) */
case (Annotated(tpt, _), _) =>
tpt =?= pattern
/* Match annot (b) */
case (_, Annotated(tpt, _)) =>
scrutinee =?= tpt

@@ -336,9 +502,9 @@ private[quoted] object Matcher {
private def caseMatches(scrutinee: CaseDef, pattern: CaseDef)(using Context, Env): Matching = {
val (caseEnv, patternMatch) = patternsMatches(scrutinee.pattern, pattern.pattern)
withEnv(caseEnv) {
patternMatch &&
treeOptMatches(scrutinee.guard, pattern.guard) &&
scrutinee.rhs =?= pattern.rhs
patternMatch
&&& treeOptMatches(scrutinee.guard, pattern.guard)
&&& scrutinee.rhs =?= pattern.rhs
}
}

@@ -354,24 +520,30 @@ private[quoted] object Matcher {
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
*/
private def patternsMatches(scrutinee: Tree, pattern: Tree)(using Context, Env): (Env, Matching) = (scrutinee, pattern) match {
/* Match pattern ref splice */
case (v1: Term, Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
if patternHole.symbol.owner == summon[Context].requiredModule("scala.runtime.quoted.Matcher") =>
(summon[Env], matched(v1.seal))

/* Match pattern whildcard */
case (Ident("_"), Ident("_")) =>
(summon[Env], matched)

/* Match pattern bind */
case (Bind(name1, body1), Bind(name2, body2)) =>
val bindEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol)
patternsMatches(body1, body2)(using summon[Context], bindEnv)

/* Match pattern unapply */
case (Unapply(fun1, implicits1, patterns1), Unapply(fun2, implicits2, patterns2)) =>
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
(patEnv, patternsMatches(fun1, fun2)._2 && implicits1 =?= implicits2 && patternsMatch)
(patEnv, patternsMatches(fun1, fun2)._2 &&& implicits1 =?= implicits2 &&& patternsMatch)

/* Match pattern alternatives */
case (Alternatives(patterns1), Alternatives(patterns2)) =>
foldPatterns(patterns1, patterns2)

/* Match pattern type test */
case (Typed(Ident("_"), tpt1), Typed(Ident("_"), tpt2)) =>
(summon[Env], tpt1 =?= tpt2)

@@ -403,7 +575,7 @@ private[quoted] object Matcher {
if (patterns1.size != patterns2.size) (summon[Env], notMatched)
else patterns1.zip(patterns2).foldLeft((summon[Env], matched)) { (acc, x) =>
val (env, res) = patternsMatches(x._1, x._2)(using summon[Context], acc._1)
(env, acc._2 && res)
(env, acc._2 &&& res)
}
}

@@ -425,7 +597,7 @@ private[quoted] object Matcher {
def (self: Matching) asOptionOfTuple: Option[Tuple] = self

/** Concatenates the contents of two successful matchings or return a `notMatched` */
def (self: Matching) && (that: => Matching): Matching = self match {
def (self: Matching) &&& (that: => Matching): Matching = self match {
case Some(x) =>
that match {
case Some(y) => Some(x ++ y)