Skip to content

Instantly share code, notes, and snippets.

@revdfdev
Created January 17, 2026 14:24
Show Gist options
  • Select an option

  • Save revdfdev/c868aec3ec53e2d7f8530f0a8c6c9bf0 to your computer and use it in GitHub Desktop.

Select an option

Save revdfdev/c868aec3ec53e2d7f8530f0a8c6c9bf0 to your computer and use it in GitHub Desktop.
RAG with permissioned retrieval
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/config"
bedrockruntime "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
bedrockagentruntime "github.com/aws/aws-sdk-go-v2/service/bedrockagentruntime"
barTypes "github.com/aws/aws-sdk-go-v2/service/bedrockagentruntime/types"
)
// Chunk is a single passage returned by knowledge-base retrieval.
type Chunk struct {
	// Text is the passage content (whitespace-trimmed by retrieve).
	Text string
	// Source identifies where the passage came from, derived from the
	// retrieval result's location.
	Source string
	// Score is the relevance score reported by retrieval, or 0 when absent.
	Score float64
}
// Answer is the outcome of Ask: either a cited answer or a refusal.
type Answer struct {
	// Text is the model's answer, or a user-facing refusal message.
	Text string
	// Sources holds one "S<i>: <source>" line per context block offered to
	// the model; Ask leaves it empty on refusals.
	Sources []string
	// Refusal reports whether Ask declined to answer.
	Refusal bool
	// Reason briefly explains a refusal; empty for normal answers.
	Reason string
}
// main runs one demo question through Ask and prints the answer text, the
// refusal reason (if any), and the listed sources. Any error is fatal.
func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run holds main's logic so deferred cleanup (the context cancel) executes
// before the process exits; calling log.Fatal directly would skip defers.
func run() error {
	// Bound retrieval + generation so a stalled network call cannot hang the
	// demo indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	awsCfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return fmt.Errorf("load aws config: %w", err)
	}
	agent := bedrockagentruntime.NewFromConfig(awsCfg)
	model := bedrockruntime.NewFromConfig(awsCfg)

	kbID := "YOUR_KB_ID"
	modelID := "anthropic.claude-3-5-sonnet-20240620-v1:0"
	userID := "user-123"
	q := "What changed in prod latency since yesterday and why?"

	ans, err := Ask(ctx, agent, model, kbID, modelID, userID, q)
	if err != nil {
		return err
	}
	fmt.Println(ans.Text)
	if ans.Refusal {
		fmt.Println("refused:", ans.Reason)
	}
	for _, s := range ans.Sources {
		fmt.Println("-", s)
	}
	return nil
}
// Ask answers question for userID with retrieval-augmented generation over
// knowledge base kbID, refusing rather than guessing when it cannot ground
// the answer.
//
// Refusals come back as an *Answer with Refusal set and a nil error, when:
//   - the user may not access the knowledge base ("kb:"+kbID);
//   - fewer than two retrieved chunks remain after per-document permission
//     filtering;
//   - the model output cites none of the supplied [S#] context blocks.
//
// A non-nil error is returned only when retrieval or generation fails.
func Ask(
	ctx context.Context,
	agent *bedrockagentruntime.Client,
	model *bedrockruntime.Client,
	kbID, modelID, userID, question string,
) (*Answer, error) {
	// Don't even hit retrieval if the user shouldn't see this KB.
	if !isAllowed(userID, "kb:"+kbID) {
		return &Answer{
			Refusal: true,
			Text:    "I can’t access the data needed to answer this.",
			Reason:  "access denied",
		}, nil
	}

	chunks, err := retrieve(ctx, agent, kbID, question, 8)
	if err != nil {
		return nil, err
	}

	// Access to the KB does not imply access to every document in it: drop
	// chunks whose source the user may not read so restricted text never
	// reaches the prompt. Filtering happens before the grounding check so the
	// user only gets an answer built from documents they can see.
	// NOTE(review): the "doc:"+Source resource key is a placeholder — align it
	// with whatever policy store isAllowed ends up consulting.
	permitted := chunks[:0]
	for _, c := range chunks {
		if isAllowed(userID, "doc:"+c.Source) {
			permitted = append(permitted, c)
		}
	}
	chunks = permitted

	// If we can't ground it, don't try to be clever.
	if len(chunks) < 2 {
		return &Answer{
			Refusal: true,
			Text:    "I don’t have enough verified context to answer this.",
			Reason:  "not enough relevant KB results",
		}, nil
	}

	prompt, sources := makePrompt(question, chunks)
	out, err := generate(ctx, model, modelID, prompt)
	if err != nil {
		return nil, err
	}

	// If the model didn't cite anything, treat it as ungrounded.
	if !hasCitation(out, len(chunks)) {
		return &Answer{
			Refusal: true,
			Text:    "I couldn’t produce a grounded answer with citations.",
			Reason:  "missing citations",
		}, nil
	}
	return &Answer{
		Text:    out,
		Sources: sources,
	}, nil
}
// retrieve runs a vector search for question against Bedrock knowledge base
// kbID, asking for k results, and returns the non-empty text chunks in the
// order the service returned them.
//
// k must be at least 1. Chunk text is whitespace-trimmed, a missing score
// becomes 0, and the source is a compact location string (see sourceOf).
func retrieve(
	ctx context.Context,
	agent *bedrockagentruntime.Client,
	kbID, question string,
	k int32,
) ([]Chunk, error) {
	if k < 1 {
		return nil, fmt.Errorf("retrieve: k must be >= 1, got %d", k)
	}
	in := &bedrockagentruntime.RetrieveInput{
		KnowledgeBaseId: &kbID,
		RetrievalQuery:  &barTypes.KnowledgeBaseQuery{Text: &question},
		RetrievalConfiguration: &barTypes.KnowledgeBaseRetrievalConfiguration{
			VectorSearchConfiguration: &barTypes.KnowledgeBaseVectorSearchConfiguration{
				NumberOfResults: &k,
			},
		},
	}
	resp, err := agent.Retrieve(ctx, in)
	if err != nil {
		return nil, fmt.Errorf("retrieve: %w", err)
	}

	out := make([]Chunk, 0, len(resp.RetrievalResults))
	for _, r := range resp.RetrievalResults {
		var text string
		if r.Content != nil && r.Content.Text != nil {
			text = strings.TrimSpace(*r.Content.Text)
		}
		if text == "" {
			continue
		}
		score := 0.0
		if r.Score != nil {
			score = *r.Score
		}
		out = append(out, Chunk{Text: text, Source: sourceOf(r.Location), Score: score})
	}
	return out, nil
}

// sourceOf returns a compact identifier for a retrieval result location:
// the S3 URI or web URL when present, otherwise the JSON encoding of the
// whole location. It returns "" for a nil location or one that cannot be
// encoded.
//
// Preferring the URI matters because the full JSON carries a null entry for
// every location kind, which would eat most of a truncated display string
// before reaching the useful part.
func sourceOf(loc *barTypes.RetrievalResultLocation) string {
	if loc == nil {
		return ""
	}
	if s3 := loc.S3Location; s3 != nil && s3.Uri != nil && *s3.Uri != "" {
		return *s3.Uri
	}
	if web := loc.WebLocation; web != nil && web.Url != nil && *web.Url != "" {
		return *web.Url
	}
	b, err := json.Marshal(loc)
	if err != nil {
		// Best effort: a location we cannot encode is reported as unknown
		// rather than failing the whole retrieval.
		return ""
	}
	return string(b)
}
// makePrompt builds the grounded-answer prompt for question from chunks.
//
// Every chunk becomes a context block headed "[S<i>] (score=<score>)", with
// i counting from 1. The second return value holds one display line per
// block, "S<i>: <source>", where the source is passed through oneLine and then
// shorten(…, 180).
func makePrompt(question string, chunks []Chunk) (string, []string) {
	const header = "Answer the question using only the context below.\n" +
		"If the context doesn't contain the answer, say so.\n" +
		"Cite sources like [S1], [S2] based on the context blocks.\n" +
		"Keep it to 3-6 bullets.\n\n"

	var sb strings.Builder
	sb.WriteString(header)
	sb.WriteString("Question:\n" + question + "\n\nContext:\n")

	labels := make([]string, 0, len(chunks))
	for idx := range chunks {
		chunk := chunks[idx]
		label := fmt.Sprintf("S%d", idx+1)
		labels = append(labels, label+": "+shorten(oneLine(chunk.Source), 180))
		fmt.Fprintf(&sb, "[%s] (score=%.3f)\n%s\n\n", label, chunk.Score, chunk.Text)
	}
	return sb.String(), labels
}
// generate sends prompt as a single user message to an Anthropic model on
// Bedrock (Messages API via InvokeModel, temperature 0, at most 600 output
// tokens) and returns the whitespace-trimmed text of the reply.
//
// All "text" content blocks are concatenated. An error is returned when the
// call fails, the response body cannot be decoded, or the reply contains no
// text block at all. A raw response dump is never passed off as an answer.
func generate(ctx context.Context, model *bedrockruntime.Client, modelID, prompt string) (string, error) {
	payload := map[string]any{
		"anthropic_version": "bedrock-2023-05-31",
		"max_tokens":        600,
		"temperature":       0,
		"messages": []map[string]string{
			{"role": "user", "content": prompt},
		},
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("marshal payload: %w", err)
	}
	resp, err := model.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
		ModelId:     &modelID,
		Body:        body,
		ContentType: strPtr("application/json"),
		Accept:      strPtr("application/json"),
	})
	if err != nil {
		return "", fmt.Errorf("invoke model: %w", err)
	}

	// Decode only the fields this function uses from the Messages response.
	var decoded struct {
		Content []struct {
			Type string `json:"type"`
			Text string `json:"text"`
		} `json:"content"`
		StopReason string `json:"stop_reason"`
	}
	if err := json.Unmarshal(resp.Body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal response: %w", err)
	}

	var (
		sb    strings.Builder
		found bool
	)
	for _, block := range decoded.Content {
		if block.Type != "text" {
			continue
		}
		found = true
		sb.WriteString(block.Text)
	}
	if !found {
		return "", fmt.Errorf("model response has no text content (stop_reason %q)", decoded.StopReason)
	}
	return strings.TrimSpace(sb.String()), nil
}
// hasCitation reports whether s cites at least one of the context blocks
// S1..Sn.
//
// A citation is a bracketed group whose comma-separated entries include a tag
// "S<i>" with 1 <= i <= n. So "[S2]", "[S1, S3]" and "[S1][S2]" all count,
// while out-of-range tags such as "[S9]" with n = 8 do not. When n < 1 the
// result is always false.
func hasCitation(s string, n int) bool {
	if n < 1 {
		return false
	}
	valid := make(map[string]struct{}, n)
	for i := 1; i <= n; i++ {
		valid[fmt.Sprintf("S%d", i)] = struct{}{}
	}
	for {
		end := strings.IndexByte(s, ']')
		if end < 0 {
			return false
		}
		// Use the innermost group: the last '[' before this ']'.
		if open := strings.LastIndexByte(s[:end], '['); open >= 0 {
			for _, tag := range strings.Split(s[open+1:end], ",") {
				if _, ok := valid[strings.TrimSpace(tag)]; ok {
					return true
				}
			}
		}
		s = s[end+1:]
	}
}
// isAllowed reports whether userID may access resource (Ask passes keys such
// as "kb:<id>").
//
// WARNING: this is a placeholder that grants every request (fail-open).
// Replace it with a real policy lookup before trusting Ask's access checks.
func isAllowed(userID, resource string) bool {
	return true
}
// oneLine collapses every run of whitespace in s (newlines included) into a
// single space and drops leading and trailing whitespace.
func oneLine(s string) string {
	// strings.Fields already treats '\n' as a separator, so no separate
	// newline replacement is needed.
	words := strings.Fields(s)
	return strings.Join(words, " ")
}
// shorten returns s unchanged when it is at most n bytes long. Otherwise it
// returns a prefix of at most n bytes followed by "...".
//
// The cut is moved back to a rune boundary so a multi-byte UTF-8 character
// is never split, which would produce invalid UTF-8. A negative n is treated
// as 0 instead of panicking.
func shorten(s string, n int) string {
	if len(s) <= n {
		return s
	}
	cut := 0
	// Ranging over a string yields the byte offset of each rune start; keep
	// the last start that still fits within n bytes.
	for i := range s {
		if i > n {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
// strPtr returns a pointer to a fresh copy of v, for SDK fields that take
// *string.
func strPtr(v string) *string {
	p := &v
	return p
}
// Blank-identifier reference that keeps the "time" import in use until a
// time-based feature (e.g. time-window metadata filters) needs it.
var _ time.Time = time.Now()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment