This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def expand_path(path: PredictedPath,
                max_branch: int = 3) -> list[PredictedPath]:
    """Extend ``path`` by one step for each of its most frequent successors.

    The entries of ``path.array`` at indices ``step - 2`` and ``step - 1``
    are handed to ``get_successors``; each of the ``max_branch``
    highest-count successors is then turned into a new path through
    ``evolve_path``.

    Args:
        path: Path to extend; its ``array`` and ``step`` attributes are read.
        max_branch: Upper bound on the number of successors expanded.

    Returns:
        One path per expanded successor, in the order ``most_common``
        yields them. The list is empty when no successors are known.
    """
    previous_token = path.array[path.step - 2]
    current_token = path.array[path.step - 1]
    counts = get_successors(previous_token, current_token)

    # The share is taken over *all* recorded successors, not only the
    # ``max_branch`` that end up being expanded.
    denominator = counts.total()
    return [
        evolve_path(path, token, occurrences / denominator)
        for token, occurrences in counts.most_common(max_branch)
    ]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def expand_seed(h0: int, h1: int, | |
max_branch: int = 3, | |
max_length: int = 10) -> list[PredictedPath]: | |
final = [] | |
seed = PredictedPath.from_seed(h0, h1, max_length) | |
paths = expand_path(seed, max_branch) | |
while len(final) < max_branch: | |
expanded = [] | |
for p in paths: | |
expanded.extend(expand_path(p, max_branch=max_branch)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class PredictedPath:
    """A path held in a fixed-size integer array, with a probability.

    NOTE(review): ``step`` presumably counts the filled entries of
    ``array`` (other code reads indices ``step - 2`` and ``step - 1``) --
    confirm against the code that advances it.
    """

    def __init__(self,
                 probability: float = 1.0,
                 step: int = 1,
                 size: int = 0):
        """Store the given values and allocate a zero-filled array.

        Args:
            probability: Probability associated with the path.
            step: Current step of the path.
            size: Number of integer slots to allocate in ``array``.
        """
        self.probability = probability
        self.step = step
        self.size = size
        # Every slot starts at zero.
        zero_filled: np.ndarray = np.zeros(size, dtype=int)
        self.array: np.ndarray = zero_filled
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict(max_branch: int = 3, | |
max_length: int = 10) -> FeatureGroup | None: | |
fg = folium.FeatureGroup(name="polylines") | |
if "token_list" in st.session_state: | |
token_list = st.session_state["token_list"] | |
seed = token_list[-3:-1] | |
if len(seed) > 1: | |
paths = expand_seed(seed[0], seed[1], | |
max_branch=max_branch, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict(max_branch: int = 3, | |
max_length: int = 10) -> FeatureGroup | None: | |
fg = folium.FeatureGroup(name="polylines") | |
if "token_list" in st.session_state: | |
hex_list = st.session_state["token_list"] | |
seed = hex_list[-3:-1] | |
if len(seed) > 1: | |
paths = expand_seed(seed[0], seed[1], | |
max_branch=max_branch, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_successors(h0: int, h1: int) -> Counter:
    """Count the ``t2`` values stored after the pair ``(h0, h1)``.

    The cache is consulted first through ``get_cache_successors``. On a
    miss (``None``), the ``triple`` table is queried for every row whose
    ``t0`` and ``t1`` match, the ``t2`` values are tallied, and the tally
    is stored through ``set_cache_successors`` before being returned.

    Args:
        h0: First element of the pair; converted with ``int`` for the query.
        h1: Second element of the pair; converted with ``int`` for the query.

    Returns:
        A ``Counter`` mapping each ``t2`` value to its number of rows.
        NOTE(review): the same object is handed to the cache setter; if
        the cache keeps it by reference, callers should not mutate it --
        confirm.
    """
    cached = get_cache_successors(h0, h1)
    if cached is not None:
        return cached

    database = TrajDb()
    statement = "select t2 from triple where t0=? and t1=?"
    rows = database.query(statement, [int(h0), int(h1)])
    tally = Counter([row[0] for row in rows])
    set_cache_successors(h0, h1, tally)
    return tally
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def compute_probability(token_list: list[int]) -> float: | |
nodes = token_list[1:-1] | |
prob = 0.0 | |
if len(nodes) > 2: | |
prob = 1.0 | |
for i in range(len(nodes)-2): | |
t0, t1, t2 = nodes[i:i+3] | |
cnt = get_successors(int(t0), int(t1)) | |
if len(cnt): | |
prob *= cnt[t2] / cnt.total() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def insert_h3_nodes(h3_nodes: list[tuple[int,tuple[float,float]]]):
    """Bulk-insert nodes into the ``h3_node`` table, ignoring conflicts.

    Args:
        h3_nodes: Items of the form ``(h3, pair)``. ``pair[1]`` is written
            to the ``lat`` column and ``pair[0]`` to the ``lon`` column.
            NOTE(review): this implies each pair is ordered (lon, lat) --
            confirm with the producer of these tuples.
    """
    database = TrajDb()
    statement = "insert or ignore into h3_node (h3, lat, lon) values (?, ?, ?)"

    rows = []
    for node in h3_nodes:
        h3_index = node[0]
        position = node[1]
        # Column order is (h3, lat, lon), so the pair is swapped here.
        rows.append((h3_index, position[1], position[0]))

    database.execute_sql(statement, rows, many=True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def insert_triples(traj_id: int,
                   triples: list[tuple[int, int, int]]) -> None:
    """Bulk-insert the triples of one trajectory into the ``triple`` table.

    Args:
        traj_id: Value written to the ``traj_id`` column of every row.
        triples: ``(t0, t1, t2)`` tuples; each one becomes a row.
    """
    db = TrajDb()
    sql = "insert into triple (traj_id, t0, t1, t2) values (?, ?, ?, ?)"
    # Prefix every triple with the trajectory id to match the column list.
    params = [(traj_id, t0, t1, t2) for t0, t1, t2 in triples]
    db.execute_sql(sql, params, many=True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def insert_h3(traj_id: int,
              h3_list: list[int]) -> None:
    """Bulk-insert the values of one trajectory into the ``traj_h3`` table.

    Args:
        traj_id: Value written to the ``traj_id`` column of every row.
        h3_list: Values written to the ``h3`` column, one row each; every
            value is converted with ``int`` first.
    """
    database = TrajDb()
    statement = "insert into traj_h3 (traj_id, h3) values (?, ?)"

    rows = []
    for cell in h3_list:
        rows.append([traj_id, int(cell)])

    database.execute_sql(statement, rows, many=True)