Skip to content

Instantly share code, notes, and snippets.

@SumanSudhir
Created January 15, 2020 11:52
Show Gist options
  • Save SumanSudhir/5bf57c330e88c2818a30a6cbaee12704 to your computer and use it in GitHub Desktop.
Save SumanSudhir/5bf57c330e88c2818a30a6cbaee12704 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "swift_crash_example.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "swift",
"display_name": "Swift"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "kZRlD4utdPuX",
"colab_type": "code",
"colab": {}
},
"source": [
"import TensorFlow"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "S1SxQfK_xkyr",
"colab_type": "code",
"colab": {}
},
"source": [
"public extension Tensor where Scalar: TensorFlowFloatingPoint {\n",
"\n",
" /// Helper function that assess if `axis` is in the range `[-rank, rank)`, where `rank` is the rank of\n",
" /// the provided tensors.\n",
" @inlinable\n",
" internal func isAxisInRange(_ axis: Int) -> Bool {\n",
" return axis >= -rank && axis < rank\n",
" }\n",
"\n",
" @inlinable\n",
" internal func areAxesInRange(_ axes: Tensor<Int32>) -> Bool {\n",
" return !axes.scalars.contains(where: { !isAxisInRange(Int($0)) })\n",
" }\n",
"\n",
" /// Returns the mean and variance of this tensor along the specified axes. The reduced\n",
" /// dimensions are removed.\n",
" ///\n",
" /// - Parameter axes: The dimensions to reduce.\n",
" /// - Precondition: `axes` must have rank `1`.\n",
" /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.\n",
" @inlinable\n",
" @differentiable(wrt: self)\n",
" func momentsNew(squeezingAxes axes: Tensor<Int32>) -> Moments<Scalar> {\n",
" precondition(axes.rank == 1, \"Axes must have rank 1\")\n",
" precondition(\n",
" areAxesInRange(axes),\n",
" \"\"\"\n",
" The axis must be in the range [-rank, rank)\n",
" of the provided tensors.\n",
" \"\"\")\n",
" let mean = self.mean(alongAxes: axes)\n",
" let variance = squaredDifference(self, mean).mean(squeezingAxes: axes)\n",
" return Moments(\n",
" // The following is required because `Tensor.squeezingShape(at:)` does not accept\n",
" // `Tensor<Int32>`-valued arguments.\n",
" mean: mean.sum(squeezingAxes: axes),\n",
" variance: variance)\n",
" }\n",
"}"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sFRhaug8xuso",
"colab_type": "code",
"colab": {}
},
"source": [
"var input: Tensor<Float32> = Tensor([[1.0, 2.0, 5.0], [3.0, 4.0, 6.0], [5.0, 6.0, 7.0]])\n",
"var axis: Tensor<Int32> = Tensor([1, 1, 0])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qdB5pGVvxw-I",
"colab_type": "code",
"outputId": "763095aa-8a11-404c-87f8-e50c23db266e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"var x = input.momentsNew(squeezingAxes: axis)\n",
"x"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"▿ Moments<Float>\n",
" - mean : 4.3333335\n",
" - variance : 3.5555556\n"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_ajsiTY0xzhZ",
"colab_type": "code",
"outputId": "4404336b-1b57-437b-893f-67a1c04fd5aa",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
}
},
"source": [
"var axis: Tensor<Int32> = Tensor([1, 1, 2])\n",
"var x = input.momentsNew(squeezingAxes: axis)\n",
"x"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Precondition failed: The axis must be in the range [-rank, rank)\r\n",
"of the provided tensors.: file <Cell 2>, line 25\r\n",
"Current stack trace:\r\n",
"0 libswiftCore.so 0x00007f791d53dcb0 swift_reportError + 50\r\n",
"1 libswiftCore.so 0x00007f791d5af5f0 _swift_stdlib_reportFatalErrorInFile + 115\r\n",
"2 libswiftCore.so 0x00007f791d2afe1e <unavailable> + 1478174\r\n",
"3 libswiftCore.so 0x00007f791d2afa27 <unavailable> + 1477159\r\n",
"4 libswiftCore.so 0x00007f791d2b0008 <unavailable> + 1478664\r\n",
"5 libswiftCore.so 0x00007f791d2ae2c0 _assertionFailure(_:_:file:line:flags:) + 517\r\n",
"8 repl_swift 0x0000000000400480 <unavailable> + 1152\r\n",
"9 libswiftCore.so 0x00007f791d54e2f0 <unavailable> + 4223728\r\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "G16VJj2Bx2RX",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment