import geoopt


def spherical_avg(p, w=None, tol=1e-6, max_iter=1000):
    """Return the weighted average of the rows of ``p`` on ``geoopt.Sphere``.

    The rows of ``p`` are first projected onto the manifold with
    ``Sphere.projx``.  The estimate ``q`` is initialised with the projected
    weighted Euclidean mean and then refined by the fixed-point iteration
    ``q <- retr(q, sum_i w_i * logmap(q, p_i))`` until two consecutive
    iterates differ by at most ``tol`` in Euclidean norm.

    Args:
        p: 2-D tensor with one point per row.
        w: Optional 1-D tensor with one weight per row of ``p``.  Defaults
            to all ones.  The weights are normalised to sum to one.
        tol: Convergence threshold on ``||q_old - q_new||``.
        max_iter: Upper bound on the number of refinement steps.  When it
            is reached without meeting ``tol`` a ``RuntimeWarning`` is
            emitted and the last iterate is returned.

    Returns:
        1-D tensor holding the averaged point.

    Raises:
        ValueError: If the shapes of ``p`` and ``w`` are inconsistent, or
            if an iterate becomes non-finite.
    """
    import warnings

    sphere = geoopt.Sphere()
    if w is None:
        w = p.new_ones([p.shape[0]])
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if p.ndim != 2 or w.ndim != 1 or len(p) != len(w):
        raise ValueError(
            "expected p with shape (n, d) and w with shape (n,); got "
            f"p.shape={tuple(p.shape)}, w.shape={tuple(w.shape)}"
        )

    # Normalise once and keep the column shape needed to broadcast over rows.
    w_col = (w / w.sum()).unsqueeze(1)
    p = sphere.projx(p)
    q = sphere.projx(p.mul(w_col).sum(dim=0))

    # A non-finite start would make every later comparison against ``tol``
    # false, so the iteration could never terminate.
    # NOTE(review): presumably this happens when the weights sum to zero, a
    # row of ``p`` is all zeros, or the weighted mean is the zero vector
    # (e.g. antipodal points) -- confirm against Sphere.projx.
    if not bool(q.isfinite().all()):
        raise ValueError(
            "initial estimate is not finite; check for zero-sum weights, "
            "zero rows in p, or points whose weighted mean is the origin"
        )

    for _ in range(max_iter):
        tangent_mean = sphere.logmap(q, p).mul(w_col).sum(dim=0)
        q_new = sphere.retr(q, tangent_mean)
        step = q.sub(q_new).norm()
        # A NaN/inf step means the new iterate is unusable; fail loudly
        # rather than loop on a comparison that can never succeed.
        if not bool(step.isfinite()):
            raise ValueError("iteration produced a non-finite estimate")
        q = q_new
        if step <= tol:
            return q

    warnings.warn(
        f"spherical_avg did not converge within {max_iter} iterations "
        f"(tol={tol})",
        RuntimeWarning,
        stacklevel=2,
    )
    return q