Skip to content

Commit 6756c41

Browse files
committed
Partial function synthesis changesOwner of selector
The selector expression may be non-trivial.
1 parent 09d64c6 commit 6756c41

File tree

9 files changed

+96
-40
lines changed

9 files changed

+96
-40
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
865865
/** An extractor for def of a closure contained the block of the closure. */
866866
object closureDef {
867867
def unapply(tree: Tree)(using Context): Option[DefDef] = tree match {
868-
case Block((meth : DefDef) :: Nil, closure: Closure) if meth.symbol == closure.meth.symbol =>
868+
case Block((meth: DefDef) :: Nil, closure: Closure) if meth.symbol == closure.meth.symbol =>
869869
Some(meth)
870870
case Block(Nil, expr) =>
871871
unapply(expr)

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
375375
* new parents { termForwarders; typeAliases }
376376
*
377377
* @param parents a non-empty list of class types
378-
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
378+
* @param termForwarders a non-empty list of forwarding definitions specified by their name
379+
* and the definition they forward to.
379380
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
380381
* @param adaptVarargs if true, allow matching a vararg superclass constructor
381382
* with a missing argument in superArgs, and synthesize an

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6048,19 +6048,18 @@ object Types extends TypeUtils {
60486048
end samParent
60496049

60506050
def samClass(tp: Type)(using Context): Symbol = tp match
6051-
case tp: ClassInfo =>
6052-
val cls = tp.cls
6051+
case tp @ ClassInfo(_, cls, _, _, _) =>
60536052
def takesNoArgs(tp: Type) =
60546053
!tp.classSymbol.primaryConstructor.exists
60556054
// e.g. `ContextFunctionN` does not have constructors
6056-
|| tp.applicableConstructors(Nil, adaptVarargs = true).lengthCompare(1) == 0
6055+
|| tp.applicableConstructors(argTypes = Nil, adaptVarargs = true).lengthCompare(1) == 0
60576056
// we require a unique constructor so that SAM expansion is deterministic
60586057
val noArgsNeeded: Boolean =
60596058
takesNoArgs(tp)
6060-
&& (!tp.cls.is(Trait) || takesNoArgs(tp.parents.head))
6059+
&& (!cls.is(Trait) || takesNoArgs(tp.parents.head))
60616060
def isInstantiable =
6062-
!tp.cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
6063-
if noArgsNeeded && isInstantiable then tp.cls
6061+
!cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
6062+
if noArgsNeeded && isInstantiable then cls
60646063
else NoSymbol
60656064
case tp: AppliedType =>
60666065
samClass(tp.superType)

compiler/src/dotty/tools/dotc/transform/Dependencies.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co
4545

4646
/** A map from local methods and classes to the owners to which they will be lifted as members.
4747
* For methods and classes that do not have any dependencies this will be the enclosing package.
48-
* symbols with packages as lifted owners will subsequently represented as static
48+
* Symbols with packages as lifted owners will be subsequently represented as static
4949
* members of their toplevel class, unless their enclosing class was already static.
50-
* Note: During tree transform (which runs at phase LambdaLift + 1), liftedOwner
50+
* Note: During tree transform (which runs at phase LambdaLift + 1), logicOwner
5151
* is also used to decide whether a method had a term owner before.
5252
*/
5353
private val logicOwner = new LinkedHashMap[Symbol, Symbol]
@@ -75,8 +75,8 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co
7575
|| owner.is(Trait) && isLocal(owner)
7676
|| sym.isConstructor && isLocal(owner)
7777

78-
/** Set `liftedOwner(sym)` to `owner` if `owner` is more deeply nested
79-
* than the previous value of `liftedowner(sym)`.
78+
/** Set `logicOwner(sym)` to `owner` if `owner` is more deeply nested
79+
* than the previous value of `logicOwner(sym)`.
8080
*/
8181
private def narrowLogicOwner(sym: Symbol, owner: Symbol)(using Context): Unit =
8282
if sym.maybeOwner.isTerm
@@ -89,7 +89,7 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co
8989

9090
/** Mark symbol `sym` as being free in `enclosure`, unless `sym` is defined
9191
* in `enclosure` or there is an intermediate class properly containing `enclosure`
92-
* in which `sym` is also free. Also, update `liftedOwner` of `enclosure` so
92+
* in which `sym` is also free. Also, update `logicOwner` of `enclosure` so
9393
* that `enclosure` can access `sym`, or its proxy in an intermediate class.
9494
* This means:
9595
*
@@ -284,7 +284,7 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co
284284
changedFreeVars
285285
do ()
286286

287-
/** Compute final liftedOwner map by closing over caller dependencies */
287+
/** Compute final logicOwner map by closing over caller dependencies */
288288
private def computeLogicOwners()(using Context): Unit =
289289
while
290290
changedLogicOwner = false

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ import Names.TypeName
1010

1111
import NullOpsDecorator.*
1212
import ast.untpd
13+
import scala.collection.mutable.ListBuffer
1314

1415
/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
1516
* These fall into five categories
1617
*
1718
* 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these.
1819
* 2. Closures implementing non-trait classes
1920
* 3. Closures implementing classes that inherit from a class other than Object
20-
* (a lambda cannot not be a run-time subtype of such a class)
21+
* (a lambda cannot be a run-time subtype of such a class)
2122
* 4. Closures that implement traits which run initialization code.
2223
* 5. Closures that get synthesized abstract methods in the transformation pipeline. These methods can be
2324
* (1) superaccessors, (2) outer references, (3) accessors for fields.
@@ -59,7 +60,7 @@ class ExpandSAMs extends MiniPhase:
5960
// A SAM type is allowed to have type aliases refinements (see
6061
// SAMType#samParent) which must be converted into type members if
6162
// the closure is desugared into a class.
62-
val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]()
63+
val refinements = ListBuffer.empty[(TypeName, TypeAlias)]
6364
def collectAndStripRefinements(tp: Type): Type = tp match
6465
case RefinedType(parent, name, info: TypeAlias) =>
6566
val res = collectAndStripRefinements(parent)
@@ -81,34 +82,40 @@ class ExpandSAMs extends MiniPhase:
8182
tree
8283
}
8384

84-
/** A partial function literal:
85+
/** A pattern-matching anonymous function:
8586
*
8687
* ```
8788
* val x: PartialFunction[A, B] = { case C1 => E1; ...; case Cn => En }
8889
* ```
90+
* or
91+
* ```
92+
* x => e(x) { case C1 => E1; ...; case Cn => En }
93+
* ```
94+
* where the expression `e(x)` may be trivially `x`
8995
*
9096
* which desugars to:
9197
*
9298
* ```
9399
* val x: PartialFunction[A, B] = {
94-
* def $anonfun(x: A): B = x match { case C1 => E1; ...; case Cn => En }
100+
* def $anonfun(x: A): B = e(x) match { case C1 => E1; ...; case Cn => En }
95101
* closure($anonfun: PartialFunction[A, B])
96102
* }
97103
* ```
104+
* where the expression `e(x)` defaults to `x` for a simple block of cases
98105
*
99106
* is expanded to an anonymous class:
100107
*
101108
* ```
102109
* val x: PartialFunction[A, B] = {
103110
* class $anon extends AbstractPartialFunction[A, B] {
104-
* final def isDefinedAt(x: A): Boolean = x match {
111+
* final def isDefinedAt(x: A): Boolean = e(x) match {
105112
* case C1 => true
106113
* ...
107114
* case Cn => true
108115
* case _ => false
109116
* }
110117
*
111-
* final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = x match {
118+
* final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = e(x) match {
112119
* case C1 => E1
113120
* ...
114121
* case Cn => En
@@ -120,7 +127,7 @@ class ExpandSAMs extends MiniPhase:
120127
* }
121128
* ```
122129
*/
123-
private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = {
130+
private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree =
124131
val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree: @unchecked
125132

126133
// The right hand side from which to construct the partial function. This is always a Match.
@@ -146,7 +153,7 @@ class ExpandSAMs extends MiniPhase:
146153
defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
147154
defn.SerializableType)
148155

149-
AnonClass(anonSym.owner, parents, tree.span) { pfSym =>
156+
AnonClass(anonSym.owner, parents, tree.span): pfSym =>
150157
def overrideSym(sym: Symbol) = sym.copy(
151158
owner = pfSym,
152159
flags = Synthetic | Method | Final | Override,
@@ -155,7 +162,8 @@ class ExpandSAMs extends MiniPhase:
155162
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
156163
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
157164

158-
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = {
165+
def translateMatch(owner: Symbol)(pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) =
166+
val tree: Match = pfRHS
159167
val selector = tree.selector
160168
val cases1 = if cases.exists(isDefaultCase) then cases
161169
else
@@ -165,31 +173,27 @@ class ExpandSAMs extends MiniPhase:
165173
cases :+ defaultCase
166174
cpy.Match(tree)(selector, cases1)
167175
.subst(param.symbol :: Nil, pfParam :: Nil)
168-
// Needed because a partial function can be written as:
176+
// Needed because a partial function can be written as:
169177
// param => param match { case "foo" if foo(param) => param }
170178
// And we need to update all references to 'param'
171-
}
179+
.changeOwner(anonSym, owner)
172180

173-
def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = {
181+
def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) =
174182
val tru = Literal(Constant(true))
175-
def translateCase(cdef: CaseDef) =
176-
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
183+
def translateCase(cdef: CaseDef) = cpy.CaseDef(cdef)(body = tru)
177184
val paramRef = paramRefss.head.head
178185
val defaultValue = Literal(Constant(false))
179-
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
180-
}
186+
translateMatch(isDefinedAtFn)(paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
181187

182-
def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
188+
def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) =
183189
val List(paramRef, defaultRef) = paramRefss(1)
184-
def translateCase(cdef: CaseDef) =
185-
cdef.changeOwner(anonSym, applyOrElseFn)
186190
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
187-
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
188-
}
191+
translateMatch(applyOrElseFn)(paramRef.symbol, pfRHS.cases, defaultValue)
189192

190-
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
191-
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
193+
val isDefinedAtDef = transformFollowingDeep:
194+
DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn)))
195+
val applyOrElseDef = transformFollowingDeep:
196+
DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn)))
192197
List(isDefinedAtDef, applyOrElseDef)
193-
}
194-
}
198+
end toPartialFunction
195199
end ExpandSAMs

docs/_spec/08-pattern-matching.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ new scala.PartialFunction[´S´, ´T´] {
616616
def apply(´x´: ´S´): ´T´ = x match {
617617
case ´p_1´ => ´b_1´ ... case ´p_n´ => ´b_n´
618618
}
619-
def isDefinedAt(´x´: ´S´): Boolean = {
619+
def isDefinedAt(´x´: ´S´): Boolean = x match {
620620
case ´p_1´ => true ... case ´p_n´ => true
621621
case _ => false
622622
}
@@ -626,6 +626,22 @@ new scala.PartialFunction[´S´, ´T´] {
626626
Here, ´x´ is a fresh name and ´T´ is the least upper bound of the types of all ´b_i´.
627627
The final default case in the `isDefinedAt` method is omitted if one of the patterns ´p_1, ..., p_n´ is already a variable or wildcard pattern.
628628

629+
As a convenience, the partial function may be written using function literal notation:
630+
631+
```scala
632+
x: S´) => e(´x´) match {
633+
case ´p_1´ => ´b_1´ ... case ´p_n´ => ´b_n´
634+
}
635+
```
636+
where the selector expression is used for matches in the expansion.
637+
The body of the function must consist solely of the match expression.
638+
639+
This syntax permits annotating the selector:
640+
641+
```scala
642+
x: S´) => (e(´x´): @unchecked) match { ... }
643+
```
644+
629645
###### Example
630646
Here's an example which uses `foldLeft` to compute the scalar product of two vectors:
631647

tests/pos/i23025.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
class A {
3+
def f: PartialFunction[Int, Int] =
4+
a => { (try a catch { case e : Throwable => throw e}) match { case n => n } }
5+
}

tests/pos/i23054.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
object Bug:
3+
4+
def m0(f: PartialFunction[Char, Unit]): Unit = ()
5+
6+
def m1(): Unit =
7+
m0: x =>
8+
"abc".filter(_ == x) match
9+
case _ => ()
10+
11+
def m2(): Unit =
12+
m0: x =>
13+
x match
14+
case _ => ()
15+

tests/pos/i23310.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
object Example {
3+
val pf: PartialFunction[Unit, Unit] = s => (s match {
4+
case a => a
5+
}) match {
6+
case a => ()
7+
}
8+
}
9+
10+
object ExampleB:
11+
def test =
12+
List(42).collect:
13+
_.match
14+
case x => x
15+
.match
16+
case y => y + 27

0 commit comments

Comments
 (0)