Last active
April 2, 2017 23:43
-
-
Save irpap/1a17e723fda2de06afb2e7ab8187a0d5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"errors" | |
"flag" | |
"fmt" | |
"io" | |
"io/ioutil" | |
"math/rand" | |
"os" | |
"sort" | |
"strings" | |
"time" | |
"log" | |
sentences "gopkg.in/neurosnap/sentences.v1" | |
"gopkg.in/neurosnap/sentences.v1/english" | |
) | |
// Sentinel tokens that frame every sentence stored in a Chain: Build places
// prefixLen copies of BEGIN before a sentence's words and one END after them.
const (
	// BEGIN Marks the beginning of a sentence
	BEGIN = "_BEGIN_"
	// END Marks the end of a sentence
	END = "_END_"
)
func main() { | |
prefixLen := flag.Int("prefix", 2, "prefix length in words") | |
filename := flag.String("f", "", "input file to build the markov chain from") | |
interactive := flag.Bool("i", false, "interactive mode") | |
hints := flag.Bool("h", false, "show hints in interactive mode") | |
n := flag.Int("n", 500, "maximum number of words to generate") | |
flag.Parse() // Parse command-line flags. | |
rand.Seed(time.Now().UnixNano()) // Seed the random number generator. | |
c := NewChain(*prefixLen) | |
var corpus Corpus | |
if *filename != "" { | |
if file, err := os.Open(*filename); err != nil { | |
log.Fatal(err) | |
} else { | |
corpus = NewCorpus(file) | |
} | |
} else { | |
corpus = NewCorpus(os.Stdin) | |
} | |
c.Build(corpus) | |
c.precomputeBeginning() | |
if *interactive { | |
for { | |
text := c.GenerateInteractiveSentence(*hints) | |
fmt.Println(text) | |
fmt.Println() | |
} | |
} else { | |
for i := 0; i < *n; { | |
sentence := c.GenerateSentence() | |
words := len(strings.Fields(sentence)) | |
fmt.Println(sentence) | |
i += words | |
} | |
} | |
} | |
// Prefix is a Markov chain prefix: the last prefixLen words (possibly zero)
// that determine which word may come next.
type Prefix []string

// String returns the Prefix as a string (for use as a map key).
func (p Prefix) String() string {
	return strings.Join(p, " ")
}

// Shift removes the first word from the Prefix and appends the given word,
// modifying p in place and keeping its length unchanged. Shifting an empty
// Prefix (a chain with prefix length 0) is a no-op rather than a panic.
func (p Prefix) Shift(word string) {
	if len(p) == 0 {
		return
	}
	copy(p, p[1:])
	p[len(p)-1] = word
}
// Chain is a word-level Markov chain.
//
// chain is keyed by a prefix of prefixLen words joined with single spaces
// (the form produced by Prefix.String); each value maps a possible next word
// (a suffix) to the number of times it followed that prefix.
//
// beginningChoices and beginningCumdist hold the suffixes of the all-BEGIN
// starting prefix and their cumulative weights, filled in by
// precomputeBeginning.
type Chain struct {
	chain            map[string]map[string]int
	prefixLen        int
	beginningChoices []string
	beginningCumdist []int
}

