Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Source/MLX/Transforms+Compile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ public func compile(
inputs: [any Updatable] = [], outputs: [any Updatable] = [], shapeless: Bool = false,
_ f: @Sendable @escaping (MLXArray, MLXArray, MLXArray) -> MLXArray
)
-> (MLXArray, MLXArray, MLXArray) -> MLXArray
-> @Sendable (MLXArray, MLXArray, MLXArray) -> MLXArray

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticed this was missing while working on this (match the other compile() implementations)

{
let compileState = CompiledFunction(inputs: inputs, outputs: outputs, shapeless: shapeless) {
[f($0[0], $0[1], $0[2])]
Expand Down
228 changes: 228 additions & 0 deletions Source/MLXNN/Activations.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,28 @@ public func softsign(_ x: MLXArray) -> MLXArray {
compiledSoftsign(x)
}

/// Applies the Softshrink activation function.
///
/// This is (element-wise):
///
/// ```swift
/// if x > lambda {
/// x - lambda
/// } else if x < -lambda {
/// x + lambda
/// } else {
/// 0
/// }
/// ```
///
/// - Parameters:
/// - x: input array
/// - lambda: lambda value
public func softshrink(_ x: MLXArray, lambda: Float = 0.5) -> MLXArray {

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Several missing activations

let lambda = lambda.asMLXArray(dtype: x.dtype)
return compiledSoftshrink(x, lambda)
}

/// Applies the Continuously Differentiable Exponential Linear Unit.
///
/// This is:
Expand Down Expand Up @@ -353,6 +375,61 @@ public func hardSwish(_ x: MLXArray) -> MLXArray {
compiledHardSwish(x)
}

/// Applies the HardTanh function.
///
/// This is (element-wise):
///
/// ```swift
/// maximum(minimum(x, max), min)
/// ```
///
/// - Parameters:
/// - x: input array
/// - min: minimum value
/// - max: maximum value
public func hardTanH(_ x: MLXArray, min: Float = -1, max: Float = 1) -> MLXArray {
let min = min.asMLXArray(dtype: x.dtype)
let max = max.asMLXArray(dtype: x.dtype)
return compiledHardTanh(x, min, max)
}

/// Applies the HardShrink activation function.
///
/// This is (element-wise):
///
/// ```swift
/// if x > lambda {
/// x
/// } else if x < -lambda {
/// x
/// } else {
/// 0
/// }
/// ```
///
/// - Parameters:
/// - x: input array
/// - lambda: lambda value
public func hardShrink(_ x: MLXArray, lambda: Float = 0.5) -> MLXArray {
let lambda = lambda.asMLXArray(dtype: x.dtype)
return compiledHardShrink(x, lambda)
}

/// Applies the Softmin function.
///
/// This operation is a numerically stable version of:
///
/// ```swift
///exp(-a) / sum(exp(-a), axis, keepdims: true)
/// ```
///
/// - Parameters:
/// - x: input array
/// - axis: axis to evaluate on
public func softmin(_ x: MLXArray, axis: Int = -1) -> MLXArray {
softmax(-x, axis: axis)
}

/// Applies the gated linear unit function.
///
/// This function splits the `axis` dimension of the input into two halves
Expand Down Expand Up @@ -520,6 +597,29 @@ open class Softmax: Module, UnaryLayer {
}
}

/// Applies the Softmin function.
///
/// This operation is a numerically stable version of:
///
/// ```swift
/// exp(-a) / sum(exp(-a), axis, keepdims: true)
/// ```
///
/// ### See Also
/// - <doc:activations>
/// - ``softmin(_:axis:)``
open class Softmin: Module, UnaryLayer {

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And their layers

public var axis: Int

public init(axis: Int = -1) {
self.axis = axis
}

open func callAsFunction(_ x: MLXArray) -> MLXArray {
softmin(x, axis: axis)
}
}

@available(*, deprecated, renamed: "Softplus")
@_documentation(visibility: internal)
open class SoftPlus: Module, UnaryLayer {
Expand Down Expand Up @@ -570,6 +670,60 @@ open class Softsign: Module, UnaryLayer {
}
}

/// Applies the Softshrink activation function.
///
/// This is (element-wise):
///
/// ```swift
/// if x > lambda {
/// x - lambda
/// } else if x < -lambda {
/// x + lambda
/// } else {
/// 0
/// }
/// ```
///
/// ### See Also
/// - <doc:activations>
/// - ``softshrink(_:lambda:)``
open class Softshrink: Module, UnaryLayer {
public var lambda: Float

public init(lambda: Float = 0.5) {
self.lambda = lambda
super.init()
}

open func callAsFunction(_ x: MLXArray) -> MLXArray {
softshrink(x, lambda: lambda)
}
}

/// Applies the Exponential Linear Unit.
///
/// This is:
///
/// ```swift
/// MLX.which(x .> 0, x, alpha * (exp(x) - 1))
/// ```
///
/// ### See Also
/// - <doc:activations>
/// - ``elu(_:alpha:)``
open class ELU: Module, UnaryLayer {
public var alpha: Float

public init(alpha: Float = 1.0) {
self.alpha = alpha
super.init()
}

open func callAsFunction(_ x: MLXArray) -> MLXArray {
elu(x, alpha: alpha)
}
}

/// Applies the Continuously Differentiable Exponential Linear Unit.
///
/// This is:
Expand Down Expand Up @@ -753,6 +907,62 @@ open class HardSwish: Module, UnaryLayer {
}
}

