Skip to content

Commit 287058c

Browse files
committed
[LAYERS] Merge 0.7.2 changes
1 parent 578c04e commit 287058c

File tree

11 files changed

+38
-290
lines changed

11 files changed

+38
-290
lines changed

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/backend/tfjs_backend.scala

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,9 @@ import io.brunk.tfjs.tf._
3030
@js.native
3131
@JSGlobalScope
3232
object Tfjs_backend extends js.Object {
33-
def disposeScalarCache(): Unit = js.native
3433
def setBackend(requestedBackend: String): Unit = js.native
3534
def getBackend(): String = js.native
36-
def getScalar(value: Double, dtype: DataType = ???): Scalar = js.native
37-
val epsilon: Double = js.native
3835
def isBackendSymbolic(): Boolean = js.native
39-
def shape(x: TensorND | SymbolicTensor): Shape = js.native
40-
def intShape(x: TensorND | SymbolicTensor): js.Array[Double] = js.native
41-
def ndim(x: TensorND | SymbolicTensor): Double = js.native
42-
def dtype(x: TensorND | SymbolicTensor): DataType = js.native
4336
def countParams(x: TensorND | SymbolicTensor): Double = js.native
4437
def cast(x: TensorND, dtype: String): TensorND = js.native
4538
def expandDims(x: TensorND, axis: Double = ???): TensorND = js.native
@@ -53,11 +46,6 @@ object Tfjs_backend extends js.Object {
5346
def concatenate(tensors: js.Array[TensorND], axis: Double = ???): TensorND = js.native
5447
def concatAlongFirstAxis(a: TensorND, b: TensorND): TensorND = js.native
5548
def tile(x: TensorND, n: Double | js.Array[Double]): TensorND = js.native
56-
def identity(x: TensorND): TensorND = js.native
57-
def eyeVariable(size: Double, dtype: DataType = ???, name: String = ???): LayerVariable =
58-
js.native
59-
def scalarTimesArray(c: Scalar, x: TensorND): TensorND = js.native
60-
def scalarPlusArray(c: Scalar, x: TensorND): TensorND = js.native
6149
def randomNormal(
6250
shape: Shape,
6351
mean: Double = ???,
@@ -67,7 +55,6 @@ object Tfjs_backend extends js.Object {
6755
): TensorND = js.native
6856
def dot(x: TensorND, y: TensorND): TensorND = js.native
6957
def sign(x: TensorND): TensorND = js.native
70-
def qr(x: Tensor2D): js.Tuple2[TensorND, TensorND] = js.native
7158
def oneHot(indices: TensorND, numClasses: Double): TensorND = js.native
7259
def gather(
7360
reference: TensorND,
@@ -85,14 +72,7 @@ object Tfjs_backend extends js.Object {
8572
noiseShape: js.Array[Double] = ???,
8673
seed: Double = ???
8774
): TensorND = js.native
88-
def nameScope[T](name: String, fn: js.Function0[T]): T = js.native
89-
def floatx(): DataType = js.native
90-
def getUid(prefix: String = ???): String = js.native
9175
def hardSigmoid(x: TensorND): TensorND = js.native
9276
def inTrainPhase[T](x: js.Function0[T], alt: js.Function0[T], training: Boolean = ???): T =
9377
js.native
94-
def gradients(
95-
lossFn: js.Function0[Scalar],
96-
variables: js.Array[LayerVariable]
97-
): js.Array[TensorND] = js.native
9878
}

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/callbacks.scala

Lines changed: 0 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -27,111 +27,6 @@ import js.{ Promise, | }
2727
@js.native
2828
@JSImport("@tensorflow/tfjs-layers", "Callback")
2929
abstract class Callback extends js.Object {
30-
var validationData: TensorND | js.Array[TensorND] = js.native
3130
var model: Model = js.native
32-
var params: Params = js.native
33-
def setParams(params: Params): Unit = js.native
3431
def setModel(model: Model): Unit = js.native
35-
def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
36-
def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
37-
def onBatchBegin(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
38-
def onBatchEnd(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
39-
def onTrainBegin(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
40-
def onTrainEnd(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
41-
}
42-
43-
@js.native
44-
@JSImport("@tensorflow/tfjs-layers", "CallbackList")
45-
class CallbackList protected () extends js.Object {
46-
def this(callbacks: js.Array[Callback] = ???, queueLength: Double = ???) = this()
47-
var callbacks: js.Array[Callback] = js.native
48-
var queueLength: Double = js.native
49-
def append(callback: Callback): Unit = js.native
50-
def setParams(params: Params): Unit = js.native
51-
def setModel(model: Model): Unit = js.native
52-
def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
53-
def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
54-
def onBatchBegin(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
55-
def onBatchEnd(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
56-
def onTrainBegin(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
57-
def onTrainEnd(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
58-
}
59-
60-
@js.native
61-
@JSGlobal
62-
class BaseLogger extends Callback {
63-
override def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
64-
override def onBatchEnd(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
65-
override def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
66-
}
67-
68-
/**
69-
* Callback that records events into a `History` object. This callback is
70-
* automatically applied to every TF.js Layers model. The `History` object gets
71-
* returned by the `fit` method of models.
72-
*/
73-
@js.native
74-
@JSGlobal
75-
class History extends Callback {
76-
var epoch: js.Array[Double] = js.native
77-
var history: History.History = js.native
78-
override def onTrainBegin(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
79-
override def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
80-
def syncData(): Promise[Unit] = js.native
81-
}
82-
83-
object History {
84-
85-
@js.native
86-
trait History extends js.Object {
87-
@JSBracketAccess
88-
def apply(key: String): js.Array[Double | TensorND] = js.native
89-
@JSBracketAccess
90-
def update(key: String, v: js.Array[Double | TensorND]): Unit = js.native
91-
}
92-
}
93-
94-
@js.native
95-
trait CustomCallbackConfig extends js.Object {
96-
var onTrainBegin: js.Function1[Logs, Promise[Unit]] = js.native
97-
var onTrainEnd: js.Function1[Logs, Promise[Unit]] = js.native
98-
var onEpochBegin: js.Function2[Double, Logs, Promise[Unit]] = js.native
99-
var onEpochEnd: js.Function2[Double, Logs, Promise[Unit]] = js.native
100-
var onBatchBegin: js.Function2[Double, Logs, Promise[Unit]] = js.native
101-
var onBatchEnd: js.Function2[Double, Logs, Promise[Unit]] = js.native
102-
}
103-
104-
@js.native
105-
@JSImport("@tensorflow/tfjs-layers", "CustomCallback")
106-
class CustomCallback protected () extends Callback {
107-
def this(config: CustomCallbackConfig) = this()
108-
protected def trainBegin: js.Function1[Logs, Promise[Unit]] = js.native
109-
protected def trainEnd: js.Function1[Logs, Promise[Unit]] = js.native
110-
protected def epochBegin: js.Function2[Double, Logs, Promise[Unit]] = js.native
111-
protected def epochEnd: js.Function2[Double, Logs, Promise[Unit]] = js.native
112-
protected def batchBegin: js.Function2[Double, Logs, Promise[Unit]] = js.native
113-
protected def batchEnd: js.Function2[Double, Logs, Promise[Unit]] = js.native
114-
override def onEpochBegin(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
115-
override def onEpochEnd(epoch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
116-
override def onBatchBegin(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
117-
override def onBatchEnd(batch: Double, logs: UnresolvedLogs = ???): Promise[Unit] = js.native
118-
override def onTrainBegin(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
119-
override def onTrainEnd(logs: UnresolvedLogs = ???): Promise[Unit] = js.native
120-
}
121-
122-
@js.native
123-
@JSGlobalScope
124-
object Callbacks extends js.Object {
125-
type UnresolvedLogs = js.Dictionary[Double | Scalar]
126-
type Logs = js.Dictionary[Double]
127-
type Params = js.Dictionary[
128-
Double | String | Boolean | js.Array[Double] | js.Array[String] | js.Array[Boolean]
129-
]
130-
def resolveScalarsInLogs(logs: UnresolvedLogs): Promise[Unit] = js.native
131-
def disposeTensorsInLogs(logs: UnresolvedLogs): Unit = js.native
132-
def standardizeCallbacks(
133-
callbacks: Callback | js.Array[Callback] | CustomCallbackConfig | js.Array[
134-
CustomCallbackConfig
135-
]
136-
): js.Array[Callback] = js.native
13732
}

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/engine/topology.scala

Lines changed: 26 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,32 @@ object InputSpec {
7070
}
7171
}
7272

73+
@js.native
74+
@JSGlobal
75+
class SymbolicTensor protected () extends js.Object {
76+
def this(
77+
dtype: DataType,
78+
shape: Shape,
79+
sourceLayer: Layer,
80+
inputs: js.Array[SymbolicTensor],
81+
callArgs: Kwargs,
82+
name: String = ???,
83+
outputTensorIndex: Double = ???
84+
) = this()
85+
def dtype: DataType = js.native
86+
def shape: Shape = js.native
87+
var sourceLayer: Layer = js.native
88+
def inputs: js.Array[SymbolicTensor] = js.native
89+
def callArgs: Kwargs = js.native
90+
def outputTensorIndex: Double = js.native
91+
def id: Double = js.native
92+
def name: String = js.native
93+
def originalName: String = js.native
94+
def rank: Double = js.native
95+
var nodeIndex: Double = js.native
96+
var tensorIndex: Double = js.native
97+
}
98+
7399
@js.native
74100
trait NodeConfig extends js.Object {
75101
var outboundLayer: Layer = js.native
@@ -188,143 +214,14 @@ object Layer extends js.Object {
188214
def nodeKey(layer: Layer, nodeIndex: Double): String = js.native
189215
}
190216

191-
@js.native
192-
trait InputLayerConfig extends js.Object {
193-
var inputShape: Shape = js.native
194-
var batchSize: Double = js.native
195-
var batchInputShape: Shape = js.native
196-
var dtype: DataType = js.native
197-
var sparse: Boolean = js.native
198-
var name: String = js.native
199-
}
200-
201-
@js.native
202-
@JSGlobal
203-
class InputLayer protected () extends Layer {
204-
def this(config: InputLayerConfig) = this()
205-
var sparse: Boolean = js.native
206-
// TODO until we have real union types aka dotty, we have to stick with the wider type of the base class
207-
// @JSName("apply")
208-
//override def apply(inputs: TensorND | js.Array[TensorND] | SymbolicTensor | js.Array[SymbolicTensor], kwargs: Kwargs = ???): TensorND | js.Array[TensorND] | SymbolicTensor = js.native
209-
override def getConfig(): serialization.ConfigDict = js.native
210-
}
211-
212-
@js.native
213-
@JSGlobal
214-
object InputLayer extends js.Object {
215-
def className: String = js.native
216-
}
217-
218-
@js.native
219-
trait InputConfig extends js.Object {
220-
var shape: Shape = js.native
221-
var batchShape: Shape = js.native
222-
var name: String = js.native
223-
var dtype: DataType = js.native
224-
var sparse: Boolean = js.native
225-
}
226-
227-
@js.native
228-
trait ContainerConfig extends js.Object {
229-
var inputs: SymbolicTensor | js.Array[SymbolicTensor] = js.native
230-
var outputs: SymbolicTensor | js.Array[SymbolicTensor] = js.native
231-
var name: String = js.native
232-
}
233-
234-
@js.native
235-
@JSGlobal
236-
abstract class Container protected () extends Layer {
237-
def this(config: ContainerConfig) = this()
238-
var inputs: js.Array[SymbolicTensor] = js.native
239-
var outputs: js.Array[SymbolicTensor] = js.native
240-
var inputLayers: js.Array[Layer] = js.native
241-
var inputLayersNodeIndices: js.Array[Double] = js.native
242-
var inputLayersTensorIndices: js.Array[Double] = js.native
243-
var outputLayers: js.Array[Layer] = js.native
244-
var outputLayersNodeIndices: js.Array[Double] = js.native
245-
var outputLayersTensorIndices: js.Array[Double] = js.native
246-
var layers: js.Array[Layer] = js.native
247-
var layersByDepth: Container.LayersByDepth = js.native
248-
var nodesByDepth: Container.NodesByDepth = js.native
249-
var containerNodes: Set[String] = js.native
250-
var inputNames: js.Array[String] = js.native
251-
var outputNames: js.Array[String] = js.native
252-
var feedInputShapes: js.Array[Shape] = js.native
253-
protected var internalInputShapes: js.Array[Shape] = js.native
254-
protected var internalOutputShapes: js.Array[Shape] = js.native
255-
protected var feedInputNames: js.Array[String] = js.native
256-
protected var feedOutputNames: js.Array[String] = js.native
257-
// TODO This is a var in the base class. Unlike TS Scala does not allow to override a var with a def or val in subclasses
258-
//def trainableWeights: js.Array[LayerVariable] = js.native
259-
//def nonTrainableWeights: js.Array[LayerVariable] = js.native
260-
override def weights: js.Array[LayerVariable] = js.native
261-
def loadWeights(
262-
weightsJSON: JsonDict | NamedTensorMap,
263-
skipMismatch: Boolean = ???,
264-
isNamedTensorMap: Boolean = ???
265-
): Unit = js.native
266-
def toJSON(unused: js.Any = ???, returnString: Boolean = ???): String | JsonDict = js.native
267-
override def call(
268-
inputs: TensorND | js.Array[TensorND],
269-
kwargs: Kwargs
270-
): TensorND | js.Array[TensorND] = js.native
271-
override def computeMask(
272-
inputs: TensorND | js.Array[TensorND],
273-
mask: TensorND | js.Array[TensorND] = ???
274-
): TensorND | js.Array[TensorND] = js.native
275-
override def computeOutputShape(inputShape: Shape | js.Array[Shape]): Shape | js.Array[Shape] =
276-
js.native
277-
def runInternalGraph(
278-
inputs: js.Array[TensorND],
279-
masks: js.Array[TensorND] = ???
280-
): js.Tuple3[js.Array[TensorND], js.Array[TensorND], js.Array[Shape]] = js.native
281-
def getLayer(name: String = ???, index: Double = ???): Layer = js.native
282-
override def calculateLosses(): js.Array[Scalar] = js.native
283-
override def getConfig(): serialization.ConfigDict = js.native
284-
override def stateful: Boolean = js.native
285-
}
286-
287-
@js.native
288-
@JSGlobal
289-
object Container extends js.Object {
290-
291-
@js.native
292-
trait LayersByDepth extends js.Object {
293-
@JSBracketAccess
294-
def apply(depth: String): js.Array[Layer] = js.native
295-
@JSBracketAccess
296-
def update(depth: String, v: js.Array[Layer]): Unit = js.native
297-
}
298-
299-
@js.native
300-
trait NodesByDepth extends js.Object {
301-
@JSBracketAccess
302-
def apply(depth: String): js.Array[Node] = js.native
303-
@JSBracketAccess
304-
def update(depth: String, v: js.Array[Node]): Unit = js.native
305-
}
306-
def fromConfig[T <: serialization.Serializable](
307-
cls: serialization.SerializableConstructor[T],
308-
config: serialization.ConfigDict
309-
): T = js.native
310-
}
311-
312217
@js.native
313218
@JSGlobalScope
314219
object Topology extends js.Object {
315220
type Op = js.Function1[LayerVariable, LayerVariable]
316221
type CallHook = js.Function2[TensorND | js.Array[TensorND], Kwargs, Unit]
317-
def Input(config: InputConfig): SymbolicTensor = js.native
318222
def getSourceInputs(
319223
tensor: SymbolicTensor,
320224
layer: Layer = ???,
321225
nodeIndex: Double = ???
322226
): js.Array[SymbolicTensor] = js.native
323-
def loadWeightsFromNamedTensorMap(weights: NamedTensorMap, layers: js.Array[Layer]): Unit =
324-
js.native
325-
def loadWeightsFromJson(
326-
weightsJSON: JsonDict,
327-
layers: js.Array[Layer],
328-
skipMismatch: Boolean = ???
329-
): Unit = js.native
330227
}

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/engine/training.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ trait ModelFitConfig extends js.Object {
6565
var initialEpoch: Double = js.native
6666
var stepsPerEpoch: Double = js.native
6767
var validationSteps: Double = js.native
68+
var yieldEvery: String = js.native
6869
}
6970

7071
object ModelFitConfig {

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/layers/convolutional.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ abstract class BaseConv protected () extends Layer {
7171
protected var bias: LayerVariable = js.native
7272
def DEFAULT_KERNEL_INITIALIZER: InitializerIdentifier = js.native
7373
def DEFAULT_BIAS_INITIALIZER: InitializerIdentifier = js.native
74+
override def getConfig(): serialization.ConfigDict = js.native
7475
}
7576

7677
@js.native

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/layers/merge.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class Concatenate protected () extends Merge {
118118
override def mergeFunction(inputs: js.Array[TensorND]): TensorND = js.native
119119
override def computeOutputShape(inputShape: Shape | js.Array[Shape]): Shape | js.Array[Shape] =
120120
js.native
121+
override def getConfig(): serialization.ConfigDict = js.native
121122
}
122123

123124
@js.native

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/layers/recurrent.scala

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,6 @@ class RNN protected () extends Layer {
6767
//override def computeMask(inputs: TensorND | js.Array[TensorND], mask: TensorND | js.Array[TensorND] = ???): TensorND = js.native
6868
override def build(inputShape: Shape | js.Array[Shape]): Unit = js.native
6969
def resetStates(states: TensorND | js.Array[TensorND] = ???): Unit = js.native
70-
def standardizeArgs(
71-
inputs: TensorND | js.Array[TensorND] | SymbolicTensor | js.Array[SymbolicTensor],
72-
initialState: TensorND | js.Array[TensorND] | SymbolicTensor | js.Array[SymbolicTensor],
73-
constants: TensorND | js.Array[TensorND] | SymbolicTensor | js.Array[SymbolicTensor]
74-
): js.Any = js.native
7570
@JSName("apply")
7671
override def apply(
7772
inputs: TensorND | js.Array[TensorND] | SymbolicTensor | js.Array[SymbolicTensor],

tfjs-layers/src/main/scala/io/brunk/tfjs/layers/layers/wrappers.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ class Bidirectional protected () extends Wrapper {
114114
@JSGlobal
115115
object Bidirectional extends js.Object {
116116
var className: String = js.native
117+
def fromConfig[T <: serialization.Serializable](
118+
cls: serialization.SerializableConstructor[T],
119+
config: serialization.ConfigDict
120+
): T = js.native
117121
}
118122

119123
@js.native

0 commit comments

Comments
 (0)