// NewChain returns an empty Chain whose prefixes are prefixLen words long.
func NewChain(prefixLen int) *Chain {
	c := new(Chain)
	c.prefixLen = prefixLen
	c.chain = map[string]map[string]int{}
	return c
}
// Build reads text from the provided Corpus and | |
// parses it into prefixes and suffixes that are stored in Chain. | |
func (c *Chain) Build(corpus Corpus) { | |
for _, s := range corpus { | |
words := strings.Fields(s.Text) | |
beginning := createBeginning(c) | |
sentence := append(beginning, words...) | |
sentence = append(sentence, END) | |
for j := 0; j < len(sentence)-c.prefixLen; j++ { | |
state := sentence[j : j+c.prefixLen].String() | |
next := sentence[j+c.prefixLen] | |
if c.chain[state] == nil { | |
c.chain[state] = make(map[string]int) | |
} | |
c.chain[state][next]++ | |
} | |
} | |
} | |
// Move returns the next transition from the current state at random based on the weighted probability. | |
func (c *Chain) Move(p Prefix) (string, error) { | |
var choices []string | |
var cumdist []int | |
if p.String() == createBeginning(c).String() { | |
choices = c.beginningChoices | |
cumdist = c.beginningCumdist | |
} else { | |
vs := c.chain[p.String()] | |
if vs == nil { | |
return "", errors.New("Prefix not found") | |
} | |
choices = make([]string, 0, len(vs)) | |
weights := make([]int, 0, len(vs)) | |
for k, v := range vs { | |
choices = append(choices, k) | |
weights = append(weights, v) | |
} | |
cumdist = accumulate(weights) | |
} | |
r := rand.Intn(cumdist[len(cumdist)-1]) | |
i := sort.Search(len(cumdist), func(i int) bool { return cumdist[i] >= r }) | |
next := choices[i] | |
return next, nil | |
} | |
// GenerateSentence returns a sentence generated from Chain. | |
func (c *Chain) GenerateSentence() string { | |
p := createBeginning(c) | |
var words []string | |
for { | |
next, e := c.Move(p) | |
if e != nil { | |
return strings.Join(words, " ") | |
} | |
if next == END { | |
break | |
} | |
words = append(words, next) | |
p.Shift(next) | |
} | |
return strings.Join(words, " ") | |
} | |
// GenerateInteractiveSentence returns a sentence generated from Chain and user input. | |
func (c *Chain) GenerateInteractiveSentence(hints bool) string { | |
p := createBeginning(c) | |
var words []string | |
turn := "computer" | |
loop: | |
for { | |
switch turn { | |
case "computer": | |
next, e := c.Move(p) | |
if e != nil { | |
fmt.Println("\n### Hit a dead end! ###") | |
return strings.Join(words, " ") | |
} | |
if next == END { | |
break loop | |
} | |
words = append(words, next) | |
p.Shift(next) | |
turn = "user" | |
case "user": | |
optionsDict := c.chain[p.String()] | |
options := make([]string, 0, len(optionsDict)) | |
for k := range optionsDict { | |
options = append(options, k) | |
} | |
if hints { | |
fmt.Printf("Possible choices are: \n%v\n\n", strings.Join(options, "|")) | |
} | |
fmt.Printf("%v ... ", strings.Join(words, " ")) | |
var userWord string | |
fmt.Scan(&userWord) | |
words = append(words, userWord) | |
p.Shift(userWord) | |
turn = "computer" | |
} | |
} | |
return strings.Join(words, " ") | |
} | |
func (c *Chain) precomputeBeginning() { | |
prefix := createBeginning(c).String() | |
vs := c.chain[prefix] | |
choices := make([]string, 0, len(vs)) | |
weights := make([]int, 0, len(vs)) | |
for k, v := range vs { | |
choices = append(choices, k) | |
weights = append(weights, v) | |
} | |
cumdist := accumulate(weights) | |
c.beginningChoices = choices | |
c.beginningCumdist = cumdist | |
} | |
func createBeginning(c *Chain) Prefix { | |
beginning := make(Prefix, c.prefixLen) | |
for i := 0; i < c.prefixLen; i++ { | |
beginning[i] = BEGIN | |
} | |
return beginning | |
} | |
// Corpus a collection of sentences | |
type Corpus []*sentences.Sentence | |
// NewCorpus creates a new corpus from a io.Reader | |
func NewCorpus(r io.Reader) Corpus { | |
b, err := ioutil.ReadAll(r) | |
if err != nil { | |
panic(err) | |
} | |
text := string(b) | |
tokenizer, err := english.NewSentenceTokenizer(nil) | |
if err != nil { | |
panic(err) | |
} | |
sentences := tokenizer.Tokenize(text) | |
return sentences | |
} | |
// accumulate returns the running sums of weights: result[i] is
// weights[0] + ... + weights[i]. An empty or nil input yields an empty
// result instead of panicking on weights[0].
func accumulate(weights []int) []int {
	result := make([]int, len(weights))
	sum := 0
	for i, w := range weights {
		sum += w
		result[i] = sum
	}
	return result
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment