This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Install the tsfm library (with its notebook extras), upgrading if present.
# NOTE(review): the "[email protected]" segment is email-obfuscation damage from
# a web scrape; restore the original "<repo>.git@<ref>" (presumably
# granite-tsfm.git@<tag>) before running, or pip cannot resolve the URL.
! pip install "tsfm_public[notebooks] @ git+https://github.com/ibm-granite/[email protected]" -U
# Pin NumPy and reinstall pandas in ONE resolver run so that pandas is chosen
# to be compatible with numpy==1.26.4. Running `--force-reinstall pandas` as a
# separate, later command makes pip ignore the installed NumPy and pull the
# newest release from the index, silently undoing the pin above.
!pip install --force-reinstall "numpy==1.26.4" pandas
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
M = ((P * 4B) / (32 / Q)) * 1.2
Where:
- M = memory needed, in GB
- P = number of parameters, in billions (so that P * 4 bytes comes out in GB)
- 4B = 4 bytes per parameter at full 32-bit precision
- 32 = bits in 4 bytes
- Q = bits per parameter after quantization (e.g. 16, 8 or 4)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import matplotlib.pyplot as plt | |
def visualize_positional_encoding(seq_length=83, d_model=32): | |
pos = torch.arange(0, seq_length).unsqueeze(-1).float() | |
pos_encoding = torch.zeros(seq_length, d_model) | |
pos_encoding[:, 0::2] = torch.sin(pos / 10000 ** (torch.arange(0, d_model, 2) / d_model)) | |
pos_encoding[:, 1::2] = torch.cos(pos / 10000 ** (torch.arange(1, d_model, 2) / d_model)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts | |
import random | |
from GADEv2_2.util.dataset_loaders.arc_dataset.loader import ARCDataLoader | |
from GADEv2_2.networks.multi_dim_trial import MultiDimTrial | |
class ZeTaskMan: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class MultiDimTrial(nn.Module): | |
def __init__(self, input_shape, output_shape, num_demos, mode='train'): | |
super(MultiDimTrial, self).__init__() | |
self.name = 'MultiDimTrial' | |
self.mode = mode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
import torch | |
import numpy as np | |
from torch.utils.data import Dataset, DataLoader | |
class ARCDataset(Dataset): | |
def __init__(self, challenges_file, solutions_file=None): | |
self.tasks = [] | |
# Load challenges |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import math | |
import random | |
import copy | |
XOR_INPUTS = [ | |
([0, 0], 0), | |
([0, 1], 1), | |
([1, 0], 1), | |
([1, 1], 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import nbformat | |
import ast | |
# Locate the GADEv2_2 repository root from this script's own path: the script
# lives two directory levels below the root, so step up twice from its folder.
script_dir = os.path.dirname(os.path.abspath(__file__))
repo_path = os.path.abspath(
    os.path.join(script_dir, "../..")
)
# Absolute path to main.py at the repository root.
main_file = os.path.join(repo_path, "main.py")
def extract_imports_from_main(main_file): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import yfinance as yf | |
import os

# FRED API key (you need to get your own from https://fred.stlouisfed.org/).
# Read it from the FRED_API_KEY environment variable so the secret never has
# to be committed alongside the script; fall back to the original placeholder.
FRED_API_KEY = os.environ.get("FRED_API_KEY", 'YOUR_KEY_HERE')
# FRED observations endpoint for the FEDFUNDS (Federal Funds Rate) series,
# requesting a JSON response.
FRED_URL = f'https://api.stlouisfed.org/fred/series/observations?series_id=FEDFUNDS&api_key={FRED_API_KEY}&file_type=json'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<script src="../node_modules/steal/steal.js"> | |
window.rerun = true; | |
function rerunIfSuccessful() { | |
var result = document.getElementById('qunit-testresult'); | |
var failed = result.getElementsByClassName('failed'); | |
if (!failed.length) { | |
setTimeout(rerunIfSuccessful, 4000); | |
} else if (failed[0].innerHTML === "0" && window.rerun) { |