Last active
November 15, 2019 05:39
-
-
Save allenanie/7e3a55bb8d8a26fdf0064802a78f4183 to your computer and use it in GitHub Desktop.
Amortized RSA (Rational Speech Acts), without explicitly materializing the S1 pragmatic speaker (it is not needed)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
# Candidate utterances: single feature words the speaker may produce.
utterances = "blue green square circle".split()
# Referents in the context, each described by its colour and shape words.
objects = [
    "blue square",
    "blue circle",
    "green square",
]
def meaning(utt, obj):
    """Return 1 if utterance ``utt`` is literally true of object ``obj``, else 0.

    Objects are space-separated feature words (e.g. ``"blue square"``).  An
    utterance is true of an object when every word of the utterance is one
    of the object's words.  This whole-word test replaces a raw substring
    test, which wrongly judged e.g. ``"circle"`` true of ``"semicircle"``.
    Multi-word utterances (``"blue square"``) still match their object.
    """
    return int(set(utt.split()).issubset(obj.split()))
def normalize(space):
    """Rescale the values of ``space`` in place so they sum to 1, and return it.

    ``space`` maps outcomes to weights (presumably non-negative; not checked).
    When the weights sum to zero (including an empty mapping) there is no
    distribution to normalize to, so ``space`` is returned unchanged rather
    than raising ``ZeroDivisionError``.  For example, an utterance that is
    true of no object keeps an all-zero row.
    """
    total = float(sum(space.values()))
    if total == 0:
        return space
    for key in space:
        # Only existing keys are reassigned, so mutating during iteration is safe.
        space[key] /= total
    return space
def cond_normalize(conditional_model):
    """Normalize every row of a conditional model in place and return the model.

    ``conditional_model`` maps each conditioning value (here, an utterance)
    to a ``{outcome: weight}`` dict.  Each such dict is rescaled with
    ``normalize``.  The conditioning keys themselves are not needed, so only
    the values are iterated.
    """
    for space in conditional_model.values():
        normalize(space)
    return conditional_model
def prob_model():
    """Build the literal listener L0 as ``{utt -> {obj: prob}}``.

    Each utterance's row gives ``meaning(utt, obj)`` for every object in
    ``objects``.  The row is then normalized over objects by
    ``cond_normalize``.  The result is a ``defaultdict(dict)``.
    """
    literal = defaultdict(dict)
    for utt in utterances:
        literal[utt] = {obj: meaning(utt, obj) for obj in objects}
    return cond_normalize(literal)
def get_optimal(cond_model, utt):
    """Return the next-level listener distribution ``{obj: prob}`` for ``utt``.

    ``cond_model`` is a listener model ``{utt -> {obj: prob}}``.  The
    computation has two steps:

    1. For each object in ``cond_model[utt]``, compute the speaker score
       ``cond_model[utt][obj] / sum_u cond_model[u][obj]``.  This is S1 of
       ``utt`` given ``obj``, with no cost or prior terms.
    2. Normalize those scores over objects.

    Objects whose total mass across all utterances is zero are omitted,
    because the speaker score is undefined for them.

    Raises:
        KeyError: if ``utt`` is not a key of ``cond_model``.
    """
    # Check explicitly so that a defaultdict model is not silently extended
    # with an empty row for an unknown utterance.
    if utt not in cond_model:
        raise KeyError(utt)
    row = cond_model[utt]
    opt_model = {}
    # Complexity: |S| x |U|.  The per-object denominator Z_S could be cached.
    # Iterate the row's own objects, not a global list, so rows that omit
    # zero-mass objects (from a previous level) still work.
    for obj, nom in row.items():
        denom = sum(other.get(obj, 0) for other in cond_model.values())
        if denom > 0:
            # S1(utt | obj), before the final normalization over objects.
            opt_model[obj] = nom / float(denom)
    return normalize(opt_model)
def get_optimal_model(cond_model):
    """Recompute the whole model one level up, applying ``get_optimal`` per utterance.

    The result is a plain dict ``{utt -> {obj: prob}}`` keyed by every entry
    of ``utterances``.  It has the same layout as ``prob_model()``, so it can
    be passed back in to recurse another level.
    """
    return {utt: get_optimal(cond_model, utt) for utt in utterances}
if __name__ == "__main__":
    # Build the literal listener once; the calls below only read it.
    base_model = prob_model()  # L0

    print("L0:")
    print(base_model)

    print("L1:")
    # Partial evaluation of the L1 model, one utterance at a time.
    for utt in ('blue', 'square', 'green', 'circle'):
        print(get_optimal(base_model, utt))

    # Full evaluation of the L1, L2 and L3 models: each level applies
    # get_optimal_model to the previous level's model.
    model = base_model
    for level in (1, 2, 3):
        model = get_optimal_model(model)
        print("Full L{} model:".format(level))
        print(model)

    # Hypothesis (unverified): iterating the recursion like this is probably
    # equivalent to adjusting the rationality parameter.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment