Created
May 6, 2025 13:44
-
-
Save awni/a760f92cec559060eb70947f2dd507a8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Gated linear unit.
///
/// Splits the input into two parts along `dim` and multiplies the first part
/// by the sigmoid of the second.
class GLU: Module, UnaryLayer {
    /// Axis along which the input is split into the value part and the gate part.
    let dim: Int

    /// - Parameter dim: Axis to split along.
    init(dim: Int) {
        self.dim = dim
    }

    /// Applies the gate to `x`.
    ///
    /// - Parameter x: Array to split along `dim`.
    /// - Returns: The first part of `x` scaled element-wise by `sigmoid` of the second part.
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let halves: (MLXArray, MLXArray) = x.split(axis: dim)
        let values = halves.0
        let gate = MLXNN.sigmoid(halves.1)
        return values * gate
    }
}
/// Bidirectional LSTM followed by a linear projection back to `hiddenSize` features.
///
/// MLX's `LSTM` reads the sequence in one direction only and emits `hiddenSize`
/// features per step. A single LSTM therefore cannot feed the
/// `Linear(hiddenSize * 2, hiddenSize)` projection. This layer runs one LSTM over
/// the sequence as given and a second one over the time-reversed sequence, then
/// concatenates both outputs on the feature axis before projecting.
class BLSTM: Module, UnaryLayer {
    /// Reads the sequence in its original order.
    let lstm: LSTM
    /// Reads the sequence in reverse order.
    let reverseLSTM: LSTM
    /// Projects the concatenated `hiddenSize * 2` features down to `hiddenSize`.
    let linear: Linear

    /// - Parameters:
    ///   - inputSize: Number of input features per time step.
    ///   - hiddenSize: Number of hidden features produced by each direction.
    ///   - numLayers: Accepted for source compatibility.
    ///     NOTE(review): this value is not used; one LSTM layer per direction is
    ///     built regardless. Stack layers here if a deeper recurrent core is needed.
    init(inputSize: Int, hiddenSize: Int, numLayers: Int = 2) {
        self.lstm = LSTM(inputSize: inputSize, hiddenSize: hiddenSize)
        self.reverseLSTM = LSTM(inputSize: inputSize, hiddenSize: hiddenSize)
        self.linear = Linear(hiddenSize * 2, hiddenSize)
    }

    /// Runs both directions over `x` and projects the combined features.
    ///
    /// - Parameter x: Input whose second-to-last axis is time and whose last axis
    ///   holds `inputSize` features (the layout MLX's `LSTM` expects).
    /// - Returns: The output of `linear` applied to the concatenated forward and
    ///   backward hidden states.
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // Time is the axis just before the feature axis, with or without a batch axis.
        let timeAxis = x.ndim - 2

        let (forwardOut, _) = lstm(x)

        // Feed the sequence back to front, then flip the result so that step t of
        // the backward output lines up with step t of the forward output.
        let (reversedOut, _) = reverseLSTM(BLSTM.flipped(x, axis: timeAxis))
        let backwardOut = BLSTM.flipped(reversedOut, axis: timeAxis)

        let combined = concatenated([forwardOut, backwardOut], axis: -1)
        return linear(combined)
    }

    /// Returns `array` with the order of its entries along `axis` reversed.
    private static func flipped(_ array: MLXArray, axis: Int) -> MLXArray {
        let order = (0 ..< array.dim(axis)).reversed().map { Int32($0) }
        return array.take(MLXArray(order), axis: axis)
    }
}
/// One encoder stage: strided convolution, ReLU, 1x1 convolution that doubles
/// the channel count, and a GLU that halves it again.
class EncoderBlock: Module, UnaryLayer {
    /// Strided convolution (`kernelSize: 8`, `stride: 4`) mapping `inChannels` to `outChannels`.
    let conv1: Conv1d
    let relu: ReLU
    /// 1x1 convolution producing `outChannels * 2` features for the gate.
    let conv2: Conv1d
    /// Gate over the feature axis, bringing the channel count back to `outChannels`.
    let glu: GLU

    /// - Parameters:
    ///   - inChannels: Number of input channels.
    ///   - outChannels: Number of channels the block emits.
    init(inChannels: Int, outChannels: Int) {
        self.conv1 = Conv1d(
            inputChannels: inChannels, outputChannels: outChannels, kernelSize: 8, stride: 4)
        self.relu = ReLU()
        self.conv2 = Conv1d(
            inputChannels: outChannels, outputChannels: outChannels * 2, kernelSize: 1, stride: 1)
        // MLX convolutions are channels-last, so the doubled features live on the
        // last axis. Gating on axis 1 would split the sequence instead and leave
        // `outChannels * 2` channels for the next block.
        self.glu = GLU(dim: -1)
    }

    /// Applies `conv1`, `relu`, `conv2` and `glu` in that order.
    ///
    /// - Parameter x: Channels-last input with `inChannels` features.
    /// - Returns: The gated output of the block.
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = conv1(x)
        x = relu(x)
        x = conv2(x)
        x = glu(x)
        return x
    }
}
/// One decoder stage: convolution that doubles the channel count, GLU that
/// halves it again, a second strided convolution, and an optional ReLU.
class DecoderBlock: Module, UnaryLayer {
    /// Convolution (`kernelSize: 3`, `stride: 1`) producing `inChannels * 2` features for the gate.
    let conv1: Conv1d
    /// Gate over the feature axis, bringing the channel count back to `inChannels`.
    let glu: GLU
    /// Convolution (`kernelSize: 8`, `stride: 4`) mapping `inChannels` to `outChannels`.
    ///
    /// NOTE(review): despite the name this is a regular strided `Conv1d`, which
    /// shortens the sequence rather than lengthening it. Confirm whether a
    /// transposed convolution was intended here.
    let convTranspose: Conv1d
    /// Final activation; `nil` for the last block of the decoder.
    let relu: ReLU?

    /// - Parameters:
    ///   - inChannels: Number of input channels.
    ///   - outChannels: Number of channels the block emits.
    ///   - isLast: When `true`, no ReLU is applied after the last convolution.
    public init(inChannels: Int, outChannels: Int, isLast: Bool = false) {
        self.conv1 = Conv1d(
            inputChannels: inChannels, outputChannels: inChannels * 2, kernelSize: 3, stride: 1)
        // MLX convolutions are channels-last, so the doubled features live on the
        // last axis. Gating on axis 1 would split the sequence instead and hand
        // `inChannels * 2` channels to `convTranspose`.
        self.glu = GLU(dim: -1)
        self.convTranspose = Conv1d(
            inputChannels: inChannels, outputChannels: outChannels, kernelSize: 8, stride: 4)
        self.relu = isLast ? nil : ReLU()
    }

    /// Applies `conv1`, `glu`, `convTranspose` and, when present, `relu`.
    ///
    /// - Parameter x: Channels-last input with `inChannels` features.
    /// - Returns: The output of the block.
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = conv1(x)
        x = glu(x)
        x = convTranspose(x)
        if let relu = relu {
            x = relu(x)
        }
        return x
    }
}
/// Encoder / recurrent core / decoder stack.
///
/// The encoder is six `EncoderBlock`s widening from 2 to 2048 channels, the core
/// is a `BLSTM` over 2048 features, and the decoder is six `DecoderBlock`s
/// narrowing from 2048 to 8 channels, with no activation after the last block.
class DemucsModel: Module, UnaryLayer {
    let encoder: Sequential
    let lstm: BLSTM
    let decoder: Sequential

    /// Builds the three stages in order: encoder, recurrent core, decoder.
    override init() {
        let encoderStack = Sequential {
            EncoderBlock(inChannels: 2, outChannels: 64)
            EncoderBlock(inChannels: 64, outChannels: 128)
            EncoderBlock(inChannels: 128, outChannels: 256)
            EncoderBlock(inChannels: 256, outChannels: 512)
            EncoderBlock(inChannels: 512, outChannels: 1024)
            EncoderBlock(inChannels: 1024, outChannels: 2048)
        }
        let recurrentCore = BLSTM(inputSize: 2048, hiddenSize: 2048)
        let decoderStack = Sequential {
            DecoderBlock(inChannels: 2048, outChannels: 1024)
            DecoderBlock(inChannels: 1024, outChannels: 512)
            DecoderBlock(inChannels: 512, outChannels: 256)
            DecoderBlock(inChannels: 256, outChannels: 128)
            DecoderBlock(inChannels: 128, outChannels: 64)
            DecoderBlock(inChannels: 64, outChannels: 8, isLast: true)
        }

        encoder = encoderStack
        lstm = recurrentCore
        decoder = decoderStack
    }

    /// Passes `x` through the encoder, the recurrent core and the decoder.
    ///
    /// - Parameter x: Input handed to the first encoder block.
    /// - Returns: The output of the last decoder block.
    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        return decoder(lstm(encoder(x)))
    }
}
/// Shows the trainer's messages and a button whose action depends on the
/// trainer's current state.
struct TrainingView: View {
    @Binding var trainer: ModelState

    /// Description of the most recent training failure, or `nil` when there is none.
    @State private var trainingError: String? = nil

    var body: some View {
        VStack {
            Spacer()

            ScrollView(.vertical) {
                ForEach(trainer.messages, id: \.self) {
                    Text($0)
                }
            }

            // Surface a failed training run instead of crashing the app.
            if let trainingError = trainingError {
                Text(trainingError)
                    .foregroundColor(.red)
            }

            HStack {
                Spacer()
                switch trainer.state {
                case .untrained:
                    Button("Train") {
                        trainingError = nil
                        Task {
                            do {
                                try await trainer.train()
                            } catch {
                                trainingError = "Training failed: \(error)"
                            }
                        }
                    }
                case .trained(let model), .predict(let model):
                    Button("Draw a digit") {
                        trainer.state = .predict(model)
                    }
                }
                Spacer()
            }

            Spacer()
        }
        .padding()
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment