Created
January 17, 2026 14:24
-
-
Save revdfdev/c868aec3ec53e2d7f8530f0a8c6c9bf0 to your computer and use it in GitHub Desktop.
RAG with permissioned retrieval
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package main | |
| import ( | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "log" | |
| "strings" | |
| "time" | |
| "github.com/aws/aws-sdk-go-v2/config" | |
| bedrockruntime "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | |
| bedrockagentruntime "github.com/aws/aws-sdk-go-v2/service/bedrockagentruntime" | |
| barTypes "github.com/aws/aws-sdk-go-v2/service/bedrockagentruntime/types" | |
| ) | |
// Chunk is one passage returned by knowledge-base retrieval, carried through
// to prompt construction.
type Chunk struct {
	Text, Source string  // passage body, and its origin rendered as a string
	Score        float64 // relevance score attached to the passage
}
// Answer is the outcome of a question: either generated text together with
// the sources offered to the model, or a refusal with a short reason.
type Answer struct {
	// Text is the message shown to the user (model output or refusal notice).
	Text string
	// Sources lists the labelled context sources backing Text.
	Sources []string
	// Refusal is true when no answer was produced.
	Refusal bool
	// Reason briefly explains a refusal; empty otherwise.
	Reason string
}
| func main() { | |
| ctx := context.Background() | |
| awsCfg, err := config.LoadDefaultConfig(ctx) | |
| if err != nil { | |
| log.Fatal(err) | |
| } | |
| agent := bedrockagentruntime.NewFromConfig(awsCfg) | |
| model := bedrockruntime.NewFromConfig(awsCfg) | |
| kbID := "YOUR_KB_ID" | |
| modelID := "anthropic.claude-3-5-sonnet-20240620-v1:0" | |
| userID := "user-123" | |
| q := "What changed in prod latency since yesterday and why?" | |
| ans, err := Ask(ctx, agent, model, kbID, modelID, userID, q) | |
| if err != nil { | |
| log.Fatal(err) | |
| } | |
| fmt.Println(ans.Text) | |
| if ans.Refusal { | |
| fmt.Println("refused:", ans.Reason) | |
| } | |
| for _, s := range ans.Sources { | |
| fmt.Println("-", s) | |
| } | |
| } | |
| func Ask( | |
| ctx context.Context, | |
| agent *bedrockagentruntime.Client, | |
| model *bedrockruntime.Client, | |
| kbID, modelID, userID, question string, | |
| ) (*Answer, error) { | |
| // Don't even hit retrieval if the user shouldn't see this KB. | |
| if !isAllowed(userID, "kb:"+kbID) { | |
| return &Answer{ | |
| Refusal: true, | |
| Text: "I can’t access the data needed to answer this.", | |
| Reason: "access denied", | |
| }, nil | |
| } | |
| chunks, err := retrieve(ctx, agent, kbID, question, 8) | |
| if err != nil { | |
| return nil, err | |
| } | |
| // If we can't ground it, don't try to be clever. | |
| if len(chunks) < 2 { | |
| return &Answer{ | |
| Refusal: true, | |
| Text: "I don’t have enough verified context to answer this.", | |
| Reason: "not enough relevant KB results", | |
| }, nil | |
| } | |
| prompt, sources := makePrompt(question, chunks) | |
| out, err := generate(ctx, model, modelID, prompt) | |
| if err != nil { | |
| return nil, err | |
| } | |
| // If the model didn't cite anything, treat it as ungrounded. | |
| if !hasCitation(out, len(chunks)) { | |
| return &Answer{ | |
| Refusal: true, | |
| Text: "I couldn’t produce a grounded answer with citations.", | |
| Reason: "missing citations", | |
| }, nil | |
| } | |
| return &Answer{ | |
| Text: out, | |
| Sources: sources, | |
| }, nil | |
| } | |
| func retrieve( | |
| ctx context.Context, | |
| agent *bedrockagentruntime.Client, | |
| kbID, question string, | |
| k int32, | |
| ) ([]Chunk, error) { | |
| in := &bedrockagentruntime.RetrieveInput{ | |
| KnowledgeBaseId: &kbID, | |
| RetrievalQuery: &barTypes.KnowledgeBaseQuery{Text: &question}, | |
| RetrievalConfiguration: &barTypes.KnowledgeBaseRetrievalConfiguration{ | |
| VectorSearchConfiguration: &barTypes.KnowledgeBaseVectorSearchConfiguration{ | |
| NumberOfResults: &k, | |
| }, | |
| }, | |
| } | |
| resp, err := agent.Retrieve(ctx, in) | |
| if err != nil { | |
| return nil, fmt.Errorf("retrieve: %w", err) | |
| } | |
| out := make([]Chunk, 0, len(resp.RetrievalResults)) | |
| for _, r := range resp.RetrievalResults { | |
| var text string | |
| if r.Content != nil && r.Content.Text != nil { | |
| text = strings.TrimSpace(*r.Content.Text) | |
| } | |
| if text == "" { | |
| continue | |
| } | |
| var src string | |
| if r.Location != nil { | |
| b, _ := json.Marshal(r.Location) | |
| src = string(b) | |
| } | |
| score := 0.0 | |
| if r.Score != nil { | |
| score = *r.Score | |
| } | |
| out = append(out, Chunk{Text: text, Source: src, Score: score}) | |
| } | |
| return out, nil | |
| } | |
| func makePrompt(question string, chunks []Chunk) (string, []string) { | |
| var b strings.Builder | |
| b.WriteString("Answer the question using only the context below.\n") | |
| b.WriteString("If the context doesn't contain the answer, say so.\n") | |
| b.WriteString("Cite sources like [S1], [S2] based on the context blocks.\n") | |
| b.WriteString("Keep it to 3-6 bullets.\n\n") | |
| b.WriteString("Question:\n") | |
| b.WriteString(question) | |
| b.WriteString("\n\nContext:\n") | |
| sources := make([]string, 0, len(chunks)) | |
| for i, c := range chunks { | |
| tag := fmt.Sprintf("S%d", i+1) | |
| sources = append(sources, fmt.Sprintf("%s: %s", tag, shorten(oneLine(c.Source), 180))) | |
| b.WriteString(fmt.Sprintf("[%s] (score=%.3f)\n", tag, c.Score)) | |
| b.WriteString(c.Text) | |
| b.WriteString("\n\n") | |
| } | |
| return b.String(), sources | |
| } | |
| func generate(ctx context.Context, model *bedrockruntime.Client, modelID, prompt string) (string, error) { | |
| payload := map[string]any{ | |
| "anthropic_version": "bedrock-2023-05-31", | |
| "max_tokens": 600, | |
| "temperature": 0, | |
| "messages": []map[string]string{ | |
| {"role": "user", "content": prompt}, | |
| }, | |
| } | |
| body, err := json.Marshal(payload) | |
| if err != nil { | |
| return "", fmt.Errorf("marshal payload: %w", err) | |
| } | |
| resp, err := model.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ | |
| ModelId: &modelID, | |
| Body: body, | |
| ContentType: strPtr("application/json"), | |
| Accept: strPtr("application/json"), | |
| }) | |
| if err != nil { | |
| return "", fmt.Errorf("invoke model: %w", err) | |
| } | |
| var decoded map[string]any | |
| if err := json.Unmarshal(resp.Body, &decoded); err != nil { | |
| return "", fmt.Errorf("unmarshal response: %w", err) | |
| } | |
| if content, ok := decoded["content"].([]any); ok && len(content) > 0 { | |
| if first, ok := content[0].(map[string]any); ok { | |
| if txt, ok := first["text"].(string); ok { | |
| return strings.TrimSpace(txt), nil | |
| } | |
| } | |
| } | |
| return strings.TrimSpace(fmt.Sprintf("%v", decoded)), nil | |
| } | |
// hasCitation reports whether s contains at least one citation marker of the
// form [S<n>] with n between 1 and max inclusive. It returns false when max
// is less than 1.
func hasCitation(s string, max int) bool {
	for n := 1; n <= max; n++ {
		marker := "[S" + fmt.Sprint(n) + "]"
		if strings.Contains(s, marker) {
			return true
		}
	}
	return false
}
// isAllowed reports whether userID may access resource.
//
// NOTE(review): this is a placeholder. Both arguments are ignored and every
// request is granted; a real policy lookup has to replace it before the
// check provides any protection.
func isAllowed(userID, resource string) bool {
	const allowAll = true
	return allowAll
}
// oneLine collapses s onto a single line: every run of whitespace, newlines
// included, becomes one space, and leading and trailing whitespace is removed.
func oneLine(s string) string {
	// Fields already splits on newlines, so no separate replacement is needed.
	words := strings.Fields(s)
	return strings.Join(words, " ")
}
// shorten limits s to at most n bytes of content and appends "..." when it
// had to cut.
//
// A string of n bytes or fewer is returned unchanged. Otherwise the cut is
// made at the last rune boundary at or before byte n, so a multi-byte UTF-8
// character is never split into invalid bytes. For plain ASCII this is
// exactly the first n bytes. A zero or negative n yields just "...".
func shorten(s string, n int) string {
	if len(s) <= n {
		return s
	}

	// Ranging over a string yields the byte offset of each rune start, so the
	// last offset not exceeding n is a safe place to cut.
	cut := 0
	for i := range s {
		if i > n {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
// strPtr returns a pointer to a fresh copy of s, for populating optional
// *string fields from literals.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// Reference package time so its import stays valid until time-window
// metadata filters are added.
var _ time.Time
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment