Created
June 14, 2024 18:39
-
-
Save thquinn/1211c697e6d341484ac99d7aeabd923a to your computer and use it in GitHub Desktop.
LLM text generation without using the letter 'e'.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// .NET 8.0, uses LLamaSharp and LLamaSharp's CUDA backend NuGet packages. | |
using LLama.Common; | |
using LLama; | |
using System.Text; | |
using Laureate; | |
using LLama.Native; | |
using LLama.Batched; | |
// Load a local GGUF model and run a batched generation loop whose sampler
// (ScroetrySampler) forbids any token containing the letter 'e'.
string modelPath = @"C:/Users/Cae/Downloads/noromaid-v0.4-mixtral-instruct-8x7b-zloss.Q3_K_M.gguf";
var parameters = new ModelParams(modelPath) {
    ContextSize = 128, // The longest length of chat as memory.
    GpuLayerCount = 12 // How many layers to offload to GPU. Please adjust it according to your GPU memory.
};
using var model = LLamaWeights.LoadFromFile(parameters);
var executor = new BatchedExecutor(model, parameters);
var run = executor.Create();
run.Prompt(executor.Context.Tokenize("The following is a poem about waffles:\n\n"));
var sampler = new ScroetrySampler(new int[0]);
var decoder = new StreamingTokenDecoder(executor.Context);
// Generate until the model emits its end-of-sequence token. The original loop
// was unconditional (`while (true)`), which made the trailing Console.ReadLine()
// unreachable and would eventually run past the 128-token context.
while (true) {
    await executor.Infer();
    var token = sampler.Sample(executor.Context.NativeHandle, run.Sample(), ReadOnlySpan<LLamaToken>.Empty);
    if (token == model.Tokens.EOS) {
        break; // end of sequence: stop generating so the program can exit cleanly
    }
    decoder.Add(token);
    Console.Write(decoder.Read());
    run.Prompt(token);
}
Console.ReadLine();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LLama; | |
using LLama.Native; | |
using LLama.Sampling; | |
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using System.Threading.Tasks; | |
namespace Laureate {
    /// <summary>
    /// Sampling pipeline that bans every token whose decoded text contains the
    /// letter 'e' (case-insensitive), then samples with temperature 0.8 and top-k 25.
    /// </summary>
    /// <param name="tileCounts">Currently unused; kept for interface compatibility.</param>
    class ScroetrySampler(int[] tileCounts) : BaseSamplingPipeline {
        // Lazily-built cache mapping token id -> decoded token text; the
        // vocabulary size is taken from logits.Length on the first call.
        string[]? tokens = null;
        public override ISamplingPipeline Clone() {
            throw new NotImplementedException();
        }
        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) {
            // Populate the token-text cache once (decoding the whole vocabulary is slow).
            if (tokens == null) {
                tokens = new string[logits.Length];
                var decoder = new StreamingTokenDecoder(Encoding.UTF8, ctx);
                for (int i = 0; i < tokens.Length; i++) {
                    decoder.Add(i);
                    tokens[i] = decoder.Read();
                }
                Console.WriteLine("Populated sampler tokens.");
            }
            // Ban tokens containing 'e'/'E'. Logits are raw pre-softmax scores, so a
            // banned token must be set to -infinity: the original code set it to 0,
            // which still leaves non-zero probability after softmax (and, since most
            // logits are negative, a 0 logit can even become the most likely token).
            for (int i = 0; i < logits.Length; i++) {
                if (tokens[i].Contains('e', StringComparison.InvariantCultureIgnoreCase)) {
                    logits[i] = float.NegativeInfinity;
                }
            }
            // No renormalization needed: the softmax applied during sampling
            // normalizes the surviving logits.
        }
        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens) {
            // Standard temperature + top-k sampling over the filtered candidates.
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);
            return candidates.SampleToken(ctx);
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment