@rbiswasfc
Created September 6, 2024 08:41
Minai

Introduction

Minai is a flexible and lightweight deep learning training framework developed interactively in Part 2 of fast.ai's 2022-23 course.

The core of minai is the Learner class, which orchestrates the training process through an extensive callback system, allowing users to easily modify or extend any part of the training loop.

Minai provides a set of utilities for data handling, device management, visualization, and performance optimization, making it suitable for both quick prototyping and advanced deep learning projects.


README.md


minai

The mini version of fastai's miniai PyTorch framework created during the fastai course 2022-2023.

Installation

pip install minai

or to install from source clone this repo and run:

pip install -e .

Usage

This is still a work in progress - I'll add example usage soon. But in general, for examples from the course where you have from miniai.something import X you should be able to do from minai import X. You can do import minai as mi or even from minai import * for quick access to all the functions and things, if you're so inclined.
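
For example, all of these are equivalent ways of getting at the exported names (a quick sketch - Learner and TrainCB are just two of the things minai exports):

from minai import Learner, TrainCB  # explicit imports
import minai as mi                  # namespaced access, e.g. mi.Learner
from minai import *                 # wildcard import for quick notebook work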

Tutorial 1 has a minimal example of fitting a model using minai - open it in Google Colab here.

Tutorial 2 shows callbacks in action on a slightly more complex task - open it in Google Colab here.

An example of the library in action: this notebook shows how to train a diffusion model on spectrograms to generate birdcalls, using minai. It is covered in the final lesson of Part 2 of the FastAI course.

And a lovely demo of use in the wild is this report by Thomas Capelle where he uses diffusion models to predict the next frame of an image sequence.


core.ipynb


Minai Core Functionality

Introduction

This is the core functionality of minai - the training framework developed interactively in the most recent FastAI course ('Impractical Deep Learning for Coders'). It is built on top of ideas from the fastai library, but is more flexible, and simpler.

#|export
import sys, gc, traceback, math, typing, random, numpy as np
from collections.abc import Mapping
from copy import copy
from itertools import zip_longest
from functools import partial, wraps
from operator import attrgetter, itemgetter

import matplotlib.pyplot as plt
import fastcore.all as fc
from fastprogress import progress_bar, master_bar

import torch, torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import default_collate

from torcheval.metrics import Mean
#|export
try: from accelerate import Accelerator
except ImportError: Accelerator=None
from torcheval.metrics import MulticlassAccuracy
#| export
def set_seed(seed, deterministic=False):
    torch.use_deterministic_algorithms(deterministic)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

Datasets and Dataloaders

from datasets import load_dataset
import torchvision.transforms.functional as TF
ds = load_dataset('fashion_mnist')
trn = ds['train']
val = ds['test']
trn

Dataset({ features: ['image', 'label'], num_rows: 60000 })

A dataset needs to match the PyTorch dataset interface, which is very simple. It needs to have a __len__ method, and a __getitem__ method. The __getitem__ method typically returns a tuple of (x, y) where x is the input and y is the target.

#|export
class Dataset():
    "Simple dataset that combines two collections"
    def __init__(self, x, y): self.x,self.y = x,y
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i],self.y[i]

A dataset can optionally include transformation functions.

#|export
class TfmDataset(Dataset):
    "Dataset subclass that transforms items"
    def __init__(self, x, y, tfm_x=None, tfm_y=None):
        super().__init__(x,y)
        self.tfm_x,self.tfm_y = tfm_x,tfm_y
        
    def __getitem__(self, i):
        x,y = self.x[i],self.y[i]
        return self.tfm_x(x) if self.tfm_x else x, self.tfm_y(y) if self.tfm_y else y
trn_ds = TfmDataset(trn['image'], trn['label'], tfm_x=TF.to_tensor)
val_ds = TfmDataset(val['image'], val['label'], tfm_x=TF.to_tensor)
#|export
def get_dls(train_ds, valid_ds, bs, **kwargs):
    "Convert train and validation datasets to data loaders"
    return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
            DataLoader(valid_ds, batch_size=bs*2, **kwargs))
bs = 1024
trn_dl,val_dl = get_dls(trn_ds, val_ds, bs=bs)
xb,yb = next(iter(trn_dl))
xb.shape,yb[:10]

(torch.Size([1024, 1, 28, 28]), tensor([9, 3, 1, 2, 4, 4, 2, 9, 8, 4]))

Hugging Face datasets typically return dictionaries, so here's a collate function to handle them:

#|export 
def collate_dict(ds):
    get = itemgetter(*ds.features)
    def _f(b): return get(default_collate(b))
    return _f

We can transform these dict-style datasets using with_transform:

def transforms(b):
    b['image'] = [TF.to_tensor(o)*2-1 for o in b['image']]
    return b
tds = ds.with_transform(transforms)
#|export
class DataLoaders:
    "Convert a `DatasetDict` into a pair of `DataLoader`s"
    def __init__(self, *dls): self.train,self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        f = collate_dict(dd['train'])
        return cls(*get_dls(*dd.values(), bs=batch_size, collate_fn=f))

We can recreate the same DataLoaders as before using this:

dls = DataLoaders.from_dd(tds, batch_size=bs)
xb,yb = next(iter(dls.train))
xb.shape,yb
(torch.Size([1024, 1, 28, 28]), tensor([3, 0, 4,  ..., 0, 8, 4]))

Utilities for Displaying Images

We often work with images - this section has some convenient utilities for displaying them. At some point I'll add some examples :)

#| export
@fc.delegates(plt.Axes.imshow)
def show_image(im, ax=None, figsize=None, title=None, noframe=True, **kwargs):
    "Show a PIL or PyTorch image on `ax`."
    if fc.hasattrs(im, ('cpu','permute','detach')):
        im = im.detach().cpu()
        if len(im.shape)==3 and im.shape[0]<5: im=im.permute(1,2,0)
    elif not isinstance(im,np.ndarray): im=np.array(im)
    if im.shape[-1]==1: im=im[...,0]
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    ax.set_xticks([]) 
    ax.set_yticks([]) 
    if noframe: ax.axis('off')
    return ax
#| export
@fc.delegates(plt.subplots, keep=True)
def subplots(
    nrows:int=1, # Number of rows in returned axes grid
    ncols:int=1, # Number of columns in returned axes grid
    figsize:tuple=None, # Width, height in inches of the returned figure
    imsize:int=3, # Size (in inches) of images that will be displayed in the returned figure
    suptitle:str=None, # Title to be set to returned figure
    **kwargs
): # fig and axs
    "A figure and set of subplots to display images of `imsize` inches"
    if figsize is None: figsize=(ncols*imsize, nrows*imsize)
    fig,ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)
    if suptitle is not None: fig.suptitle(suptitle)
    if nrows*ncols==1: ax = np.array([ax])
    return fig,ax
#| export
@fc.delegates(subplots)
def get_grid(
    n:int, # Number of axes
    nrows:int=None, # Number of rows, defaulting to `int(math.sqrt(n))`
    ncols:int=None, # Number of columns, defaulting to `ceil(n/rows)`
    title:str=None, # If passed, title set to the figure
    weight:str='bold', # Title font weight
    size:int=14, # Title font size
    **kwargs,
): # fig and axs
    "Return a grid of `n` axes, `rows` by `cols`"
    if nrows: ncols = ncols or int(np.floor(n/nrows))
    elif ncols: nrows = nrows or int(np.ceil(n/ncols))
    else:
        nrows = int(math.sqrt(n))
        ncols = int(np.floor(n/nrows))
    fig,axs = subplots(nrows, ncols, **kwargs)
    for i in range(n, nrows*ncols): axs.flat[i].set_axis_off()
    if title is not None: fig.suptitle(title, weight=weight, size=size)
    return fig,axs
#| export
@fc.delegates(subplots)
def show_images(ims:list, # Images to show
                nrows:typing.Union[int, None]=None, # Number of rows in grid
                ncols:typing.Union[int, None]=None, # Number of columns in grid (auto-calculated if None)
                titles:typing.Union[list, None]=None, # Optional list of titles for each image
                **kwargs):
    "Show all images `ims` as subplots with `rows` using `titles`"
    axs = get_grid(len(ims), nrows, ncols, **kwargs)[1].flat
    for im,t,ax in zip_longest(ims, [] if titles is None else titles, axs): show_image(im, ax=ax, title=t)
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'gray_r'
feat = fc.nested_attr(dls, 'train.dataset.features')
names = feat['label'].names
titles = [names[i] for i in yb]
show_images(xb[:9], titles=titles[:9], figsize=(3,4));

Device management

Convenience functions related to device management. Note that if you do from minai import * then def_device will be defined and will be used as the default in functions like to_device() unless you specify otherwise.

#|export
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
#|export
def to_device(x, device=def_device):
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k:v.to(device) for k,v in x.items()}
    return type(x)(to_device(o, device) for o in x)
#|export
def to_cpu(x):
    if isinstance(x, Mapping): return {k:to_cpu(v) for k,v in x.items()}
    if isinstance(x, list): return [to_cpu(o) for o in x]
    if isinstance(x, tuple): return tuple(to_cpu(list(x)))
    return x.detach().cpu()
#|export
def collate_device(b): return to_device(default_collate(b))

The Learner

The core of minai is the Learner class. It binds together model, dataloaders, loss function and so on. The goal is to handle everything required for training a model while still providing complete control over any of the steps involved. This is done by leaning heavily on callbacks.

It may be instructive here to look at how a single batch is processed (code from the Learner definition below):

@with_cbs('batch')
def _one_batch(self):
    self.predict()
    self.callback('after_predict')
    self.get_loss()
    self.callback('after_loss')
    if self.training:
        self.backward()
        self.callback('after_backward')
        self.step()
        self.callback('after_step')
        self.zero_grad()

The @with_cbs('batch') decorator means that any of the Learner's callbacks that define a before_batch or after_batch method will have that method called at the appropriate point. This allows you to do things like pre-process a batch of data before the model sees it, or log the loss after the batch has been processed. Within the _one_batch method, five special functions are called (predict, get_loss, backward, step, zero_grad). These are the five steps required for training a model. There are additional points where callbacks can be called, for example after the loss has been calculated but before the model has been updated.

The default Learner doesn't even define these methods - instead, it looks to see if they are defined in any of its callbacks. If you're not planning on doing anything fancy, then TrainCB is all you need: it defines all five of these methods. Alternatively, you can use TrainLearner, a subclass of Learner that defines them directly.

What's the point? These choices mean that if you're just fitting a basic model then you can use TrainLearner or TrainCB and you don't need to worry about any of the details. BUT if you do want to go in and add something fancy, you now have that option. Need a custom step to modify gradients before they are applied? No problem - re-define the step method. Need to make sure everything is on the right device before calling the model? No problem - check out DeviceCB to see how easy it is to do this!
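
As a concrete (and hypothetical) illustration using the TrainCB class defined later in this notebook: a sketch of a callback that clips gradient norms between backward and the optimizer update, simply by redefining step.

class ClipGradTrainCB(TrainCB):
    "Sketch: TrainCB variant that clips gradients before the optimizer step"
    def __init__(self, max_norm=1.0, **kwargs):
        super().__init__(**kwargs)
        self.max_norm = max_norm

    def step(self, learn):
        torch.nn.utils.clip_grad_norm_(learn.model.parameters(), self.max_norm)
        learn.opt.step()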

Callback Infrastructure

#|export
class CancelFitException(Exception): pass
class CancelBatchException(Exception): pass
class CancelEpochException(Exception): pass
#|export
class Callback(): order = 0
#|export
def run_cbs(cbs, method_nm, learn=None):
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None: method(learn)
#|export
class with_cbs:
    def __init__(self, nm): self.nm = nm
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.callback(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.callback(f'after_{self.nm}')
            except globals()[f'Cancel{self.nm.title()}Exception']: pass
            finally: o.callback(f'cleanup_{self.nm}')
        return _f

Learner class and friends

#|export
from itertools import cycle
#|export
class CycleDL():
    def __init__(self, items, sz=None):
        self.items = items
        self.sz = len(items) if sz is None else sz
        self.it = None

    def __len__(self): return len(self.items) if self.sz is None else self.sz
    def __iter__(self):
        if self.it is None: self.it = cycle(iter(self.items))
        for i in range(self.sz): yield next(self.it)
d = CycleDL(range(10), 3)
[list(d) for _ in range(5)]
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 1], [2, 3, 4]]
#|export
class Learner():
    def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD, epoch_sz=None):
        cbs = fc.L(cbs)
        fc.store_attr()

    @with_cbs('batch')
    def _one_batch(self):
        self.predict()
        self.callback('after_predict')
        self.get_loss()
        self.callback('after_loss')
        if self.training:
            self.backward()
            self.callback('after_backward')
            self.step()
            self.callback('after_step')
            self.zero_grad()

    @with_cbs('epoch')
    def _one_epoch(self):
        for self.iter,self.batch in enumerate(self.dl): self._one_batch()

    def one_epoch(self, training):
        self.model.train(training)
        self.dl = self.train_dl if training else self.dls.valid
        self._one_epoch()

    @with_cbs('fit')
    def _fit(self, train, valid):
        self.train_dl = self.dls.train
        if self.epoch_sz is not None: self.train_dl = CycleDL(self.train_dl, self.epoch_sz)
        for self.epoch in self.epochs:
            if train: self.one_epoch(True)
            if valid:
                with torch.inference_mode(): self.one_epoch(False)

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        cbs = fc.L(cbs)
        self.cbs += cbs
        try:
            self.n_epochs = n_epochs
            self.epochs = range(n_epochs)
            if lr is None: lr = self.lr
            if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr)
            self._fit(train, valid)
        finally:
            for cb in cbs: self.cbs.remove(cb)

    def __getattr__(self, name):
        if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name)
        raise AttributeError(name)

    def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)
    
    @property
    def training(self): return self.model.training
#|export
def _get_inp(b, n_inp, inp_nm):
    if inp_nm is not None: return [b[inp_nm]]
    return b[:n_inp]

def _get_lbl(b, n_inp, lbl_nm):
    if lbl_nm is not None: return [b[lbl_nm]]
    return b[n_inp:]

def _get_preds(b, preds_nm):
    return b if preds_nm is None else getattr(b, preds_nm)
#|export
class TrainLearner(Learner):
    def __init__(self, model, dls, loss_func, lr=None, cbs=None, opt_func=torch.optim.SGD, epoch_sz=None,
                 n_inp=1, inp_nm=None, lbl_nm=None, preds_nm=None):
        super().__init__(model, dls, loss_func, lr, cbs, opt_func=opt_func, epoch_sz=epoch_sz)
        self.n_inp,self.inp_nm,self.lbl_nm,self.preds_nm = n_inp,inp_nm,lbl_nm,preds_nm

    def predict(self):
        inps = _get_inp(self.batch, self.n_inp, self.inp_nm)
        self.preds = self.model(*inps)

    def get_loss(self):
        lbls = _get_lbl(self.batch, self.n_inp, self.lbl_nm)
        preds = _get_preds(self.preds, self.preds_nm)
        self.loss = self.loss_func(preds, *lbls)

    def backward(self): self.loss.backward()
    def step(self): self.opt.step()
    def zero_grad(self): self.opt.zero_grad()
#|export
class TrainCB(Callback):
    def __init__(self, n_inp=1, inp_nm=None, lbl_nm=None, preds_nm=None):
        self.n_inp,self.inp_nm,self.lbl_nm,self.preds_nm = n_inp,inp_nm,lbl_nm,preds_nm

    def predict(self, learn):
        inps = _get_inp(learn.batch, self.n_inp, self.inp_nm)
        learn.preds = learn.model(*inps)

    def get_loss(self, learn):
        lbls = _get_lbl(learn.batch, self.n_inp, self.lbl_nm)
        preds = _get_preds(learn.preds, self.preds_nm)
        learn.loss = learn.loss_func(preds, *lbls)

    def backward(self, learn): learn.loss.backward()
    def step(self, learn): learn.opt.step()
    def zero_grad(self, learn): learn.opt.zero_grad()
#|export
class MomentumLearner(TrainLearner):
    def __init__(self, model, dls, loss_func, lr=None, cbs=None, opt_func=torch.optim.SGD, epoch_sz=None,
                 n_inp=1, inp_nm=None, lbl_nm=None, preds_nm=None, mom=0.85):
        self.mom = mom
        super().__init__(model, dls, loss_func, lr, cbs, opt_func=opt_func, epoch_sz=epoch_sz, n_inp=n_inp,
                        inp_nm=inp_nm, lbl_nm=lbl_nm, preds_nm=preds_nm)

    def zero_grad(self):
        with torch.no_grad():
            for p in self.model.parameters():
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad *= self.mom

Some useful callbacks

DeviceCB - makes sure everything is on the right device:

#|export
class DeviceCB(Callback):
    def __init__(self, device=def_device): fc.store_attr()
    def before_fit(self, learn):
        if hasattr(learn.model, 'to'): learn.model.to(self.device)
    def before_batch(self, learn): learn.batch = to_device(learn.batch, device=self.device)

Sometimes you only want to run a single batch, for example when debugging. This callback allows you to do that, making use of the CancelFitException:

#|export
class SingleBatchCB(Callback):
    order = 1
    def after_batch(self, learn): raise CancelFitException()

We rely on torcheval for calculating metrics. By default, adding a MetricsCB will track the loss, but you can feed in other metrics as needed:

#|export
class MetricsCB(Callback):
    def __init__(self, *ms, **metrics):
        for o in ms: metrics[type(o).__name__] = o
        self.metrics = metrics
        self.all_metrics = copy(metrics)
        self.all_metrics['loss'] = self.loss = Mean()

    def _log(self, d): print(d)
    def before_fit(self, learn): learn.metrics = self
    def before_epoch(self, learn): [o.reset() for o in self.all_metrics.values()]

    def after_epoch(self, learn):
        log = {k:f'{v.compute():.3f}' for k,v in self.all_metrics.items()}
        log['epoch'] = learn.epoch
        log['train'] = 'train' if learn.model.training else 'eval'
        self._log(log)

    def after_batch(self, learn):
        x,y,*_ = to_cpu(learn.batch)
        for m in self.metrics.values(): m.update(to_cpu(learn.preds), y)
        self.loss.update(to_cpu(learn.loss), weight=len(x))

ProgressCB is a simple callback that prints out the progress of training. It's not very sophisticated, but it's useful for debugging. Any metrics being tracked via MetricsCB will be printed out as well. And you can set plot=True to get a plot of the loss during training:

#|export
class ProgressCB(Callback):
    order = MetricsCB.order+1
    def __init__(self, plot=False): self.plot = plot
    def before_fit(self, learn):
        learn.epochs = self.mbar = master_bar(learn.epochs)
        self.first = True
        if hasattr(learn, 'metrics'): learn.metrics._log = self._log
        self.losses = []

    def _log(self, d):
        if self.first:
            self.mbar.write(list(d), table=True)
            self.first = False
        self.mbar.write(list(d.values()), table=True)

    def before_epoch(self, learn): learn.dl = progress_bar(learn.dl, leave=False, parent=self.mbar)
    def after_batch(self, learn):
        learn.dl.comment = f'{learn.loss:.3f}'
        if self.plot and hasattr(learn, 'metrics') and learn.training:
            self.losses.append(learn.loss.item())
            self.mbar.update_graph([[fc.L.range(self.losses), self.losses]])

There are callbacks for all sorts of things; here are some common ones:

cbs = [
    TrainCB(), # Handles the core steps in the training loop
    DeviceCB(), # Puts data and model on GPU
    MetricsCB(accuracy=MulticlassAccuracy())
]

We'll try a very basic convnet:

def conv(nin,nout):
    return [nn.Conv2d(nin, nout, kernel_size=3, stride=2, padding=1),
            nn.ReLU(), nn.BatchNorm2d(nout)]

def get_model():
    return nn.Sequential(*conv(1, 16), *conv(16,32), *conv(32,32),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32,10))
opt = partial(torch.optim.AdamW, betas=(0.9,0.95), eps=1e-5)
learn = Learner(get_model(), dls, nn.CrossEntropyLoss(), lr=0.05, cbs=cbs, opt_func=opt)
learn.fit(1, cbs=[ProgressCB(plot=True)])
| accuracy | loss | epoch | train |
|--|--|--|--|
| 0.706 | 0.808 | 0 | train |
| 0.699 | 0.957 | 0 | eval |
#| export
class CapturePreds(Callback):
    def before_fit(self, learn): self.all_inps,self.all_preds,self.all_targs = [],[],[]
    def after_batch(self, learn):
        self.all_inps.append(to_cpu(learn.batch[0]))
        self.all_preds.append(to_cpu(learn.preds))
        self.all_targs.append(to_cpu(learn.batch[1]))
    def after_fit(self, learn):
        self.all_preds,self.all_targs,self.all_inps = map(torch.cat, [self.all_preds,self.all_targs,self.all_inps])
#| export
@fc.patch
def capture_preds(self: Learner, cbs=None, inps=False):
    cp = CapturePreds()
    with torch.inference_mode(): self.fit(1, train=False, cbs=[cp]+fc.L(cbs))
    res = cp.all_preds,cp.all_targs
    if inps: res = res+(cp.all_inps,)
    return res
preds,targs = learn.capture_preds()
preds.shape,targs.shape
(torch.Size([10000, 10]), torch.Size([10000]))
(preds.argmax(1)==targs).float().mean()
tensor(0.6987)

If you want logging more often than the end of each full pass through the data loader, set epoch_sz to be the number of batches you'd like to count as an "epoch".

learn = Learner(get_model(), dls, nn.CrossEntropyLoss(), lr=0.05,
                cbs=cbs, opt_func=opt, epoch_sz=10)
learn.fit(3, cbs=[ProgressCB()])
| accuracy | loss | epoch | train |
|--|--|--|--|
| 0.398 | 1.579 | 0 | train |
| 0.517 | 1.706 | 0 | eval |
| 0.671 | 0.935 | 1 | train |
| 0.449 | 1.795 | 1 | eval |
| 0.729 | 0.733 | 2 | train |
| 0.712 | 0.810 | 2 | eval |
#| export
@fc.patch
@fc.delegates(show_images)
def show_image_batch(self:Learner, max_n=9, cbs=None, **kwargs):
    self.fit(1, cbs=[SingleBatchCB()]+fc.L(cbs))
    xb,yb = self.batch
    feat = fc.nested_attr(self.dls, 'train.dataset.features')
    if feat is None: titles = np.array(yb)
    else:
        names = feat['label'].names
        titles = [names[i] for i in yb]
    show_images(xb[:max_n], titles=titles[:max_n], **kwargs)
learn.show_image_batch(figsize=(3,4))

Learning Rate Finder

A useful method from FastAI, the learning rate finder tries batches at ever-increasing LRs and plots the loss. It's a good way to get a sense of what LR to use for training. Note that we also patch in the lr_find method to the Learner class, so you can call it directly.

#|export
class LRFinderCB(Callback):
    def __init__(self, gamma=1.3, max_mult=3): fc.store_attr()
    
    def before_fit(self, learn):
        self.sched = ExponentialLR(learn.opt, self.gamma)
        self.lrs,self.losses = [],[]
        self.min = math.inf

    def after_batch(self, learn):
        if not learn.training: raise CancelEpochException()
        self.lrs.append(learn.opt.param_groups[0]['lr'])
        loss = to_cpu(learn.loss)
        self.losses.append(loss)
        if loss < self.min: self.min = loss
        if loss > self.min*self.max_mult:
            raise CancelFitException()
        self.sched.step()

    def cleanup_fit(self, learn):
        plt.plot(self.lrs, self.losses)
        plt.xscale('log')
#|export
@fc.patch
def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
    self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult))
learn = Learner(get_model(), dls, nn.CrossEntropyLoss(), lr=0.1, cbs=cbs, opt_func=opt)
learn.lr_find()

Scheduling callbacks

#|export
class RecorderCB(Callback):
    def __init__(self, **d): self.d = d
    def before_fit(self, learn):
        self.recs = {k:[] for k in self.d}
        self.pg = learn.opt.param_groups[0]
    
    def after_batch(self, learn):
        if not learn.training: return
        for k,v in self.d.items():
            self.recs[k].append(v(self))

    def plot(self):
        for k,v in self.recs.items():
            plt.plot(v, label=k)
            plt.legend()
            plt.show()
#|export
class BaseSchedCB(Callback):
    def __init__(self, sched): self.sched = sched
    def before_fit(self, learn): self.schedo = self.sched(learn.opt)
    def _step(self, learn):
        if learn.training: self.schedo.step()      
#|export  
class BatchSchedCB(BaseSchedCB):
    def after_batch(self, learn): self._step(learn)
#|export
class EpochSchedCB(BaseSchedCB):
    def after_epoch(self, learn): self._step(learn)
#|export
class HasLearnCB(Callback):
    def before_fit(self, learn): self.learn = learn 
    def after_fit(self, learn): self.learn = None
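
This notebook doesn't include a scheduling example yet, so here's a hedged usage sketch (the OneCycleLR settings are illustrative): step a standard PyTorch scheduler once per batch with BatchSchedCB, and record the learning rate with RecorderCB.

from torch.optim.lr_scheduler import OneCycleLR
sched = partial(OneCycleLR, max_lr=0.05, total_steps=len(dls.train))  # one epoch's worth of steps
rec = RecorderCB(lr=lambda cb: cb.pg['lr'])  # each recorder fn is called with the RecorderCB itself
learn = Learner(get_model(), dls, nn.CrossEntropyLoss(), lr=0.05,
                cbs=cbs+[BatchSchedCB(sched), rec], opt_func=opt)
learn.fit(1)
rec.plot()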

Accelerated training loops

#|export 
class MixedPrecision(TrainCB):
    order = DeviceCB.order+10
    def __init__(self, n_inp=1, dtype=torch.bfloat16):
        super().__init__(n_inp=n_inp)
        self.dtype=dtype
    
    def before_fit(self, learn): self.scaler = torch.cuda.amp.GradScaler()

    def before_batch(self, learn):
        self.autocast = torch.autocast("cuda", dtype=self.dtype)
        self.autocast.__enter__()

    def after_loss(self, learn): self.autocast.__exit__(None, None, None)
        
    def backward(self, learn): self.scaler.scale(learn.loss).backward()

    def step(self, learn):
        self.scaler.step(learn.opt)
        self.scaler.update()
cbs = [MixedPrecision(), DeviceCB(), MetricsCB(accuracy=MulticlassAccuracy())]
learn = Learner(get_model(), dls, nn.CrossEntropyLoss(), lr=0.05, cbs=cbs, opt_func=opt)
learn.fit(1, cbs=[ProgressCB(plot=True)])
| accuracy | loss | epoch | train |
|--|--|--|--|
| 0.704 | 0.823 | 0 | train |
| 0.737 | 0.723 | 0 | eval |
#|export
class AccelerateCB(TrainCB):
    order = DeviceCB.order+10
    def __init__(self, n_inp=1, mixed_precision="fp16"):
        super().__init__(n_inp=n_inp)
        self.acc = Accelerator(mixed_precision=mixed_precision)
        
    def before_fit(self, learn):
        learn.model,learn.opt,learn.dls.train,learn.dls.valid = self.acc.prepare(
            learn.model, learn.opt, learn.dls.train, learn.dls.valid)
    
    def after_fit(self, learn): learn.model = self.acc.unwrap_model(learn.model)
    def backward(self, learn): self.acc.backward(learn.loss)
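
There's no AccelerateCB example in this notebook; here's a hedged sketch of how it could be used, mirroring the MixedPrecision example above (it requires the accelerate package, and it takes the place of TrainCB since it subclasses it):

cbs = [AccelerateCB(), DeviceCB(), MetricsCB(accuracy=MulticlassAccuracy())]
learn = Learner(get_model(), dls, nn.CrossEntropyLoss(), lr=0.05, cbs=cbs, opt_func=opt)
learn.fit(1, cbs=[ProgressCB(plot=True)])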

Hooks

#| export
def append_stats(hook, mod, inp, outp):
    if not hasattr(hook,'stats'): hook.stats = ([],[],[])
    acts = to_cpu(outp).float()
    hook.stats[0].append(acts.mean())
    hook.stats[1].append(acts.std())
    hook.stats[2].append(acts.abs().histc(40,0,10))
#| export
def get_min(h):
    h1 = torch.stack(h.stats[2]).t().float()
    return h1[0]/h1.sum(0)
#| export
class Hook():
    def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self))
    def remove(self): self.hook.remove()
    def __del__(self): self.remove()
#| export
class Hooks(list):
    def __init__(self, ms, f): super().__init__([Hook(m, f) for m in ms])
    def __enter__(self, *args): return self
    def __exit__ (self, *args): self.remove()
    def __del__(self): self.remove()
    def __delitem__(self, i):
        self[i].remove()
        super().__delitem__(i)
    def remove(self):
        for h in self: h.remove()
#| export
class HooksCallback(Callback):
    def __init__(self, hookfunc, mod_filter=fc.noop, on_train=True, on_valid=False, mods=None):
        fc.store_attr()
        super().__init__()
    
    def before_fit(self, learn):
        if self.mods: mods=self.mods
        else: mods = fc.filter_ex(learn.model.modules(), self.mod_filter)
        self.hooks = Hooks(mods, partial(self._hookfunc, learn))

    def _hookfunc(self, learn, *args, **kwargs):
        if (self.on_train and learn.training) or (self.on_valid and not learn.training): self.hookfunc(*args, **kwargs)

    def after_fit(self, learn): self.hooks.remove()
    def __iter__(self): return iter(self.hooks)
    def __len__(self): return len(self.hooks)
#| export
# Thanks to @ste for the initial version of the histogram plotting code
def get_hist(h): return torch.stack(h.stats[2]).t().float().log1p()
#|export
class ActivationStats(HooksCallback):
    def __init__(self, mod_filter=fc.noop): super().__init__(append_stats, mod_filter)

    def color_dim(self, figsize=(11,5)):
        fig,axes = get_grid(len(self), figsize=figsize)
        for ax,h in zip(axes.flat, self):
            show_image(get_hist(h), ax, origin='lower')

    def dead_chart(self, figsize=(11,5)):
        fig,axes = get_grid(len(self), figsize=figsize)
        for ax,h in zip(axes.flatten(), self):
            ax.plot(get_min(h))
            ax.set_ylim(0,1)

    def plot_stats(self, figsize=(10,4)):
        fig,axs = plt.subplots(1,2, figsize=figsize)
        for h in self:
            for i in 0,1: axs[i].plot(h.stats[i])
        axs[0].set_title('Means')
        axs[1].set_title('Stdevs')
        plt.legend(fc.L.range(self))
astats = ActivationStats(fc.risinstance(nn.Conv2d))
learn = Learner(get_model(), dls, nn.CrossEntropyLoss(), lr=0.05, cbs=cbs, opt_func=opt)
learn.fit(1, cbs=[ProgressCB(), astats])
| accuracy | loss | epoch | train |
|--|--|--|--|
| 0.718 | 0.777 | 0 | train |
| 0.460 | 2.456 | 0 | eval |
plt.rc('image', cmap='viridis')
astats.color_dim()
astats.plot_stats()
astats.dead_chart()
#|export
def _flops(x, h, w):
    if x.dim()<3: return x.numel()
    if x.dim()==4: return x.numel()*h*w
#|export
@fc.patch
def summary(self:Learner):
    res = '|Module|Input|Output|Num params|MFLOPS|\n|--|--|--|--|--|\n'
    totp,totf = 0,0
    def _f(hook, mod, inp, outp):
        nonlocal res,totp,totf
        nparms = sum(o.numel() for o in mod.parameters())
        totp += nparms
        *_,h,w = outp.shape
        flops = sum(_flops(o, h, w) for o in mod.parameters())/1e6
        totf += flops
        res += f'|{type(mod).__name__}|{tuple(inp[0].shape)}|{tuple(outp.shape)}|{nparms}|{flops:.1f}|\n'
    with Hooks(self.model, _f) as hooks: self.fit(1, lr=1, cbs=SingleBatchCB())
    print(f"Tot params: {totp}; MFLOPS: {totf:.1f}")
    if fc.IN_NOTEBOOK:
        from IPython.display import Markdown
        return Markdown(res)
    else: print(res)
learn.summary()
Tot params: 14538; MFLOPS: 0.4
| Module | Input | Output | Num params | MFLOPS |
|--|--|--|--|--|
| Conv2d | (1024, 1, 28, 28) | (1024, 16, 14, 14) | 160 | 0.0 |
| ReLU | (1024, 16, 14, 14) | (1024, 16, 14, 14) | 0 | 0.0 |
| BatchNorm2d | (1024, 16, 14, 14) | (1024, 16, 14, 14) | 32 | 0.0 |
| Conv2d | (1024, 16, 14, 14) | (1024, 32, 7, 7) | 4640 | 0.2 |
| ReLU | (1024, 32, 7, 7) | (1024, 32, 7, 7) | 0 | 0.0 |
| BatchNorm2d | (1024, 32, 7, 7) | (1024, 32, 7, 7) | 64 | 0.0 |
| Conv2d | (1024, 32, 7, 7) | (1024, 32, 4, 4) | 9248 | 0.1 |
| ReLU | (1024, 32, 4, 4) | (1024, 32, 4, 4) | 0 | 0.0 |
| BatchNorm2d | (1024, 32, 4, 4) | (1024, 32, 4, 4) | 64 | 0.0 |
| AdaptiveAvgPool2d | (1024, 32, 4, 4) | (1024, 32, 1, 1) | 0 | 0.0 |
| Flatten | (1024, 32, 1, 1) | (1024, 32) | 0 | 0.0 |
| Linear | (1024, 32) | (1024, 10) | 330 | 0.0 |
#| export
class BatchTransformCB(Callback):
    def __init__(self, tfm, on_train=True, on_val=True): fc.store_attr()

    def before_batch(self, learn):
        if (self.on_train and learn.training) or (self.on_val and not learn.training):
            learn.batch = self.tfm(learn.batch)
def _norm(b): return (b[0]-xmean)/xstd,b[1]
norm = BatchTransformCB(_norm)
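
Here xmean and xstd are assumed to have been computed beforehand; a minimal sketch of doing that from a single training batch and plugging the callback into a Learner:

xb,_ = next(iter(dls.train))
xmean,xstd = xb.mean(),xb.std()  # statistics used by _norm above
learn = Learner(get_model(), dls, nn.CrossEntropyLoss(), lr=0.05,
                cbs=[TrainCB(), DeviceCB(), MetricsCB(accuracy=MulticlassAccuracy()), norm], opt_func=opt)
learn.fit(1)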
#|export
class GeneralRelu(nn.Module):
    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak,self.sub,self.maxv = leak,sub,maxv

    def forward(self, x): 
        x = F.leaky_relu(x,self.leak) if self.leak is not None else F.relu(x)
        if self.sub is not None: x -= self.sub
        if self.maxv is not None: x.clamp_max_(self.maxv)
        return x

Data augmentation

import random
#|export
def _rand_erase1(x, pct, xm, xs, mn, mx):
    szx = int(pct*x.shape[-2])
    szy = int(pct*x.shape[-1])
    stx = int(random.random()*(1-pct)*x.shape[-2])
    sty = int(random.random()*(1-pct)*x.shape[-1])
    nn.init.normal_(x[:,:,stx:stx+szx,sty:sty+szy], mean=xm, std=xs)
    x.clamp_(mn, mx)
#|export
def rand_erase(x, pct=0.2, min_num=0, max_num = 4):
    xm,xs,mn,mx = x.mean(),x.std(),x.min(),x.max()
    num = random.randint(min_num, max_num)
    for i in range(num): _rand_erase1(x, pct, xm, xs, mn, mx)
    return x
plt.rcParams['image.cmap'] = 'gray_r'
xb,_ = next(iter(dls.train))
xbt = xb[:9]
rand_erase(xbt, 0.2, 3, 3)
show_images(xbt, imsize=1.3)
#|export
class RandErase(nn.Module):
    def __init__(self, pct=0.2, max_num=4):
        super().__init__()
        self.pct,self.max_num = pct,max_num
    def forward(self, x): return rand_erase(x, self.pct, max_num=self.max_num)  # pass as max_num (min_num keeps its default)
#|export
def _rand_copy1(x, pct):
    szx = int(pct*x.shape[-2])
    szy = int(pct*x.shape[-1])
    stx1 = int(random.random()*(1-pct)*x.shape[-2])
    sty1 = int(random.random()*(1-pct)*x.shape[-1])
    stx2 = int(random.random()*(1-pct)*x.shape[-2])
    sty2 = int(random.random()*(1-pct)*x.shape[-1])
    x[:,:,stx1:stx1+szx,sty1:sty1+szy] = x[:,:,stx2:stx2+szx,sty2:sty2+szy]
#|export
def rand_copy(x, pct=0.2, min_num=0, max_num=4):
    num = random.randint(min_num, max_num)
    for i in range(num): _rand_copy1(x, pct)
    return x
xb,_ = next(iter(dls.train))
xbt = xb[:9]
rand_copy(xbt, 0.2, 3, 3)
show_images(xbt, imsize=1.3)
#|export
class RandCopy(nn.Module):
    def __init__(self, pct=0.2, max_num=4):
        super().__init__()
        self.pct,self.max_num = pct,max_num
    def forward(self, x): return rand_copy(x, self.pct, max_num=self.max_num)  # pass as max_num (min_num keeps its default)

Clean memory

#|export
def clean_ipython_hist():
    # Code in this function mainly copied from IPython source
    if not 'get_ipython' in globals(): return
    ip = get_ipython()
    user_ns = ip.user_ns
    ip.displayhook.flush()
    pc = ip.displayhook.prompt_count + 1
    for n in range(1, pc): user_ns.pop('_i'+repr(n),None)
    user_ns.update(dict(_i='',_ii='',_iii=''))
    hm = ip.history_manager
    hm.input_hist_parsed[:] = [''] * pc
    hm.input_hist_raw[:] = [''] * pc
    hm._i = hm._ii = hm._iii = hm._i00 =  ''
#|export
def clean_tb():
    # h/t Piotr Czapla
    if hasattr(sys, 'last_traceback'):
        traceback.clear_frames(sys.last_traceback)
        delattr(sys, 'last_traceback')
    if hasattr(sys, 'last_type'): delattr(sys, 'last_type')
    if hasattr(sys, 'last_value'): delattr(sys, 'last_value')
#|export
def clean_mem():
    clean_tb()
    clean_ipython_hist()
    gc.collect()
    torch.cuda.empty_cache()
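
Typical notebook usage, e.g. after hitting a CUDA out-of-memory error (the objects being deleted are just whatever large things you have lying around):

del learn  # drop references to big objects first
clean_mem()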

tutorial_01.ipynb


Minai Tutorial 1: Introduction

This is a minimal example to get you started, showing the basic flow of training a model using (mini)miniai.

Setup

Installing the library and importing a few useful things:

# Imports
import torch.nn as nn
import minai as mi  # So we can see what is from minai in this tutorial
import torchvision.transforms.functional as TF
from datasets import load_dataset
from torcheval.metrics import MulticlassAccuracy

Preparing the DataLoaders

The DataLoaders class is just a tiny wrapper around two PyTorch dataloaders, dls.train and dls.valid. You can create your dataloaders with dls=DataLoaders(train_dl, valid_dl) or use the from_dd method like we do here to load them from a DatasetDict (for datasets from Hugging Face with the datasets library):
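
The manual route looks like this (a minimal sketch; train_ds and valid_ds stand for any map-style PyTorch datasets returning (x, y) tuples):

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=128)
dls = mi.DataLoaders(train_dl, valid_dl)

This tutorial takes the from_dd route instead: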

# Load a dataset from HF
dataset = load_dataset('mnist')

# Specify transforms
def transforms(b):
    b['image'] = [TF.to_tensor(o) for o in b['image']]
    return b
dataset = dataset.with_transform(transforms)

# Turn it into dls
dls = mi.DataLoaders.from_dd(dataset, batch_size=64)

# Look at the data
xb, yb = next(iter(dls.train))
xb.shape, yb.shape, yb[:5]

The library has some useful utility functions, such as:

mi.show_images(xb[:5], ncols=5, titles=list(yb[:5].numpy()))

[Image: sample images from the MNIST dataset]

You can do a lot of fancy stuff with your collate function if your data requires more processing or augmentation.
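
For instance, here's a hedged sketch of a collate function that adds batch-level augmentation (it assumes the items are dicts with 'image' and 'label' keys, as above, and that minai's rand_erase is available):

from torch.utils.data import DataLoader, default_collate

def aug_collate(items):
    b = default_collate(items)                    # dict of stacked tensors
    return mi.rand_erase(b['image']), b['label']  # erase random patches as augmentation

train_dl = DataLoader(dataset['train'], batch_size=64, shuffle=True, collate_fn=aug_collate)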

Prepare the Model

The model can be pretty much any PyTorch model, no changes needed here:

model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten()
)

Create and Fit the Learner

The heart of (mini)miniai is the Learner class. It pulls together the data, model, and loss function, and can be extended in all sorts of cool ways using callbacks. Here's a somewhat minimal example, training our model on this classification task and plotting some stats as we do so:

# There are callbacks for all sorts of things, here are some common ones:
cbs = [
    mi.TrainCB(),  # Handles the core steps in the training loop. Can be left out if using TrainLearner
    mi.DeviceCB(),  # Handles making sure data and model are on the right device
    mi.MetricsCB(accuracy=MulticlassAccuracy()),  # Keep track of any relevant metrics
    mi.ProgressCB(),  # Displays metrics and loss during training, optionally plot=True for a pretty graph
]

# Nothing fancy for the loss function
loss_fn = nn.CrossEntropyLoss()

# The learner takes a model, dataloaders and loss function, plus some optional extras like a list of callbacks
learn = mi.Learner(model, dls, loss_fn, lr=0.1, cbs=cbs)

# And fit does the magic :)
learn.fit(3)

Training Results

| accuracy | loss | epoch | train |
|--|--|--|--|
| 0.337 | 1.853 | 0 | train |
| 0.587 | 1.261 | 0 | eval |
| 0.707 | 0.906 | 1 | train |
| 0.801 | 0.648 | 1 | eval |
| 0.819 | 0.586 | 2 | train |
| 0.838 | 0.522 | 2 | eval |

When more complex tutorials are available, they will show some of the other existing callbacks in action. However, for most tasks, this is pretty much all you need! The model (learn.model) is just a regular PyTorch model, so you can save it and load it later somewhere else without needing any minai code at all.
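
For example (a minimal sketch - the filename is arbitrary):

import torch
torch.save(learn.model.state_dict(), 'mnist_convnet.pt')  # plain PyTorch checkpoint
model.load_state_dict(torch.load('mnist_convnet.pt'))     # reload later, no minai required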


tutorial_02.ipynb


Minai Tutorial 2: Customization via Callbacks

This tutorial demonstrates how to customize minai to suit your needs. We'll work with the same dataset as in tutorial 1, but we'll train a variational autoencoder (VAE) model to show how to modify the training process when dealing with a more complicated task.

Setup

First, let's install the required packages and import the necessary modules:

# Install requirements
!pip install -q minai datasets

# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import minai as mi
import torchvision.transforms.functional as TF
from functools import partial
from datasets import load_dataset

# Set default matplotlib colormap to 'gray'
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'gray'

The Data

We'll use MNIST as in Tutorial 1 for demonstration purposes. The dataloader code remains identical, even though we don't use the class label in this tutorial.

# Load a dataset from HF
dataset = load_dataset('mnist')

# Specify transforms
def transforms(b):
    b['image'] = [TF.to_tensor(o) for o in b['image']]
    return b
dataset = dataset.with_transform(transforms)

# Turn it into dls
dls = mi.DataLoaders.from_dd(dataset, batch_size=64)

# Look at the data
xb, yb = next(iter(dls.train))
mi.show_images(xb[:5], ncols=5, titles=list(yb[:5].numpy()))

[Image: sample images from the MNIST dataset]

The Model

The model is a variational auto-encoder (VAE). It takes in an image and compresses it down to a much smaller representation, 'z'. Unlike a normal autoencoder, the representation is not a single vector, but a pair of vectors: the mean and (log-)variance of a normal distribution. The model samples from this distribution and decodes the sample to reconstruct the image.

class ConvVAE(nn.Module):
    def __init__(self, hdim=20):
        super(ConvVAE, self).__init__()
        # Encoder layers
        self.enc_conv1 = nn.Conv2d(1, 32, 3)
        self.enc_conv2 = nn.Conv2d(32, 32, 3)
        self.enc_fc1 = nn.Linear(32*24*24, 128)
        self.enc_fc2_mean = nn.Linear(128, hdim)
        self.enc_fc2_logvar = nn.Linear(128, hdim)

        # Decoder layers
        self.dec_fc1 = nn.Linear(hdim, 128)
        self.dec_fc2 = nn.Linear(128, 32*24*24)
        self.dec_unflatten = nn.Unflatten(dim=-1, unflattened_size=(32, 24, 24))
        self.dec_conv1 = nn.ConvTranspose2d(32, 32, 3)
        self.dec_conv2 = nn.ConvTranspose2d(32, 1, 3)

    def encode(self, x):
        # The method bodies below are a standard VAE sketch, filled in to match the layers above
        x = F.relu(self.enc_conv1(x))
        x = F.relu(self.enc_conv2(x))
        x = F.relu(self.enc_fc1(x.flatten(start_dim=1)))
        return self.enc_fc2_mean(x), self.enc_fc2_logvar(x)

    def reparameterize(self, mean, logvar):
        # Sample z = mean + eps*std (the reparameterization trick)
        std = torch.exp(0.5*logvar)
        return mean + std*torch.randn_like(std)

    def decode(self, z):
        # Mirror the encoder; sigmoid keeps outputs in [0,1] for the BCE reconstruction loss
        z = F.relu(self.dec_fc1(z))
        z = self.dec_unflatten(F.relu(self.dec_fc2(z)))
        z = F.relu(self.dec_conv1(z))
        return torch.sigmoid(self.dec_conv2(z))

    def forward(self, x, return_dist=False):
        # Encode, sample a latent, decode; optionally return the distribution parameters too
        mean, logvar = self.encode(x)
        recon = self.decode(self.reparameterize(mean, logvar))
        return (recon, mean, logvar) if return_dist else recon

Custom Training Callback

For our VAE, we need to modify the training process. We'll create a custom training callback by subclassing TrainCB and overriding the relevant methods:

class VaeTrainCB(mi.TrainCB):
    def predict(self, learn): 
        learn.preds = learn.model(learn.batch[0], return_dist=True) 
    
    def get_loss(self, learn):
        reconstruction, mean, logvar = learn.preds
        BCE = F.binary_cross_entropy(reconstruction, learn.batch[0], reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        learn.loss = BCE + KLD

Training the Model

Now we can train our VAE using the custom callback:

model = ConvVAE()  # instantiate the VAE defined above
cbs = [VaeTrainCB(), mi.DeviceCB(), mi.MetricsCB(), mi.ProgressCB(plot=True)]
lr = 1e-3
opt_func = partial(torch.optim.Adam, weight_decay=1e-5)
learn = mi.Learner(model, dls, loss_func=None, lr=lr, cbs=cbs, opt_func=opt_func)  # loss_func unused: VaeTrainCB computes the loss
learn.fit(1)

Generating New Outputs

After training, we can generate new outputs by 'decoding' random latents:

z = torch.randn(8, 20).to(mi.def_device)
mi.show_images(model.decode(z))

[Image: generated MNIST-like images from random latents]

Logging Images During Training

To visualize the model's progress, we can create a callback to save generated images during training:

from torchvision.utils import make_grid
import numpy as np
from PIL import Image

def save_images(preds, save_name):
    # Sketch of an implementation: tile the predictions into a grid and save to disk
    grid = make_grid(preds.detach().cpu(), nrow=4)
    arr = (grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    Image.fromarray(arr).save(save_name)

class LogPreds(mi.Callback):
    def __init__(self, log_every=100):
        self.log_every = log_every
    
    def after_batch(self, learn):
        if learn.iter % self.log_every == 0:
            z = torch.randn(8, 20).to(mi.def_device)
            preds = learn.model.decode(z)
            save_images(preds, f'preds_{learn.iter}.jpeg')

# Fit another epoch with our new callback
learn.fit(1, cbs=[LogPreds()])

Here's an example of a saved image during training:

[Image: generated images at iteration 100]

Conclusion

This tutorial demonstrated how callbacks can be used to modify and extend the training process without resorting to a fully custom training loop and lots of boilerplate. You can explore further by:

  1. Saving samples to Weights and Biases for better visualization (see the sketch below).
  2. Modifying the ProgressCB to plot samples alongside the loss graph.

These customizations allow you to tailor the training process to your specific needs while keeping the code modular and easy to maintain.
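
For the first of those, here's a hedged sketch of what a Weights & Biases logging callback could look like (it assumes wandb is installed and wandb.init() has already been called; the latent shape matches the VAE above):

import wandb

class WandbLogPreds(mi.Callback):
    def __init__(self, log_every=100): self.log_every = log_every

    def after_batch(self, learn):
        if learn.iter % self.log_every != 0: return
        z = torch.randn(8, 20).to(mi.def_device)
        with torch.no_grad(): preds = learn.model.decode(z)
        imgs = [wandb.Image(p.squeeze().cpu().numpy()) for p in preds]
        wandb.log({'samples': imgs, 'loss': learn.loss.item()})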


llm_example.ipynb


Minai LLM Fine-tuning Example

This example demonstrates how to fine-tune a large language model (LLM) using the minai library.

Setup

First, let's import the necessary libraries and set up our environment:

import os
import torch
import numpy as np
from minai.core import *
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch import nn, tensor
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from torch import optim
from functools import partial

set_seed(42)

Model and Tokenizer Preparation

We'll use the Llama-2-7b model for this example:

model_id = 'meta-llama/Llama-2-7b-hf'
m = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=0,
    use_flash_attention_2=True,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    use_cache=False
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

Data Preparation

We'll use the "know_sql" dataset for fine-tuning:

dataset = load_dataset("knowrohit07/know_sql", revision='f33425d13f9e8aab1b46fa945326e9356d6d5726', split="train")

def to_text(x):
    x['text'] = f"Context: {x['context']}\nQuestion: {x['question']}\nAnswer: {x['answer']}"
    return x

dataset = dataset.shuffle(42).map(to_text).filter(lambda x: len(x['text']) < 380)
train_dataset = dataset.select(range(0, len(dataset)-200))
eval_dataset = dataset.select(range(len(dataset)-200, len(dataset)))

DataLoader Setup

Create DataLoaders for training and evaluation:

def collate_fn(examples):
    # Tokenize the batch, then shift by one token: inputs are all but the last token,
    # targets are all but the first (standard next-token prediction)
    input_ids = tokenizer([e['text'] for e in examples], return_tensors='pt', padding=True)['input_ids']
    return (input_ids[:, :-1], input_ids[:, 1:])

batch_size = 64
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
dls = DataLoaders(train_dataloader, eval_dataloader)

Model Configuration

Configure the model for fine-tuning:

def loss_fn(x, y):
    # Flatten (batch, seq, vocab) logits and (batch, seq) targets for cross-entropy
    return torch.nn.functional.cross_entropy(x.view(-1, x.shape[-1]), y.view(-1))

# Freeze first 24 layers for memory efficiency
n_freeze = 24
for param in m.parameters(): param.requires_grad = False
for param in m.lm_head.parameters(): param.requires_grad = True
for param in m.model.layers[n_freeze:].parameters(): param.requires_grad = True

m.gradient_checkpointing_enable()

Training Setup

Set up the training environment:

prog = ProgressCB(plot=True)
cbs = [DeviceCB(), MetricsCB()]

optim = partial(torch.optim.Adam, betas=(0.9, 0.99), eps=1e-5)
lr = 1e-3
sz = len(dls.train) // 50

learn = MomentumLearner(m, dls, loss_func=loss_fn, lr=lr, cbs=cbs, preds_nm='logits', epoch_sz=sz, mom=0.9)

Training

Now, let's train the model:

learn.fit(1, cbs=prog)

Testing

After training, we can test the model on a sample prompt:

prompt = f"Context: {eval_dataset[0]['context']}\nQuestion: {eval_dataset[0]['question']}\nAnswer:"
tokenized_prompt = tokenizer(prompt, return_tensors='pt')['input_ids'].cuda()

with torch.inference_mode():
    output = m.generate(tokenized_prompt, max_new_tokens=90)

print(prompt + tokenizer.decode(output[0][len(tokenized_prompt[0]):], skip_special_tokens=True))

Conclusion

This example demonstrates how to fine-tune an LLM using the minai library. You can further customize the training process by adjusting hyperparameters, adding more callbacks, or implementing different optimization techniques.
