Created
January 15, 2020 11:52
-
-
Save SumanSudhir/5bf57c330e88c2818a30a6cbaee12704 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "swift_crash_example.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "swift", | |
"display_name": "Swift" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kZRlD4utdPuX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import TensorFlow" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "S1SxQfK_xkyr", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"public extension Tensor where Scalar: TensorFlowFloatingPoint {\n", | |
"\n", | |
" /// Helper function that assesses whether `axis` is in the range `[-rank, rank)`, where `rank` is the rank of\n", | |
" /// this tensor.\n", | |
" @inlinable\n", | |
" internal func isAxisInRange(_ axis: Int) -> Bool {\n", | |
" return axis >= -rank && axis < rank\n", | |
" }\n", | |
"\n", | |
" @inlinable\n", | |
" internal func areAxesInRange(_ axes: Tensor<Int32>) -> Bool {\n", | |
" return !axes.scalars.contains(where: { !isAxisInRange(Int($0)) })\n", | |
" }\n", | |
"\n", | |
" /// Returns the mean and variance of this tensor along the specified axes. The reduced\n", | |
" /// dimensions are removed.\n", | |
" ///\n", | |
" /// - Parameter axes: The dimensions to reduce.\n", | |
" /// - Precondition: `axes` must have rank `1`.\n", | |
" /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.\n", | |
" @inlinable\n", | |
" @differentiable(wrt: self)\n", | |
" func momentsNew(squeezingAxes axes: Tensor<Int32>) -> Moments<Scalar> {\n", | |
" precondition(axes.rank == 1, \"Axes must have rank 1\")\n", | |
" precondition(\n", | |
" areAxesInRange(axes),\n", | |
" \"\"\"\n", | |
" The axis must be in the range [-rank, rank)\n", | |
" of the provided tensors.\n", | |
" \"\"\")\n", | |
" let mean = self.mean(alongAxes: axes)\n", | |
" let variance = squaredDifference(self, mean).mean(squeezingAxes: axes)\n", | |
" return Moments(\n", | |
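" // `mean` still has the reduced axes as size-1 dimensions; summing over those axes removes them.\n", | |
" // This workaround is needed because `Tensor.squeezingShape(at:)` does not accept\n", | |
" // `Tensor<Int32>`-valued arguments.\n", | |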
" mean: mean.sum(squeezingAxes: axes),\n", | |
" variance: variance)\n", | |
" }\n", | |
"}" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "sFRhaug8xuso", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"var input: Tensor<Float32> = Tensor([[1.0, 2.0, 5.0], [3.0, 4.0, 6.0], [5.0, 6.0, 7.0]])\n", | |
"var axis: Tensor<Int32> = Tensor([1, 1, 0])" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qdB5pGVvxw-I", | |
"colab_type": "code", | |
"outputId": "763095aa-8a11-404c-87f8-e50c23db266e", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 68 | |
} | |
}, | |
"source": [ | |
"var x = input.momentsNew(squeezingAxes: axis)\n", | |
"x" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"▿ Moments<Float>\n", | |
" - mean : 4.3333335\n", | |
" - variance : 3.5555556\n" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_ajsiTY0xzhZ", | |
"colab_type": "code", | |
"outputId": "4404336b-1b57-437b-893f-67a1c04fd5aa", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 204 | |
} | |
}, | |
"source": [ | |
"var axis: Tensor<Int32> = Tensor([1, 1, 2])\n", | |
"var x = input.momentsNew(squeezingAxes: axis)\n", | |
"x" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Precondition failed: The axis must be in the range [-rank, rank)\r\n", | |
"of the provided tensors.: file <Cell 2>, line 25\r\n", | |
"Current stack trace:\r\n", | |
"0 libswiftCore.so 0x00007f791d53dcb0 swift_reportError + 50\r\n", | |
"1 libswiftCore.so 0x00007f791d5af5f0 _swift_stdlib_reportFatalErrorInFile + 115\r\n", | |
"2 libswiftCore.so 0x00007f791d2afe1e <unavailable> + 1478174\r\n", | |
"3 libswiftCore.so 0x00007f791d2afa27 <unavailable> + 1477159\r\n", | |
"4 libswiftCore.so 0x00007f791d2b0008 <unavailable> + 1478664\r\n", | |
"5 libswiftCore.so 0x00007f791d2ae2c0 _assertionFailure(_:_:file:line:flags:) + 517\r\n", | |
"8 repl_swift 0x0000000000400480 <unavailable> + 1152\r\n", | |
"9 libswiftCore.so 0x00007f791d54e2f0 <unavailable> + 4223728\r\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "G16VJj2Bx2RX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment