Skip to content

Instantly share code, notes, and snippets.

@awni
Created May 6, 2025 13:44
Show Gist options
  • Save awni/a760f92cec559060eb70947f2dd507a8 to your computer and use it in GitHub Desktop.
Save awni/a760f92cec559060eb70947f2dd507a8 to your computer and use it in GitHub Desktop.
/// Gated linear unit.
///
/// Splits the input in two along `dim`, then multiplies the first part by the sigmoid of
/// the second part.
class GLU: Module, UnaryLayer {
    /// Axis along which the input is split into the value half and the gate half.
    let dim: Int

    init(dim: Int) {
        self.dim = dim
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let halves = x.split(axis: dim)
        let gate = MLXNN.sigmoid(halves.1)
        return halves.0 * gate
    }
}
/// Bidirectional, multi-layer LSTM followed by a linear projection back to `hiddenSize`.
///
/// Each of the `numLayers` layers works in three steps:
/// - One LSTM runs over the sequence in its normal order.
/// - A second LSTM runs over the time-reversed sequence, and its output is reversed back
///   so it lines up with the forward output.
/// - The two outputs are concatenated on the feature axis, giving `hiddenSize * 2` features.
///
/// Layer 0 consumes `inputSize` features. Later layers consume the `hiddenSize * 2`
/// concatenation. The final `Linear` maps `hiddenSize * 2` down to `hiddenSize`.
///
/// The previous implementation used a single unidirectional LSTM (output `hiddenSize`)
/// feeding `Linear(hiddenSize * 2, …)`, which is a shape mismatch, and it ignored `numLayers`.
///
/// - Parameters:
///   - inputSize: Feature size of the incoming sequence.
///   - hiddenSize: Hidden size of each directional LSTM and output size of the block.
///   - numLayers: Number of stacked bidirectional layers. Must be at least 1.
class BLSTM: Module, UnaryLayer {
    let forwardLayers: [LSTM]
    let backwardLayers: [LSTM]
    let linear: Linear

    init(inputSize: Int, hiddenSize: Int, numLayers: Int = 2) {
        precondition(numLayers > 0, "BLSTM requires at least one layer")
        // The first layer sees the raw input. Deeper layers see both directions concatenated.
        let layerInputSizes = (0..<numLayers).map { $0 == 0 ? inputSize : hiddenSize * 2 }
        self.forwardLayers = layerInputSizes.map { LSTM(inputSize: $0, hiddenSize: hiddenSize) }
        self.backwardLayers = layerInputSizes.map { LSTM(inputSize: $0, hiddenSize: hiddenSize) }
        self.linear = Linear(hiddenSize * 2, hiddenSize)
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var y = x
        for (forward, backward) in zip(forwardLayers, backwardLayers) {
            let (forwardOut, _) = forward(y)
            let (backwardOut, _) = backward(reversedInTime(y))
            // Undo the reversal so both directions are aligned per time step.
            y = concatenated([forwardOut, reversedInTime(backwardOut)], axis: -1)
        }
        return linear(y)
    }

    /// Reverses `x` along its second-to-last axis.
    ///
    /// That axis is the time axis for the `LD` / `NLD` layouts that MLX `LSTM` consumes.
    private func reversedInTime(_ x: MLXArray) -> MLXArray {
        x[.ellipsis, .stride(by: -1), 0...]
    }
}
/// One Demucs-style encoder stage.
///
/// The stage applies, in order:
/// 1. A strided conv with kernel 8 and stride 4.
/// 2. ReLU.
/// 3. A 1x1 conv that doubles the channels to `outChannels * 2`.
/// 4. A GLU that gates one half of those channels with the other, returning to `outChannels`.
///
/// MLX `Conv1d` is channels-last (`[batch, length, channels]`), so the GLU must split the
/// last axis. The previous `GLU(dim: 1)` (the PyTorch NCL channel axis) split the time axis
/// instead and left `outChannels * 2` channels, which the next stage's `conv1` rejects.
///
/// - Parameters:
///   - inChannels: Channel count of the incoming sequence.
///   - outChannels: Channel count produced by this stage.
class EncoderBlock: Module, UnaryLayer {
    let conv1: Conv1d
    let relu: ReLU
    let conv2: Conv1d
    let glu: GLU

    init(inChannels: Int, outChannels: Int) {
        // Kernel 8 with stride 4 and no padding shrinks the time axis by roughly 4x.
        self.conv1 = Conv1d(
            inputChannels: inChannels, outputChannels: outChannels, kernelSize: 8, stride: 4)
        self.relu = ReLU()
        // The 1x1 conv doubles the channels so the GLU has a value half and a gate half.
        self.conv2 = Conv1d(
            inputChannels: outChannels, outputChannels: outChannels * 2, kernelSize: 1, stride: 1)
        // Gate over the channel axis, which is the last axis in MLX's layout.
        self.glu = GLU(dim: -1)
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = conv1(x)
        x = relu(x)
        x = conv2(x)
        x = glu(x)
        return x
    }
}
/// One Demucs-style decoder stage.
///
/// The stage applies, in order:
/// 1. A kernel-3 conv that doubles the channels to `inChannels * 2`.
/// 2. A GLU that gates them back to `inChannels`.
/// 3. A transposed conv with kernel 8 and stride 4, producing `outChannels`.
/// 4. ReLU, unless `isLast` is set.
///
/// Two fixes relative to the previous version:
/// - The GLU splits the last axis. MLX convolutions are channels-last, and the previous
///   `dim: 1` split the time axis instead.
/// - `convTranspose` is now an actual `ConvTransposed1d`. The previous strided `Conv1d`
///   downsampled the time axis a second time instead of upsampling it.
///
/// - Parameters:
///   - inChannels: Channel count of the incoming sequence.
///   - outChannels: Channel count produced by this stage.
///   - isLast: When `true`, the trailing ReLU is omitted so the final output is unconstrained.
class DecoderBlock: Module, UnaryLayer {
    let conv1: Conv1d
    let glu: GLU
    let convTranspose: ConvTransposed1d
    let relu: ReLU?

    public init(inChannels: Int, outChannels: Int, isLast: Bool = false) {
        self.conv1 = Conv1d(
            inputChannels: inChannels, outputChannels: inChannels * 2, kernelSize: 3, stride: 1)
        // Gate over the channel axis, which is the last axis in MLX's layout.
        self.glu = GLU(dim: -1)
        // The transposed conv grows the time axis by roughly 4x, mirroring EncoderBlock's conv1.
        self.convTranspose = ConvTransposed1d(
            inputChannels: inChannels, outputChannels: outChannels, kernelSize: 8, stride: 4)
        self.relu = isLast ? nil : ReLU()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = conv1(x)
        x = glu(x)
        x = convTranspose(x)
        if let relu {
            x = relu(x)
        }
        return x
    }
}
/// Simplified Demucs-style network.
///
/// The model has three parts:
/// - Six encoder stages that widen 2 input channels up to 2048.
/// - A `BLSTM` bottleneck at 2048 features.
/// - Six decoder stages that narrow back down, the last of which emits 8 channels with no
///   trailing ReLU.
///
/// The stages run strictly in sequence. This model has no encoder-to-decoder skip
/// connections.
class DemucsModel: Module, UnaryLayer {
    let encoder: Sequential
    let lstm: BLSTM
    let decoder: Sequential

    override init() {
        // Channel widths: 2 -> 64 -> 128 -> 256 -> 512 -> 1024 -> 2048
        encoder = Sequential {
            EncoderBlock(inChannels: 2, outChannels: 64)
            EncoderBlock(inChannels: 64, outChannels: 128)
            EncoderBlock(inChannels: 128, outChannels: 256)
            EncoderBlock(inChannels: 256, outChannels: 512)
            EncoderBlock(inChannels: 512, outChannels: 1024)
            EncoderBlock(inChannels: 1024, outChannels: 2048)
        }
        lstm = BLSTM(inputSize: 2048, hiddenSize: 2048)
        // Channel widths: 2048 -> 1024 -> 512 -> 256 -> 128 -> 64 -> 8
        decoder = Sequential {
            DecoderBlock(inChannels: 2048, outChannels: 1024)
            DecoderBlock(inChannels: 1024, outChannels: 512)
            DecoderBlock(inChannels: 512, outChannels: 256)
            DecoderBlock(inChannels: 256, outChannels: 128)
            DecoderBlock(inChannels: 128, outChannels: 64)
            DecoderBlock(inChannels: 64, outChannels: 8, isLast: true)
        }
    }

    /// Runs the encoder, then the recurrent bottleneck, then the decoder.
    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        decoder(lstm(encoder(x)))
    }
}
/// Minimal training UI.
///
/// Shows `trainer.messages` in a vertical scroll view above one action button. The button
/// depends on `trainer.state`:
/// - While untrained, it starts training.
/// - Once a model exists, it switches the state to `.predict`.
struct TrainingView: View {
    @Binding var trainer: ModelState

    var body: some View {
        VStack {
            Spacer()
            ScrollView(.vertical) {
                ForEach(trainer.messages, id: \.self) {
                    Text($0)
                }
            }
            HStack {
                Spacer()
                switch trainer.state {
                case .untrained:
                    Button("Train") {
                        Task {
                            // Report failures in the on-screen log rather than crashing the
                            // app with `try!`.
                            // NOTE(review): assumes `messages` is a mutable [String]; confirm
                            // against ModelState.
                            do {
                                try await trainer.train()
                            } catch {
                                trainer.messages.append("Training failed: \(error)")
                            }
                        }
                    }
                case .trained(let model), .predict(let model):
                    Button("Draw a digit") {
                        trainer.state = .predict(model)
                    }
                }
                Spacer()
            }
            Spacer()
        }
        .padding()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment