-
Notifications
You must be signed in to change notification settings - Fork 244
synchronize MLXNN code with python implementation #340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
eb76f9b
e37890f
e6a1e65
0435762
616f706
201c6c5
6a9e39d
0759d82
d409e7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -160,6 +160,28 @@ public func softsign(_ x: MLXArray) -> MLXArray { | |
| compiledSoftsign(x) | ||
| } | ||
|
|
||
| /// Applies the Softshrink activation function. | ||
| /// | ||
| /// This is (element-wise): | ||
| /// | ||
| /// ```swift | ||
| /// if x > lambda { | ||
| /// x - lambda | ||
| /// } else if x < -lambda { | ||
| /// x + lambda | ||
| /// } else { | ||
| /// 0 | ||
| /// } | ||
| /// ``` | ||
| /// | ||
| /// - Parameters: | ||
| /// - x: input array | ||
| /// - lambda: lambda value | ||
| public func softshrink(_ x: MLXArray, lambda: Float = 0.5) -> MLXArray { | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Several missing activations |
||
| let lambda = lambda.asMLXArray(dtype: x.dtype) | ||
| return compiledSoftshrink(x, lambda) | ||
| } | ||
|
|
||
| /// Applies the Continuously Differentiable Exponential Linear Unit. | ||
| /// | ||
| /// This is: | ||
|
|
@@ -353,6 +375,61 @@ public func hardSwish(_ x: MLXArray) -> MLXArray { | |
| compiledHardSwish(x) | ||
| } | ||
|
|
||
| /// Applies the HardTanh function. | ||
| /// | ||
| /// This is (element-wise): | ||
| /// | ||
| /// ```swift | ||
| /// maximum(minimum(x, max), min) | ||
| /// ``` | ||
| /// | ||
| /// - Parameters: | ||
| /// - x: input array | ||
| /// - min: minimum value | ||
| /// - max: maximum value | ||
| public func hardTanH(_ x: MLXArray, min: Float = -1, max: Float = 1) -> MLXArray { | ||
| let min = min.asMLXArray(dtype: x.dtype) | ||
| let max = max.asMLXArray(dtype: x.dtype) | ||
| return compiledHardTanh(x, min, max) | ||
| } | ||
|
|
||
| /// Applies the HardShrink activation function. | ||
| /// | ||
| /// This is (element-wise): | ||
| /// | ||
| /// ```swift | ||
| /// if x > lambda { | ||
| /// x | ||
| /// } else if x < -lambda { | ||
| /// x | ||
| /// } else { | ||
| /// 0 | ||
| /// } | ||
| /// ``` | ||
| /// | ||
| /// - Parameters: | ||
| /// - x: input array | ||
| /// - lambda: lambda value | ||
| public func hardShrink(_ x: MLXArray, lambda: Float = 0.5) -> MLXArray { | ||
| let lambda = lambda.asMLXArray(dtype: x.dtype) | ||
| return compiledHardShrink(x, lambda) | ||
| } | ||
|
|
||
| /// Applies the Softmin function. | ||
| /// | ||
| /// This operation is a numerically stable version of: | ||
| /// | ||
| /// ```swift | ||
| ///exp(-a) / sum(exp(-a), axis, keepdims: true) | ||
| /// ``` | ||
| /// | ||
| /// - Parameters: | ||
| /// - x: input array | ||
| /// - axis: axis to evaluate on | ||
| public func softmin(_ x: MLXArray, axis: Int = -1) -> MLXArray { | ||
| softmax(-x, axis: axis) | ||
| } | ||
|
|
||
| /// Applies the gated linear unit function. | ||
| /// | ||
| /// This function splits the `axis` dimension of the input into two halves | ||
|
|
@@ -520,6 +597,29 @@ open class Softmax: Module, UnaryLayer { | |
| } | ||
| } | ||
|
|
||
| /// Applies the Softmin function. | ||
| /// | ||
| /// This operation is a numerically stable version of: | ||
| /// | ||
| /// ```swift | ||
| /// exp(-a) / sum(exp(-a), axis, keepdims: true) | ||
| /// ``` | ||
| /// | ||
| /// ### See Also | ||
| /// - <doc:activations> | ||
| /// - ``softmin(_:axis:)`` | ||
| open class Softmin: Module, UnaryLayer { | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And their layers |
||
| public var axis: Int | ||
|
|
||
| public init(axis: Int = -1) { | ||
| self.axis = axis | ||
| } | ||
|
|
||
| open func callAsFunction(_ x: MLXArray) -> MLXArray { | ||
| softmin(x, axis: axis) | ||
| } | ||
| } | ||
|
|
||
| @available(*, deprecated, renamed: "Softplus") | ||
| @_documentation(visibility: internal) | ||
| open class SoftPlus: Module, UnaryLayer { | ||
|
|
@@ -570,6 +670,60 @@ open class Softsign: Module, UnaryLayer { | |
| } | ||
| } | ||
|
|
||
| /// Applies the Softshrink activation function. | ||
| /// | ||
| /// This is (element-wise): | ||
| /// | ||
| /// ```swift | ||
| /// if x > lambda { | ||
| /// x - lambda | ||
| /// } else if x < -lambda { | ||
| /// x + lambda | ||
| /// } else { | ||
| /// 0 | ||
| /// } | ||
| /// ``` | ||
| /// | ||
| /// ### See Also | ||
| /// - <doc:activations> | ||
| /// - ``softshrink(_:lambda:)`` | ||
| open class Softshrink: Module, UnaryLayer { | ||
| public var lambda: Float | ||
|
|
||
| public init(lambda: Float = 0.5) { | ||
| self.lambda = lambda | ||
| super.init() | ||
| } | ||
|
|
||
| open func callAsFunction(_ x: MLXArray) -> MLXArray { | ||
| softshrink(x, lambda: lambda) | ||
| } | ||
| } | ||
|
|
||
| /// Applies the Exponential Linear Unit. | ||
| /// | ||
| /// This is: | ||
| /// | ||
| /// ```swift | ||
| /// MLX.which(x .> 0, x, alpha * (exp(x) - 1)) | ||
| /// ``` | ||
| /// | ||
| /// ### See Also | ||
| /// - <doc:activations> | ||
| /// - ``elu(_:alpha:)`` | ||
| open class ELU: Module, UnaryLayer { | ||
| public var alpha: Float | ||
|
|
||
| public init(alpha: Float = 1.0) { | ||
| self.alpha = alpha | ||
| super.init() | ||
| } | ||
|
|
||
| open func callAsFunction(_ x: MLXArray) -> MLXArray { | ||
| elu(x, alpha: alpha) | ||
| } | ||
| } | ||
|
|
||
| /// Applies the Continuously Differentiable Exponential Linear Unit. | ||
| /// | ||
| /// This is: | ||
|
|
@@ -753,6 +907,62 @@ open class HardSwish: Module, UnaryLayer { | |
| } | ||
| } | ||
|
|
||
| /// Applies the HardTanh function. | ||
| /// | ||
| /// This is (element-wise): | ||
| /// | ||
| /// ```swift | ||
| /// maximum(minimum(x, max), min) | ||
| /// ``` | ||
| /// | ||
| /// ### See Also | ||
| /// - <doc:activations> | ||
| /// - ``hardTanH(_:min:max:)`` | ||
| open class HardTanh: Module, UnaryLayer { | ||
| public var min: Float | ||
| public var max: Float | ||
|
|
||
| public init(min: Float = -1, max: Float = 1) { | ||
| self.min = min | ||
| self.max = max | ||
| super.init() | ||
| } | ||
|
|
||
| open func callAsFunction(_ x: MLXArray) -> MLXArray { | ||
| hardTanH(x, min: min, max: max) | ||
| } | ||
| } | ||
|
|
||
| /// Applies the HardShrink activation function. | ||
| /// | ||
| /// This is (element-wise): | ||
| /// | ||
| /// ```swift | ||
| /// if x > lambda { | ||
| /// x | ||
| /// } else if x < -lambda { | ||
| /// x | ||
| /// } else { | ||
| /// 0 | ||
| /// } | ||
| /// ``` | ||
| /// | ||
| /// ### See Also | ||
| /// - <doc:activations> | ||
| /// - ``hardShrink(_:lambda:)`` | ||
| open class HardShrink: Module, UnaryLayer { | ||
| public var lambda: Float | ||
|
|
||
| public init(lambda: Float = 0.5) { | ||
| self.lambda = lambda | ||
| super.init() | ||
| } | ||
|
|
||
| open func callAsFunction(_ x: MLXArray) -> MLXArray { | ||
| hardShrink(x, lambda: lambda) | ||
| } | ||
| } | ||
|
|
||
| /// Applies the Step Activation Function. | ||
| /// | ||
| /// This function implements a binary step activation, where the output is set | ||
|
|
@@ -824,6 +1034,12 @@ private let compiledSoftsign: @Sendable (MLXArray) -> MLXArray = { | |
| } | ||
| }() | ||
|
|
||
| private let compiledSoftshrink: @Sendable (MLXArray, MLXArray) -> MLXArray = { | ||
| compile(shapeless: true) { x, lambda in | ||
| which(abs(x) .> lambda, x - sign(x) * lambda, 0) | ||
| } | ||
| }() | ||
|
|
||
| private let compiledCelu: @Sendable (MLXArray, MLXArray) -> MLXArray = { | ||
| compile(shapeless: true) { x, alpha in | ||
| maximum(x, 0.0) + alpha * (exp(minimum(x, 0.0) / alpha) - 1) | ||
|
|
@@ -885,6 +1101,18 @@ private let compiledHardSwish: @Sendable (MLXArray) -> MLXArray = { | |
| } | ||
| }() | ||
|
|
||
| private let compiledHardTanh: @Sendable (MLXArray, MLXArray, MLXArray) -> MLXArray = { | ||
| compile(shapeless: true) { x, min, max in | ||
| minimum(maximum(x, min), max) | ||
| } | ||
| }() | ||
|
|
||
| private let compiledHardShrink: @Sendable (MLXArray, MLXArray) -> MLXArray = { | ||
| compile(shapeless: true) { x, lambda in | ||
| which(abs(x) .> lambda, x, 0) | ||
| } | ||
| }() | ||
|
|
||
| private let compiledRelu: @Sendable (MLXArray) -> MLXArray = { | ||
| compile(shapeless: true) { x in | ||
| maximum(x, 0) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,7 +50,12 @@ open class Conv1d: Module, UnaryLayer { | |
| precondition(inputChannels % groups == 0, "Input channels must be divisible by groups") | ||
|
|
||
| self.weight = MLXRandom.uniform( | ||
| low: -scale, high: scale, [outputChannels, kernelSize, inputChannels / groups]) | ||
| low: -scale, high: scale, | ||
| [ | ||
| outputChannels, | ||
| kernelSize, | ||
| inputChannels / groups, | ||
| ]) | ||
| self.bias = bias ? MLXArray.zeros([outputChannels]) : nil | ||
| self.padding = padding | ||
| self.dilation = dilation | ||
|
|
@@ -117,7 +122,11 @@ open class Conv2d: Module, UnaryLayer { | |
|
|
||
| self.weight = MLXRandom.uniform( | ||
| low: -scale, high: scale, | ||
| [outputChannels, kernelSize.first, kernelSize.second, inputChannels / groups]) | ||
| [ | ||
| outputChannels, | ||
| kernelSize.first, kernelSize.second, | ||
| inputChannels / groups, | ||
| ]) | ||
| self.bias = bias ? MLXArray.zeros([outputChannels]) : nil | ||
| self.padding = padding.values | ||
| self.dilation = dilation.values | ||
|
|
@@ -185,7 +194,11 @@ open class Conv3d: Module, UnaryLayer { | |
|
|
||
| self.weight = MLXRandom.uniform( | ||
| low: -scale, high: scale, | ||
| [outputChannels, kernelSize.first, kernelSize.second, kernelSize.third, inputChannels]) | ||
| [ | ||
| outputChannels, | ||
| kernelSize.first, kernelSize.second, kernelSize.third, | ||
| inputChannels / groups, | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inputChannels should be dived by groups. curiously the swift version has
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should fix that in Python. |
||
| ]) | ||
| self.bias = bias ? MLXArray.zeros([outputChannels]) : nil | ||
| self.padding = padding.values | ||
| self.dilation = dilation.values | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noticed this was missing while working on this (match the other
compile()implementations)