/// Applies the HardTanh function.
///
/// This is (element-wise):
///
/// ```swift
/// maximum(minimum(x, max), min)
/// ```
///
/// ### See Also
/// - <doc:activations>
/// - ``hardTanH(_:min:max:)``
open class HardTanh: Module, UnaryLayer {
public var min: Float
public var max: Float

public init(min: Float = -1, max: Float = 1) {
self.min = min
self.max = max
super.init()
}

open func callAsFunction(_ x: MLXArray) -> MLXArray {
hardTanH(x, min: min, max: max)
}
}

/// Applies the HardShrink activation function.
///
/// This is (element-wise):
///
/// ```swift
/// if x > lambda {
/// x
/// } else if x < -lambda {
/// x
/// } else {
/// 0
/// }
/// ```
///
/// ### See Also
/// - <doc:activations>
/// - ``hardShrink(_:lambda:)``
open class HardShrink: Module, UnaryLayer {
public var lambda: Float

public init(lambda: Float = 0.5) {
self.lambda = lambda
super.init()
}

open func callAsFunction(_ x: MLXArray) -> MLXArray {
hardShrink(x, lambda: lambda)
}
}

/// Applies the Step Activation Function.
///
/// This function implements a binary step activation, where the output is set
Expand Down Expand Up @@ -824,6 +1034,12 @@ private let compiledSoftsign: @Sendable (MLXArray) -> MLXArray = {
}
}()

private let compiledSoftshrink: @Sendable (MLXArray, MLXArray) -> MLXArray = {
compile(shapeless: true) { x, lambda in
which(abs(x) .> lambda, x - sign(x) * lambda, 0)
}
}()

private let compiledCelu: @Sendable (MLXArray, MLXArray) -> MLXArray = {
compile(shapeless: true) { x, alpha in
maximum(x, 0.0) + alpha * (exp(minimum(x, 0.0) / alpha) - 1)
Expand Down Expand Up @@ -885,6 +1101,18 @@ private let compiledHardSwish: @Sendable (MLXArray) -> MLXArray = {
}
}()

private let compiledHardTanh: @Sendable (MLXArray, MLXArray, MLXArray) -> MLXArray = {
compile(shapeless: true) { x, min, max in
minimum(maximum(x, min), max)
}
}()

private let compiledHardShrink: @Sendable (MLXArray, MLXArray) -> MLXArray = {
compile(shapeless: true) { x, lambda in
which(abs(x) .> lambda, x, 0)
}
}()

private let compiledRelu: @Sendable (MLXArray) -> MLXArray = {
compile(shapeless: true) { x in
maximum(x, 0)
Expand Down
19 changes: 16 additions & 3 deletions Source/MLXNN/Convolution.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ open class Conv1d: Module, UnaryLayer {
precondition(inputChannels % groups == 0, "Input channels must be divisible by groups")

self.weight = MLXRandom.uniform(
low: -scale, high: scale, [outputChannels, kernelSize, inputChannels / groups])
low: -scale, high: scale,
[
outputChannels,
kernelSize,
inputChannels / groups,
])
self.bias = bias ? MLXArray.zeros([outputChannels]) : nil
self.padding = padding
self.dilation = dilation
Expand Down Expand Up @@ -117,7 +122,11 @@ open class Conv2d: Module, UnaryLayer {

self.weight = MLXRandom.uniform(
low: -scale, high: scale,
[outputChannels, kernelSize.first, kernelSize.second, inputChannels / groups])
[
outputChannels,
kernelSize.first, kernelSize.second,
inputChannels / groups,
])
self.bias = bias ? MLXArray.zeros([outputChannels]) : nil
self.padding = padding.values
self.dilation = dilation.values
Expand Down Expand Up @@ -185,7 +194,11 @@ open class Conv3d: Module, UnaryLayer {

self.weight = MLXRandom.uniform(
low: -scale, high: scale,
[outputChannels, kernelSize.first, kernelSize.second, kernelSize.third, inputChannels])
[
outputChannels,
kernelSize.first, kernelSize.second, kernelSize.third,
inputChannels / groups,

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inputChannels should be dived by groups. curiously the swift version has groups for all 3 convolutions but python only has it for 1d and 2d 🤷

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should fix that in Python.

])
self.bias = bias ? MLXArray.zeros([outputChannels]) : nil
self.padding = padding.values
self.dilation = dilation.values
Expand Down
Loading