|
package main |
|
|
|
import ( |
|
"context" |
|
"fmt" |
|
"log" |
|
"net/http" |
|
"os" |
|
"path" |
|
"time" |
|
|
|
"github.com/intel-go/fastjson" |
|
|
|
"github.com/mattn/go-w2v" |
|
"github.com/osamingo/jsonrpc" |
|
) |
|
|
|
// Word2VecSimilarityResult is one entry of a similarity query result: a
// vocabulary token together with the similarity score the model computed
// between that token's vector and the query vector.
type Word2VecSimilarityResult struct {
	// Score is the similarity value reported by model.CosineSimilars.
	Score float64 `json:"score"`
	// Token is the word whose vector produced Score.
	Token string `json:"token"`
}
|
|
|
// Request/response types for the word2vec similarity JSON-RPC method
// (registered by main as "App.most_similar").
type (
	// Word2VecQueryHandler serves similarity queries. It holds no state of
	// its own; queries run against the package-level MainModel.
	Word2VecQueryHandler struct{}

	// Word2VecQueryParams are the JSON-RPC params of a similarity query.
	Word2VecQueryParams struct {
		// Positive lists words whose vectors are added to the query vector.
		Positive []string `json:"positive"`
		// Negative lists words whose vectors are subtracted from the query vector.
		Negative []string `json:"negative"`
	}

	// Word2VecQueryResult is the JSON-RPC result of a similarity query.
	Word2VecQueryResult struct {
		// Status is set to "ok" by ServeJSONRPC on success.
		Status string `json:"status"`
		// Result holds the best-matching tokens with their scores.
		Result []Word2VecSimilarityResult `json:"result"`
	}
)
|
|
|
func (h Word2VecQueryHandler) ServeJSONRPC(c context.Context, params *fastjson.RawMessage) (interface{}, *jsonrpc.Error) { |
|
var p Word2VecQueryParams |
|
if err := jsonrpc.Unmarshal(params, &p); err != nil { |
|
return nil, err |
|
} |
|
|
|
result := GetWord2VecSimilarityResult(MainModel, p.Positive, p.Negative, 10) |
|
return Word2VecQueryResult{ |
|
Status: "ok", |
|
Result: result, |
|
}, nil |
|
} |
|
|
|
// Now returns the current wall-clock time as fractional seconds since the
// Unix epoch. Callers subtract two readings to get an elapsed time in
// seconds; since it reads the wall clock, clock adjustments between the two
// readings would skew the difference.
func Now() float64 {
	nanos := time.Now().UnixNano()
	return float64(nanos) / float64(time.Second)
}
|
|
|
func GetWord2VecSimilarityResult(model *w2v.Model, positiveTokenColl []string, negativeTokenColl []string, limit int) []Word2VecSimilarityResult { |
|
var queryVec *w2v.Vector |
|
for _, posToken := range positiveTokenColl { |
|
vec := model.Find(posToken) |
|
if queryVec == nil { |
|
queryVec = vec |
|
} else { |
|
queryVec = queryVec.Add(vec) |
|
} |
|
} |
|
for _, negToken := range negativeTokenColl { |
|
vec := model.Find(negToken) |
|
if queryVec == nil { |
|
queryVec = vec |
|
} else { |
|
queryVec = queryVec.Sub(vec) |
|
} |
|
} |
|
|
|
t0 := Now() |
|
similarEntryColl := model.CosineSimilars(queryVec) |
|
tN := Now() |
|
fmt.Printf("OK: (elapsed %fs) search complete\n", tN-t0) |
|
|
|
nItem := len(similarEntryColl) |
|
var nTake int |
|
if nItem < limit { |
|
nTake = nItem |
|
} else { |
|
nTake = limit |
|
} |
|
out := make([]Word2VecSimilarityResult, 0) |
|
for _, entry := range similarEntryColl[:nTake] { |
|
out = append(out, Word2VecSimilarityResult{ |
|
Score: entry.Value, |
|
Token: entry.Vector.Word(), |
|
}) |
|
} |
|
return out |
|
} |
|
|
|
// MainModel is the word2vec model that main loads at startup and that
// Word2VecQueryHandler queries. In this file it is assigned in main before
// the HTTP server is started.
var MainModel *w2v.Model
|
|
|
func main() { |
|
modelPath := string(path.Join("word2vec.6B.50d.txt")) |
|
|
|
fmt.Printf("loading: %v\n", modelPath) |
|
|
|
modelFile, err := os.Open(modelPath) |
|
if err != nil { |
|
log.Fatal("Failed to open file") |
|
} |
|
|
|
var t0, tN float64 |
|
t0 = Now() |
|
model, err := w2v.LoadText(modelFile) |
|
if err != nil { |
|
log.Fatal("Failed to load model") |
|
} |
|
MainModel = &model |
|
tN = Now() |
|
fmt.Printf("OK: (elapsed %fs) loaded %v\n", tN-t0, modelPath) |
|
|
|
positives := []string{"queen", "man"} |
|
negatives := []string{"king"} |
|
|
|
fmt.Printf("positives: %v\n", positives) |
|
fmt.Printf("negatives: %v\n", negatives) |
|
|
|
myTest := GetWord2VecSimilarityResult(MainModel, positives, negatives, 5) |
|
for i, res := range myTest { |
|
fmt.Printf("sim %v: %v -- %v\n", i, res.Token, res.Score) |
|
} |
|
|
|
mr := jsonrpc.NewMethodRepository() |
|
if err := mr.RegisterMethod( |
|
"App.most_similar", |
|
Word2VecQueryHandler{}, |
|
Word2VecQueryParams{}, |
|
Word2VecQueryResult{}); err != nil { |
|
log.Fatalln(err) |
|
} |
|
|
|
http.Handle("/rpc", mr) |
|
http.HandleFunc("/rpc/debug", mr.ServeDebug) |
|
if err := http.ListenAndServe(":5002", http.DefaultServeMux); err != nil { |
|
log.Fatalln(err) |
|
} |
|
} |