Skip to content

Commit 06ce25f

Browse files
committed
[LAYERS] Fix a few exports and issues after the 0.7.2 merge
1 parent 9d4bb7d commit 06ce25f

File tree

24 files changed

+186
-139
lines changed

24 files changed

+186
-139
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ lazy val npmLibrary =
114114
new {
115115
object Version {
116116
val tfjsCore = "0.12.8"
117-
val tfjsLayers = "0.6.7"
117+
val tfjsLayers = "0.7.2"
118118
val tfjsConverter = "0.4.3"
119119
}
120120
val tfjsCore = "@tensorflow/tfjs-core" -> Version.tfjsCore

example/src/main/scala/example/Example.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ object Example {
3434
def main(args: Array[String]): Unit = {
3535

3636
println("Hello scalajs-tfjs!")
37-
println(s"tfjs version: ${tf.version}")
37+
println(s"tfjs-core version: ${tf.version}")
3838

3939
// Tensors
4040
val shape = js.Array(2, 3) // 2 rows, 3 columns

examples/mobilenet/src/main/scala/MobilenetDemo.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ object MobilenetDemo {
208208
val status = (msg: String) => demoStatusElement.asInstanceOf[RichHTMLElement].innerText = msg
209209

210210
def main(args: Array[String]): Unit = {
211+
println(s"tfjs-core version: ${tf.version}")
212+
println(s"tfjs-layers version: ${tfl.version_layers}")
211213
mobilenetDemo
212214
}
213215
}

tfjs-core/src/main/scala/io/brunk/tfjs/core/environment.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,15 @@ object Environment extends js.Object {
5656
def getBackend(): String = js.native
5757
def disposeVariables(): Unit = js.native
5858
def memory(): MemoryInfo = js.native
59+
// TODO find a way to deal with the fact that the compiler doesn't see T as subtype of the TensorContainer union type
5960
def tidy[T <: TensorContainer](
60-
nameOrFn: String | ScopeFn[T],
61-
fn: ScopeFn[T] = ???,
62-
gradMode: Boolean = ???
61+
nameOrFn: String | ScopeFn[T],
62+
fn: ScopeFn[T] = ???,
63+
gradMode: Boolean = ???
6364
): T = js.native
65+
def tidy(
66+
nameOrFn: js.Function0[TensorND],
67+
): TensorND = js.native
6468
def dispose(container: TensorContainer): Unit = js.native
6569
def keep[T <: TensorND](result: T): T = js.native
6670
def time(f: js.Function0[Unit]): Promise[TimingInfo] = js.native

tfjs-core/src/main/scala/io/brunk/tfjs/core/globals.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ trait Globals extends js.Object {
4747
fn: ScopeFn[T] = ???,
4848
gradMode: Boolean = ???
4949
): T = js.native
50+
def tidy(
51+
nameOrFn: js.Function0[TensorND],
52+
): TensorND = js.native
5053
def dispose(container: TensorContainer): Unit = js.native
5154
def keep[T <: TensorND](result: T): T = js.native
5255
def time(f: js.Function0[Unit]): Promise[TimingInfo] = js.native

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/backend/state.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package io.brunk.tfjs.layers.backend
1818

19+
import io.brunk.tfjs.core.DataType
20+
import io.brunk.tfjs.tf.Scalar
21+
1922
import scala.scalajs.js
2023
import js.annotation._
2124
import js.|

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/backend/tfjs_backend.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ package io.brunk.tfjs.layers.backend
1818

1919
import io.brunk.tfjs.core.DataType
2020
import io.brunk.tfjs.layers.Common.DataFormat
21-
import io.brunk.tfjs.layers.{ LayerVariable, SymbolicTensor }
21+
import io.brunk.tfjs.layers.LayerVariable
2222
import io.brunk.tfjs.layers.Types.Shape
23+
import io.brunk.tfjs.layers.engine.SymbolicTensor
2324
import io.brunk.tfjs.tf.TensorND
2425

2526
import scala.scalajs.js

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/base_callbacks.scala

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@
1616

1717
package io.brunk.tfjs.layers
1818

19+
import io.brunk.tfjs.layers.Base_callbacks.Params
20+
import io.brunk.tfjs.layers.Logs.{Logs, UnresolvedLogs}
21+
import io.brunk.tfjs.layers.engine.Container.Container
22+
1923
import scala.scalajs.js
2024
import js.annotation._
21-
import js.|
25+
import js.{Promise, |}
26+
import io.brunk.tfjs.tf.TensorND
2227

2328
@js.native
2429
@JSGlobal
2530
abstract class BaseCallback extends js.Object {
26-
var validationData: Tensor | js.Array[Tensor] = js.native
31+
var validationData: TensorND | js.Array[TensorND] = js.native
2732
var params: Params = js.native
2833
def setParams(params: Params): Unit = js.native
2934
def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
@@ -55,18 +60,18 @@ class CallbackList protected () extends js.Object {
5560
@js.native
5661
@JSGlobal
5762
class BaseLogger extends BaseCallback {
58-
def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
59-
def onBatchEnd(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
60-
def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
63+
override def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
64+
override def onBatchEnd(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
65+
override def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
6166
}
6267

6368
@js.native
6469
@JSGlobal
6570
class History extends BaseCallback {
6671
var epoch: js.Array[Double] = js.native
6772
var history: History.History = js.native
68-
def onTrainBegin(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
69-
def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
73+
override def onTrainBegin(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
74+
override def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
7075
def syncData(): Promise[Unit] = js.native
7176
}
7277

@@ -75,9 +80,9 @@ object History {
7580
@js.native
7681
trait History extends js.Object {
7782
@JSBracketAccess
78-
def apply(key: String): js.Array[Double | Tensor] = js.native
83+
def apply(key: String): js.Array[Double | TensorND] = js.native
7984
@JSBracketAccess
80-
def update(key: String, v: js.Array[Double | Tensor]): Unit = js.native
85+
def update(key: String, v: js.Array[Double | TensorND]): Unit = js.native
8186
}
8287
}
8388

@@ -101,12 +106,12 @@ class CustomCallback protected () extends BaseCallback {
101106
protected def epochEnd: js.Function2[Double, Logs, Promise[Unit]] = js.native
102107
protected def batchBegin: js.Function2[Double, Logs, Promise[Unit]] = js.native
103108
protected def batchEnd: js.Function2[Double, Logs, Promise[Unit]] = js.native
104-
def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
105-
def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
106-
def onBatchBegin(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
107-
def onBatchEnd(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
108-
def onTrainBegin(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
109-
def onTrainEnd(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
109+
override def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
110+
override def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
111+
override def onBatchBegin(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
112+
override def onBatchEnd(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
113+
override def onTrainBegin(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
114+
override def onTrainEnd(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
110115
}
111116

112117
@js.native

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/callbacks.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package io.brunk.tfjs.layers
1818

19-
import io.brunk.tfjs.layers.Callbacks.{ Logs, Params, UnresolvedLogs }
2019
import io.brunk.tfjs.layers.engine.Model
2120
import io.brunk.tfjs.tf._
2221

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/engine/container.scala

Lines changed: 81 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,91 +16,100 @@
1616

1717
package io.brunk.tfjs.layers.engine
1818

19+
import io.brunk.tfjs.core.serialization
20+
import io.brunk.tfjs.layers.Types.{Kwargs, Shape}
21+
import io.brunk.tfjs.layers.{JsonDict, LayerVariable, NamedTensorMap}
22+
import io.brunk.tfjs.tf.{Scalar, TensorND}
23+
1924
import scala.scalajs.js
2025
import js.annotation._
2126
import js.|
2227

2328
@js.native
24-
trait ContainerConfig extends js.Object {
25-
var inputs: SymbolicTensor | js.Array[SymbolicTensor] = js.native
26-
var outputs: SymbolicTensor | js.Array[SymbolicTensor] = js.native
27-
var name: String = js.native
28-
}
29-
30-
@js.native
31-
@JSGlobal
32-
abstract class Container protected () extends Layer {
33-
def this(config: ContainerConfig) = this()
34-
var inputs: js.Array[SymbolicTensor] = js.native
35-
var outputs: js.Array[SymbolicTensor] = js.native
36-
var inputLayers: js.Array[Layer] = js.native
37-
var inputLayersNodeIndices: js.Array[Double] = js.native
38-
var inputLayersTensorIndices: js.Array[Double] = js.native
39-
var outputLayers: js.Array[Layer] = js.native
40-
var outputLayersNodeIndices: js.Array[Double] = js.native
41-
var outputLayersTensorIndices: js.Array[Double] = js.native
42-
var layers: js.Array[Layer] = js.native
43-
var layersByDepth: Container.LayersByDepth = js.native
44-
var nodesByDepth: Container.NodesByDepth = js.native
45-
var containerNodes: Set[String] = js.native
46-
var inputNames: js.Array[String] = js.native
47-
var outputNames: js.Array[String] = js.native
48-
var feedInputShapes: js.Array[Shape] = js.native
49-
protected var internalInputShapes: js.Array[Shape] = js.native
50-
protected var internalOutputShapes: js.Array[Shape] = js.native
51-
protected var feedInputNames: js.Array[String] = js.native
52-
protected var feedOutputNames: js.Array[String] = js.native
53-
def trainableWeights: js.Array[LayerVariable] = js.native
54-
def nonTrainableWeights: js.Array[LayerVariable] = js.native
55-
def weights: js.Array[LayerVariable] = js.native
56-
def loadWeights(
57-
weightsJSON: JsonDict | NamedTensorMap,
58-
skipMismatch: Boolean = ???,
59-
isNamedTensorMap: Boolean = ???
60-
): Unit = js.native
61-
def toJSON(unused: js.Any = ???, returnString: Boolean = ???): String | JsonDict = js.native
62-
def call(inputs: Tensor | js.Array[Tensor], kwargs: Kwargs): Tensor | js.Array[Tensor] = js.native
63-
def computeMask(
64-
inputs: Tensor | js.Array[Tensor],
65-
mask: Tensor | js.Array[Tensor] = ???
66-
): Tensor | js.Array[Tensor] = js.native
67-
def computeOutputShape(inputShape: Shape | js.Array[Shape]): Shape | js.Array[Shape] = js.native
68-
def runInternalGraph(
69-
inputs: js.Array[Tensor],
70-
masks: js.Array[Tensor] = ???
71-
): js.Tuple3[js.Array[Tensor], js.Array[Tensor], js.Array[Shape]] = js.native
72-
def getLayer(name: String = ???, index: Double = ???): Layer = js.native
73-
def calculateLosses(): js.Array[Scalar] = js.native
74-
def getConfig(): serialization.ConfigDict = js.native
75-
def stateful: Boolean = js.native
76-
}
29+
@JSGlobalScope
30+
object Container extends js.Object {
7731

78-
object Container {
32+
@js.native
33+
trait ContainerConfig extends js.Object {
34+
var inputs: SymbolicTensor | js.Array[SymbolicTensor] = js.native
35+
var outputs: SymbolicTensor | js.Array[SymbolicTensor] = js.native
36+
var name: String = js.native
37+
}
7938

8039
@js.native
81-
trait LayersByDepth extends js.Object {
82-
@JSBracketAccess
83-
def apply(depth: String): js.Array[Layer] = js.native
84-
@JSBracketAccess
85-
def update(depth: String, v: js.Array[Layer]): Unit = js.native
40+
abstract class Container protected () extends Layer {
41+
def this(config: ContainerConfig) = this()
42+
var inputs: js.Array[SymbolicTensor] = js.native
43+
var outputs: js.Array[SymbolicTensor] = js.native
44+
var inputLayers: js.Array[Layer] = js.native
45+
var inputLayersNodeIndices: js.Array[Double] = js.native
46+
var inputLayersTensorIndices: js.Array[Double] = js.native
47+
var outputLayers: js.Array[Layer] = js.native
48+
var outputLayersNodeIndices: js.Array[Double] = js.native
49+
var outputLayersTensorIndices: js.Array[Double] = js.native
50+
var layers: js.Array[Layer] = js.native
51+
var layersByDepth: Container.LayersByDepth = js.native
52+
var nodesByDepth: Container.NodesByDepth = js.native
53+
var containerNodes: Set[String] = js.native
54+
var inputNames: js.Array[String] = js.native
55+
var outputNames: js.Array[String] = js.native
56+
var feedInputShapes: js.Array[Shape] = js.native
57+
protected var internalInputShapes: js.Array[Shape] = js.native
58+
protected var internalOutputShapes: js.Array[Shape] = js.native
59+
protected var feedInputNames: js.Array[String] = js.native
60+
protected var feedOutputNames: js.Array[String] = js.native
61+
// method trainableWeights cannot override a mutable variable
62+
// override def trainableWeights: js.Array[LayerVariable] = js.native
63+
// method nonTrainableWeights cannot override a mutable variable
64+
// override def nonTrainableWeights: js.Array[LayerVariable] = js.native
65+
override def weights: js.Array[LayerVariable] = js.native
66+
def loadWeights(
67+
weightsJSON: JsonDict | NamedTensorMap,
68+
skipMismatch: Boolean = ???,
69+
isNamedTensorMap: Boolean = ???
70+
): Unit = js.native
71+
def toJSON(unused: js.Any = ???, returnString: Boolean = ???): String | JsonDict = js.native
72+
override def call(inputs: TensorND | js.Array[TensorND], kwargs: Kwargs): TensorND | js.Array[TensorND] =
73+
js.native
74+
override def computeMask(
75+
inputs: TensorND | js.Array[TensorND],
76+
mask: TensorND | js.Array[TensorND] = ???
77+
): TensorND | js.Array[TensorND] = js.native
78+
override def computeOutputShape(inputShape: Shape | js.Array[Shape]): Shape | js.Array[Shape] = js.native
79+
def runInternalGraph(
80+
inputs: js.Array[TensorND],
81+
masks: js.Array[TensorND] = ???
82+
): js.Tuple3[js.Array[TensorND], js.Array[TensorND], js.Array[Shape]] = js.native
83+
def getLayer(name: String = ???, index: Double = ???): Layer = js.native
84+
override def calculateLosses(): js.Array[Scalar] = js.native
85+
override def getConfig(): serialization.ConfigDict = js.native
86+
override def stateful: Boolean = js.native
8687
}
8788

8889
@js.native
89-
trait NodesByDepth extends js.Object {
90-
@JSBracketAccess
91-
def apply(depth: String): js.Array[Node] = js.native
92-
@JSBracketAccess
93-
def update(depth: String, v: js.Array[Node]): Unit = js.native
90+
object Container extends js.Object {
91+
92+
@js.native
93+
trait LayersByDepth extends js.Object {
94+
@JSBracketAccess
95+
def apply(depth: String): js.Array[Layer] = js.native
96+
@JSBracketAccess
97+
def update(depth: String, v: js.Array[Layer]): Unit = js.native
98+
}
99+
100+
@js.native
101+
trait NodesByDepth extends js.Object {
102+
@JSBracketAccess
103+
def apply(depth: String): js.Array[Node] = js.native
104+
@JSBracketAccess
105+
def update(depth: String, v: js.Array[Node]): Unit = js.native
106+
}
107+
def fromConfig[T <: serialization.Serializable](
108+
cls: serialization.SerializableConstructor[T],
109+
config: serialization.ConfigDict
110+
): T = js.native
94111
}
95-
def fromConfig[T <: serialization.Serializable](
96-
cls: serialization.SerializableConstructor[T],
97-
config: serialization.ConfigDict
98-
): T = js.native
99-
}
100112

101-
@js.native
102-
@JSGlobalScope
103-
object Container extends js.Object {
104113
def loadWeightsFromJson(
105114
weightsJSON: JsonDict,
106115
layers: js.Array[Layer],

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/engine/executor.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package io.brunk.tfjs.layers.engine
1818

19-
import io.brunk.tfjs.layers.SymbolicTensor
2019
import io.brunk.tfjs.layers.Types.Kwargs
2120

2221
import scala.scalajs.js

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/engine/input_layer.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
package io.brunk.tfjs.layers.engine
1818

19+
import io.brunk.tfjs.core.{DataType, serialization}
20+
import io.brunk.tfjs.layers.Types.{Kwargs, Shape}
21+
import io.brunk.tfjs.core.TensorModule.TensorND
22+
1923
import scala.scalajs.js
2024
import js.annotation._
2125
import js.|
@@ -35,12 +39,13 @@ trait InputLayerConfig extends js.Object {
3539
class InputLayer protected () extends Layer {
3640
def this(config: InputLayerConfig) = this()
3741
var sparse: Boolean = js.native
38-
@JSName("apply")
39-
def apply(
40-
inputs: Tensor | js.Array[Tensor] | SymbolicTensor | js.Array[SymbolicTensor],
41-
kwargs: Kwargs = ???
42-
): Tensor | js.Array[Tensor] | SymbolicTensor = js.native
43-
def getConfig(): serialization.ConfigDict = js.native
42+
// until we have real union types aka dotty, we have to stick with the wider type of the base class
43+
// @JSName("apply")
44+
// override def apply(
45+
// inputs: TensorND | js.Array[TensorND] | SymbolicTensor | js.Array[SymbolicTensor],
46+
// kwargs: Kwargs = ???
47+
// ): TensorND | js.Array[TensorND] | SymbolicTensor = js.native
48+
override def getConfig(): serialization.ConfigDict = js.native
4449
}
4550

4651
@js.native

0 commit comments

Comments
 (0)