Skip to content

Instantly share code, notes, and snippets.

@asford
Created February 7, 2024 22:33
Show Gist options
  • Save asford/ee688d59f0747a6507b9670a83fa7c47 to your computer and use it in GitHub Desktop.
Save asford/ee688d59f0747a6507b9670a83fa7c47 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.4'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import numpy\n",
"import array_api_compat as aac\n",
"\n",
"aac.__version__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"t = torch.arange(10)\n",
"n = numpy.arange(10)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.add(n, 1.0)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.add(t, 1.0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"aac.get_namespace(n).add(n, 1.0)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'float' object has no attribute 'dtype'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43maac\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_namespace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:91\u001b[0m, in \u001b[0;36m_two_arg.<locals>._f\u001b[0;34m(x1, x2, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(f)\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_f\u001b[39m(x1, x2, \u001b[38;5;241m/\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 91\u001b[0m x1, x2 \u001b[38;5;241m=\u001b[39m \u001b[43m_fix_promotion\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx2\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m f(x1, x2, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
"File \u001b[0;32m~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:104\u001b[0m, in \u001b[0;36m_fix_promotion\u001b[0;34m(x1, x2, only_scalar)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_fix_promotion\u001b[39m(x1, x2, only_scalar\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[0;32m--> 104\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x1\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m _array_api_dtypes \u001b[38;5;129;01mor\u001b[39;00m \u001b[43mx2\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m _array_api_dtypes:\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x1, x2\n\u001b[1;32m 106\u001b[0m \u001b[38;5;66;03m# If an argument is 0-D pytorch downcasts the other argument\u001b[39;00m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'float' object has no attribute 'dtype'"
]
}
],
"source": [
"aac.get_namespace(t).add(t, 1.0)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment