Created
February 7, 2024 22:33
-
-
Save asford/ee688d59f0747a6507b9670a83fa7c47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'1.4'" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import numpy\n", | |
"import array_api_compat as aac\n", | |
"\n", | |
"aac.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"t = torch.arange(10)\n", | |
"n = numpy.arange(10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"numpy.add(n, 1.0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.add(t, 1.0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"aac.get_namespace(n).add(n, 1.0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "AttributeError", | |
"evalue": "'float' object has no attribute 'dtype'", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43maac\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_namespace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:91\u001b[0m, in \u001b[0;36m_two_arg.<locals>._f\u001b[0;34m(x1, x2, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(f)\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_f\u001b[39m(x1, x2, \u001b[38;5;241m/\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 91\u001b[0m x1, x2 \u001b[38;5;241m=\u001b[39m \u001b[43m_fix_promotion\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx2\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m f(x1, x2, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", | |
"File \u001b[0;32m~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:104\u001b[0m, in \u001b[0;36m_fix_promotion\u001b[0;34m(x1, x2, only_scalar)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_fix_promotion\u001b[39m(x1, x2, only_scalar\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[0;32m--> 104\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x1\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m _array_api_dtypes \u001b[38;5;129;01mor\u001b[39;00m \u001b[43mx2\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m _array_api_dtypes:\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x1, x2\n\u001b[1;32m 106\u001b[0m \u001b[38;5;66;03m# If an argument is 0-D pytorch downcasts the other argument\u001b[39;00m\n", | |
"\u001b[0;31mAttributeError\u001b[0m: 'float' object has no attribute 'dtype'" | |
] | |
} | |
], | |
"source": [ | |
"aac.get_namespace(t).add(t, 1.0)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "base", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment