additional cases for #122 #136

Open · wants to merge 3 commits into main
34 changes: 27 additions & 7 deletions Source/MLXNN/Activations.swift
@@ -347,7 +347,7 @@ public func hardSwish(_ x: MLXArray) -> MLXArray {
/// - <doc:activations>
/// - ``glu(_:axis:)``
open class GLU: Module, UnaryLayer {
- public let axis: Int
+ public var axis: Int

Collaborator Author:
Change to var for consistency
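
A usage sketch (not taken from this diff) of what the `let` to `var` change enables: reconfiguring an existing layer in place. The concrete values are arbitrary.

```swift
import MLXNN

// Properties changed to `var` in this PR can be retuned after construction.
let glu = GLU()             // defaults to axis -1
glu.axis = 1                // later: split along axis 1 instead of the last axis

let leaky = LeakyReLU()
leaky.negativeSlope = 0.2   // same pattern for the other exposed parameters
```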


public init(axis: Int = -1) {
self.axis = axis
@@ -430,7 +430,7 @@ open class ReLU: Module, UnaryLayer {
/// - ``leakyRelu(_:negativeSlope:)``
open class LeakyReLU: Module, UnaryLayer {

- public let negativeSlope: Float
+ public var negativeSlope: Float

public init(negativeSlope: Float = 0.01) {
self.negativeSlope = negativeSlope
@@ -458,6 +458,14 @@ open class ReLU6: Module, UnaryLayer {
}
}

@available(*, deprecated, renamed: "Softmax")
@_documentation(visibility:internal)
open class SoftMax: Module, UnaryLayer {

Collaborator Author:
I missed this one earlier -- this redirects it to be named Softmax
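
For illustration, a hypothetical call site showing the effect of the rename (the warning text is paraphrased, not copied from the compiler):

```swift
import MLXNN

// The old spelling still compiles but now carries a deprecation warning
// pointing at the renamed class; new code should use `Softmax`.
let legacy = SoftMax()      // deprecated: renamed to 'Softmax'
let preferred = Softmax()   // defaults to axis -1
```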

open func callAsFunction(_ x: MLXArray) -> MLXArray {
softmax(x)
}
}

/// Applies the Softmax function.
///
/// This is:
@@ -468,9 +476,15 @@ open class ReLU6: Module, UnaryLayer {
///
/// ### See Also
/// - <doc:activations>
- open class SoftMax: Module, UnaryLayer {
+ open class Softmax: Module, UnaryLayer {
public var axis: Int

public init(axis: Int = -1) {
self.axis = axis
}

Collaborator Author:
GLU already had this (exposed axis parameter) and the original request was to add negative slope. This just exposes axis on Softmax and LogSoftmax.
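
A short sketch of the newly exposed `axis` on both layers; the `MLXArray` construction and the choice of axis 1 are illustrative assumptions, not part of this PR:

```swift
import MLX
import MLXNN

// Build a small [2, 2] input and normalize along axis 1 (per row).
let x = MLXArray([1, 2, 3, 4] as [Float], [2, 2])
let probs = Softmax(axis: 1)(x)          // rows sum to 1
let logProbs = LogSoftmax(axis: 1)(x)    // log-domain equivalent along the same axis
```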


open func callAsFunction(_ x: MLXArray) -> MLXArray {
- softmax(x, axis: -1)
+ softmax(x, axis: axis)
}
}

@@ -536,7 +550,7 @@ open class Softsign: Module, UnaryLayer {
/// - <doc:activations>
/// - ``celu(_:alpha:)``
open class CELU: Module, UnaryLayer {
- public let alpha: Float
+ public var alpha: Float

public init(alpha: Float = 1.0) {
self.alpha = alpha
@@ -585,8 +599,14 @@ open class LogSoftMax: Module, UnaryLayer {
/// - <doc:activations>
/// - ``logSoftmax(_:axis:)``
open class LogSoftmax: Module, UnaryLayer {
public var axis: Int

public init(axis: Int = -1) {
self.axis = axis
}

open func callAsFunction(_ x: MLXArray) -> MLXArray {
- logSoftmax(x)
+ logSoftmax(x, axis: axis)
}
}

@@ -715,7 +735,7 @@ open class HardSwish: Module, UnaryLayer {
/// - ``step(_:threshold:)``
open class Step: Module, UnaryLayer {

- public let threshold: Float
+ public var threshold: Float

public init(threshold: Float = 0.0) {
self.threshold = threshold
2 changes: 1 addition & 1 deletion Tests/MLXTests/IntegrationTests.swift
@@ -6433,7 +6433,7 @@ class MLXIntegrationTests: XCTestCase {
XCTAssertEqual(
a.sum().item(Float.self), 131.68545532226562,
accuracy: 2.6337091064453126)
- let result = SoftMax()(a)
+ let result = Softmax()(a)
XCTAssertEqual(result.shape, [2, 8, 16])
XCTAssertEqual(result.dtype, .float32)
XCTAssertEqual(