This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# mypy: allow-untyped-defs | |
# mypy: allow-untyped-decorators | |
import torch | |
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten | |
from torch.utils import module_tracker as mt | |
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator | |
from collections import defaultdict | |
from torch.utils._python_dispatch import TorchDispatchMode | |
from math import prod | |
from functools import wraps |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
The MIT License (MIT) | |
Copyright (c) 2014 by the President and Fellows of Harvard University | |
Condensed version Copyright (c) 2024 by Gavia Gray | |
This condensed version is based on the original mattjj/autodidact implementation, | |
with assistance from Claude 3.5 Sonnet. | |
Permission is hereby granted, free of charge, to any person obtaining a copy |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# function to untar arxiv source files into a folder | |
untar_arxiv () { | |
# Check if the file has a .tar.gz suffix | |
if [[ $1 == *.tar.gz ]]; then | |
# Extract the base name without the .tar.gz suffix | |
base_name=$(basename "$1" .tar.gz) | |
mkdir -p "$base_name" | |
tar -xvf "$1" -C "$base_name" | |
else | |
# Add the .tar.gz suffix to the filename |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import itertools | |
from collections import OrderedDict | |
def einsum_itertools(equation, *operands, verbose=False): | |
# Parse the equation | |
input_labels, output_labels = equation.split('->') | |
input_labels = input_labels.split(',') | |
if verbose: | |
print(f"{input_labels=} {output_labels=}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# script to add gcloud instances to ssh config | |
# | |
# Usage: tpu_configssh.py <instance_name> <instance_name> ... | |
# | |
# A version of this exists in gcloud compute config-ssh but it doesn't work for TPU VMs | |
# | |
# Works by parsing the output of dryrun mode of gcloud compute ssh, example: | |
# $ gcloud alpha compute tpus tpu-vm ssh instance-name --dry-run | |
# /usr/bin/ssh -t -i /home/user/.ssh/google_compute_engine -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=<alias> -o IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=/home/user/.ssh/google_compute_known_hosts user@IP |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# autocomplete nvim scp commands | |
nvim () { | |
local params=(); | |
while [[ ! -z $1 ]]; do | |
if [[ "$1" =~ ^[a-z0-9-]*:/.*$ ]]; then | |
params=("scp://${1/:\//\/\//}" "${params[@]}"); | |
else | |
params+=("$1"); | |
fi; | |
shift; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# Copyright 2020 The Google AI Perception Team Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from einops import rearrange, repeat, reduce | |
def relation(input, g, embedding=None, max_pairwise=None): | |
r"""Applies an all-to-all pairwise relation function to a set of objects. | |
See :class:`~torch.nn.Relation` for details. | |
""" | |
# Batch size, number of objects, feature size | |
b, o, c = input.size() | |
# Create pairwise matrix |
NewerOlder