import warnings

import geoopt


def spherical_avg(p, w=None, tol=1e-6, *, max_iter=1000):
    """Return the weighted Riemannian (Karcher) mean of points on the unit sphere.

    Starts from the normalized weighted Euclidean mean. Each step moves the
    estimate ``q`` along the weighted sum of the log-map vectors from ``q`` to
    every point, then maps the result back onto the sphere with the manifold
    retraction. Stops once two consecutive estimates are within ``tol``.

    Args:
        p: 2-D tensor with one point per row. Rows are first projected onto
            the sphere with ``Sphere.projx``, so they need not be normalized.
        w: Optional 1-D tensor with one weight per row of ``p``. Defaults to
            uniform weights. Rescaled to sum to one.
        tol: Convergence threshold on the Euclidean norm of the difference
            between consecutive estimates.
        max_iter: Maximum number of update steps (keyword-only). ``None``
            removes the bound, matching the historical unbounded loop.

    Returns:
        1-D tensor of length ``p.shape[1]``: the estimated mean.

    Raises:
        AssertionError: If ``p`` is not 2-D, ``w`` is not 1-D, or their
            lengths differ.
        ValueError: If the starting point or an update is not finite. This
            happens when the weighted Euclidean mean is the zero vector (for
            example, two antipodal points with equal weight), when ``p`` is
            empty, or when the weights sum to zero.

    Warns:
        RuntimeWarning: If ``max_iter`` steps pass without convergence. The
            last estimate is returned in that case.
    """
    sphere = geoopt.Sphere()
    if w is None:
        w = p.new_ones([p.shape[0]])
    assert p.ndim == 2 and w.ndim == 1 and len(p) == len(w)
    w = w / w.sum()
    # Column of weights, broadcast across the coordinates of each row.
    # Computed once because it is loop-invariant.
    w_col = w.unsqueeze(1)
    p = sphere.projx(p)

    # Normalizing a zero-length weighted mean yields NaN. Without this check,
    # the loop below would never terminate, because NaN <= tol is always False.
    q = sphere.projx(p.mul(w_col).sum(dim=0))
    if not bool(q.isfinite().all()):
        raise ValueError(
            "spherical_avg: starting point is not finite; the weighted "
            "Euclidean mean of the points is zero or the weights sum to zero"
        )

    n_iter = 0
    while max_iter is None or n_iter < max_iter:
        n_iter += 1
        # Weighted sum of tangent vectors at q pointing toward each point.
        # Its vanishing is the first-order condition for a weighted Karcher mean.
        step = sphere.logmap(q, p).mul(w_col).sum(dim=0)
        q_new = sphere.retr(q, step)
        norm = q.sub(q_new).norm()
        q = q_new
        # Guard against a NaN/inf norm, which would never satisfy norm <= tol.
        if not bool(norm.isfinite()):
            raise ValueError(
                "spherical_avg: iteration produced non-finite values"
            )
        if norm <= tol:
            return q

    warnings.warn(
        f"spherical_avg did not converge to tol={tol} within "
        f"{max_iter} iterations; returning the last estimate",
        RuntimeWarning,
        stacklevel=2,
    )
    return q