
@dhruvbird
dhruvbird / GetArchiveName.py
Created November 16, 2021 22:28
Get the top level folder name of the PyTorch Lite Model
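import zipfile
# A PyTorch Lite (.ptl) file is a zip archive; every entry sits under one top-level folder.
zf = zipfile.ZipFile("AddTensorsModel.ptl")
elem = zf.infolist()[0]
archive_name = elem.filename.split("/")[0]
print(archive_name)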
AddTensorsModel
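The snippet above takes the folder name from the first entry only. A slightly more defensive sketch, assuming every entry in a `.ptl` archive sits under a single top-level folder (as in the `AddTensorsModel.ptl` listing), collects the first path component of every entry and checks that they agree. The helper name `get_archive_name` is hypothetical, and an in-memory archive stands in for a real model file.

```python
import io
import zipfile


def get_archive_name(zf: zipfile.ZipFile) -> str:
    """Return the single top-level folder shared by every entry in the zip.

    Hypothetical helper: assumes the model archive has exactly one
    top-level folder and raises ValueError otherwise.
    """
    roots = {info.filename.split("/")[0] for info in zf.infolist()}
    if len(roots) != 1:
        raise ValueError(f"expected one top-level folder, found {sorted(roots)}")
    return roots.pop()


# Usage with a small in-memory archive standing in for a real .ptl file.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("AddTensorsModel/version", "7")
    zf.writestr("AddTensorsModel/data.pkl", b"\x80")
with zipfile.ZipFile(buf) as zf:
    archive_name = get_archive_name(zf)
print(archive_name)  # AddTensorsModel
```

Checking all entries costs one pass over the central directory, which is cheap, and it fails loudly instead of silently returning a wrong name if an archive is laid out differently.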
dhruvbird / AddMetadataToPyTorchLiteModel.py
Created November 16, 2021 22:33
Add additional metadata to PyTorch Lite Model
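import zipfile
# Mode "a" appends to the existing archive without rewriting its entries.
zf2 = zipfile.ZipFile("AddTensorsModel.ptl", "a")
contents = "Both tensors can be any dtype as long as they can be added"
# archive_name comes from GetArchiveName.py; extra files live under <archive_name>/extra/.
zf2.writestr("{}/extra/dtype_info.txt".format(archive_name), contents)
zf2.close()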
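The same append step can be sketched with a context manager, so the archive is closed even if `writestr` raises. This sketch uses an in-memory archive in place of `AddTensorsModel.ptl`; the helper name `add_metadata` is hypothetical, and it assumes, as the snippet above does, that metadata belongs under `<archive_name>/extra/`.

```python
import io
import zipfile


def add_metadata(zip_file, archive_name: str, name: str, contents: str) -> None:
    """Append <archive_name>/extra/<name> to an existing zip (hypothetical helper)."""
    # Mode "a" keeps the existing entries and adds the new one at the end.
    with zipfile.ZipFile(zip_file, "a") as zf:
        zf.writestr(f"{archive_name}/extra/{name}", contents)


# Build a stand-in archive, then append the dtype note to it.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("AddTensorsModel/version", "7")

add_metadata(buf, "AddTensorsModel", "dtype_info.txt",
             "Both tensors can be any dtype as long as they can be added")

with zipfile.ZipFile(buf) as zf:
    print(zf.namelist())
    print(zf.read("AddTensorsModel/extra/dtype_info.txt").decode("utf-8"))
```

`zipfile.ZipFile` accepts either a path or a seekable file object, so the same helper works on `"AddTensorsModel.ptl"` directly.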
dhruvbird / ListArchiveContents.py
Created November 16, 2021 22:35
Print the contents of the model file and the contents of the newly added metadata file
import zipfile
zf = zipfile.ZipFile("AddTensorsModel.ptl")
zf.infolist()  # in an interactive session this echoes the list below
[<ZipInfo filename='AddTensorsModel/extra/model_info.txt' file_size=70>,
<ZipInfo filename='AddTensorsModel/data.pkl' file_size=86>,
<ZipInfo filename='AddTensorsModel/code/__torch__.py' compress_type=deflate file_size=247 compress_size=166>,
<ZipInfo filename='AddTensorsModel/code/__torch__.py.debug_pkl' file_size=145>,
<ZipInfo filename='AddTensorsModel/constants.pkl' file_size=4>,
<ZipInfo filename='AddTensorsModel/bytecode.pkl' file_size=452>,
<ZipInfo filename='AddTensorsModel/version' file_size=2>,
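To print the metadata text itself rather than only the `ZipInfo` records, one can filter the entries whose path has the form `<archive_name>/extra/<file_name>` and decode them. A minimal sketch, assuming the metadata files are UTF-8 text; `read_extra_files` is a hypothetical name, and the in-memory archive with placeholder strings stands in for `AddTensorsModel.ptl`.

```python
import io
import zipfile


def read_extra_files(zf: zipfile.ZipFile) -> dict:
    """Map each file under <archive>/extra/ to its decoded text (hypothetical helper)."""
    extras = {}
    for info in zf.infolist():
        parts = info.filename.split("/")
        # Entries look like <archive_name>/extra/<file_name>.
        if len(parts) == 3 and parts[1] == "extra":
            extras[parts[2]] = zf.read(info).decode("utf-8")
    return extras


# Stand-in archive; the text values are placeholders, not real model metadata.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("AddTensorsModel/version", "7")
    zf.writestr("AddTensorsModel/extra/model_info.txt", "placeholder model info")
    zf.writestr("AddTensorsModel/extra/dtype_info.txt", "placeholder dtype info")

with zipfile.ZipFile(buf) as zf:
    for info in zf.infolist():
        print(info.filename, info.file_size)
    extras = read_extra_files(zf)
print(extras)
```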
dhruvbird / FetchModelMetadata.cpp
Created November 16, 2021 23:37
Fetch only the metadata using the PyTorch Lite Model C++ API
#include <string>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
int main() {
std::string model_path = "AddTensorsModel.ptl";
torch::jit::ExtraFilesMap extra_files;
// Add keys into extra_files to indicate which metadata files
// we wish to fetch.
extra_files["model_info.txt"] = "";
include_all_non_op_selectives: false
build_features: []
operators:
  aten::__getitem__.t:
    is_used_for_training: false
    is_root_operator: true
    include_all_overloads: false
  aten::_set_item.str:
    is_used_for_training: false
    is_root_operator: true
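The C++ snippet fills `extra_files` with the keys it wants, and, as its comment indicates, those keys select which metadata files are fetched from the model's `extra/` folder. The Python sketch below imitates that lookup-by-key idea directly on the zip archive. It is an illustration, not the PyTorch loader: `fetch_extra_files` is a hypothetical name, and leaving keys with no matching file unchanged is an assumption of this sketch.

```python
import io
import zipfile


def fetch_extra_files(zf: zipfile.ZipFile, extra_files: dict) -> None:
    """Fill extra_files[key] from <archive>/extra/<key>, in place (hypothetical helper)."""
    names = zf.namelist()
    # Assumes a single top-level folder, as in AddTensorsModel.ptl.
    archive_name = names[0].split("/")[0]
    for key in extra_files:
        path = f"{archive_name}/extra/{key}"
        if path in names:
            extra_files[key] = zf.read(path).decode("utf-8")


# Stand-in archive with placeholder metadata.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("AddTensorsModel/extra/model_info.txt", "placeholder model info")
    zf.writestr("AddTensorsModel/data.pkl", b"\x80")

# Request one file that exists and one that does not.
requested = {"model_info.txt": "", "missing.txt": ""}
with zipfile.ZipFile(buf) as zf:
    fetch_extra_files(zf, requested)
print(requested)
```

Only the requested entries are read and decompressed; the rest of the archive (bytecode, constants) is never touched, which is the point of fetching metadata alone.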
dhruvbird / conv2d_linear_equivalence.py
Created June 20, 2023 14:44
Showing equivalence of nn.Conv2d and nn.Linear
import torch
from torch import nn
from torch.nn import functional as F
from matplotlib import pyplot as plt
n_out_channels = 1
mat = torch.arange(0, 36).reshape((6, 6)).float()
print(f"mat:\n{mat}")
patches = []
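The equivalence named in this gist's title can be sketched without PyTorch. A 2D convolution with one output channel, stride 1 and no padding computes, at each position, the dot product of the flattened kernel with the flattened input patch, which is what a linear layer whose weight is the flattened kernel computes on each flattened patch. This pure-Python sketch uses the same 6x6 matrix of values 0 to 35 as the gist; the 3x3 kernel values and the helper names `conv2d_valid` and `linear_on_patches` are assumptions for illustration. Like `nn.Conv2d`, it computes cross-correlation (the kernel is not flipped).

```python
def conv2d_valid(mat, kernel):
    """Cross-correlation with stride 1 and no padding, one input and one output channel."""
    k = len(kernel)
    n_out = len(mat) - k + 1
    return [
        [
            sum(mat[i + di][j + dj] * kernel[di][dj]
                for di in range(k) for dj in range(k))
            for j in range(n_out)
        ]
        for i in range(n_out)
    ]


def linear_on_patches(mat, kernel):
    """Flatten each k x k patch and apply a linear map whose weight is the flattened kernel."""
    k = len(kernel)
    n_out = len(mat) - k + 1
    weight = [w for row in kernel for w in row]  # k*k inputs -> 1 output feature
    patches = []
    for i in range(n_out):
        for j in range(n_out):
            patches.append([mat[i + di][j + dj] for di in range(k) for dj in range(k)])
    outputs = [sum(w * x for w, x in zip(weight, p)) for p in patches]
    # Reshape the flat list of outputs back into an n_out x n_out grid.
    return [outputs[r * n_out:(r + 1) * n_out] for r in range(n_out)]


# Same values as torch.arange(0, 36).reshape((6, 6)).float().
mat = [[float(6 * r + c) for c in range(6)] for r in range(6)]
kernel = [[1.0, 2.0, 0.0], [0.0, 1.0, -1.0], [0.5, 0.0, 1.0]]  # arbitrary example kernel

print(conv2d_valid(mat, kernel) == linear_on_patches(mat, kernel))  # True
```

Both functions sum the products in the same row-major order, so the floating-point results match exactly rather than only approximately.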