From 208a2ca2f0eaf9fadade0eb99c564ebd31e9d81f Mon Sep 17 00:00:00 2001 From: Benjamin Meyer Date: Mon, 18 May 2026 22:13:53 +0200 Subject: [PATCH 1/2] Remove nn package. - Move ActivationFunctions to core - Examples are already in deepwit - Move optimziers to optimizer package --- build.sbt | 12 +- .../scala/dimwit/nn/ActivationFunctions.scala | 5 +- core/src/main/scala/dimwit/nn/package.scala | 1 + .../src/main/scala/basic/Autoencoder.scala | 221 -------- .../main/scala/basic/LogisticRegression.scala | 4 +- .../main/scala/basic/MLClassifierMNist.scala | 165 ------ .../scala/basic/MLClassifierMNistCNN.scala | 133 ----- .../src/main/scala/basic/Playground.scala | 15 - examples/src/main/scala/complex/GPT2.scala | 400 --------------- .../src/main/scala/complex/GPT2Train.scala | 483 ------------------ .../main/scala/complex/GPTCheckpoint.scala | 24 - .../complex/VariationalAutoencoder.scala | 36 +- examples/src/main/scala/package.scala | 7 - nn/src/main/scala/nn/Conv2DLayer.scala | 26 - nn/src/main/scala/nn/LinearLayer.scala | 46 -- nn/src/main/scala/nn/Loss.scala | 20 - .../main/scala/nn/TransposeConv2DLayer.scala | 42 -- nn/src/main/scala/nn/package.scala | 1 - .../scala/optimizer}/GradientOptimizer.scala | 2 +- 19 files changed, 43 insertions(+), 1600 deletions(-) rename nn/src/main/scala/nn/Activation.scala => core/src/main/scala/dimwit/nn/ActivationFunctions.scala (88%) create mode 100644 core/src/main/scala/dimwit/nn/package.scala delete mode 100644 examples/src/main/scala/basic/Autoencoder.scala delete mode 100644 examples/src/main/scala/basic/MLClassifierMNist.scala delete mode 100644 examples/src/main/scala/basic/MLClassifierMNistCNN.scala delete mode 100644 examples/src/main/scala/basic/Playground.scala delete mode 100644 examples/src/main/scala/complex/GPT2.scala delete mode 100644 examples/src/main/scala/complex/GPT2Train.scala delete mode 100644 examples/src/main/scala/complex/GPTCheckpoint.scala delete mode 100644 examples/src/main/scala/package.scala delete mode 100644 nn/src/main/scala/nn/Conv2DLayer.scala delete mode 100644 nn/src/main/scala/nn/LinearLayer.scala delete mode 100644 nn/src/main/scala/nn/Loss.scala delete mode 100644 nn/src/main/scala/nn/TransposeConv2DLayer.scala delete mode 100644 nn/src/main/scala/nn/package.scala rename {nn/src/main/scala/nn => optimizer/src/main/scala/optimizer}/GradientOptimizer.scala (99%) diff --git a/build.sbt b/build.sbt index e3d09c1e..6144ddf8 100644 --- a/build.sbt +++ b/build.sbt @@ -11,7 +11,7 @@ ThisBuild / resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/c addCommandAlias("testAndCoverage", "; clean; coverage; test; coverageReport") lazy val root = (project in file(".")) - .aggregate(core, nn, examples) + .aggregate(core, optimizer, examples) .settings( name := "dimwit-root" ) @@ -43,16 +43,16 @@ lazy val core = (project in file("core")) Compile / packageDoc / publishArtifact := true ) -lazy val nn = (project in file("nn")) +lazy val optimizer = (project in file("optimizer")) .settings( - name := "dimwit-nn" + name := "dimwit-optimizer" ) .dependsOn(core) // Examples subproject lazy val examples = (project in file("examples")) .dependsOn(core) - .dependsOn(nn) + .dependsOn(optimizer) .settings( name := "dimwit-examples", // Examples use the same Scala version and dependencies as main project @@ -81,7 +81,7 @@ lazy val examples = (project in file("examples")) // Processes files in /mdocs that need to be copied to the root (e.g. README.md) lazy val docsRoot = (project in file(".dimwit-docs-root")) .enablePlugins(MdocPlugin) - .dependsOn(core, nn) + .dependsOn(core, optimizer) .settings( name := "dimwit-docs-root", mdocIn := (ThisBuild / baseDirectory).value / "mdocs", @@ -98,7 +98,7 @@ lazy val docsRoot = (project in file(".dimwit-docs-root")) // Processes all other docs in /mdocs/docs/ → output to docs/ lazy val docs = (project in file(".dimwit-docs")) .enablePlugins(MdocPlugin) - .dependsOn(core, nn) + .dependsOn(core, optimizer) .settings( name := "dimwit-docs", mdocIn := (ThisBuild / baseDirectory).value / "mdocs/docs", diff --git a/nn/src/main/scala/nn/Activation.scala b/core/src/main/scala/dimwit/nn/ActivationFunctions.scala similarity index 88% rename from nn/src/main/scala/nn/Activation.scala rename to core/src/main/scala/dimwit/nn/ActivationFunctions.scala index f2c0f709..542419f4 100644 --- a/nn/src/main/scala/nn/Activation.scala +++ b/core/src/main/scala/dimwit/nn/ActivationFunctions.scala @@ -1,6 +1,7 @@ -package nn +package dimwit.nn -import dimwit.* +import dimwit.tensor.* +import dimwit.tensor.TensorOps.IsFloating import dimwit.jax.Jax import dimwit.python.PyBridge.{liftPyTensor, toPyTensor} diff --git a/core/src/main/scala/dimwit/nn/package.scala b/core/src/main/scala/dimwit/nn/package.scala new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/core/src/main/scala/dimwit/nn/package.scala @@ -0,0 +1 @@ + diff --git a/examples/src/main/scala/basic/Autoencoder.scala b/examples/src/main/scala/basic/Autoencoder.scala deleted file mode 100644 index 152f33a4..00000000 --- a/examples/src/main/scala/basic/Autoencoder.scala +++ /dev/null @@ -1,221 +0,0 @@ -package examples.basic.ae - -import dimwit.* -import dimwit.Conversions.given - -import examples.timed -import dimwit.stats.Normal -import dimwit.random.Random -import nn.LinearLayer -import nn.ActivationFunctions.relu -import nn.GradientDescent -import dimwit.jax.Jax -import nn.ActivationFunctions.sigmoid -import dimwit.random.Random.Key -import dimwit.autodiff.* -import dimwit.autodiff.FloatTree.* -import examples.dataset.MNISTLoader - -import MNISTLoader.{Sample, TrainSample, TestSample, Height, Width} -import dimwit.python.PyBridge.toPyTensor -trait Hidden derives Label -trait Output derives Label - -type Pixel = Height |*| Width -type ReconstructedPixel = Height |*| Width - -trait EHidden1 derives Label -trait EHidden2 derives Label - -trait Latent derives Label - -trait DHidden1 derives Label -trait DHidden2 derives Label - -trait Batch derives Label - -class Encoder(p: Encoder.EncoderParams): - - val layer1 = LinearLayer(p.layer1) - val layer2 = LinearLayer(p.layer2) - val latentLayer = LinearLayer(p.latentLayer) - - def apply(v: Tensor1[Pixel, Float32]): Tensor1[Latent, Float32] = - val h1 = relu(layer1(v)) - val h2 = relu(layer2(h1)) - latentLayer(h2) - -object Encoder: - case class EncoderParams( - layer1: LinearLayer.Params[Pixel, EHidden1], - layer2: LinearLayer.Params[EHidden1, EHidden2], - latentLayer: LinearLayer.Params[EHidden2, Latent] - ) - -class Decoder(p: Decoder.DecoderParams): - - val layer1 = LinearLayer(p.layer1) - val layer2 = LinearLayer(p.layer2) - val outputLayer = LinearLayer(p.outputLayer) - - def apply(v: Tensor1[Latent, Float32]): Tensor1[ReconstructedPixel, Float32] = - val h1 = relu(layer1(v)) - val h2 = relu(layer2(h1)) - sigmoid(outputLayer(h2)) - -object Decoder: - case class DecoderParams( - layer1: LinearLayer.Params[Latent, DHidden1], - layer2: LinearLayer.Params[DHidden1, DHidden2], - outputLayer: LinearLayer.Params[DHidden2, ReconstructedPixel] - ) - -case class Autoencoder(params: Autoencoder.Params): - - val encoder = Encoder(params.encoderParams) - val decoder = Decoder(params.decoderParams) - - def apply(v: Tensor1[Pixel, Float32]): (Tensor1[ReconstructedPixel, Float32], Tensor1[Latent, Float32]) = - val latent = encoder(v) - val reconstructed = decoder(latent) - (reconstructed, latent) - - def loss(original: Tensor1[Pixel, Float32]): Tensor0[Float32] = - val (reconstructed, _) = apply(original) - val eps = 1e-5f - val reconstructionLoss = -((original * (reconstructed +! eps).log) + ((Tensor0(1f) -! original) * (1f -! reconstructed +! eps).log)).sum - reconstructionLoss - -object Autoencoder: - case class Params( - encoderParams: Encoder.EncoderParams, - decoderParams: Decoder.DecoderParams - ) - object Params: - def apply(params: Autoencoder.Params): Params = - Params( - params.encoderParams, - params.decoderParams - ) - -object AutoencoderExample: - - def main(args: Array[String]): Unit = - - dimwit.initialize() - - val learningRate = 5e-4f - - val numTestSamples = 9728 - val batchSize = 512 - val numSamples = 59904 - val numEpochs = 50 - val latentDim = 20 - - val initKey = Random.Key(42) - - val (trainX, trainY) = MNISTLoader.createTrainingDataset(maxSamples = Some(numSamples)).get - val (testX, testY) = MNISTLoader.createTestDataset(maxSamples = Some(numTestSamples)).get - - /* - * Initialize the model parameters - * */ - val initKeys = initKey.split(6) - val encoderParams = Encoder.EncoderParams( - LinearLayer.Params[Pixel, EHidden1](initKeys(0))( - Axis[Pixel] -> (28 * 28), - Axis[EHidden1] -> 512 - ), - LinearLayer.Params[EHidden1, EHidden2](initKeys(1))( - Axis[EHidden1] -> 512, - Axis[EHidden2] -> 256 - ), - LinearLayer.Params[EHidden2, Latent](initKeys(2))( - Axis[EHidden2] -> 256, - Axis[Latent] -> latentDim - ) - ) - val decoderParams = Decoder.DecoderParams( - LinearLayer.Params[Latent, DHidden1](initKeys(3))( - Axis[Latent] -> 20, - Axis[DHidden1] -> 256 - ), - LinearLayer.Params[DHidden1, DHidden2](initKeys(4))( - Axis[DHidden1] -> 256, - Axis[DHidden2] -> 512 - ), - LinearLayer.Params[DHidden2, ReconstructedPixel](initKeys(5))( - Axis[DHidden2] -> 512, - Axis[ReconstructedPixel] -> (28 * 28) - ) - ) - - // we need to scale down the initial parameters for - // better training stability. - // TODO linear layer et al. should support custom initializers - // or xavier initialization - val initialParams = Autoencoder.Params(encoderParams, decoderParams) - val scaledInitialParams = initialParams.map([T <: Tuple] => (n: Labels[T]) ?=> (t: Tensor[T, Float32]) => t *! Tensor0(0.1f)) - - /* - * Training loop - * */ - - def loss[S <: Sample: Label](trainData: Tensor3[S, Height, Width, Float32])(params: Autoencoder.Params): Tensor0[Float32] = - val ae = Autoencoder(params) - trainData - .vmap(Axis[S])(sample => ae.loss(sample.flatten)) - .mean - - val batches = trainX.chunk(Axis[TrainSample], numSamples / batchSize) - - val optimizer = GradientDescent(learningRate = Tensor0(learningRate)) - - def gradientStep(batch: Tensor3[TrainSample, Height, Width, Float32], params: Autoencoder.Params): Autoencoder.Params = - val grads = grad(loss(batch))(params) - val (newParams, _) = optimizer.update(grads, params, ()) - newParams - - val jittedGradientStep = jit(gradientStep) - - def trainEpoch(params: Autoencoder.Params): Autoencoder.Params = - batches.foldLeft(params): - case (currentParams, batch) => - jittedGradientStep(batch, currentParams) - - // run the loop - val trainTrajectory = Iterator.iterate(scaledInitialParams): currentParams => - timed("Training"): - dimwit.gc() - trainEpoch(currentParams) - - val trainedParams = trainTrajectory.zipWithIndex - .tapEach: - case (params, epoch) => - timed("Evaluation"): - val lossValue = loss(testX)(params) - println(s"Epoch $epoch | Test loss: $lossValue") - .map((params, _) => params) - .drop(numEpochs) - .next() - - /* - * Evaluation - * */ - val ae = Autoencoder(trainedParams) - - val reconstructed = testX - .slice(Axis[TestSample].at(0 until 64)) - .vmap(Axis[TestSample]): sample => - val latent = ae.encoder(sample.flatten) - ae.decoder(latent) - .relabel(Axis[TestSample].as(Axis[Prime[Height] |*| Prime[Width]])) - - val img2d = reconstructed.rearrange( - (Axis[Prime[Height] |*| Height], Axis[Prime[Width] |*| Width]), - (Axis[Prime[Height]] -> 8, Axis[Prime[Width]] -> 8, Axis[Height] -> 28, Axis[Width] -> 28) - ) - import me.shadaj.scalapy.py - val plt = py.module("matplotlib.pyplot") - plt.imshow(toPyTensor(img2d), cmap = "gray") - plt.show() diff --git a/examples/src/main/scala/basic/LogisticRegression.scala b/examples/src/main/scala/basic/LogisticRegression.scala index 85070310..06a69a0f 100644 --- a/examples/src/main/scala/basic/LogisticRegression.scala +++ b/examples/src/main/scala/basic/LogisticRegression.scala @@ -3,8 +3,8 @@ package examples.basic import dimwit.* import dimwit.Conversions.given import dimwit.autodiff.* -import nn.* -import nn.ActivationFunctions.{sigmoid, relu} +import dimwit.optimizer.GradientDescent +import dimwit.nn.ActivationFunctions.{sigmoid, relu} import dimwit.random.Random import dimwit.stats.Normal diff --git a/examples/src/main/scala/basic/MLClassifierMNist.scala b/examples/src/main/scala/basic/MLClassifierMNist.scala deleted file mode 100644 index 8fe12d54..00000000 --- a/examples/src/main/scala/basic/MLClassifierMNist.scala +++ /dev/null @@ -1,165 +0,0 @@ -package examples.basic - -import dimwit.* -import dimwit.autodiff.* -import nn.* -import nn.ActivationFunctions.{relu, sigmoid} -import dimwit.random.Random - -import examples.timed -import examples.dataset.MNISTLoader - -def binaryCrossEntropy[L: Label]( - logits: Tensor1[L, Float32], - label: Tensor0[Int32] -): Tensor0[Float32] = - val maxLogit = logits.max - val stableExp = (logits -! maxLogit).exp - val logSumExp = stableExp.sum.log + maxLogit - val targetLogit = logits.slice(Axis[L].at(label)) - -(targetLogit - logSumExp) - -object MLPClassifierMNist: - - import MNISTLoader.{Sample, TrainSample, Height, Width} - trait Hidden derives Label - trait Output derives Label - - object MLP: - case class Params( - layer1: LinearLayer.Params[Height |*| Width, Hidden], - layer2: LinearLayer.Params[Hidden, Output] - ) - - object Params: - - def apply( - layer1Dim: AxisExtent[Height |*| Width], - layer2Dim: AxisExtent[Hidden], - outputDim: AxisExtent[Output] - )( - paramKey: Random.Key - ): Params = - val (key1, key2) = paramKey.split2() - Params( - layer1 = LinearLayer.Params(key1)(layer1Dim, layer2Dim), - layer2 = LinearLayer.Params(key2)(layer2Dim, outputDim) - ) - - case class MLP(params: MLP.Params) extends Function[Tensor2[Height, Width, Float32], Tensor0[Int32]]: - - private val layer1 = LinearLayer(params.layer1) - private val layer2 = LinearLayer(params.layer2) - - def logits( - image: Tensor2[Height, Width, Float32] - ): Tensor1[Output, Float32] = - val hidden = relu(layer1(image.flatten)) - layer2(hidden) - - override def apply(image: Tensor2[Height, Width, Float32]): Tensor0[Int32] = logits(image).argmax(Axis[Output]) - - def main(args: Array[String]): Unit = - - dimwit.initialize() - - val numSamples = 59904 - val numTestSamples = 9728 - val batchSize = 512 - val numEpochs = 50 - val (dataKey, trainKey) = Random.Key(42).split2() - val (initKey, restKey) = trainKey.split2() - - val (trainX, trainY) = MNISTLoader.createTrainingDataset(maxSamples = Some(numSamples)).get - val (testX, testY) = MNISTLoader.createTestDataset(maxSamples = Some(numTestSamples)).get - - def batchLoss(batchImages: Tensor[(TrainSample, Height, Width), Float32], batchLabels: Tensor1[TrainSample, Int32])( - params: MLP.Params - ): Tensor0[Float32] = - val model = MLP(params) - val losses = zipvmap(Axis[TrainSample])(batchImages, batchLabels): - case (image, label) => - val logits = model.logits(image) - binaryCrossEntropy(logits, label) - losses.mean - val initParams = MLP.Params( - Axis[Height |*| Width] -> 28 * 28, - Axis[Hidden] -> 128, - Axis[Output] -> 10 - )(initKey) - - def accuracy[S: Label]( - predictions: Tensor1[S, Int32], - targets: Tensor1[S, Int32] - ): Tensor0[Float32] = - val matches = zipvmap(Axis[S])(predictions, targets)(_ === _) - matches.asFloat32.mean - - // val optimizer = GradientDescent(learningRate = Tensor0(1e-4f)) - // type OptState = Unit - - // val optimizer = Lion(learningRate = Tensor0(1e-4f), weightDecay = Tensor0(0f)) - // type OptState = MLP.Params - - val optimizer = Adam(learningRate = Tensor0(1e-4f)) - type OptState = AdamState[MLP.Params] - - def gradientStep( - imageBatch: Tensor[(TrainSample, Height, Width), Float32], - labelBatch: Tensor1[TrainSample, Int32], - params: MLP.Params, - state: OptState - ): (MLP.Params, OptState) = - val lossBatch = batchLoss(imageBatch, labelBatch) - val grads = grad(lossBatch)(params) - optimizer.update(grads, params, state) - val (jitDonate, jitStep, jitReclaim) = jitDonating(gradientStep) - - def miniBatchGradientDescent( - imageBatches: Seq[Tensor[(TrainSample, Height, Width), Float32]], - labelBatches: Seq[Tensor1[TrainSample, Int32]] - )( - params: MLP.Params, - initialState: OptState - ): (MLP.Params, OptState) = - val res = imageBatches - .zip(labelBatches) - .foldLeft(jitDonate(params, initialState)): - case ((currentParams, state), (imageBatch, labelBatch)) => - jitStep(imageBatch, labelBatch.asInt32, currentParams, state) - jitReclaim(res) - - val trainMiniBatchGradientDescent = miniBatchGradientDescent( - trainX.chunk(Axis[TrainSample], numSamples / batchSize), - trainY.asInt32.chunk(Axis[TrainSample], numSamples / batchSize) - ) - val trainTrajectory = Iterator.iterate((initParams, optimizer.init(initParams))): (currentParams, state) => - timed("Training"): - dimwit.gc() - trainMiniBatchGradientDescent(currentParams, state) - def evaluate[S <: Sample: Label]( - params: MLP.Params, - dataX: Tensor3[S, Height, Width, Float32], - dataY: Tensor1[S, Int32] - ): Tensor0[Float32] = - val model = MLP(params) - val predictions = dataX.vmap(Axis[S])(model) - accuracy(predictions, dataY.asInt32) - val jitEvaluate = evaluate - val (finalParams, finalState) = trainTrajectory.zipWithIndex - .tapEach: - case ((params, state), epoch) => - timed("Evaluation"): - val testAccuracy = evaluate(params, testX, testY.asInt32) - val trainAccuracy = evaluate(params, trainX, trainY.asInt32) - println( - List( - s"Epoch $epoch", - f"Test accuracy: ${testAccuracy.item * 100}%.2f%%", - f"Train accuracy: ${trainAccuracy.item * 100}%.2f%%" - ).mkString(", ") - ) - .drop(numEpochs) - .next() - - println("\nTraining complete!") diff --git a/examples/src/main/scala/basic/MLClassifierMNistCNN.scala b/examples/src/main/scala/basic/MLClassifierMNistCNN.scala deleted file mode 100644 index 605d4fee..00000000 --- a/examples/src/main/scala/basic/MLClassifierMNistCNN.scala +++ /dev/null @@ -1,133 +0,0 @@ -package examples.basic.mnistcnn - -import dimwit.* -import dimwit.Conversions.given -import dimwit.autodiff.FloatTree.ops.* -import nn.* -import nn.ActivationFunctions.relu -import dimwit.random.Random -import examples.timed -import examples.dataset.MNISTLoader -import examples.basic.MLPClassifierMNist.MLP - -// Logits-based Cross Entropy (same as yours) -def binaryCrossEntropy[L: Label]( - logits: Tensor1[L, Float32], - label: Tensor0[Int32] -): Tensor0[Float32] = - val maxLogit = logits.max - val logSumExp = ((logits -! maxLogit).exp.sum + 1e-7f).log + maxLogit - val targetLogit = logits.slice(Axis[L].at(label)) - logSumExp - targetLogit - -object MNistCNN: - import MNISTLoader.{Sample, TrainSample, Height, Width} - - // New labels for CNN architecture - trait Channel derives Label - trait Hidden derives Label - trait PixelEmbedding derives Label - type ImageEmbedding = Height |*| Width |*| PixelEmbedding - trait Output derives Label - - object CNN: - case class Params( - conv1: Conv2DLayer.Params[Height, Width, Channel, Hidden, Float32], - conv2: Conv2DLayer.Params[Height, Width, Hidden, PixelEmbedding, Float32], - output: LinearLayer.Params[ImageEmbedding, Output] - ) - - object Params: - def apply(paramKey: Random.Key)( - numHidden1: Int, - numHidden2: Int - ): Params = - val keys = paramKey.split(3) - val kernelHeightDim = Axis[Height] -> 3 - val kernelWidthDim = Axis[Width] -> 3 - val channelDim = Axis[Channel] -> 1 - val hiddenDim = Axis[Hidden] -> numHidden1 - val pixelEmbeddingDim = Axis[PixelEmbedding] -> numHidden2 - val embeddingDim = Axis[ImageEmbedding] -> 7 * 7 * numHidden2 - val outputDim = Axis[Output] -> 10 - Params( - conv1 = Conv2DLayer.Params(keys(0))(Shape(kernelHeightDim, kernelWidthDim, channelDim, hiddenDim)), - conv2 = Conv2DLayer.Params(keys(1))(Shape(kernelHeightDim, kernelWidthDim, hiddenDim, pixelEmbeddingDim)), - output = LinearLayer.Params(keys(2))(embeddingDim, outputDim) - ) - - case class CNN(params: CNN.Params) extends Function[Tensor2[Height, Width, Float32], Tensor0[Int32]]: - private val conv1 = Conv2DLayer(params.conv1, stride = 2, padding = Padding.SAME) - private val conv2 = Conv2DLayer(params.conv2, stride = 2, padding = Padding.SAME) - private val output = LinearLayer(params.output) - - def logits(image: Tensor2[Height, Width, Float32]): Tensor1[Output, Float32] = - val input = image.appendAxis(Axis[Channel]) - val hidden = relu(conv1(input)) - val features = relu(conv2(hidden)) - output(features.flatten) - - override def apply(image: Tensor2[Height, Width, Float32]): Tensor0[Int32] = - logits(image).argmax(Axis[Output]) - - def main(args: Array[String]): Unit = - - dimwit.initialize() - - val learningRate = 0.01f - val numSamples = 59904 - val batchSize = 128 - val numEpochs = 50 - - val (dataKey, trainKey) = Random.Key(42).split2() - val (trainX, trainY) = MNISTLoader.createTrainingDataset(maxSamples = Some(numSamples)).get - val (testX, testY) = MNISTLoader.createTestDataset(maxSamples = Some(9728)).get - - val initParams = CNN.Params(trainKey)(16, 32) - val scaledInitialParams = initParams **! Tensor0(0.1f) - - def batchLoss(batchImages: Tensor[(TrainSample, Height, Width), Float32], batchLabels: Tensor1[TrainSample, Int32])( - params: CNN.Params - ): Tensor0[Float32] = - val model = CNN(params) - val batchLosses = zipvmap(Axis[TrainSample])(batchImages, batchLabels): - case (img, lbl) => - binaryCrossEntropy(model.logits(img), lbl) - batchLosses.mean - - val optimizer = GradientDescent(learningRate = Tensor0(learningRate)) - - def gradientStep( - imageBatch: Tensor[(TrainSample, Height, Width), Float32], - labelBatch: Tensor1[TrainSample, Int32], - params: CNN.Params - ): CNN.Params = - val grads = Autodiff.grad(batchLoss(imageBatch, labelBatch))(params) - val (newParams, newState) = optimizer.update(grads, params, ()) - newParams - - val (jitDonate, jitStep, jitReclaim) = jitDonating(gradientStep) - - // Training Loop - val trainTrajectory = Iterator.iterate(scaledInitialParams): params => - timed("Training Epoch"): - val imgBatches = trainX.chunk(Axis[TrainSample], numSamples / batchSize) - val lblBatches = trainY.chunk(Axis[TrainSample], numSamples / batchSize) - val newParams = imgBatches.zip(lblBatches).foldLeft(jitDonate(params)): - case (params, (imgB, lblB)) => - jitStep(imgB, lblB.asInt32, params) - jitReclaim(newParams) - - // Evaluation - def evaluate[S <: Sample: Label](params: CNN.Params, dataX: Tensor[(S, Height, Width), Float32], dataY: Tensor1[S, Int32]): Tensor0[Float32] = - val model = CNN(params) - val predictions = dataX.vmap(Axis[S])(model) - val matches = zipvmap(Axis[S])(predictions, dataY)(_ === _) - matches.asFloat32.mean - - trainTrajectory.drop(1).zipWithIndex.foreach: - case (params, epoch) => - if epoch % 1 == 0 then - dimwit.gc() - val acc = evaluate(params, testX, testY.asInt32) - println(f"Epoch $epoch | Test Accuracy: ${acc.item * 100}%.2f%%") diff --git a/examples/src/main/scala/basic/Playground.scala b/examples/src/main/scala/basic/Playground.scala deleted file mode 100644 index db0ba479..00000000 --- a/examples/src/main/scala/basic/Playground.scala +++ /dev/null @@ -1,15 +0,0 @@ -package src.main.scala.basic - -import dimwit.* -import dimwit.autodiff.* - -object Playground extends App: - val k = Key(42) - - trait A derives Label - trait B derives Label - - def f(x: Tensor1[A, Float32]): Tensor0[Float32] = - x.sum - - grad(f) diff --git a/examples/src/main/scala/complex/GPT2.scala b/examples/src/main/scala/complex/GPT2.scala deleted file mode 100644 index 6caec9fd..00000000 --- a/examples/src/main/scala/complex/GPT2.scala +++ /dev/null @@ -1,400 +0,0 @@ -package examples.complex - -import dimwit.* -import dimwit.Conversions.given -import dimwit.python.PyBridge.liftPyTensor - -import nn.ActivationFunctions.* - -// Dimensions -trait Vocab derives Label // 50257 -trait Embedding derives Label // 768 -trait Context derives Label // 1024 -trait EmbeddingMixed derives Label // 3072 - -trait Batch derives Label - -case class LayerNormalizationParams( - weight: Tensor1[Embedding, Float32], - bias: Tensor1[Embedding, Float32] -) - -case class LinearLayerParams[In, Out]( - weight: Tensor2[In, Out, Float32], - bias: Tensor1[Out, Float32] -) - -case class ProjectionLayerParams[In, Out]( - weight: Tensor2[In, Out, Float32] -) - -trait Head derives Label -trait HeadKey derives Label -trait HeadQuery derives Label -trait HeadValue derives Label - -case class HeadsParams[Kind](val weights: Tensor3[Head, Embedding, Kind, Float32], val bias: Tensor2[Head, Kind, Float32]) - -case class MultiHeadAttentionParams( - wq: HeadsParams[HeadQuery], - wk: HeadsParams[HeadKey], - wv: HeadsParams[HeadValue], - proj: LinearLayerParams[Head |*| HeadValue, Embedding] -) derives TensorTree - -case class EmbeddingMixerParams( - c_fc: LinearLayerParams[Embedding, EmbeddingMixed], - c_proj: LinearLayerParams[EmbeddingMixed, Embedding] -) - -case class TransformerLayerParams( - ln1: LayerNormalizationParams, - attn: MultiHeadAttentionParams, - ln2: LayerNormalizationParams, - embeddingMixer: EmbeddingMixerParams -) - -case class GPT2Params( - vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], - positionalEmbeddings: Tensor2[Context, Embedding, Float32], - layers: List[TransformerLayerParams], - outputNormalization: LayerNormalizationParams, - output: ProjectionLayerParams[Embedding, Vocab] -) - -object GPT2Params: - def apply( - vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], - positionalEmbeddings: Tensor2[Context, Embedding, Float32], - layers: List[TransformerLayerParams], - outputNormalization: LayerNormalizationParams - ): GPT2Params = - val outputParams = ProjectionLayerParams( - vocabularyEmbeddings.transpose // Tying output weights with input embeddings - ) - GPT2Params(vocabularyEmbeddings, positionalEmbeddings, layers, outputNormalization, outputParams) - -case class GPT2(params: GPT2Params) extends (Tensor2[Batch, Context, Int32] => Tensor2[Batch, Context, Int32]): - - private case class LinearLayer[In: Label, Out: Label](params: LinearLayerParams[In, Out]) extends (Tensor1[In, Float32] => Tensor1[Out, Float32]): - override def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = - x.dot(Axis[In])(params.weight) + params.bias - - private case class EmbeddingMixer(params: EmbeddingMixerParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): - private val hiddenLayer = LinearLayer(params.c_fc) - private val outputLayer = LinearLayer(params.c_proj) - // TODO add dropout - - def apply(in: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = - in.vmap(Axis[Context])(x => - val hidden = gelu(hiddenLayer(x)) - outputLayer(hidden) - ) - - private case class ProjectionLayer[In: Label, Out: Label](params: ProjectionLayerParams[In, Out]) extends (Tensor1[In, Float32] => Tensor1[Out, Float32]): - def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = - x.dot(Axis[In])(params.weight) - - private case class MultiHeadAttention(params: MultiHeadAttentionParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): - - private val projection = LinearLayer(params.proj) - - def apply(x: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = - val heads = zipvmap(Axis[Head])(params.wq.weights, params.wq.bias, params.wk.weights, params.wk.bias, params.wv.weights, params.wv.bias): - attention.tupled(_)(x) - heads.vmap(Axis[Context])(heads => projection(heads.flatten)) - - private def attention( - wq: Tensor2[Embedding, HeadQuery, Float32], - wqBias: Tensor1[HeadQuery, Float32], - wk: Tensor2[Embedding, HeadKey, Float32], - wkBias: Tensor1[HeadKey, Float32], - wv: Tensor2[Embedding, HeadValue, Float32], - wvBias: Tensor1[HeadValue, Float32] - )(x: Tensor2[Context, Embedding, Float32]): Tensor2[Context, HeadValue, Float32] = - - trait AttnWeights derives Label - - def causalMasking(attnScores: Tensor2[Context, Prime[Context], Float32]): Tensor2[Context, Prime[Context], Float32] = - val ctxLength = attnScores.shape(Axis[Context]) - val causalMask = tril(Tensor(Shape((Axis[Context] -> ctxLength, Axis[Prime[Context]] -> ctxLength))).fill(true)) - where(causalMask, attnScores, Tensor.like(attnScores).fill(Float.NegativeInfinity)) - - val queries = x.dot(Axis[Embedding])(wq) +! wqBias - val keys = x.dot(Axis[Embedding])(wk) +! wkBias - val values = x.dot(Axis[Embedding])(wv) +! wvBias - val dk = Tensor0(Math.sqrt(keys.shape(Axis[HeadKey])).toFloat) - val attnScores = (queries.dot(Axis[HeadQuery ~ HeadKey])(keys) /! dk) - val attnWeights = causalMasking(attnScores) - .vmap(Axis[Context])(attnScore => softmax(attnScore).relabelTo(Axis[AttnWeights])) - val res = attnWeights.dot(Axis[AttnWeights ~ Context])(values) - res - - private case class LayerNorm(params: LayerNormalizationParams) extends (Tensor1[Embedding, Float32] => Tensor1[Embedding, Float32]): - - private def standardize(x: Tensor1[Embedding, Float32]): Tensor1[Embedding, Float32] = - val x0 = x -! x.mean - val variance = x0.pow(2).mean - val epsilon = 1e-6f - x0 /! (variance + epsilon).sqrt - - def apply(x: Tensor1[Embedding, Float32]): Tensor1[Embedding, Float32] = - standardize(x) * params.weight + params.bias - - private case class TransformerLayer(params: TransformerLayerParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): - private val embeddingMixer = EmbeddingMixer(params.embeddingMixer) - private val multiHeadAttention = MultiHeadAttention(params.attn) - private val preNormalization = LayerNorm(params.ln1) - private val postNormalization = LayerNorm(params.ln2) - - def apply(t: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = - var x = t - x = x + multiHeadAttention(x.vmap(Axis[Context])(preNormalization)) - x = x + embeddingMixer(x.vmap(Axis[Context])(postNormalization)) - x - - private case class TransformerBlock(layers: List[TransformerLayer]) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): - override def apply(t: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = - layers.foldLeft(t): - case (t, layer) => layer(t) - - case class Embedder(vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], positionalEmbeddings: Tensor2[Context, Embedding, Float32]): - - def apply(tokens: Tensor1[Context, Int32]): Tensor2[Context, Embedding, Float32] = - val embeddings = vocabularyEmbeddings.take(Axis[Vocab])(tokens) - embeddings + positionalEmbeddings - - case class OutputLayer(normalization: LayerNormalizationParams, projectionParams: ProjectionLayerParams[Embedding, Vocab]) extends (Tensor1[Embedding, Float32] => Tensor1[Vocab, Float32]): - private val normalizationLayer = LayerNorm(normalization) - private val projection = ProjectionLayer(projectionParams) - override def apply(x: Tensor1[Embedding, Float32]): Tensor1[Vocab, Float32] = - projection(normalizationLayer(x)) - - private val embedder = Embedder(params.vocabularyEmbeddings, params.positionalEmbeddings) - private val transformerBlock = TransformerBlock(params.layers.map(TransformerLayer(_))) - private val outputLayer = OutputLayer(params.outputNormalization, params.output) - - def logits(inputTokens: Tensor2[Batch, Context, Int32]): Tensor3[Batch, Context, Vocab, Float32] = - inputTokens.vmap(Axis[Batch]): - case tokens => - val startEmbeddings = embedder(tokens) - val endEmbeddings = transformerBlock(startEmbeddings) - endEmbeddings.vmap(Axis[Context])(x => outputLayer(x)) - - def probits(inputTokens: Tensor2[Batch, Context, Int32]): Tensor3[Batch, Context, Vocab, Float32] = - val x = logits(inputTokens) - val res = x.vapply(Axis[Vocab])(softmax) - return res - - def apply(inputTokens: Tensor2[Batch, Context, Int32]): Tensor2[Batch, Context, Int32] = - val x = probits(inputTokens) - val res = x.argmax(Axis[Vocab]) - return res - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -lazy val tiktoken = py.module("tiktoken") - -case class Tokenizer(enc: py.Dynamic): - def encode(s: String): List[Int] = - val pythonSet = py.Dynamic.global.set(Seq("<|endoftext|>").toPythonProxy) - enc.encode(s, allowed_special = pythonSet).as[List[Int]] - - def decode(l: List[Int]): String = - enc.decode(l.toPythonProxy).as[String] - -case class Inference(gpt2: GPT2, tokenizer: Tokenizer): - - def apply(input: String): LazyList[String] = - println(s"Start inference for input: \"$input\"") - val tokenIds = tokenizer.encode(input) - def loop(currentTokenIds: List[Int]): LazyList[String] = - println(s"Current Token Ids: $currentTokenIds") - val paddedTokenIds = currentTokenIds ++ List.fill(1024 - currentTokenIds.length)(0) - val inputTensor = Tensor( - Shape((Axis[Batch] -> 1, Axis[Context] -> paddedTokenIds.length)) - ).fromArray( - paddedTokenIds.toArray - ) - val predTokensTensor = gpt2(inputTensor).slice((Axis[Batch].at(0))) - val nextToken = predTokensTensor.slice(Axis[Context].at(currentTokenIds.length - 1)) - val nextTokens = currentTokenIds :+ nextToken.item - val decoded = tokenizer.decode(nextTokens) - System.gc() - LazyList.cons(decoded, loop(nextTokens)) - loop(tokenIds) - -object GPT2Inference: - - import java.io.RandomAccessFile - import java.nio.channels.FileChannel - import java.nio.{ByteBuffer, ByteOrder} - import java.nio.charset.StandardCharsets - import dimwit.jax.Jax - import dimwit.tensor.DType - import me.shadaj.scalapy.py - import me.shadaj.scalapy.py.SeqConverters - - case class TensorInfo(dtype: String, shape: List[Int], start: Long, end: Long) - - object SafeTensorsReader: - import me.shadaj.scalapy.py.SeqConverters - import java.util.Base64 - - // A compact Python loader that decodes Base64 back to a tensor - // Defined as a single line to completely avoid IndentationErrors - private val pythonLoader = py.eval("""lambda b64, dtype, shape: (__import__('numpy').frombuffer(__import__('base64').b64decode(b64), dtype={'F32':__import__('numpy').float32,'I32':__import__('numpy').int32,'I64':__import__('numpy').int64}[dtype]).reshape(shape))""") - - def readHeader(filePath: String): (Map[String, TensorInfo], Long) = - - val file = new RandomAccessFile(filePath, "r") - val channel = file.getChannel - try - val headerSizeBuffer = ByteBuffer.allocate(8) - headerSizeBuffer.order(ByteOrder.LITTLE_ENDIAN) - channel.read(headerSizeBuffer) - headerSizeBuffer.flip() - val headerSize = headerSizeBuffer.getLong - - val jsonBuffer = ByteBuffer.allocate(headerSize.toInt) - channel.read(jsonBuffer) - jsonBuffer.flip() - val jsonString = new String(jsonBuffer.array(), StandardCharsets.UTF_8) - - val json = ujson.read(jsonString) - val meta = json.obj - - val tensorMap = meta - .filterKeys(_ != "__metadata__") - .map { case (name, data) => - val offsets = data("data_offsets").arr.map(_.num.toLong) - val shape = data("shape").arr.map(_.num.toInt).toList - val dtype = data("dtype").str - name -> TensorInfo(dtype, shape, offsets(0), offsets(1)) - } - .toMap - - val dataStartPos = 8 + headerSize - (tensorMap, dataStartPos) - finally file.close() - - def loadTensor(filePath: String, info: TensorInfo, dataStartPos: Long): Jax.PyDynamic = - val file = new RandomAccessFile(filePath, "r") - try - val len = (info.end - info.start).toInt - val bytes = new Array[Byte](len) - - file.seek(dataStartPos + info.start) - file.readFully(bytes) - - val b64String = Base64.getEncoder.encodeToString(bytes) - Jax.jnp.array(pythonLoader(b64String, info.dtype, info.shape.toPythonProxy)) - finally file.close() - - def main(args: Array[String]): Unit = - - dimwit.initialize() - - val filePath = "data/gpt.safetensors" - - val (tensorMap, dataStartPos) = SafeTensorsReader.readHeader(filePath) - - def loadAttnWeights(cAttnName: String, cProjName: String, numHeads: Int = 12): MultiHeadAttentionParams = - - /* - * Define types to make loading easier. - * QKV type expresses the structure of the flat stored format for attention weights and biases in GPT-2. - * The format is [Query |+| Key |+| Value], meaning that the weights for Query, Key, and Value are concatenated along a single dimension. - * Where Query, Key, and Value are themselves the combinations of their attention head weights. - */ - type Query = Head |*| HeadQuery - type Key = Head |*| HeadKey - type Value = Head |*| HeadValue - type QKV = Query |+| Key |+| Value - - val cAttn = loadLinear(cAttnName, Axis[Embedding], Axis[QKV]) - val cProj = loadLinear(cProjName, Axis[Head |*| HeadValue], Axis[Embedding]) - - def splitWeightToHeads[L](t: Tensor2[Embedding, Head |*| L, Float32], numHeads: Int)(using label: Label[L]): Tensor3[Head, Embedding, L, Float32] = - val tLength = t.shape(Axis[Head |*| L]) - require(tLength % numHeads == 0, s"T length $tLength not divisible by numHeads $numHeads") - t.rearrange( - (Axis[Head], Axis[Embedding], Axis[L]), - Axis[Head] -> numHeads, - Axis[L] -> (tLength / numHeads) - ) - def splitBiasToHeads[L](t: Tensor1[Head |*| L, Float32], numHeads: Int)(using label: Label[L]): Tensor2[Head, L, Float32] = - val tLength = t.shape(Axis[Head |*| L]) - require(tLength % numHeads == 0, s"T length $tLength not divisible by numHeads $numHeads") - t.rearrange( - (Axis[Head], Axis[L]), - Axis[Head] -> numHeads, - Axis[L] -> (tLength / numHeads) - ) - val qkvLength = cAttn.weight.shape(Axis[QKV]) - require(qkvLength % 3 == 0, s"QKV length $qkvLength not divisible by 3") - val (qLength, kLength, vLength) = (qkvLength / 3, qkvLength / 3, qkvLength / 3) - - val (wq, wk, wv) = cAttn.weight.deconcatenate( - axis = Axis[QKV], - ((Axis[Query] -> qLength), (Axis[Key] -> kLength), (Axis[Value] -> vLength)) - ) - val (wqb, wkb, wvb) = cAttn.bias.deconcatenate( - axis = Axis[QKV], - ((Axis[Query] -> qLength), (Axis[Key] -> kLength), (Axis[Value] -> vLength)) - ) - - MultiHeadAttentionParams( - wq = HeadsParams(splitWeightToHeads(wq, numHeads), splitBiasToHeads(wqb, numHeads)), - wk = HeadsParams(splitWeightToHeads(wk, numHeads), splitBiasToHeads(wkb, numHeads)), - wv = HeadsParams(splitWeightToHeads(wv, numHeads), splitBiasToHeads(wvb, numHeads)), - proj = cProj - ) - - def load1[L](name: String, axis: Axis[L])(using Label[L]): Tensor1[L, Float32] = - val info = tensorMap(name) - val jaxArray = SafeTensorsReader.loadTensor(filePath, info, dataStartPos) - liftPyTensor(jaxArray) - - def load2[L1, L2](name: String, axis1: Axis[L1], axis2: Axis[L2])(using Label[L1], Label[L2]): Tensor2[L1, L2, Float32] = - val info = tensorMap(name) - val jaxArray = SafeTensorsReader.loadTensor(filePath, info, dataStartPos) - liftPyTensor(jaxArray) - - def loadLinear[In, Out](prefix: String, inAxis: Axis[In], outAxis: Axis[Out])(using Label[In], Label[Out]): LinearLayerParams[In, Out] = - val w = load2(s"$prefix.weight", inAxis, outAxis) - val b = load1(s"$prefix.bias", outAxis) - LinearLayerParams(w, b) - - def loadLN(prefix: String): LayerNormalizationParams = - val w = load1(s"$prefix.weight", Axis[Embedding]) - val b = load1(s"$prefix.bias", Axis[Embedding]) - LayerNormalizationParams(w, b) - - val wpe = load2("wpe.weight", Axis[Context], Axis[Embedding]) - println("Successfully loaded WPE parameters") - val wte = load2("wte.weight", Axis[Vocab], Axis[Embedding]) - println("Successfully loaded WTE parameters") - val outputNormalization = loadLN("ln_f") - println("Successfully loaded final LayerNorm parameters") - - val layers = (0 until 12).map { i => - val prefix = s"h.$i" - val ln1 = loadLN(s"$prefix.ln_1") - val ln2 = loadLN(s"$prefix.ln_2") - val attn = loadAttnWeights(s"$prefix.attn.c_attn", s"$prefix.attn.c_proj") - val c_fc = loadLinear(s"$prefix.mlp.c_fc", Axis[Embedding], Axis[EmbeddingMixed]) - val c_proj = loadLinear(s"$prefix.mlp.c_proj", Axis[EmbeddingMixed], Axis[Embedding]) - val mlp = EmbeddingMixerParams(c_fc, c_proj) - println(s"Successfully loaded layer $i parameters") - - TransformerLayerParams(ln1, attn, ln2, mlp) - }.toList - println("Successfully loaded all layers parameters") - - val params = GPT2Params(wte, wpe, layers, outputNormalization) - val gpt2 = GPT2(params) - val inference = Inference(gpt2, Tokenizer(tiktoken.get_encoding("gpt2"))) - // val stream = inference("Hello, my name is Beni. Who ") - val stream = inference("Deep Learning is quite complicated. However, with the right tools, ") - stream.foreach(println) diff --git a/examples/src/main/scala/complex/GPT2Train.scala b/examples/src/main/scala/complex/GPT2Train.scala deleted file mode 100644 index 1e16115d..00000000 --- a/examples/src/main/scala/complex/GPT2Train.scala +++ /dev/null @@ -1,483 +0,0 @@ -package examples.complex.nanoGPT - -import dimwit.* -import dimwit.Conversions.given - -import nn.ActivationFunctions.* -import dimwit.random.Random -import dimwit.stats.Normal -import nn.AdamW -import nn.Adam -import nn.Loss -import examples.timed -import dimwit.python.PythonSetup -import src.main.scala.complex.safePyTree - -import java.io.RandomAccessFile -import java.nio.channels.FileChannel -import java.nio.{ByteBuffer, ByteOrder} -import java.nio.charset.StandardCharsets -import dimwit.jax.Jax -import dimwit.tensor.DType -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import src.main.scala.complex.loadPyTree -import dimwit.stats.Categorical -import dimwit.stats.Uniform - -object Config: - inline val numIterations = 60_000 - inline val trainLogInterval = 10 - inline val evalLogInterval = 250 - inline val numberOfEvalIterations = 200 - inline val vocabSize = 65 - inline val learningRate = 1e-3f - inline val beta1 = 0.9f - inline val beta2 = 0.99f - inline val batchSize = 64 - inline val contextLength = 256 - inline val numberOfLayers = 6 - inline val numberOfHeads = 6 - inline val extentEmbedding = 384 - inline val dropout = 0.2 - - private inline def validateConfig: Unit = - inline if extentEmbedding % numberOfHeads != 0 then - import scala.compiletime.{error, constValue} - import scala.compiletime.ops.int.ToString - scala.compiletime.error( - "Config Error: 'extentEmbedding' must be divisible by 'numberOfHeads', but got extentEmbedding = " + constValue[ToString[extentEmbedding.type]] + " and numberOfHeads = " + constValue[ToString[numberOfHeads.type]] - ) - - validateConfig - -import Config.* - -// assert(extentEmbedding % numberOfHeads == 0, "Embedding size must be divisible by number of heads") -val headAxisExtent = Axis[Head] -> numberOfHeads -val headKeyAxisExtent = Axis[HeadKey] -> extentEmbedding / numberOfHeads -val headQueryAxisExtent = Axis[HeadQuery] -> extentEmbedding / numberOfHeads -val headValueAxisExtent = Axis[HeadValue] -> extentEmbedding / numberOfHeads -val embeddingAxisExtent = Axis[Embedding] -> extentEmbedding -val embeddingMixedAxisExtent = Axis[EmbeddingMixed] -> extentEmbedding * 4 -val vocabAxisExtent = Axis[Vocab] -> vocabSize -val contextAxisExtent = Axis[Context] -> contextLength - -// Dimensions -trait Vocab derives Label -trait Embedding derives Label -trait Context derives Label -trait EmbeddingMixed derives Label - -trait Batch derives Label - -case class LayerNormalizationParams( - weight: Tensor1[Embedding, Float32], - bias: Tensor1[Embedding, Float32] -) - -case class LinearLayerParams[In, Out]( - weight: Tensor2[In, Out, Float32], - bias: Tensor1[Out, Float32] -) - -case class ProjectionLayerParams[In, Out]( - weight: Tensor2[In, Out, Float32] -) - -trait Head derives Label -trait HeadKey derives Label -trait HeadQuery derives Label -trait HeadValue derives Label - -case class HeadsParams[Kind](val weights: Tensor3[Head, Embedding, Kind, Float32], val bias: Tensor2[Head, Kind, Float32]) - -case class MultiHeadAttentionParams( - wq: HeadsParams[HeadQuery], - wk: HeadsParams[HeadKey], - wv: HeadsParams[HeadValue], - proj: LinearLayerParams[Head |*| HeadValue, Embedding] -) derives TensorTree - -case class EmbeddingMixerParams( - c_fc: LinearLayerParams[Embedding, EmbeddingMixed], - c_proj: LinearLayerParams[EmbeddingMixed, Embedding] -) derives TensorTree - -case class TransformerLayerParams( - ln1: LayerNormalizationParams, - attn: MultiHeadAttentionParams, - ln2: LayerNormalizationParams, - embeddingMixer: EmbeddingMixerParams -) derives TensorTree - -case class GPT2Params( - vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], - positionalEmbeddings: Tensor2[Context, Embedding, Float32], - layers: List[TransformerLayerParams], - outputNormalization: LayerNormalizationParams -) derives TensorTree - -object GPT2Params: - - def init(initKey: Random.Key): GPT2Params = - def initLayerNormalizationParams(): LayerNormalizationParams = - LayerNormalizationParams( - weight = Tensor(Shape(embeddingAxisExtent)).fill(1f), - bias = Tensor(Shape(embeddingAxisExtent)).fill(0f) - ) - def initMutliHeadAttentionParams(key: Random.Key): MultiHeadAttentionParams = - MultiHeadAttentionParams( - wq = HeadsParams( - weights = Normal.standardIsotropic(Shape(headAxisExtent, embeddingAxisExtent, headQueryAxisExtent), scale = 0.02f).sample(key), - bias = Tensor(Shape(headAxisExtent, headQueryAxisExtent)).fill(0f) - ), - wk = HeadsParams( - weights = Normal.standardIsotropic(Shape(headAxisExtent, embeddingAxisExtent, headKeyAxisExtent), scale = 0.02f).sample(key), - bias = Tensor(Shape(headAxisExtent, headKeyAxisExtent)).fill(0f) - ), - wv = HeadsParams( - weights = Normal.standardIsotropic(Shape(headAxisExtent, embeddingAxisExtent, headValueAxisExtent), scale = 0.02f).sample(key), - bias = Tensor(Shape(headAxisExtent, headValueAxisExtent)).fill(0f) - ), - proj = LinearLayerParams( - weight = Normal.standardIsotropic(Shape(headAxisExtent * headValueAxisExtent, embeddingAxisExtent), scale = 0.02f).sample(key), - bias = Tensor(Shape(embeddingAxisExtent)).fill(0f) - ) - ) - def initEmbeddingMixerParams(key: Random.Key): EmbeddingMixerParams = - val (fcKey, projKey) = key.split2() - EmbeddingMixerParams( - c_fc = LinearLayerParams( - weight = Normal.standardIsotropic(Shape(embeddingAxisExtent, embeddingMixedAxisExtent), scale = 0.02f).sample(fcKey), - bias = Tensor(Shape(embeddingMixedAxisExtent)).fill(0f) - ), - c_proj = LinearLayerParams( - weight = Normal.standardIsotropic(Shape(embeddingMixedAxisExtent, embeddingAxisExtent), scale = 0.02f).sample(projKey), - bias = Tensor(Shape(embeddingAxisExtent)).fill(0f) - ) - ) - def initTransformerLayerParams(key: Random.Key): TransformerLayerParams = - val (attnKey, mixKey) = key.split2() - TransformerLayerParams( - ln1 = initLayerNormalizationParams(), - attn = initMutliHeadAttentionParams(attnKey), - ln2 = initLayerNormalizationParams(), - embeddingMixer = initEmbeddingMixerParams(mixKey) - ) - val keys = initKey.split(4) - val layerKeys = keys(2).split(numberOfLayers) - GPT2Params( - vocabularyEmbeddings = Normal.standardIsotropic(Shape(vocabAxisExtent, embeddingAxisExtent), scale = 0.02f).sample(keys(0)), - positionalEmbeddings = Normal.standardIsotropic(Shape(contextAxisExtent, embeddingAxisExtent), scale = 0.02f).sample(keys(1)), - layers = layerKeys.map(initTransformerLayerParams).toList, - outputNormalization = initLayerNormalizationParams() - ) - -case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int32] => Tensor1[Context, Int32]): - - private case class LinearLayer[In: Label, Out: Label](params: LinearLayerParams[In, Out]) extends (Tensor1[In, Float32] => Tensor1[Out, Float32]): - override def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = - x.dot(Axis[In])(params.weight) + params.bias - - private case class EmbeddingMixer(params: EmbeddingMixerParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): - private val hiddenLayer = LinearLayer(params.c_fc) - private val outputLayer = LinearLayer(params.c_proj) - // TODO add dropout - - def apply(in: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = - in.vmap(Axis[Context])(x => - val hidden = gelu(hiddenLayer(x)) - outputLayer(hidden) - ) - - private case class ProjectionLayer[In: Label, Out: Label](params: ProjectionLayerParams[In, Out]) extends (Tensor1[In, Float32] => Tensor1[Out, Float32]): - def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = - x.dot(Axis[In])(params.weight) - - private case class MultiHeadAttention(params: MultiHeadAttentionParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): - - private val projection = LinearLayer(params.proj) - - def apply(x: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = - val heads = zipvmap(Axis[Head])(params.wq.weights, params.wq.bias, params.wk.weights, params.wk.bias, params.wv.weights, params.wv.bias): - attention.tupled(_)(x) - heads.vmap(Axis[Context])(heads => projection(heads.flatten)) - - private def attention( - wq: Tensor2[Embedding, HeadQuery, Float32], - wqBias: Tensor1[HeadQuery, Float32], - wk: Tensor2[Embedding, HeadKey, Float32], - wkBias: Tensor1[HeadKey, Float32], - wv: Tensor2[Embedding, HeadValue, Float32], - wvBias: Tensor1[HeadValue, Float32] - )(x: Tensor2[Context, Embedding, Float32]): Tensor2[Context, HeadValue, Float32] = - - trait AttnWeights derives Label - - def causalMasking(attnScores: Tensor2[Context, Prime[Context], Float32]): Tensor2[Context, Prime[Context], Float32] = - val ctxLength = attnScores.shape(Axis[Context]) - val causalMask = tril(Tensor(Shape((Axis[Context] -> ctxLength, Axis[Prime[Context]] -> ctxLength))).fill(true)) - where(causalMask, attnScores, Tensor.like(attnScores).fill(Float.NegativeInfinity)) - - val queries = x.dot(Axis[Embedding])(wq) +! wqBias - val keys = x.dot(Axis[Embedding])(wk) +! wkBias - val values = x.dot(Axis[Embedding])(wv) +! wvBias - val dk = Tensor0(Math.sqrt(keys.shape(Axis[HeadKey])).toFloat) - val attnScores = (queries.dot(Axis[HeadQuery ~ HeadKey])(keys) /! dk) - val attnWeights = causalMasking(attnScores) - .vmap(Axis[Context])(attnScore => softmax(attnScore).relabelTo(Axis[AttnWeights])) - val res = attnWeights.dot(Axis[AttnWeights ~ Context])(values) - res - - private case class LayerNorm(params: LayerNormalizationParams) extends (Tensor1[Embedding, Float32] => Tensor1[Embedding, Float32]): - - private def standardize(x: Tensor1[Embedding, Float32]): Tensor1[Embedding, Float32] = - val x0 = x -! x.mean - val variance = x0.pow(2).mean - val epsilon = 1e-6f - x0 /! (variance + epsilon).sqrt - - def apply(x: Tensor1[Embedding, Float32]): Tensor1[Embedding, Float32] = - standardize(x) * params.weight + params.bias - - private case class TransformerLayer(params: TransformerLayerParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): - private val embeddingMixer = EmbeddingMixer(params.embeddingMixer) - private val multiHeadAttention = MultiHeadAttention(params.attn) - private val preNormalization = LayerNorm(params.ln1) - private val postNormalization = LayerNorm(params.ln2) - - def apply(t: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = - var x = t - x = x + multiHeadAttention(x.vmap(Axis[Context])(preNormalization)) - x = x + embeddingMixer(x.vmap(Axis[Context])(postNormalization)) - x - - private case class TransformerBlock(layers: List[TransformerLayer]) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): - override def apply(t: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = - layers.foldLeft(t): - case (t, layer) => layer(t) - - case class Embedder(vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], positionalEmbeddings: Tensor2[Context, Embedding, Float32]): - - def apply(tokens: Tensor1[Context, Int32]): Tensor2[Context, Embedding, Float32] = - val embeddings = vocabularyEmbeddings.take(Axis[Vocab])(tokens) - embeddings + positionalEmbeddings - - case class OutputLayer(normalization: LayerNormalizationParams, projectionParams: ProjectionLayerParams[Embedding, Vocab]) extends (Tensor1[Embedding, Float32] => Tensor1[Vocab, Float32]): - private val normalizationLayer = LayerNorm(normalization) - private val projection = ProjectionLayer(projectionParams) - override def apply(x: Tensor1[Embedding, Float32]): Tensor1[Vocab, Float32] = - projection(normalizationLayer(x)) - - private val embedder = Embedder(params.vocabularyEmbeddings, params.positionalEmbeddings) - private val transformerBlock = TransformerBlock(params.layers.map(TransformerLayer(_))) - private val outputLayer = OutputLayer( - params.outputNormalization, - ProjectionLayerParams(params.vocabularyEmbeddings.transpose) // Tying output weights with input embeddings - ) - - def logits(inputTokens: Tensor1[Context, Int32]): Tensor2[Context, Vocab, Float32] = - val startEmbeddings = embedder(inputTokens) - val endEmbeddings = transformerBlock(startEmbeddings) - endEmbeddings.vmap(Axis[Context])(x => outputLayer(x)) - - def probits(inputTokens: Tensor1[Context, Int32]): Tensor2[Context, Vocab, Float32] = - val x = logits(inputTokens) - val res = x.vapply(Axis[Vocab])(softmax) - return res - - def apply(inputTokens: Tensor1[Context, Int32]): Tensor1[Context, Int32] = - val x = probits(inputTokens) - val res = x.argmax(Axis[Vocab]) - return res - -@main def train(): Unit = - - import Tensor0.given - - dimwit.initialize() - - trait Data derives Label - - PythonSetup.initialize - lazy val np = py.module("numpy") - case class Sample( - input: Tensor2[Batch, Context, Int32], - labels: Tensor2[Batch, Context, Int32] - ) - def createDataset(key: Random.Key, pathToBinaryFile: String): Iterator[Sample] = - val data = dimwit.python.PyBridge.liftPyTensor1(Axis[Data], VType[Int32])(Jax.jnp.asarray(np.memmap(pathToBinaryFile, dtype = np.uint16, mode = "r"))) - def sliceContextBlockAt(idx: Tensor0[Int32]): Tensor1[Context, Int32] = - data - .dynamicSlice(idx, contextLength) - .relabelTo(Axis[Context]) - val numDataPoints = data.shape(Axis[Data]) - val lastValidIdx = numDataPoints - contextLength - val batchIndicesDist = IndependentDistribution.fromUnivariate(shape = Shape1(Axis[Batch] -> batchSize), Uniform(min = 0, max = lastValidIdx)) - Iterator.unfold(key): - case key => - val randomBatchIndex = batchIndicesDist.sample(key) - val x = randomBatchIndex.vmap(Axis[Batch])(index => sliceContextBlockAt(index)) - val y = randomBatchIndex.vmap(Axis[Batch])(index => sliceContextBlockAt(index + 1)) - Some(Sample(x, y), key.next) - - def createTrainDataset(key: Random.Key): Iterator[Sample] = - val pathToTrainBinaryFile = "data/nanoGPT/shakespeare_char/train.bin" - createDataset(key, pathToTrainBinaryFile) - def createValDataset(key: Random.Key): Iterator[Sample] = - val pathToValBinaryFile = "data/nanoGPT/shakespeare_char/val.bin" - createDataset(key, pathToValBinaryFile) - - val initParams = GPT2Params.init(Random.Key(42)) - - import Tensor0.given - val adam = Adam(learningRate = learningRate, b1 = beta1, b2 = beta2, epsilon = 1e-8f) - val adamW = AdamW(adam, weightDecayFactor = 1e-1f) - type AdamWState = adamW.State[GPT2Params] - - case class TrainingState( - params: GPT2Params, - adamWState: AdamWState, - loss: Tensor0[Float32] - ) - - def batchLoss(input: Tensor2[Batch, Context, Int32], labels: Tensor2[Batch, Context, Int32])(params: GPT2Params): Tensor0[Float32] = - val model = GPT2(params) - val logits = input.vmap(Axis[Batch])(model.logits) - val lossPerSample = zipvmap(Axis[Batch])(labels, logits): (labels, logits) => - val lossPerContextPosition = zipvmap(Axis[Context])(labels, logits): (label, logits) => - Loss.crossEntropy(logits = logits, label = label) - lossPerContextPosition.mean - lossPerSample.mean - - def gradientStep( - input: Tensor2[Batch, Context, Int32], - labels: Tensor2[Batch, Context, Int32], - state: TrainingState - ): TrainingState = - val lossBatch = batchLoss(input, labels) - val grads = Autodiff.grad(lossBatch)(state.params) - val loss = lossBatch(state.params) // TODO move to gradAndValue - val (params, adamWState) = adamW.update(grads, state.params, state.adamWState) - TrainingState(params = params, adamWState = adamWState, loss = loss) - val jitStep = jitDonatingUnsafe(gradientStep) - - def evaluate(input: Tensor2[Batch, Context, Int32], labels: Tensor2[Batch, Context, Int32], params: GPT2Params): Tensor0[Float32] = - batchLoss(input, labels)(params) - // val evalF = jit(evaluate) - val evalF = eagerCleanup(evaluate) - - def miniBatchGradientDescent( - samples: Iterator[Sample], - startState: TrainingState - ): Iterator[TrainingState] = - samples.scanLeft(startState): - case (state, sample) => - dimwit.gc() - jitStep(sample.input, sample.labels, state) - - val trainSampleStream = createTrainDataset(Random.Key(42)) - val valSampleStream = createValDataset(Random.Key(42)) - val initState = TrainingState(initParams, adamW.init(initParams), Tensor0(-1f)) - val trainTrajectory = miniBatchGradientDescent(trainSampleStream, initState) - val finalState = trainTrajectory.zipWithIndex - .drop(1) - .tapEach: - case (state, iter) => - if iter % trainLogInterval == 0 then - println( - List( - s"iter $iter", - f"loss: ${state.loss.item}%.2f" - ).mkString(", ") - ) - .tapEach: - case (state, iter) => - if iter % evalLogInterval == 0 then - val valLossStream = valSampleStream.map: sample => - evalF(sample.input, sample.labels, state.params).item // evalF is new - val avgValLoss = valLossStream.take(numberOfEvalIterations).sum / numberOfEvalIterations - println(f"Evaluation at iter $iter: validation loss: $avgValLoss%.2f") - safePyTree(state.params, f"gpt2_params_iter_$iter.pkl") - // dimwit.gc() - // Thread.sleep(100) - .drop(numIterations - 1) // iterate to final iteration - .next() - -@main def inference(): Unit = - // 1. Setup - PythonSetup.initialize - val checkpointPath = "gpt2_params_iter_1000.pkl" - - // 2. Load Weights - println(s"Loading model from $checkpointPath...") - val state: GPT2Params = loadPyTree[GPT2Params](checkpointPath) - val model = GPT2(state) - - // 3. Define Prompt - // val promptText = "To be, or not to be, that is the question:" - val promptText = "Romeo, Romeo, wherefore art thou" - println(s"Prompt: $promptText") - - // 4. Run Generation - val result = InferenceUtil.generate(model, promptText, maxNewTokens = 100) - - println("-" * 50) - println("Full Generated Text:") - println(result) - -object InferenceUtil: - // Standard characters from the shakespeare_char dataset (Vocab Size = 65) - val chars = "\n" + " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - println(s"Vocab Size: ${chars.length}") - val charToInt = chars.zipWithIndex.toMap - val intToChar = chars.zipWithIndex.map(_.swap).toMap - - def encode(s: String): List[Int] = s.map(c => charToInt.getOrElse(c, 0)).toList - def decode(l: List[Int]): String = l.map(i => intToChar.getOrElse(i, ' ')).mkString - - def generate( - model: GPT2, - prompt: String, - maxNewTokens: Int, - temperatur: Float = 1.0f - ): String = - - var currentTokens = encode(prompt) - - val ctxExtent = Axis[Context] -> Config.contextLength - - println(s"Generating $maxNewTokens tokens, starting at prompt length ${currentTokens.length}...") - - val sampleKey = Random.Key.fromTime() - - for i <- 0 until maxNewTokens do - val window = - if currentTokens.length > Config.contextLength - then currentTokens.takeRight(Config.contextLength) - else currentTokens - - val effectiveLength = window.length - - val inputTensor = Tensor(Shape(ctxExtent)).fill(0) - - val paddedData = window ++ List.fill(Config.contextLength - effectiveLength)(0) - - val inputData = paddedData.toArray - val currentBatch = Tensor(Shape(ctxExtent)).fromArray(inputData) - val logits = model.logits(currentBatch) - - val lastTokenIndex = effectiveLength - 1 - val nextTokenLogits = logits.slice(Axis[Context].at(lastTokenIndex)) - val nextTokenId = - if temperatur == 0f - then nextTokenLogits.argmax(Axis[Vocab]).item - else Categorical.fromFloat(softmax(nextTokenLogits /! temperatur)).sample(sampleKey).item - - currentTokens = currentTokens :+ nextTokenId - - println((nextTokenId, decode(List(nextTokenId)))) - - println() - decode(currentTokens) diff --git a/examples/src/main/scala/complex/GPTCheckpoint.scala b/examples/src/main/scala/complex/GPTCheckpoint.scala deleted file mode 100644 index 5cfabb4b..00000000 --- a/examples/src/main/scala/complex/GPTCheckpoint.scala +++ /dev/null @@ -1,24 +0,0 @@ -package src.main.scala.complex - -import dimwit.autodiff.TensorTree -import me.shadaj.scalapy.py - -def safePyTree[T: TensorTree](value: T, path: String): Unit = - val pickle = py.module("pickle") - val pyTree = TensorTree[T].toPyTree(value) - val file = py.Dynamic.global.open(path, "wb") - try - pickle.dump(pyTree, file, protocol = 5) - println(s"Saved checkpoint: $path") - finally - file.close() - -def loadPyTree[T: TensorTree](path: String): T = - val pickle = py.module("pickle") - val file = py.Dynamic.global.open(path, "rb") - - try - val pyTree = pickle.load(file) - TensorTree[T].fromPyTree(pyTree) - finally - file.close() diff --git a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index 3f5995fd..f42fa65d 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -1,7 +1,5 @@ package examples.complex.vae -import examples.timed - import dimwit.* import dimwit.Conversions.given import dimwit.autodiff.* @@ -9,11 +7,10 @@ import dimwit.autodiff.FloatTree.* import dimwit.stats.Normal import dimwit.random.Random import examples.dataset.MNISTLoader -import nn.LinearLayer -import nn.ActivationFunctions.relu -import nn.GradientDescent +import dimwit.nn.ActivationFunctions.relu +import dimwit.optimizer.GradientDescent import dimwit.jax.Jax -import nn.ActivationFunctions.sigmoid +import dimwit.nn.ActivationFunctions.sigmoid import dimwit.random.Random.Key import MNISTLoader.{Sample, TrainSample, TestSample, Height, Width} @@ -29,6 +26,33 @@ trait Latent derives Label trait DHidden1 derives Label trait DHidden2 derives Label +def timed[A](template: String)(block: => A): A = + val t0 = System.currentTimeMillis() + val result = block + println(s"$template took ${System.currentTimeMillis() - t0} ms") + result + +object LinearLayer: + + case class Params[In, Out](weight: Tensor2[In, Out, Float32], bias: Tensor1[Out, Float32]) + + object Params: + given [I: Label, O: Label]: TensorTree[Params[I, O]] = TensorTree.derived + + def apply[In: Label, Out: Label](paramKey: Key)( + inputDim: AxisExtent[In], + outputDim: AxisExtent[Out] + ): Params[In, Out] = + Params( + weight = Normal.standardNormal(Shape(inputDim, outputDim)).sample(paramKey), + bias = Tensor(Shape(outputDim)).fill(0.0f) + ) + +case class LinearLayer[In: Label, Out: Label](params: LinearLayer.Params[In, Out]) extends Function[Tensor1[In, Float32], Tensor1[Out, Float32]]: + override def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = + import params.{weight, bias} + x.dot(Axis[In])(weight) + bias + class Encoder(p: Encoder.Params): val layer1 = LinearLayer(p.layer1) diff --git a/examples/src/main/scala/package.scala b/examples/src/main/scala/package.scala deleted file mode 100644 index 8da49595..00000000 --- a/examples/src/main/scala/package.scala +++ /dev/null @@ -1,7 +0,0 @@ -package examples - -def timed[A](template: String)(block: => A): A = - val t0 = System.currentTimeMillis() - val result = block - println(s"$template took ${System.currentTimeMillis() - t0} ms") - result diff --git a/nn/src/main/scala/nn/Conv2DLayer.scala b/nn/src/main/scala/nn/Conv2DLayer.scala deleted file mode 100644 index 4e12b541..00000000 --- a/nn/src/main/scala/nn/Conv2DLayer.scala +++ /dev/null @@ -1,26 +0,0 @@ -package nn - -import dimwit.* -import dimwit.random.Random.Key -import dimwit.stats.Normal - -object Conv2DLayer: - - case class Params[S1, S2, InChannel, OutChannel, V]( - kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V] - ) - - object Params: - given [S1: Label, S2: Label, InChannel: Label, OutChannel: Label, V]: TensorTree[Params[S1, S2, InChannel, OutChannel, V]] = TensorTree.derived - - def apply[S1: Label, S2: Label, InChannel: Label, OutChannel: Label, V: IsFloating](paramKey: Key)(kernelShape: Shape[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple]): Params[S1, S2, InChannel, OutChannel, V] = - Params(kernel = Normal.standardNormal(kernelShape).sample(paramKey).asFloat(VType[V])) - -case class Conv2DLayer[S1: Label, S2: Label, InChannel: Label, OutChannel: Label, V: IsFloating]( - params: Conv2DLayer.Params[S1, S2, InChannel, OutChannel, V], - stride: Stride2[S1, S2] | Int = 1, - padding: Padding = Padding.SAME -): - - def apply(x: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V]): Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V] = - x.conv2d(params.kernel, stride, padding) diff --git a/nn/src/main/scala/nn/LinearLayer.scala b/nn/src/main/scala/nn/LinearLayer.scala deleted file mode 100644 index 8934f997..00000000 --- a/nn/src/main/scala/nn/LinearLayer.scala +++ /dev/null @@ -1,46 +0,0 @@ -package nn - -import dimwit.* -import dimwit.random.Random -import dimwit.random.Random.Key -import dimwit.tensor.VType -import dimwit.stats.Normal - -object LinearLayer: - - case class Params[In, Out](weight: Tensor2[In, Out, Float32], bias: Tensor1[Out, Float32]) - - object Params: - given [I: Label, O: Label]: TensorTree[Params[I, O]] = TensorTree.derived - - def apply[In: Label, Out: Label](paramKey: Key)( - inputDim: AxisExtent[In], - outputDim: AxisExtent[Out] - ): Params[In, Out] = - Params( - weight = Normal.standardNormal(Shape(inputDim, outputDim)).sample(paramKey), - bias = Tensor(Shape(outputDim)).fill(0.0f) - ) - -case class LinearLayer[In: Label, Out: Label](params: LinearLayer.Params[In, Out]) extends Function[Tensor1[In, Float32], Tensor1[Out, Float32]]: - override def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = - import params.{weight, bias} - x.dot(Axis[In])(weight) + bias - -object LinearMap: - - case class Params[In](weight: Tensor1[In, Float32], bias: Tensor0[Float32]) - - object Params: - given [In: Label]: TensorTree[Params[In]] = TensorTree.derived - - def apply[In: Label](paramKey: Key)(inputDim: AxisExtent[In]): Params[In] = - Params( - weight = Normal.standardNormal(Shape(inputDim)).sample(paramKey), - bias = Tensor0(0.0f) - ) - -case class LinearMap[In: Label](params: LinearMap.Params[In]) extends Function[Tensor1[In, Float32], Tensor0[Float32]]: - override def apply(x: Tensor1[In, Float32]): Tensor0[Float32] = - import params.{weight, bias} - x.dot(Axis[In])(weight) + bias diff --git a/nn/src/main/scala/nn/Loss.scala b/nn/src/main/scala/nn/Loss.scala deleted file mode 100644 index f24603ae..00000000 --- a/nn/src/main/scala/nn/Loss.scala +++ /dev/null @@ -1,20 +0,0 @@ -package nn - -import dimwit.* -import nn.ActivationFunctions.softmax - -object Loss: - - // TODO move this to a more general utils place? - private def logsumexp[L: Label](logits: Tensor1[L, Float32]): Tensor0[Float32] = - val maxLogit = logits.max(Axis[L]) - val logSumShifted = (logits -! maxLogit).exp.sum.log - maxLogit + logSumShifted - - def crossEntropy[L: Label]( - logits: Tensor1[L, Float32], - label: Tensor0[Int32] - ): Tensor0[Float32] = - val targetLogit = logits.slice(Axis[L].at(label)) - val logNormalizer = logsumexp(logits) - logNormalizer - targetLogit diff --git a/nn/src/main/scala/nn/TransposeConv2DLayer.scala b/nn/src/main/scala/nn/TransposeConv2DLayer.scala deleted file mode 100644 index 096f8067..00000000 --- a/nn/src/main/scala/nn/TransposeConv2DLayer.scala +++ /dev/null @@ -1,42 +0,0 @@ -package nn - -import dimwit.* -import dimwit.random.Random.Key -import dimwit.stats.Normal - -object TransposeConvLayer: - - case class Params[S1, S2, InChannels, OutChannels]( - kernel: Tensor[S1 *: S2 *: InChannels *: OutChannels *: EmptyTuple, Float32] - ) - - object Params: - given [S1: Label, S2: Label, IC: Label, OC: Label]: TensorTree[Params[S1, S2, IC, OC]] = TensorTree.derived - - /** Initialize transpose convolutional layer parameters - * - * @param paramKey - * Random key for parameter initialization - * @param kernelShape - * Shape of the convolutional kernel, e.g., (KernelH, KernelW, InChannels, OutChannels) for 2D transpose conv - */ - def apply[S1: Label, S2: Label, InChannels: Label, OutChannels: Label](paramKey: Key)( - kernelShape: Shape[S1 *: S2 *: InChannels *: OutChannels *: EmptyTuple] - ): Params[S1, S2, InChannels, OutChannels] = - Params( - kernel = Normal.standardNormal(kernelShape).sample(paramKey) - ) - -case class TransposeConvLayer[S1: Label, S2: Label, InChannels: Label, OutChannels: Label]( - params: TransposeConvLayer.Params[S1, S2, InChannels, OutChannels], - stride: Int = 1, - padding: Padding = Padding.SAME -): - /** Apply transpose convolution to input tensor - * - * Note: For transpose convolution, the input has OutChannels (matching forward conv output) and the output has InChannels (matching forward conv input). This is the adjoint operation to forward convolution. - * - * Input: (Spatial..., OutChannels) Output: (Spatial..., InChannels) - */ - def apply(x: Tensor[S1 *: S2 *: OutChannels *: EmptyTuple, Float32]): Tensor[S1 *: S2 *: InChannels *: EmptyTuple, Float32] = - x.transposeConv2d(params.kernel, stride, padding) diff --git a/nn/src/main/scala/nn/package.scala b/nn/src/main/scala/nn/package.scala deleted file mode 100644 index c0bd2233..00000000 --- a/nn/src/main/scala/nn/package.scala +++ /dev/null @@ -1 +0,0 @@ -package object nn {} diff --git a/nn/src/main/scala/nn/GradientOptimizer.scala b/optimizer/src/main/scala/optimizer/GradientOptimizer.scala similarity index 99% rename from nn/src/main/scala/nn/GradientOptimizer.scala rename to optimizer/src/main/scala/optimizer/GradientOptimizer.scala index d6a348a7..2ffaa87e 100644 --- a/nn/src/main/scala/nn/GradientOptimizer.scala +++ b/optimizer/src/main/scala/optimizer/GradientOptimizer.scala @@ -1,4 +1,4 @@ -package nn +package dimwit.optimizer import dimwit.* import dimwit.autodiff.FloatTree.ops.* From da754de9d03acea485269b5f3aa37212d75a6d8c Mon Sep 17 00:00:00 2001 From: Benjamin Meyer Date: Wed, 20 May 2026 10:20:04 +0200 Subject: [PATCH 2/2] Move optimizer module into core module (as package) --- .github/workflows/ci.yml | 3 --- AGENTS.md | 4 ++-- build.sbt | 13 +++---------- .../scala/dimwit}/optimizer/GradientOptimizer.scala | 0 docs/quickstart.md | 2 +- mdocs/AGENTS.md | 4 ++-- mdocs/docs/quickstart.md | 2 +- 7 files changed, 9 insertions(+), 19 deletions(-) rename {optimizer/src/main/scala => core/src/main/scala/dimwit}/optimizer/GradientOptimizer.scala (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2fa9097c..2d9c9fa4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,9 +47,6 @@ jobs: - name: Compile core module run: sbt "project core" compile - - name: Compile nn module - run: sbt "project nn" compile - - name: Compile examples module run: sbt "project examples" compile diff --git a/AGENTS.md b/AGENTS.md index 40c66cc5..6bd69686 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -837,7 +837,7 @@ val jacFwd = Autodiff.jacFwd(linearMap) ```scala import dimwit.* -import nn.{GradientDescent, GradientOptimizer} +import dimwit.optimizer.{GradientDescent, GradientOptimizer} import dimwit.random.Random trait Feature derives Label @@ -884,7 +884,7 @@ val trained = optimizer.iterate(initModelParams)(gradFunc) ### Lion Optimizer ```scala -import nn.Lion +import dimwit.optimizer.Lion // Lion optimizer with momentum val lionOptimizer = Lion(learningRate = Tensor0(1e-3f), beta1 = Tensor0(0.9f), beta2 = Tensor0(0.99f), weightDecay = Tensor0(0.0f)) diff --git a/build.sbt b/build.sbt index 6144ddf8..0cb6d143 100644 --- a/build.sbt +++ b/build.sbt @@ -11,7 +11,7 @@ ThisBuild / resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/c addCommandAlias("testAndCoverage", "; clean; coverage; test; coverageReport") lazy val root = (project in file(".")) - .aggregate(core, optimizer, examples) + .aggregate(core, examples) .settings( name := "dimwit-root" ) @@ -43,16 +43,9 @@ lazy val core = (project in file("core")) Compile / packageDoc / publishArtifact := true ) -lazy val optimizer = (project in file("optimizer")) - .settings( - name := "dimwit-optimizer" - ) - .dependsOn(core) - // Examples subproject lazy val examples = (project in file("examples")) .dependsOn(core) - .dependsOn(optimizer) .settings( name := "dimwit-examples", // Examples use the same Scala version and dependencies as main project @@ -81,7 +74,7 @@ lazy val examples = (project in file("examples")) // Processes files in /mdocs that need to be copied to the root (e.g. README.md) lazy val docsRoot = (project in file(".dimwit-docs-root")) .enablePlugins(MdocPlugin) - .dependsOn(core, optimizer) + .dependsOn(core) .settings( name := "dimwit-docs-root", mdocIn := (ThisBuild / baseDirectory).value / "mdocs", @@ -98,7 +91,7 @@ lazy val docsRoot = (project in file(".dimwit-docs-root")) // Processes all other docs in /mdocs/docs/ → output to docs/ lazy val docs = (project in file(".dimwit-docs")) .enablePlugins(MdocPlugin) - .dependsOn(core, optimizer) + .dependsOn(core) .settings( name := "dimwit-docs", mdocIn := (ThisBuild / baseDirectory).value / "mdocs/docs", diff --git a/optimizer/src/main/scala/optimizer/GradientOptimizer.scala b/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala similarity index 100% rename from optimizer/src/main/scala/optimizer/GradientOptimizer.scala rename to core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala diff --git a/docs/quickstart.md b/docs/quickstart.md index d22a7ba1..ba901c2e 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -12,7 +12,7 @@ Before we start exploring the features of DimWit, let's look at a simple example // main imports for basic tensor operations and automatic differentiation import dimwit.* import dimwit.Autodiff.grad // TODO replace with cleaner import after PR is merged -import nn.GradientDescent // TODO replace with cleaner import after refactoring +import dimwit.optimizer.GradientDescent // TODO replace with cleaner import after refactoring // labels for tensor axes trait Batch derives Label diff --git a/mdocs/AGENTS.md b/mdocs/AGENTS.md index 178764f6..3904081a 100644 --- a/mdocs/AGENTS.md +++ b/mdocs/AGENTS.md @@ -644,7 +644,7 @@ val jacFwd = Autodiff.jacFwd(linearMap) ```scala mdoc:reset:silent import dimwit.* -import nn.{GradientDescent, GradientOptimizer} +import dimwit.optimizer.{GradientDescent, GradientOptimizer} import dimwit.random.Random trait Feature derives Label @@ -691,7 +691,7 @@ val trained = optimizer.iterate(initModelParams)(gradFunc) ### Lion Optimizer ```scala mdoc:silent -import nn.Lion +import dimwit.optimizer.Lion // Lion optimizer with momentum val lionOptimizer = Lion(learningRate = Tensor0(1e-3f), beta1 = Tensor0(0.9f), beta2 = Tensor0(0.99f), weightDecay = Tensor0(0.0f)) diff --git a/mdocs/docs/quickstart.md b/mdocs/docs/quickstart.md index 2d248f9d..231768e8 100644 --- a/mdocs/docs/quickstart.md +++ b/mdocs/docs/quickstart.md @@ -12,7 +12,7 @@ Before we start exploring the features of DimWit, let's look at a simple example // main imports for basic tensor operations and automatic differentiation import dimwit.* import dimwit.Autodiff.grad // TODO replace with cleaner import after PR is merged -import nn.GradientDescent // TODO replace with cleaner import after refactoring +import dimwit.optimizer.GradientDescent // TODO replace with cleaner import after refactoring // labels for tensor axes trait Batch derives Label