Skip to content

Instantly share code, notes, and snippets.

@thquinn
Created June 14, 2024 18:39
Show Gist options
  • Save thquinn/1211c697e6d341484ac99d7aeabd923a to your computer and use it in GitHub Desktop.
LLM text generation without using the letter 'e'.
// .NET 8.0, uses LLamaSharp and LLamaSharp's CUDA backend NuGet packages.
using LLama.Common;
using LLama;
using System.Text;
using Laureate;
using LLama.Native;
using LLama.Batched;
// Top-level program: streams a generated poem from a local GGUF model while
// ScroetrySampler masks out every token containing the letter 'e'.
string modelPath = @"C:/Users/Cae/Downloads/noromaid-v0.4-mixtral-instruct-8x7b-zloss.Q3_K_M.gguf";
var parameters = new ModelParams(modelPath) {
    ContextSize = 128, // The longest length of chat as memory.
    GpuLayerCount = 12 // How many layers to offload to GPU. Please adjust it according to your GPU memory.
};
using var model = LLamaWeights.LoadFromFile(parameters);
var executor = new BatchedExecutor(model, parameters);
var run = executor.Create();
run.Prompt(executor.Context.Tokenize("The following is a poem about waffles:\n\n"));
var sampler = new ScroetrySampler(new int[0]);
var decoder = new StreamingTokenDecoder(executor.Context);
// Generate until the model emits end-of-sequence or the context window fills.
// BUG FIX: the original `while (true)` never terminated, so Console.ReadLine()
// was unreachable and generation would eventually overrun the 128-token context.
int maxTokens = (int)(parameters.ContextSize ?? 128);
for (int generated = 0; generated < maxTokens; generated++) {
    await executor.Infer();
    var token = sampler.Sample(executor.Context.NativeHandle, run.Sample(), ReadOnlySpan<LLamaToken>.Empty);
    if (token == model.Tokens.EOS) {
        break; // model signalled the end of the poem
    }
    decoder.Add(token);
    Console.Write(decoder.Read());
    run.Prompt(token);
}
Console.ReadLine();
using LLama;
using LLama.Native;
using LLama.Sampling;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Laureate {
    /// <summary>
    /// Sampling pipeline that bans every vocabulary token whose decoded text
    /// contains the letter 'e' (either case), then samples from the remaining
    /// tokens with temperature 0.8 and top-k 25.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <paramref name="tileCounts"/> is never read — presumably
    /// reserved for a future per-letter budget; confirm before removing.
    /// </remarks>
    class ScroetrySampler(int[] tileCounts) : BaseSamplingPipeline {
        // Lazily-built cache mapping token id -> decoded token text.
        // Built on the first ProcessLogits call, when the vocabulary size is
        // known from the length of the logit vector.
        string[]? tokens = null;

        public override ISamplingPipeline Clone() {
            throw new NotImplementedException();
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) {
            // Populate the token-text cache once.
            if (tokens == null) {
                tokens = new string[logits.Length];
                var decoder = new StreamingTokenDecoder(Encoding.UTF8, ctx);
                for (int i = 0; i < tokens.Length; i++) {
                    decoder.Add(i);
                    tokens[i] = decoder.Read();
                }
                Console.WriteLine($"Populated sampler tokens.");
            }
            // Ban offending tokens.
            // BUG FIX: logits are pre-softmax scores, so writing 0 does NOT
            // make a token impossible — 0 is an ordinary logit value, and
            // 'e'-containing tokens could still be sampled. Negative infinity
            // yields probability 0 after softmax, which also makes any manual
            // renormalization unnecessary (softmax renormalizes the rest).
            // OrdinalIgnoreCase: matching the ASCII letter 'e'/'E' is a
            // non-linguistic comparison.
            for (int i = 0; i < logits.Length; i++) {
                if (tokens[i].Contains('e', StringComparison.OrdinalIgnoreCase)) {
                    logits[i] = float.NegativeInfinity;
                }
            }
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens) {
            // Standard temperature + top-k sampling over the masked logits.
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);
            return candidates.SampleToken(ctx);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment