@awni
Created August 28, 2021 15:31
interspeech_2021_tutorial.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "interspeech_2021_tutorial.ipynb",
"provenance": [],
"collapsed_sections": [
"LDYlSHEh257E",
"pFfMvNDV2fNO",
"Ih5pGEmkIqyn",
"RqkmUF3EH6x2",
"TyZtbNWHSbHa",
"qwek5IMySv79",
"Sk4nxM6_TU50",
"ASMHyHHMU_gO",
"X7ahg2VgVTJn",
"of4K8SklVsAJ",
"z30dUp8g6sNl",
"PiFDxqnf7Je_",
"S_K8Zwoy92XH"
],
"authorship_tag": "ABX9TyPL53chhUJ6BFxmm4djd64Q",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/awni/f07022805ebfa6356a623bc3c9b2696c/interspeech_2021_tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gXAUkdEd1PaQ"
},
"source": [
"## An Introduction to Automatic Differentiation with Weighted Finite-State Automata\n",
"\n",
"This is the companion notebook to the Interspeech 2021 tutorial.\n",
"\n",
"More resources on weighted automata:\n",
"\n",
"- [An Introduction to Weighted Automata in Machine Learning](https://www.awnihannun.com/writing/automata_ml.html)\n",
"- The [GTN codebase](https://github.com/gtn-org/gtn)\n",
"- The [GTN documentation](https://gtn.readthedocs.io/en/latest/)\n",
"- [Differentiable Weighted Finite-State Transducers](https://arxiv.org/abs/2010.01003): A research paper with some cutting-edge case studies\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LDYlSHEh257E"
},
"source": [
"## Setup\n",
"\n",
"Before you begin, install GTN and import the relevant packages."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Dq0iHBLVrp0O",
"outputId": "82cb1ac0-524c-4dfc-8695-fea5a27f8443"
},
"source": [
"!pip install gtn"
],
"execution_count": 45,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: gtn in /usr/local/lib/python3.7/dist-packages (0.0.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QFC-dFF5GrSW"
},
"source": [
"import gtn\n",
"from IPython.display import display, SVG\n",
"import numpy as np\n",
"import tempfile\n",
"\n",
"def draw(g, isymbols={}, osymbols={}):\n",
"    with tempfile.NamedTemporaryFile(suffix=\".svg\") as f:\n",
"        gtn.draw(g, f.name, isymbols=isymbols, osymbols=osymbols)\n",
"        display(SVG(filename=f.name))"
],
"execution_count": 48,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "pFfMvNDV2fNO"
},
"source": [
"## Acceptors and Transducers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ih5pGEmkIqyn"
},
"source": [
"### Acceptors\n",
"\n",
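"First, let's construct a basic acceptor. Nodes are referred to by integer ids assigned in the order they are added, and each arc carries a label and an optional weight (0 by default). In the drawing, start nodes have a bold outline, accept nodes have a double circle, and each arc is labeled `label/weight`."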
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "Yb5qQuT9sVc2",
"outputId": "4e975523-5eb0-48b1-ba9c-79808471eafe"
},
"source": [
"# Define the mapping from integer ids to arc symbol labels\n",
"symbols = {0: 'a', 1: 'b', 2: 'c'}\n",
"\n",
"# Make a graph\n",
"fsa = gtn.Graph()\n",
"\n",
"# Add a start node (node ids are assigned in order, so this is node 0)\n",
"fsa.add_node(start=True)\n",
"\n",
"# Add an accepting node (node 1)\n",
"fsa.add_node(accept=True)\n",
"\n",
"# Add an internal node (node 2)\n",
"fsa.add_node()\n",
"\n",
"# Add an arc from node 0 to 2 with label 0\n",
"fsa.add_arc(src_node=0, dst_node=2, label=0)\n",
"\n",
"# Add an arc from node 0 to 2 with input and output label 1 (equivalent to label=1)\n",
"fsa.add_arc(src_node=0, dst_node=2, ilabel=1, olabel=1)\n",
"\n",
"# Add an arc from node 2 to 1 with input label 0, output label 0 and weight 2\n",
"fsa.add_arc(src_node=2, dst_node=1, ilabel=0, olabel=0, weight=2)\n",
"\n",
"draw(fsa, isymbols=symbols)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 221.00 52.00\" width=\"221pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 217,-48 217,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node2\">\n<title>2</title>\n<ellipse cx=\"103\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-18.3\">2</text>\n</g>\n<!-- 0&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;2</title>\n<path d=\"M36.0263,-22C47.2957,-22 62.0476,-22 74.8373,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9997,-25.5001 84.9997,-22 74.9996,-18.5001 74.9997,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 0&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;2</title>\n<path d=\"M32.7129,-11.3344C38.1808,-7.9798 44.6133,-4.7072 51,-3 59.158,-.8193 61.842,-.8193 70,-3 73.1933,-3.8536 76.3981,-5.0986 79.4881,-6.5494\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"77.83,-9.6318 88.2871,-11.3344 81.1742,-3.4822 77.83,-9.6318\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-6.8\">b/0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node3\">\n<title>1</title>\n<ellipse cx=\"191\" 
cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-18.3\">1</text>\n</g>\n<!-- 2&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge3\">\n<title>2-&gt;1</title>\n<path d=\"M121.2337,-22C132.0103,-22 145.9708,-22 158.5692,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.7317,-25.5001 168.7317,-22 158.7317,-18.5001 158.7317,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145\" y=\"-25.8\">a/2</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
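{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the drawing concrete, the cell below lists every string the acceptor accepts, along with its score. It is a minimal pure-Python sketch and makes no GTN calls. The arc-tuple representation and the `accepted_paths` helper are hypothetical. The sketch assumes, as in GTN, that a path's score is the sum of its arc weights, and it only works for graphs without cycles."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Pure-Python sketch: enumerate the accepted strings of a small acyclic acceptor.\n",
"# Each arc is (src_node, dst_node, label, weight), mirroring the graph drawn above.\n",
"sketch_symbols = {0: 'a', 1: 'b', 2: 'c'}\n",
"sketch_arcs = [(0, 2, 0, 0.0), (0, 2, 1, 0.0), (2, 1, 0, 2.0)]\n",
"\n",
"def accepted_paths(arcs, start_nodes, accept_nodes, symbols):\n",
"    # Returns a (string, score) pair for every path from a start node to an accept node\n",
"    results = []\n",
"\n",
"    def visit(node, labels, score):\n",
"        if node in accept_nodes:\n",
"            results.append((''.join(symbols[lab] for lab in labels), score))\n",
"        for src, dst, label, weight in arcs:\n",
"            if src == node:\n",
"                # Following an arc appends its label and adds its weight\n",
"                visit(dst, labels + [label], score + weight)\n",
"\n",
"    for start in start_nodes:\n",
"        visit(start, [], 0.0)\n",
"    return results\n",
"\n",
"print(accepted_paths(sketch_arcs, {0}, {1}, sketch_symbols))  # prints [('aa', 2.0), ('ba', 2.0)]"
],
"execution_count": null,
"outputs": []
},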
{
"cell_type": "markdown",
"metadata": {
"id": "jM47_PKzFVIJ"
},
"source": [
"Acceptors can have multiple start nodes and multiple accept nodes."
]
},
{
"cell_type": "code",
"metadata": {
"id": "FNovcmj50LoA",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 201
},
"outputId": "1a0aa5ab-9e64-4c51-d1eb-063f23ff01f7"
},
"source": [
"fsa = gtn.Graph()\n",
"\n",
"# Graphs can have multiple start-nodes\n",
"fsa.add_node(start=True)\n",
"fsa.add_node(start=True)\n",
"\n",
"fsa.add_node()\n",
"\n",
"# Graphs can also have multiple accept nodes\n",
"fsa.add_node(accept=True)\n",
"fsa.add_node(accept=True)\n",
"\n",
"# Start nodes can have incoming arcs\n",
"fsa.add_arc(src_node=0, dst_node=1, label=1)\n",
"\n",
"fsa.add_arc(src_node=0, dst_node=2, label=0)\n",
"fsa.add_arc(src_node=1, dst_node=3, label=0)\n",
"fsa.add_arc(src_node=2, dst_node=3, label=1)\n",
"fsa.add_arc(src_node=2, dst_node=3, label=0)\n",
"fsa.add_arc(src_node=2, dst_node=4, label=2)\n",
"\n",
"# Accept nodes can have outgoing arcs\n",
"fsa.add_arc(src_node=3, dst_node=4, label=1)\n",
"\n",
"# Set all arc weights at once, in the order the arcs were added\n",
"fsa.set_weights([1, 1, 1, 3, 1, 2, 2])\n",
"\n",
"draw(fsa, isymbols=symbols)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"135pt\" viewBox=\"0.00 0.00 315.00 135.00\" width=\"315pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 131)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-131 311,-131 311,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-77\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-73.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-109\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-105.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M35.201,-83.4757C47.0299,-87.9289 62.9699,-93.9298 76.422,-98.9942\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"75.5227,-102.3954 86.1146,-102.6432 77.9891,-95.8442 75.5227,-102.3954\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-98.8\">b/1</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"103\" cy=\"-50\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-46.3\">2</text>\n</g>\n<!-- 0&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;2</title>\n<path d=\"M34.5929,-69.363C39.7818,-67.1373 45.5707,-64.8179 51,-63 58.7307,-60.4115 67.2603,-58.0721 75.1365,-56.1146\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" 
points=\"76.0724,-59.4897 84.9921,-53.7723 74.4538,-52.6794 76.0724,-59.4897\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-66.8\">a/1</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"192\" cy=\"-69\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"192\" cy=\"-69\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-65.3\">3</text>\n</g>\n<!-- 1&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;3</title>\n<path d=\"M119.7402,-101.4763C131.7685,-96.0703 148.302,-88.6396 162.5187,-82.25\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"164.2207,-85.3224 171.907,-78.0306 161.3511,-78.9376 164.2207,-85.3224\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-96.8\">a/1</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge4\">\n<title>2-&gt;3</title>\n<path d=\"M121.0105,-53.8449C132.2326,-56.2407 146.9884,-59.3908 160.1386,-62.1981\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.732,-65.6901 170.2424,-64.3551 161.1935,-58.8444 159.732,-65.6901\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-63.8\">b/3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge5\">\n<title>2-&gt;3</title>\n<path d=\"M119.2234,-41.5314C129.6316,-37.2448 143.332,-33.7581 155,-38 159.9913,-39.8146 164.7789,-42.6306 169.1578,-45.841\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"167.0833,-48.6671 177.0207,-52.3411 171.5433,-43.272 167.0833,-48.6671\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" 
x=\"145.5\" y=\"-41.8\">a/1</text>\n</g>\n<!-- 4 -->\n<g class=\"node\" id=\"node5\">\n<title>4</title>\n<ellipse cx=\"285\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"285\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"285\" y=\"-18.3\">4</text>\n</g>\n<!-- 2&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge6\">\n<title>2-&gt;4</title>\n<path d=\"M118.8279,-41.0741C124.1799,-38.2655 130.2564,-35.2999 136,-33 150.6224,-27.1448 154.4196,-25.3129 170,-23 197.6201,-18.8998 229.3832,-18.981 252.5197,-19.8987\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"252.6585,-23.4091 262.8124,-20.3844 252.9885,-16.4168 252.6585,-23.4091\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-26.8\">c/2</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge7\">\n<title>3-&gt;4</title>\n<path d=\"M211.7229,-59.0325C224.6694,-52.4896 241.8168,-43.8237 256.2689,-36.52\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"257.9592,-39.5874 265.3055,-31.9531 254.8018,-33.3399 257.9592,-39.5874\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"238.5\" y=\"-53.8\">b/2</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xh6xa4jhHNFZ"
},
"source": [
"ϵ transitions, which are traversed without consuming a symbol, are also allowed."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "h8axK0KfG2aA",
"outputId": "89834a83-e0d7-4be1-9d9a-5f3d52022022"
},
"source": [
"# We also allow ϵ transitions in acceptors\n",
"fsa = gtn.Graph()\n",
"fsa.add_node(start=True)\n",
"fsa.add_node()\n",
"fsa.add_node(accept=True)\n",
"\n",
"fsa.add_arc(src_node=0, dst_node=1, label=0)\n",
"fsa.add_arc(src_node=0, dst_node=1, label=gtn.epsilon)\n",
"fsa.add_arc(src_node=1, dst_node=2, label=1)\n",
"draw(fsa, isymbols=symbols)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 221.00 52.00\" width=\"221pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 217,-48 217,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"102\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.2267,-22C47.1242,-22 61.2054,-22 73.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"73.8089,-25.5001 83.8088,-22 73.8088,-18.5001 73.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M32.7129,-11.3344C38.1808,-7.9798 44.6133,-4.7072 51,-3 58.7286,-.9341 61.2714,-.9341 69,-3 72.1933,-3.8536 75.3981,-5.0986 78.4881,-6.5494\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"76.83,-9.6318 87.2871,-11.3344 80.1742,-3.4822 76.83,-9.6318\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-6.8\">ε/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"191\" cy=\"-22\" 
fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;2</title>\n<path d=\"M120.0105,-22C131.1524,-22 145.7778,-22 158.8566,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.9179,-25.5001 168.9178,-22 158.9178,-18.5001 158.9179,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"144.5\" y=\"-25.8\">b/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
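{
"cell_type": "markdown",
"metadata": {},
"source": [
"An ϵ arc is followed without consuming a symbol, so this graph accepts both `ab` and `b` but not `a`. The cell below checks this with the standard state-set simulation for nondeterministic automata. It tracks every node reachable after each symbol and follows ϵ arcs for free. This pure-Python sketch ignores weights. `EPS` is a hypothetical stand-in for `gtn.epsilon`, and `eps_closure` and `accepts` are illustrative helpers, not GTN functions."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Pure-Python sketch: does an acceptor with ϵ arcs accept a label sequence?\n",
"# Arcs are (src_node, dst_node, label), mirroring the graph drawn above.\n",
"EPS = -1  # hypothetical stand-in for gtn.epsilon\n",
"eps_arcs = [(0, 1, 0), (0, 1, EPS), (1, 2, 1)]\n",
"\n",
"def eps_closure(nodes, arcs):\n",
"    # Every node reachable from `nodes` using only ϵ arcs\n",
"    stack, reached = list(nodes), set(nodes)\n",
"    while stack:\n",
"        node = stack.pop()\n",
"        for src, dst, label in arcs:\n",
"            if src == node and label == EPS and dst not in reached:\n",
"                reached.add(dst)\n",
"                stack.append(dst)\n",
"    return reached\n",
"\n",
"def accepts(labels, arcs, start_nodes, accept_nodes):\n",
"    current = eps_closure(start_nodes, arcs)\n",
"    for sym in labels:\n",
"        # Consume one symbol, then follow any ϵ arcs\n",
"        moved = {dst for src, dst, label in arcs if src in current and label == sym}\n",
"        current = eps_closure(moved, arcs)\n",
"    return bool(current & accept_nodes)\n",
"\n",
"for word in ([0, 1], [1], [0]):  # 'ab', 'b', 'a'\n",
"    print(word, accepts(word, eps_arcs, {0}, {2}))"
],
"execution_count": null,
"outputs": []
},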
{
"cell_type": "markdown",
"metadata": {
"id": "RqkmUF3EH6x2"
},
"source": [
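"### Transducers\n",
"\n",
"In a transducer, each arc has both an input label and an output label. The drawing shows an arc as `input:output/weight`."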
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "_qPrsqA1H_KK",
"outputId": "981fc988-a630-43eb-8a4f-c24fb7ee9be8"
},
"source": [
"# Define the mappings from integer ids to arc input and output labels\n",
"isymbols = {0: 'a', 1: 'b', 2: 'c'}\n",
"osymbols = {0: 'x', 1: 'y', 2: 'z'}\n",
"\n",
"fst = gtn.Graph()\n",
"\n",
"fst.add_node(start=True)\n",
"fst.add_node()\n",
"fst.add_node(accept=True)\n",
"\n",
"# If only `label` is given, the input and output labels are both set to it\n",
"fst.add_arc(src_node=0, dst_node=1, label=0)\n",
"\n",
"# Add an arc from node 0 to 1 with input and output label 1 and weight 2\n",
"fst.add_arc(src_node=0, dst_node=1, ilabel=1, olabel=1, weight=2)\n",
"\n",
"# Arcs can also have different input and output labels (here b:z, weight 3)\n",
"fst.add_arc(src_node=1, dst_node=2, ilabel=1, olabel=2, weight=3)\n",
"\n",
"draw(fst, isymbols=isymbols, osymbols=osymbols)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 244.00 52.00\" width=\"244pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 240,-48 240,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"115\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"115\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.245,-22C50.4709,-22 70.479,-22 86.7679,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"86.8369,-25.5001 96.8369,-22 86.8368,-18.5001 86.8369,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"66.5\" y=\"-25.8\">a:x/0</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M32.7129,-11.3344C38.1808,-7.9798 44.6133,-4.7072 51,-3 64.3104,.558 68.6896,.558 82,-3 85.1933,-3.8536 88.3981,-5.0986 91.4881,-6.5494\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"89.83,-9.6318 100.2871,-11.3344 93.1742,-3.4822 89.83,-9.6318\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"66.5\" y=\"-6.8\">b:y/2</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"214\" 
cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"214\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"214\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;2</title>\n<path d=\"M133.1581,-22C146.7122,-22 165.5923,-22 181.6984,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"181.7979,-25.5001 191.7979,-22 181.7978,-18.5001 181.7979,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"162.5\" y=\"-25.8\">b:z/3</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
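{
"cell_type": "markdown",
"metadata": {},
"source": [
"A transducer maps input sequences to output sequences. The cell below sketches that mapping in plain Python for the transducer drawn above. It follows arcs whose input label matches the next input symbol, then collects the output labels and summed weights along every accepting path. The arc-tuple layout and the `transduce` helper are hypothetical, and the sketch assumes an acyclic graph without ϵ arcs."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Pure-Python sketch: map an input label sequence to (output string, score) pairs.\n",
"# Arcs are (src_node, dst_node, ilabel, olabel, weight), mirroring the graph drawn above.\n",
"out_symbols = {0: 'x', 1: 'y', 2: 'z'}\n",
"fst_arcs = [(0, 1, 0, 0, 0.0), (0, 1, 1, 1, 2.0), (1, 2, 1, 2, 3.0)]\n",
"\n",
"def transduce(inputs, arcs, start_nodes, accept_nodes, symbols):\n",
"    results = []\n",
"\n",
"    def visit(node, pos, outputs, score):\n",
"        if pos == len(inputs):\n",
"            # The whole input is consumed: keep the path only if it ends in an accept node\n",
"            if node in accept_nodes:\n",
"                results.append((''.join(symbols[o] for o in outputs), score))\n",
"            return\n",
"        for src, dst, ilabel, olabel, weight in arcs:\n",
"            if src == node and ilabel == inputs[pos]:\n",
"                visit(dst, pos + 1, outputs + [olabel], score + weight)\n",
"\n",
"    for start in start_nodes:\n",
"        visit(start, 0, [], 0.0)\n",
"    return results\n",
"\n",
"print(transduce([0, 1], fst_arcs, {0}, {2}, out_symbols))  # input 'ab' -> [('xz', 3.0)]\n",
"print(transduce([1, 1], fst_arcs, {0}, {2}, out_symbols))  # input 'bb' -> [('yz', 5.0)]"
],
"execution_count": null,
"outputs": []
},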
{
"cell_type": "markdown",
"metadata": {
"id": "b0t45RcPIYR7"
},
"source": [
"In a transducer, an arc's input label, output label, or both can be ϵ."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "qTMkD9NZIOej",
"outputId": "cdaac439-84ec-438b-d20c-b28d45c552c6"
},
"source": [
"fst = gtn.Graph()\n",
"fst.add_node(start=True)\n",
"fst.add_node()\n",
"fst.add_node()\n",
"fst.add_node(accept=True)\n",
"\n",
"# The input label on an arc can be an ϵ\n",
"fst.add_arc(src_node=0, dst_node=1, ilabel=gtn.epsilon, olabel=0)\n",
"\n",
"# The output label on an arc can be an ϵ\n",
"fst.add_arc(src_node=1, dst_node=2, ilabel=1, olabel=gtn.epsilon)\n",
"\n",
"# Both the input and the output label can be ϵ\n",
"fst.add_arc(src_node=2, dst_node=3, ilabel=gtn.epsilon, olabel=gtn.epsilon)\n",
"\n",
"draw(fst, isymbols=isymbols, osymbols=osymbols)"
],
"execution_count": 17,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 336.00 52.00\" width=\"336pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 332,-48 332,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"113\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"113\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.317,-22C49.9851,-22 68.9167,-22 84.5492,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"84.7393,-25.5001 94.7393,-22 84.7392,-18.5001 84.7393,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"65.5\" y=\"-25.8\">ε:x/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"208\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"208\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M131.317,-22C144.9851,-22 163.9167,-22 179.5492,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"179.7393,-25.5001 189.7393,-22 179.7392,-18.5001 179.7393,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" 
font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"160.5\" y=\"-25.8\">b:ε/0</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"306\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"306\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"306\" y=\"-18.3\">3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge3\">\n<title>2-&gt;3</title>\n<path d=\"M226.4331,-22C239.6776,-22 257.8831,-22 273.5374,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"273.8752,-25.5001 283.8751,-22 273.8751,-18.5001 273.8752,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"255\" y=\"-25.8\">ε:ε/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E41lTAHrJkS0"
},
"source": [
"## Basic Operations"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TyZtbNWHSbHa"
},
"source": [
"### Closure"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "P9zSdG7dR_C6",
"outputId": "3ed5646f-db90-428c-caad-9ddc32ad5485"
},
"source": [
"# Define the mapping from integer ids to arc symbol labels:\n",
"symbols = {0: 'a', 1: 'b', 2: 'c'}\n",
"\n",
"fsa = gtn.Graph()\n",
"fsa.add_node(start=True)\n",
"fsa.add_node()\n",
"fsa.add_node()\n",
"fsa.add_node(accept=True)\n",
"fsa.add_arc(src_node=0, dst_node=1, label=0)\n",
"fsa.add_arc(src_node=1, dst_node=2, label=1)\n",
"fsa.add_arc(src_node=2, dst_node=3, label=0)\n",
"\n",
"draw(fsa, isymbols=symbols)"
],
"execution_count": 58,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 305.00 52.00\" width=\"305pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 301,-48 301,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"102\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.2267,-22C47.1242,-22 61.2054,-22 73.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"73.8089,-25.5001 83.8088,-22 73.8088,-18.5001 73.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"187\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"187\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M120.0263,-22C131.2957,-22 146.0476,-22 158.8373,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.9997,-25.5001 168.9997,-22 158.9996,-18.5001 158.9997,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" 
font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"144.5\" y=\"-25.8\">b/0</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"275\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"275\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"275\" y=\"-18.3\">3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge3\">\n<title>2-&gt;3</title>\n<path d=\"M205.2337,-22C216.0103,-22 229.9708,-22 242.5692,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"242.7317,-25.5001 252.7317,-22 242.7317,-18.5001 242.7317,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"229\" y=\"-25.8\">a/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 132
},
"id": "TzP0AGoHSjAv",
"outputId": "e1e80bc0-70b1-4751-c915-32dc800c3837"
},
"source": [
"fsa_closed = gtn.closure(fsa)\n",
"draw(fsa_closed, isymbols=symbols)"
],
"execution_count": 59,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"83pt\" viewBox=\"0.00 0.00 389.00 83.00\" width=\"389pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 79)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-79 385,-79 385,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"22\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<ellipse cx=\"22\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"22\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"110\" cy=\"-53\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"110\" y=\"-49.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M42.8548,-29.3466C54.8862,-33.5849 70.1453,-38.9603 83.0757,-43.5153\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"82.2453,-46.9336 92.8402,-46.9551 84.5712,-40.3312 82.2453,-46.9336\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"68\" y=\"-44.8\">ε/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"194\" cy=\"-57\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"194\" y=\"-53.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M128.2267,-53.8679C139.1242,-54.3869 153.2054,-55.0574 165.5413,-55.6448\" fill=\"none\" 
stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"165.6536,-59.154 175.8088,-56.1338 165.9867,-52.162 165.6536,-59.154\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"152\" y=\"-58.8\">a/0</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"279\" cy=\"-54\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"279\" y=\"-50.3\">3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge3\">\n<title>2-&gt;3</title>\n<path d=\"M212.0263,-56.3638C223.2957,-55.966 238.0476,-55.4454 250.8373,-54.994\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"251.1294,-58.4859 260.9997,-54.6353 250.8824,-51.4903 251.1294,-58.4859\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"236.5\" y=\"-58.8\">b/0</text>\n</g>\n<!-- 4 -->\n<g class=\"node\" id=\"node5\">\n<title>4</title>\n<ellipse cx=\"363\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"363\" y=\"-18.3\">4</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge4\">\n<title>3-&gt;4</title>\n<path d=\"M295.9987,-47.5243C307.612,-43.1002 323.2351,-37.1485 336.4737,-32.1053\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"337.9277,-35.2968 346.0266,-28.466 335.4357,-28.7554 337.9277,-35.2968\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"321\" y=\"-43.8\">a/0</text>\n</g>\n<!-- 4&#45;&gt;0 -->\n<g class=\"edge\" id=\"edge5\">\n<title>4-&gt;0</title>\n<path d=\"M345.4862,-16.6722C328.6204,-12.0178 302.3075,-6 279,-6 110,-6 110,-6 110,-6 90.9682,-6 69.9891,-9.6674 53.2897,-13.4905\" 
fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"52.3967,-10.1055 43.5006,-15.8597 54.0434,-16.9091 52.3967,-10.1055\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"194\" y=\"-9.8\">ε/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qwek5IMySv79"
},
"source": [
"### Union"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "MO9JPbxySsyn",
"outputId": "79d4ba5d-ce42-4d85-eeb1-aa7c80c98830"
},
"source": [
"# A graph which recognizes \"aba*\": \"ab\" followed by zero or more \"a\"\n",
"g1 = gtn.Graph()\n",
"g1.add_node(start=True)\n",
"g1.add_node()\n",
"g1.add_node(accept=True)\n",
"g1.add_arc(src_node=0, dst_node=1, label=0)\n",
"g1.add_arc(src_node=1, dst_node=2, label=1)\n",
"g1.add_arc(src_node=2, dst_node=2, label=0)\n",
"\n",
"# A graph which recognizes \"ba\"\n",
"g2 = gtn.Graph()\n",
"g2.add_node(start=True)\n",
"g2.add_node()\n",
"g2.add_node(accept=True)\n",
"g2.add_arc(src_node=0, dst_node=1, label=1)\n",
"g2.add_arc(src_node=1, dst_node=2, label=0)\n",
"\n",
"# A graph which recognizes \"ac\"\n",
"g3 = gtn.Graph()\n",
"g3.add_node(start=True)\n",
"g3.add_node()\n",
"g3.add_node(accept=True)\n",
"g3.add_arc(src_node=0, dst_node=1, label=0)\n",
"g3.add_arc(src_node=1, dst_node=2, label=2)\n",
"\n",
"draw(g1, isymbols=symbols)\n",
"draw(g2, isymbols=symbols)\n",
"draw(g3, isymbols=symbols)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"85pt\" viewBox=\"0.00 0.00 221.00 85.00\" width=\"221pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 81)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-81 217,-81 217,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"102\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.2267,-22C47.1242,-22 61.2054,-22 73.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"73.8089,-25.5001 83.8088,-22 73.8088,-18.5001 73.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M120.0105,-22C131.1524,-22 145.7778,-22 158.8566,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.9179,-25.5001 
168.9178,-22 158.9178,-18.5001 158.9179,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"144.5\" y=\"-25.8\">b/0</text>\n</g>\n<!-- 2&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge3\">\n<title>2-&gt;2</title>\n<path d=\"M182.6298,-42.5808C181.4716,-52.8447 184.2617,-62 191,-62 195.3167,-62 198.0131,-58.2427 199.0891,-52.8436\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"202.595,-52.6729 199.3702,-42.5808 195.5976,-52.4812 202.595,-52.6729\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-65.8\">a/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 221.00 52.00\" width=\"221pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 217,-48 217,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.0263,-22C47.2957,-22 62.0476,-22 74.8373,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9997,-25.5001 84.9997,-22 74.9996,-18.5001 74.9997,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.8\">b/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M121.2337,-22C132.0103,-22 145.9708,-22 158.5692,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.7317,-25.5001 
168.7317,-22 158.7317,-18.5001 158.7317,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145\" y=\"-25.8\">a/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 220.00 52.00\" width=\"220pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 216,-48 216,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"102\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.2267,-22C47.1242,-22 61.2054,-22 73.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"73.8089,-25.5001 83.8088,-22 73.8088,-18.5001 73.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"190\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"190\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"190\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M120.2337,-22C131.0103,-22 144.9708,-22 157.5692,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"157.7317,-25.5001 
167.7317,-22 157.7317,-18.5001 157.7317,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"144\" y=\"-25.8\">c/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"id": "u_cRsq6dS7lH",
"outputId": "a010db1e-181f-4a3e-9462-30ce657e0d16"
},
"source": [
"g = gtn.union([g1, g2, g3])\n",
"draw(g, isymbols=symbols)"
],
"execution_count": 9,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"209pt\" viewBox=\"0.00 0.00 222.00 209.00\" width=\"222pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 205)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-205 218,-205 218,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.0263,-22C47.2957,-22 62.0476,-22 74.8373,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9997,-25.5001 84.9997,-22 74.9996,-18.5001 74.9997,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node7\">\n<title>2</title>\n<ellipse cx=\"192\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"192\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M121.0105,-22C132.1524,-22 146.7778,-22 159.8566,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.9179,-25.5001 
169.9178,-22 159.9178,-18.5001 159.9179,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-25.8\">b/0</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node3\">\n<title>3</title>\n<ellipse cx=\"18\" cy=\"-117\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-113.3\">3</text>\n</g>\n<!-- 4 -->\n<g class=\"node\" id=\"node4\">\n<title>4</title>\n<ellipse cx=\"103\" cy=\"-117\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-113.3\">4</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge2\">\n<title>3-&gt;4</title>\n<path d=\"M36.0263,-117C47.2957,-117 62.0476,-117 74.8373,-117\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9997,-120.5001 84.9997,-117 74.9996,-113.5001 74.9997,-120.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-120.8\">b/0</text>\n</g>\n<!-- 5 -->\n<g class=\"node\" id=\"node8\">\n<title>5</title>\n<ellipse cx=\"192\" cy=\"-117\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"192\" cy=\"-117\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-113.3\">5</text>\n</g>\n<!-- 4&#45;&gt;5 -->\n<g class=\"edge\" id=\"edge5\">\n<title>4-&gt;5</title>\n<path d=\"M121.0105,-117C132.1524,-117 146.7778,-117 159.8566,-117\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.9179,-120.5001 169.9178,-117 159.9178,-113.5001 159.9179,-120.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"145.5\" y=\"-120.8\">a/0</text>\n</g>\n<!-- 6 -->\n<g class=\"node\" id=\"node5\">\n<title>6</title>\n<ellipse cx=\"18\" cy=\"-179\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-175.3\">6</text>\n</g>\n<!-- 7 -->\n<g class=\"node\" id=\"node6\">\n<title>7</title>\n<ellipse cx=\"103\" cy=\"-179\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-175.3\">7</text>\n</g>\n<!-- 6&#45;&gt;7 -->\n<g class=\"edge\" id=\"edge3\">\n<title>6-&gt;7</title>\n<path d=\"M36.0263,-179C47.2957,-179 62.0476,-179 74.8373,-179\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9997,-182.5001 84.9997,-179 74.9996,-175.5001 74.9997,-182.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-182.8\">a/0</text>\n</g>\n<!-- 8 -->\n<g class=\"node\" id=\"node9\">\n<title>8</title>\n<ellipse cx=\"192\" cy=\"-179\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"192\" cy=\"-179\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-175.3\">8</text>\n</g>\n<!-- 7&#45;&gt;8 -->\n<g class=\"edge\" id=\"edge6\">\n<title>7-&gt;8</title>\n<path d=\"M121.0105,-179C132.1524,-179 146.7778,-179 159.8566,-179\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.9179,-182.5001 169.9178,-179 159.9178,-175.5001 159.9179,-182.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-182.8\">c/0</text>\n</g>\n<!-- 2&#45;&gt;2 -->\n<g class=\"edge\" 
id=\"edge7\">\n<title>2-&gt;2</title>\n<path d=\"M183.6298,-42.5808C182.4716,-52.8447 185.2617,-62 192,-62 196.3167,-62 199.0131,-58.2427 200.0891,-52.8436\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"203.595,-52.6729 200.3702,-42.5808 196.5976,-52.4812 203.595,-52.6729\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-65.8\">a/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
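{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the semantics of union concrete, here is a minimal pure-Python sketch. It does **not** use GTN and is not GTN's implementation; it assumes a toy acceptor representation (a dict holding a node count, start nodes, accept nodes and `(src, dst, label)` arcs) and hypothetical helpers `make`, `union` and `accepts`. Union lays out disjoint copies of the input graphs side by side, shifting node ids, so the result accepts a string whenever any input graph accepts it. With this offset scheme the node ids happen to match the GTN drawing above."
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Toy acceptor: node count, start nodes, accept nodes, arcs as (src, dst, label).\n",
"def make(num_nodes, start, accept, arcs):\n",
"    return {'n': num_nodes, 'start': set(start), 'accept': set(accept), 'arcs': list(arcs)}\n",
"\n",
"def union(graphs):\n",
"    # Lay out disjoint copies of each graph, shifting node ids by an offset.\n",
"    out = make(0, [], [], [])\n",
"    for g in graphs:\n",
"        off = out['n']\n",
"        out['start'] |= {s + off for s in g['start']}\n",
"        out['accept'] |= {a + off for a in g['accept']}\n",
"        out['arcs'] += [(s + off, d + off, l) for s, d, l in g['arcs']]\n",
"        out['n'] += g['n']\n",
"    return out\n",
"\n",
"def accepts(g, labels):\n",
"    # Follow every matching arc from the current node set (no epsilon arcs here).\n",
"    states = set(g['start'])\n",
"    for x in labels:\n",
"        states = {d for s, d, l in g['arcs'] if s in states and l == x}\n",
"    return bool(states & g['accept'])\n",
"\n",
"# The three graphs above, with a=0, b=1, c=2.\n",
"h1 = make(3, [0], [2], [(0, 1, 0), (1, 2, 1), (2, 2, 0)])  # aba*\n",
"h2 = make(3, [0], [2], [(0, 1, 1), (1, 2, 0)])  # ba\n",
"h3 = make(3, [0], [2], [(0, 1, 0), (1, 2, 2)])  # ac\n",
"u = union([h1, h2, h3])\n",
"# ab, ba, ac, abaa, bb\n",
"print([accepts(u, s) for s in ([0, 1], [1, 0], [0, 2], [0, 1, 0, 0], [1, 1])])  # [True, True, True, True, False]"
]
},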
{
"cell_type": "markdown",
"metadata": {
"id": "Sk4nxM6_TU50"
},
"source": [
"### Concatenate"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 225
},
"id": "JMdEJHF0TTLr",
"outputId": "4f45527e-65b3-4d49-e1e9-dcc296688185"
},
"source": [
"# The graph which recognizes \"ba\"\n",
"g1 = gtn.Graph()\n",
"g1.add_node(start=True)\n",
"g1.add_node()\n",
"g1.add_node(accept=True)\n",
"g1.add_arc(src_node=0, dst_node=1, label=1)\n",
"g1.add_arc(src_node=1, dst_node=2, label=0)\n",
"\n",
"# The graph which recognizes \"ac\" and \"bc\"\n",
"g2 = gtn.Graph()\n",
"g2.add_node(start=True)\n",
"g2.add_node()\n",
"g2.add_node()\n",
"g2.add_node(accept=True)\n",
"g2.add_arc(src_node=0, dst_node=1, label=0)\n",
"g2.add_arc(src_node=1, dst_node=3, label=2)\n",
"g2.add_arc(src_node=0, dst_node=2, label=1)\n",
"g2.add_arc(src_node=2, dst_node=3, label=2)\n",
"\n",
"draw(g1, isymbols=symbols)\n",
"draw(g2, isymbols=symbols)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 221.00 52.00\" width=\"221pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 217,-48 217,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.0263,-22C47.2957,-22 62.0476,-22 74.8373,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9997,-25.5001 84.9997,-22 74.9996,-18.5001 74.9997,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.8\">b/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M121.2337,-22C132.0103,-22 145.9708,-22 158.5692,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.7317,-25.5001 
168.7317,-22 158.7317,-18.5001 158.7317,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145\" y=\"-25.8\">a/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"98pt\" viewBox=\"0.00 0.00 221.00 98.00\" width=\"221pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 94)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-94 217,-94 217,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-45\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-41.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-72\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-68.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M35.201,-50.4639C46.9526,-54.1967 62.7617,-59.2184 76.1579,-63.4737\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"75.2342,-66.8526 85.8246,-66.5443 77.3534,-60.181 75.2342,-66.8526\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-63.8\">a/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"103\" cy=\"-18\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-14.3\">2</text>\n</g>\n<!-- 0&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;2</title>\n<path d=\"M35.201,-39.5361C46.9526,-35.8033 62.7617,-30.7816 76.1579,-26.5263\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"77.3534,-29.819 85.8246,-23.4557 75.2342,-23.1474 77.3534,-29.819\" 
stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-36.8\">b/0</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"191\" cy=\"-45\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"191\" cy=\"-45\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-41.3\">3</text>\n</g>\n<!-- 1&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;3</title>\n<path d=\"M120.386,-66.6657C131.7219,-63.1876 146.8329,-58.5513 160.167,-54.4601\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"161.3918,-57.7454 169.9253,-51.4661 159.3385,-51.0533 161.3918,-57.7454\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145\" y=\"-64.8\">c/0</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge4\">\n<title>2-&gt;3</title>\n<path d=\"M120.386,-23.3343C131.7219,-26.8124 146.8329,-31.4487 160.167,-35.5399\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.3385,-38.9467 169.9253,-38.5339 161.3918,-32.2546 159.3385,-38.9467\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145\" y=\"-36.8\">c/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 152
},
"id": "EHx_tK4STmSr",
"outputId": "6af6c1d5-1cdf-495b-b280-efef48a7e7b7"
},
"source": [
"g = gtn.concat([g1, g2])\n",
"draw(g, isymbols=symbols)"
],
"execution_count": 12,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"98pt\" viewBox=\"0.00 0.00 474.00 98.00\" width=\"474pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 94)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-94 470,-94 470,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-45\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-41.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-45\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-41.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.0263,-45C47.2957,-45 62.0476,-45 74.8373,-45\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9997,-48.5001 84.9997,-45 74.9996,-41.5001 74.9997,-48.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-48.8\">b/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"187\" cy=\"-45\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"187\" y=\"-41.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M121.2267,-45C132.1242,-45 146.2054,-45 158.5413,-45\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.8089,-48.5001 168.8088,-45 158.8088,-41.5001 158.8089,-48.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" 
font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145\" y=\"-48.8\">a/0</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"271\" cy=\"-45\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"271\" y=\"-41.3\">3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge3\">\n<title>2-&gt;3</title>\n<path d=\"M205.2267,-45C216.1242,-45 230.2054,-45 242.5413,-45\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"242.8089,-48.5001 252.8088,-45 242.8088,-41.5001 242.8089,-48.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"229\" y=\"-48.8\">ε/0</text>\n</g>\n<!-- 4 -->\n<g class=\"node\" id=\"node5\">\n<title>4</title>\n<ellipse cx=\"356\" cy=\"-72\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"356\" y=\"-68.3\">4</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge4\">\n<title>3-&gt;4</title>\n<path d=\"M288.201,-50.4639C299.9526,-54.1967 315.7617,-59.2184 329.1579,-63.4737\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"328.2342,-66.8526 338.8246,-66.5443 330.3534,-60.181 328.2342,-66.8526\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"313.5\" y=\"-63.8\">a/0</text>\n</g>\n<!-- 5 -->\n<g class=\"node\" id=\"node6\">\n<title>5</title>\n<ellipse cx=\"356\" cy=\"-18\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"356\" y=\"-14.3\">5</text>\n</g>\n<!-- 3&#45;&gt;5 -->\n<g class=\"edge\" id=\"edge5\">\n<title>3-&gt;5</title>\n<path d=\"M288.201,-39.5361C299.9526,-35.8033 
315.7617,-30.7816 329.1579,-26.5263\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"330.3534,-29.819 338.8246,-23.4557 328.2342,-23.1474 330.3534,-29.819\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"313.5\" y=\"-36.8\">b/0</text>\n</g>\n<!-- 6 -->\n<g class=\"node\" id=\"node7\">\n<title>6</title>\n<ellipse cx=\"444\" cy=\"-45\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"444\" cy=\"-45\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"444\" y=\"-41.3\">6</text>\n</g>\n<!-- 4&#45;&gt;6 -->\n<g class=\"edge\" id=\"edge6\">\n<title>4-&gt;6</title>\n<path d=\"M373.386,-66.6657C384.7219,-63.1876 399.8329,-58.5513 413.167,-54.4601\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"414.3918,-57.7454 422.9253,-51.4661 412.3385,-51.0533 414.3918,-57.7454\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"398\" y=\"-64.8\">c/0</text>\n</g>\n<!-- 5&#45;&gt;6 -->\n<g class=\"edge\" id=\"edge7\">\n<title>5-&gt;6</title>\n<path d=\"M373.386,-23.3343C384.7219,-26.8124 399.8329,-31.4487 413.167,-35.5399\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"412.3385,-38.9467 422.9253,-38.5339 414.3918,-32.2546 412.3385,-38.9467\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"398\" y=\"-36.8\">c/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
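{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concatenation can be sketched the same way. Again this is a toy pure-Python illustration, not GTN's implementation, and `make`, `concat`, `eps_closure` and `accepts` are hypothetical helpers. Following the drawing above, the sketch keeps the first graph's start nodes and the second graph's accept nodes, and adds an epsilon arc (label `None` here) from every accept node of the first graph to every start node of the second. Because of those epsilon arcs, `accepts` expands each set of reachable nodes with its epsilon closure."
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Toy acceptor: node count, start nodes, accept nodes, arcs as (src, dst, label).\n",
"def make(num_nodes, start, accept, arcs):\n",
"    return {'n': num_nodes, 'start': set(start), 'accept': set(accept), 'arcs': list(arcs)}\n",
"\n",
"def concat(g1, g2):\n",
"    # Shift g2's node ids past g1's, then bridge accept(g1) -> start(g2) with epsilon arcs.\n",
"    off = g1['n']\n",
"    arcs = g1['arcs'] + [(s + off, d + off, l) for s, d, l in g2['arcs']]\n",
"    arcs += [(a, s + off, None) for a in g1['accept'] for s in g2['start']]\n",
"    return make(g1['n'] + g2['n'], g1['start'], [a + off for a in g2['accept']], arcs)\n",
"\n",
"def eps_closure(g, states):\n",
"    # All nodes reachable from `states` using only epsilon (None) arcs.\n",
"    seen, stack = set(states), list(states)\n",
"    while stack:\n",
"        u = stack.pop()\n",
"        for s, d, l in g['arcs']:\n",
"            if s == u and l is None and d not in seen:\n",
"                seen.add(d)\n",
"                stack.append(d)\n",
"    return seen\n",
"\n",
"def accepts(g, labels):\n",
"    states = eps_closure(g, g['start'])\n",
"    for x in labels:\n",
"        states = eps_closure(g, {d for s, d, l in g['arcs'] if s in states and l == x})\n",
"    return bool(states & g['accept'])\n",
"\n",
"# The two graphs above, with a=0, b=1, c=2.\n",
"h1 = make(3, [0], [2], [(0, 1, 1), (1, 2, 0)])  # ba\n",
"h2 = make(4, [0], [3], [(0, 1, 0), (1, 3, 2), (0, 2, 1), (2, 3, 2)])  # ac, bc\n",
"c = concat(h1, h2)\n",
"# baac, babc, ba, ac\n",
"print([accepts(c, s) for s in ([1, 0, 0, 2], [1, 0, 1, 2], [1, 0], [0, 2])])  # [True, True, False, False]"
]
},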
{
"cell_type": "markdown",
"metadata": {
"id": "K3kDykhqUpVj"
},
"source": [
"## Advanced Operations"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ASMHyHHMU_gO"
},
"source": [
"### Intersect"
]
},
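{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the GTN example, here is a toy pure-Python sketch of intersection for epsilon-free acceptors using the standard product construction: node `(i, j)` pairs node `i` of the first graph with node `j` of the second, and an arc survives only when both graphs have an arc with the same label. The helpers `make`, `intersect` and `accepts` are hypothetical, and this is not GTN's implementation. The two toy graphs mirror the ones defined in the next cell: the first recognizes `a*b` and the second any two-letter string over `{a, b, c}`, so their intersection should accept only `ab`."
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Toy acceptor: node count, start nodes, accept nodes, arcs as (src, dst, label).\n",
"def make(num_nodes, start, accept, arcs):\n",
"    return {'n': num_nodes, 'start': set(start), 'accept': set(accept), 'arcs': list(arcs)}\n",
"\n",
"def intersect(g1, g2):\n",
"    # Product construction: encode the node pair (i, j) as i * n2 + j.\n",
"    n2 = g2['n']\n",
"    arcs = [(s1 * n2 + s2, d1 * n2 + d2, l1)\n",
"            for s1, d1, l1 in g1['arcs'] for s2, d2, l2 in g2['arcs'] if l1 == l2]\n",
"    start = [i * n2 + j for i in g1['start'] for j in g2['start']]\n",
"    accept = [i * n2 + j for i in g1['accept'] for j in g2['accept']]\n",
"    return make(g1['n'] * n2, start, accept, arcs)\n",
"\n",
"def accepts(g, labels):\n",
"    # Follow every matching arc from the current node set (no epsilon arcs here).\n",
"    states = set(g['start'])\n",
"    for x in labels:\n",
"        states = {d for s, d, l in g['arcs'] if s in states and l == x}\n",
"    return bool(states & g['accept'])\n",
"\n",
"# a=0, b=1, c=2\n",
"h1 = make(2, [0], [1], [(0, 0, 0), (0, 1, 1)])  # a*b\n",
"h2 = make(3, [0], [2], [(i, i + 1, l) for i in range(2) for l in range(3)])  # any two letters\n",
"x = intersect(h1, h2)\n",
"# ab, b, aab, bb\n",
"print([accepts(x, s) for s in ([0, 1], [1], [0, 0, 1], [1, 1])])  # [True, False, False, False]"
]
},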
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 220
},
"id": "ddOAKExCUr71",
"outputId": "4244f84b-f183-415c-cd45-fcb37514b75c"
},
"source": [
"# Define the mapping from integer ids to arc symbol labels:\n",
"symbols = {0: 'a', 1: 'b', 2: 'c'}\n",
"\n",
"# A graph which recognizes \"a*b\": zero or more \"a\" followed by \"b\"\n",
"g1 = gtn.Graph()\n",
"g1.add_node(start=True)\n",
"g1.add_node(accept=True)\n",
"g1.add_arc(src_node=0, dst_node=0, label=0)\n",
"g1.add_arc(src_node=0, dst_node=1, label=1)\n",
"\n",
"# A graph which recognizes any two-letter string over {a, b, c}\n",
"g2 = gtn.Graph()\n",
"g2.add_node(start=True)\n",
"g2.add_node()\n",
"g2.add_node(accept=True)\n",
"g2.add_arc(src_node=0, dst_node=1, label=0)\n",
"g2.add_arc(src_node=0, dst_node=1, label=1)\n",
"g2.add_arc(src_node=0, dst_node=1, label=2)\n",
"g2.add_arc(src_node=1, dst_node=2, label=0)\n",
"g2.add_arc(src_node=1, dst_node=2, label=1)\n",
"g2.add_arc(src_node=1, dst_node=2, label=2)\n",
"\n",
"draw(g1, isymbols=symbols)\n",
"draw(g2, isymbols=symbols)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"81pt\" viewBox=\"0.00 0.00 137.00 81.00\" width=\"137pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 77)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-77 133,-77 133,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 0&#45;&gt;0 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;0</title>\n<path d=\"M10.6172,-38.6641C8.9766,-48.625 11.4375,-58 18,-58 22.2041,-58 24.7249,-54.1525 25.5625,-48.7682\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"29.0601,-48.6002 25.3828,-38.6641 22.0612,-48.7247 29.0601,-48.6002\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-61.8\">a/0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"107\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"107\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"107\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.0105,-22C47.1524,-22 61.7778,-22 74.8566,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9179,-25.5001 84.9178,-22 74.9178,-18.5001 74.9179,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.8\">b/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"65pt\" viewBox=\"0.00 0.00 222.00 64.90\" width=\"222pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 60.8979)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-60.8979 218,-60.8979 218,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-21.1979\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-21.1979\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M33.5774,-34.2322C38.9127,-36.9378 45.0427,-39.5299 51,-40.8979 59.2302,-42.7879 61.7698,-42.7879 70,-40.8979 72.6994,-40.278 75.4343,-39.4068 78.1146,-38.3876\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"79.718,-41.5048 87.4226,-34.2322 76.8644,-35.1128 79.718,-41.5048\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-45.6979\">a/0</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.1352,-22.9515C40.9602,-22.5151 46.1726,-22.1169 51,-21.8979 59.4358,-21.5152 61.5642,-21.5152 70,-21.8979 71.5086,-21.9663 73.0547,-22.0523 74.6149,-22.1514\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.6227,-25.6626 84.8648,-22.9515 75.1675,-18.6838 74.6227,-25.6626\" stroke=\"#000000\"/>\n<text fill=\"#000000\" 
font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.6979\">b/0</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge3\">\n<title>0-&gt;1</title>\n<path d=\"M31.4382,-12.747C37.0662,-8.4407 43.9378,-4.1162 51,-1.8979 59.0564,.6326 61.9436,.6326 70,-1.8979 73.8621,-3.111 77.6673,-4.9541 81.2432,-7.0796\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"79.3269,-10.009 89.5618,-12.747 83.2682,-4.224 79.3269,-10.009\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-5.6979\">c/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"192\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"192\" cy=\"-24.8979\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-21.1979\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M118.5774,-34.2322C123.9127,-36.9378 130.0427,-39.5299 136,-40.8979 144.7831,-42.9148 154.238,-41.5881 162.8325,-39.0015\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"164.232,-42.2169 172.4346,-35.511 161.8405,-35.638 164.232,-42.2169\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-44.6979\">a/0</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge5\">\n<title>1-&gt;2</title>\n<path d=\"M121.1352,-22.9515C125.9602,-22.5151 131.1726,-22.1169 136,-21.8979 143.6897,-21.5491 151.9711,-21.7019 159.7451,-22.0797\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.7317,-25.5854 169.9284,-22.7084 160.1631,-18.5987 159.7317,-25.5854\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"145.5\" y=\"-25.6979\">b/0</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge6\">\n<title>1-&gt;2</title>\n<path d=\"M116.4382,-12.747C122.0662,-8.4407 128.9378,-4.1162 136,-1.8979 144.0564,.6326 146.88,.4202 155,-1.8979 158.6319,-2.9347 162.2548,-4.4423 165.734,-6.1974\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"164.1053,-9.2985 174.5155,-11.2683 167.6059,-3.2366 164.1053,-9.2985\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-5.6979\">c/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "kSt4A0rZVKFQ",
"outputId": "35016abf-aa88-49a2-d725-bd872902667a"
},
"source": [
"g = gtn.intersect(g1, g2)\n",
"draw(g, isymbols=symbols)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 221.00 52.00\" width=\"221pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 217,-48 217,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"102\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.2267,-22C47.1242,-22 61.2054,-22 73.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"73.8089,-25.5001 83.8088,-22 73.8088,-18.5001 73.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M120.0105,-22C131.1524,-22 145.7778,-22 158.8566,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.9179,-25.5001 
168.9178,-22 158.9178,-18.5001 158.9179,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"144.5\" y=\"-25.8\">b/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
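{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, intersection should keep exactly the strings accepted by both `g1` (which accepts `a*b`) and `g2` (which accepts every string of length two), so the only survivor is `ab`. The cell below is a brute-force sketch in plain Python, not how `gtn.intersect` works: the arc lists are transcribed by hand from the drawings above (weights are dropped since they are all zero), and the helper `is_accepted` is a hypothetical name used only for illustration. It simulates each acceptor on every string over {a, b, c} of length at most three, so its cost grows exponentially with the string length."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# A brute-force sanity check in plain Python (does not use gtn).\n",
"from itertools import product\n",
"\n",
"# Unweighted arcs (src, dst, label) transcribed from the drawings above;\n",
"# node 0 is the start node of both acceptors.\n",
"g1_arcs = [(0, 0, 'a'), (0, 1, 'b')]\n",
"g1_accept = {1}\n",
"g2_arcs = [(0, 1, lbl) for lbl in 'abc'] + [(1, 2, lbl) for lbl in 'abc']\n",
"g2_accept = {2}\n",
"\n",
"def is_accepted(arcs, accept, string, start=0):\n",
"    # Track the set of reachable nodes, so nondeterminism is handled too.\n",
"    nodes = {start}\n",
"    for sym in string:\n",
"        nodes = {dst for src, dst, lbl in arcs if src in nodes and lbl == sym}\n",
"    return bool(nodes & accept)\n",
"\n",
"# Every string over {a, b, c} of length 0 to 3.\n",
"strings = [''.join(p) for n in range(4) for p in product('abc', repeat=n)]\n",
"both = [s for s in strings\n",
"        if is_accepted(g1_arcs, g1_accept, s) and is_accepted(g2_arcs, g2_accept, s)]\n",
"print(both)  # ['ab'], matching the single path a, b in the intersected graph"
],
"execution_count": null,
"outputs": []
},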
{
"cell_type": "markdown",
"metadata": {
"id": "X7ahg2VgVTJn"
},
"source": [
"### Compose"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 228
},
"id": "xmDrnyvtVUoj",
"outputId": "3c658ab7-2812-4ba2-90ce-895a7666b52e"
},
"source": [
"# Define the mappings from integer ids to arc input and output labels\n",
"isymbols = {0: 'a', 1: 'b', 2: 'c'}\n",
"osymbols = {0: 'x', 1: 'y', 2: 'z'}\n",
"\n",
"g1 = gtn.Graph()\n",
"g1.add_node(start=True)\n",
"g1.add_node(accept=True)\n",
"g1.add_arc(src_node=0, dst_node=0, ilabel=0, olabel=0, weight=1)\n",
"g1.add_arc(src_node=0, dst_node=1, ilabel=1, olabel=1, weight=2)\n",
"g1.add_arc(src_node=1, dst_node=1, ilabel=2, olabel=2, weight=3)\n",
"\n",
"g2 = gtn.Graph()\n",
"g2.add_node(start=True)\n",
"g2.add_node()\n",
"g2.add_node()\n",
"g2.add_node(accept=True)\n",
"g2.add_arc(src_node=0, dst_node=1, ilabel=0, olabel=0)\n",
"g2.add_arc(src_node=0, dst_node=1, ilabel=0, olabel=1)\n",
"g2.add_arc(src_node=0, dst_node=1, ilabel=1, olabel=2)\n",
"g2.add_arc(src_node=1, dst_node=2, ilabel=0, olabel=0)\n",
"g2.add_arc(src_node=1, dst_node=2, ilabel=1, olabel=1)\n",
"g2.add_arc(src_node=1, dst_node=2, ilabel=2, olabel=2)\n",
"g2.add_arc(src_node=2, dst_node=3, ilabel=1, olabel=0)\n",
"g2.add_arc(src_node=2, dst_node=3, ilabel=2, olabel=1)\n",
"g2.add_arc(src_node=2, dst_node=3, ilabel=2, olabel=2)\n",
"g2.set_weights([1, 2, 3, 3, 2, 1, 2, 1, 3])\n",
"\n",
"draw(g1, isymbols=isymbols, osymbols=osymbols)\n",
"draw(g2, isymbols=osymbols, osymbols=isymbols)"
],
"execution_count": 9,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"85pt\" viewBox=\"0.00 0.00 149.00 85.00\" width=\"149pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 81)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-81 145,-81 145,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 0&#45;&gt;0 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;0</title>\n<path d=\"M9.635,-38.2903C7.6179,-48.3892 10.4063,-58 18,-58 22.8647,-58 25.7573,-54.0557 26.6778,-48.5656\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"30.1677,-48.1791 26.365,-38.2903 23.171,-48.3922 30.1677,-48.1791\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-61.8\">a:x/1</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"119\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"119\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"119\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.0568,-22C50.0203,-22 69.7227,-22 86.4212,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"86.8703,-25.5001 96.8702,-22 86.8702,-18.5001 86.8703,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"66.5\" y=\"-25.8\">b:y/2</text>\n</g>\n<!-- 1&#45;&gt;1 -->\n<g 
class=\"edge\" id=\"edge3\">\n<title>1-&gt;1</title>\n<path d=\"M109.5928,-42.1697C108.1455,-52.599 111.2813,-62 119,-62 124.0654,-62 127.1571,-57.9513 128.2751,-52.2203\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"131.7754,-52.2149 128.4072,-42.1697 124.776,-52.1228 131.7754,-52.2149\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"119\" y=\"-65.8\">c:z/3</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"67pt\" viewBox=\"0.00 0.00 341.00 67.10\" width=\"341pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 63.0966)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-63.0966 337,-63.0966 337,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-26.0966\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-22.3966\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"115\" cy=\"-26.0966\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"115\" y=\"-22.3966\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M33.5774,-35.4309C38.9127,-38.1365 45.0427,-40.7285 51,-42.0966 64.4283,-45.1802 68.5717,-45.1802 82,-42.0966 84.6994,-41.4767 87.4343,-40.6055 90.1146,-39.5863\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"91.718,-42.7034 99.4226,-35.4309 88.8644,-36.3115 91.718,-42.7034\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"66.5\" y=\"-47.8966\">x:a/1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.1352,-24.1501C40.9602,-23.7138 46.1726,-23.3155 51,-23.0966 64.7636,-22.4722 68.2364,-22.4722 82,-23.0966 83.5086,-23.165 85.0547,-23.2509 86.6149,-23.3501\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"86.6227,-26.8612 96.8648,-24.1501 87.1675,-19.8825 86.6227,-26.8612\" stroke=\"#000000\"/>\n<text fill=\"#000000\" 
font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"66.5\" y=\"-26.8966\">x:b/2</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge3\">\n<title>0-&gt;1</title>\n<path d=\"M31.4382,-13.9457C37.0662,-9.6394 43.9378,-5.3148 51,-3.0966 64.1446,1.0322 68.8554,1.0322 82,-3.0966 85.8621,-4.3097 89.6673,-6.1527 93.2432,-8.2782\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"91.3269,-11.2077 101.5618,-13.9457 95.2682,-5.4227 91.3269,-11.2077\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"66.5\" y=\"-6.8966\">y:c/3</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"212\" cy=\"-26.0966\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"212\" y=\"-22.3966\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M130.5774,-35.4309C135.9127,-38.1365 142.0427,-40.7285 148,-42.0966 161.4283,-45.1802 165.5717,-45.1802 179,-42.0966 181.6994,-41.4767 184.4343,-40.6055 187.1146,-39.5863\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"188.718,-42.7034 196.4226,-35.4309 185.8644,-36.3115 188.718,-42.7034\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"163.5\" y=\"-47.8966\">x:a/3</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge5\">\n<title>1-&gt;2</title>\n<path d=\"M133.1352,-24.1501C137.9602,-23.7138 143.1726,-23.3155 148,-23.0966 161.7636,-22.4722 165.2364,-22.4722 179,-23.0966 180.5086,-23.165 182.0547,-23.2509 183.6149,-23.3501\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"183.6227,-26.8612 193.8648,-24.1501 184.1675,-19.8825 183.6227,-26.8612\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" 
font-size=\"14.00\" text-anchor=\"middle\" x=\"163.5\" y=\"-26.8966\">y:b/2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge6\">\n<title>1-&gt;2</title>\n<path d=\"M128.4382,-13.9457C134.0662,-9.6394 140.9378,-5.3148 148,-3.0966 161.1446,1.0322 165.8554,1.0322 179,-3.0966 182.8621,-4.3097 186.6673,-6.1527 190.2432,-8.2782\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"188.3269,-11.2077 198.5618,-13.9457 192.2682,-5.4227 188.3269,-11.2077\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"163.5\" y=\"-6.8966\">z:c/1</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"311\" cy=\"-26.0966\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"311\" cy=\"-26.0966\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"311\" y=\"-22.3966\">3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge7\">\n<title>2-&gt;3</title>\n<path d=\"M227.5774,-35.4309C232.9127,-38.1365 239.0427,-40.7285 245,-42.0966 257.5619,-44.9813 261.3786,-44.7086 274,-42.0966 276.5186,-41.5753 279.0765,-40.874 281.6107,-40.0551\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"283.0226,-43.263 291.1327,-36.4457 280.5414,-36.7175 283.0226,-43.263\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"259.5\" y=\"-47.8966\">y:a/2</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge8\">\n<title>2-&gt;3</title>\n<path d=\"M230.1352,-24.1501C234.9602,-23.7138 240.1726,-23.3155 245,-23.0966 257.8756,-22.5125 261.1217,-22.5753 274,-23.0966 275.4891,-23.1568 277.0111,-23.2307 278.5479,-23.3152\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"278.5013,-26.8198 288.7129,-23.9961 278.9692,-19.8355 
278.5013,-26.8198\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"259.5\" y=\"-26.8966\">z:b/1</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge9\">\n<title>2-&gt;3</title>\n<path d=\"M225.4382,-13.9457C231.0662,-9.6394 237.9378,-5.3148 245,-3.0966 257.2966,.7658 261.6063,.4416 274,-3.0966 277.6319,-4.1334 281.2548,-5.6409 284.734,-7.3961\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"283.1053,-10.4972 293.5155,-12.467 286.6059,-4.4353 283.1053,-10.4972\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"259.5\" y=\"-6.8966\">z:c/3</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 181
},
"id": "GX7Zy8qCVeU8",
"outputId": "20c147b3-9d8d-4f26-fed5-dbc4980278f5"
},
"source": [
"g = gtn.compose(g1, g2)\n",
"draw(g, isymbols=isymbols, osymbols=isymbols)"
],
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"120pt\" viewBox=\"0.00 0.00 339.00 120.00\" width=\"339pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 116)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-116 335,-116 335,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-56\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-52.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"113\" cy=\"-83\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"113\" y=\"-79.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M33.7448,-64.7284C39.0914,-67.3936 45.1856,-70.1149 51,-72 61.841,-75.5148 74.1342,-77.9912 84.8991,-79.6919\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"84.649,-83.1917 95.0441,-81.1439 85.6408,-76.2624 84.649,-83.1917\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"65.5\" y=\"-81.8\">a:a/2</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M35.4432,-51.4624C48.212,-49.0085 65.7269,-47.4819 80,-53 85.1049,-54.9736 89.9067,-58.1446 94.1836,-61.7013\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"92.0294,-64.481 101.7028,-68.8026 96.8357,-59.3919 92.0294,-64.481\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"65.5\" y=\"-56.8\">a:b/3</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" 
id=\"node3\">\n<title>2</title>\n<ellipse cx=\"113\" cy=\"-18\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"113\" y=\"-14.3\">2</text>\n</g>\n<!-- 0&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge3\">\n<title>0-&gt;2</title>\n<path d=\"M31.1539,-43.2651C36.8449,-38.4181 43.8374,-33.2784 51,-30 61.5765,-25.159 74.0352,-22.2695 84.9994,-20.5454\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"85.4917,-24.0109 94.9434,-19.2239 84.5695,-17.0719 85.4917,-24.0109\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"65.5\" y=\"-33.8\">b:c/5</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"210\" cy=\"-94\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"210\" y=\"-90.3\">3</text>\n</g>\n<!-- 1&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;3</title>\n<path d=\"M131.245,-85.069C145.4709,-86.6823 165.479,-88.9512 181.7679,-90.7984\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"181.5061,-94.2911 191.8369,-91.9403 182.2949,-87.3357 181.5061,-94.2911\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"161.5\" y=\"-93.8\">a:a/4</text>\n</g>\n<!-- 4 -->\n<g class=\"node\" id=\"node5\">\n<title>4</title>\n<ellipse cx=\"210\" cy=\"-29\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"210\" y=\"-25.3\">4</text>\n</g>\n<!-- 1&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge5\">\n<title>1-&gt;4</title>\n<path d=\"M129.0176,-74.083C144.1819,-65.641 167.1871,-52.834 184.8075,-43.0247\" fill=\"none\" stroke=\"#000000\"/>\n<polygon 
fill=\"#000000\" points=\"186.9237,-45.8524 193.9586,-37.9302 183.5189,-39.7363 186.9237,-45.8524\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"161.5\" y=\"-66.8\">b:b/4</text>\n</g>\n<!-- 2&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge6\">\n<title>2-&gt;4</title>\n<path d=\"M131.0451,-16.9512C143.9293,-16.5253 161.6289,-16.5993 177,-19 178.9519,-19.3048 180.9489,-19.6973 182.9469,-20.1496\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"182.1764,-23.5667 192.7422,-22.7828 183.9937,-16.8067 182.1764,-23.5667\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"161.5\" y=\"-22.8\">c:c/4</text>\n</g>\n<!-- 5 -->\n<g class=\"node\" id=\"node6\">\n<title>5</title>\n<ellipse cx=\"309\" cy=\"-56\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"309\" cy=\"-56\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"309\" y=\"-52.3\">5</text>\n</g>\n<!-- 3&#45;&gt;5 -->\n<g class=\"edge\" id=\"edge7\">\n<title>3-&gt;5</title>\n<path d=\"M227.2445,-87.3809C241.4153,-81.9416 261.833,-74.1045 278.7186,-67.6232\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"280.0968,-70.8432 288.1785,-63.9921 277.5883,-64.3081 280.0968,-70.8432\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"257.5\" y=\"-83.8\">b:a/4</text>\n</g>\n<!-- 4&#45;&gt;5 -->\n<g class=\"edge\" id=\"edge8\">\n<title>4-&gt;5</title>\n<path d=\"M227.6992,-33.8271C241.6087,-37.6205 261.3277,-42.9985 277.8621,-47.5078\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"277.0965,-50.9268 287.6651,-50.1814 278.9384,-44.1735 277.0965,-50.9268\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" 
font-size=\"14.00\" text-anchor=\"middle\" x=\"257.5\" y=\"-48.8\">c:b/4</text>\n</g>\n<!-- 4&#45;&gt;5 -->\n<g class=\"edge\" id=\"edge9\">\n<title>4-&gt;5</title>\n<path d=\"M226.6226,-21.578C239.4433,-16.9617 257.3941,-12.95 272,-19 278.1687,-21.5552 283.8453,-25.6931 288.8037,-30.2943\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"286.5166,-32.9653 295.972,-37.7452 291.5611,-28.1122 286.5166,-32.9653\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"257.5\" y=\"-22.8\">c:c/6</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
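{
"cell_type": "markdown",
"metadata": {},
"source": [
"Composition pairs an arc of `g1` with an arc of `g2` whenever the output label of the first equals the input label of the second. The new arc keeps the input label from `g1`, the output label from `g2`, and the sum of the two weights, and the states of the result are pairs of nodes `(g1 node, g2 node)`. The cell below is a plain-Python sketch of this idea, not GTN's implementation: the arc lists are transcribed by hand from the drawings above, `compose` here is a hypothetical helper unrelated to `gtn.compose`, and it assumes one start node per graph, no epsilon arcs, and no pruning of states that cannot reach an accept state. Its printed arcs match the composed graph above; for instance, the arcs leaving the start pair `(0, 0)` are `a:a/2`, `a:b/3` and `b:c/5`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# A plain-Python sketch of composition without epsilon arcs (does not use gtn).\n",
"from collections import deque\n",
"\n",
"# Arcs (src, dst, ilabel, olabel, weight) transcribed from g1 and g2 above.\n",
"g1_arcs = [(0, 0, 'a', 'x', 1), (0, 1, 'b', 'y', 2), (1, 1, 'c', 'z', 3)]\n",
"g2_arcs = [(0, 1, 'x', 'a', 1), (0, 1, 'x', 'b', 2), (0, 1, 'y', 'c', 3),\n",
"           (1, 2, 'x', 'a', 3), (1, 2, 'y', 'b', 2), (1, 2, 'z', 'c', 1),\n",
"           (2, 3, 'y', 'a', 2), (2, 3, 'z', 'b', 1), (2, 3, 'z', 'c', 3)]\n",
"\n",
"def compose(arcs1, arcs2, start=(0, 0)):\n",
"    # Breadth-first search over pairs of nodes reachable from the start pair.\n",
"    seen, queue, out = {start}, deque([start]), []\n",
"    while queue:\n",
"        n1, n2 = queue.popleft()\n",
"        for src1, dst1, ilab, mid1, w1 in arcs1:\n",
"            if src1 != n1:\n",
"                continue\n",
"            for src2, dst2, mid2, olab, w2 in arcs2:\n",
"                # Match g1's output label against g2's input label.\n",
"                if src2 == n2 and mid1 == mid2:\n",
"                    out.append(((n1, n2), (dst1, dst2), ilab, olab, w1 + w2))\n",
"                    if (dst1, dst2) not in seen:\n",
"                        seen.add((dst1, dst2))\n",
"                        queue.append((dst1, dst2))\n",
"    return out\n",
"\n",
"for arc in compose(g1_arcs, g2_arcs):\n",
"    print(arc)"
],
"execution_count": null,
"outputs": []
},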
{
"cell_type": "markdown",
"metadata": {
"id": "of4K8SklVsAJ"
},
"source": [
"### Forward and Viterbi"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 132
},
"id": "0QyM43EuVt00",
"outputId": "e72645bc-5147-4205-973c-3ce873ba3ed5"
},
"source": [
"fsa = gtn.Graph()\n",
"fsa.add_node(start=True)\n",
"fsa.add_node(start=True)\n",
"fsa.add_node()\n",
"fsa.add_node(accept=True)\n",
"fsa.add_arc(src_node=0, dst_node=1, label=0)\n",
"fsa.add_arc(src_node=0, dst_node=2, label=1)\n",
"fsa.add_arc(src_node=1, dst_node=2, label=2)\n",
"fsa.add_arc(src_node=2, dst_node=3, label=0)\n",
"fsa.set_weights([1.1, 3.2, 1.4, 2.1])\n",
"\n",
"draw(fsa, isymbols=symbols)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"83pt\" viewBox=\"0.00 0.00 337.00 83.00\" width=\"337pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 79)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-79 333,-79 333,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"113\" cy=\"-57\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"113\" y=\"-53.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M34.9841,-28.2573C49.2155,-33.5004 69.7846,-41.0785 86.2297,-47.1373\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"85.2492,-50.506 95.8426,-50.6789 87.6692,-43.9375 85.2492,-50.506\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"65.5\" y=\"-47.8\">a/1.1</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"208\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"208\" y=\"-18.3\">2</text>\n</g>\n<!-- 0&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;2</title>\n<path d=\"M36.0158,-19.8967C51.6007,-18.1888 74.7518,-15.9185 95,-15 110.9836,-14.275 115.0164,-14.275 131,-15 147.1353,-15.7319 165.114,-17.3223 179.6384,-18.8024\" fill=\"none\" 
stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"179.6715,-22.3254 189.9842,-19.8967 180.4079,-15.3642 179.6715,-22.3254\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"113\" y=\"-18.8\">b/3.2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;2</title>\n<path d=\"M129.9841,-50.7427C144.2155,-45.4996 164.7846,-37.9215 181.2297,-31.8627\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"182.6692,-35.0625 190.8426,-28.3211 180.2492,-28.494 182.6692,-35.0625\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"160.5\" y=\"-47.8\">c/1.4</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"307\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"307\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"307\" y=\"-18.3\">3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge4\">\n<title>2-&gt;3</title>\n<path d=\"M226.1581,-22C239.7122,-22 258.5923,-22 274.6984,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"274.7979,-25.5001 284.7979,-22 274.7978,-18.5001 274.7979,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"255.5\" y=\"-25.8\">a/2.1</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CRBWQsR9NDQA",
"outputId": "c30066ad-57dc-41d7-90f1-8dd8eeda05cf"
},
"source": [
"score = gtn.forward_score(fsa)\n",
"print(f\"The forward score is {score.item():.2f}\")"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"The forward score is 5.81\n"
],
"name": "stdout"
}
]
},
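{
"cell_type": "markdown",
"metadata": {},
"source": [
"The forward score is the log-sum-exp, over all paths from a start node to an accept node, of each path's total arc weight. This graph has three such paths (`a c a` and `b a` from node 0, and `c a` from node 1) with scores 4.6, 5.3 and 3.5, and log(exp(4.6) + exp(5.3) + exp(3.5)) is approximately 5.81. The forward algorithm avoids enumerating paths by accumulating these sums node by node in topological order. The cell below is a plain-Python sketch of that recursion for the acyclic `fsa` above, not GTN's implementation; the arc list is transcribed by hand, `logaddexp` is a small hypothetical helper, and start nodes are assumed to contribute a score of 0."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# A plain-Python sketch of the forward algorithm (does not use gtn).\n",
"import math\n",
"\n",
"# Arcs (src, dst, label, weight) transcribed from the drawing of fsa above.\n",
"arcs = [(0, 1, 'a', 1.1), (0, 2, 'b', 3.2), (1, 2, 'c', 1.4), (2, 3, 'a', 2.1)]\n",
"starts, accepts, num_nodes = {0, 1}, {3}, 4\n",
"\n",
"def logaddexp(x, y):\n",
"    # Stable log(exp(x) + exp(y)); -inf (log 0) is the identity.\n",
"    if x == -math.inf:\n",
"        return y\n",
"    if y == -math.inf:\n",
"        return x\n",
"    m = max(x, y)\n",
"    return m + math.log(math.exp(x - m) + math.exp(y - m))\n",
"\n",
"# alpha[n] is the log-sum-exp of the scores of all paths from a start node to n.\n",
"alpha = [0.0 if n in starts else -math.inf for n in range(num_nodes)]\n",
"# Sorting by source node gives a topological order for this particular graph.\n",
"for src, dst, _, w in sorted(arcs):\n",
"    alpha[dst] = logaddexp(alpha[dst], alpha[src] + w)\n",
"\n",
"forward = -math.inf\n",
"for n in accepts:\n",
"    forward = logaddexp(forward, alpha[n])\n",
"print(f'The forward score is {forward:.2f}')  # 5.81, as computed by gtn above"
],
"execution_count": null,
"outputs": []
},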
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hKzRqb0jNK6q",
"outputId": "198302f8-8452-414e-c31c-b722ec3afae5"
},
"source": [
"score = gtn.viterbi_score(fsa)\n",
"print(f\"The Viterbi score is {score.item():.2f}\")"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"The Viterbi score is 5.30\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "Q4gZumUyNUTe",
"outputId": "0234bedd-7ce5-4f09-f4e4-67c502af274e"
},
"source": [
"viterbi_path = gtn.viterbi_path(fsa)\n",
"draw(viterbi_path, isymbols=symbols)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 244.00 52.00\" width=\"244pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 240,-48 240,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"115\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"115\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.245,-22C50.4709,-22 70.479,-22 86.7679,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"86.8369,-25.5001 96.8369,-22 86.8368,-18.5001 86.8369,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"66.5\" y=\"-25.8\">b/3.2</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"214\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"214\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"214\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M133.1581,-22C146.7122,-22 165.5923,-22 181.6984,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"181.7979,-25.5001 
191.7979,-22 181.7978,-18.5001 181.7979,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"162.5\" y=\"-25.8\">a/2.1</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZP4Eln_zVaV6"
},
"source": [
"## Differentiation with Automata"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
},
"id": "1G3nAaS_VdM2",
"outputId": "6741998c-8792-4b66-a3d1-65a9abd728ae"
},
"source": [
"symbols = {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e'}\n",
"\n",
"# Request gradients for A1 with `calc_grad=True`:\n",
"A1 = gtn.Graph(calc_grad=True)\n",
"A1.add_node(start=True)\n",
"A1.add_node()\n",
"A1.add_node(accept=True)\n",
"for i in range(3):\n",
" A1.add_arc(\n",
" src_node=0, dst_node=1, ilabel=i, olabel=i, weight=i)\n",
" A1.add_arc(\n",
" src_node=1, dst_node=2, ilabel=i, olabel=i, weight=i)\n",
"draw(A1, isymbols=symbols)"
],
"execution_count": 33,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"65pt\" viewBox=\"0.00 0.00 222.00 64.90\" width=\"222pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 60.8979)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-60.8979 218,-60.8979 218,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-21.1979\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-21.1979\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M33.5774,-34.2322C38.9127,-36.9378 45.0427,-39.5299 51,-40.8979 59.2302,-42.7879 61.7698,-42.7879 70,-40.8979 72.6994,-40.278 75.4343,-39.4068 78.1146,-38.3876\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"79.718,-41.5048 87.4226,-34.2322 76.8644,-35.1128 79.718,-41.5048\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-45.6979\">a/0</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.1352,-22.9515C40.9602,-22.5151 46.1726,-22.1169 51,-21.8979 59.4358,-21.5152 61.5642,-21.5152 70,-21.8979 71.5086,-21.9663 73.0547,-22.0523 74.6149,-22.1514\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.6227,-25.6626 84.8648,-22.9515 75.1675,-18.6838 74.6227,-25.6626\" stroke=\"#000000\"/>\n<text fill=\"#000000\" 
font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.6979\">b/1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge3\">\n<title>0-&gt;1</title>\n<path d=\"M31.4382,-12.747C37.0662,-8.4407 43.9378,-4.1162 51,-1.8979 59.0564,.6326 61.9436,.6326 70,-1.8979 73.8621,-3.111 77.6673,-4.9541 81.2432,-7.0796\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"79.3269,-10.009 89.5618,-12.747 83.2682,-4.224 79.3269,-10.009\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-5.6979\">c/2</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"192\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"192\" cy=\"-24.8979\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-21.1979\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M118.5774,-34.2322C123.9127,-36.9378 130.0427,-39.5299 136,-40.8979 144.7831,-42.9148 154.238,-41.5881 162.8325,-39.0015\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"164.232,-42.2169 172.4346,-35.511 161.8405,-35.638 164.232,-42.2169\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-44.6979\">a/0</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge5\">\n<title>1-&gt;2</title>\n<path d=\"M121.1352,-22.9515C125.9602,-22.5151 131.1726,-22.1169 136,-21.8979 143.6897,-21.5491 151.9711,-21.7019 159.7451,-22.0797\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.7317,-25.5854 169.9284,-22.7084 160.1631,-18.5987 159.7317,-25.5854\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"145.5\" y=\"-25.6979\">b/1</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge6\">\n<title>1-&gt;2</title>\n<path d=\"M116.4382,-12.747C122.0662,-8.4407 128.9378,-4.1162 136,-1.8979 144.0564,.6326 146.88,.4202 155,-1.8979 158.6319,-2.9347 162.2548,-4.4423 165.734,-6.1974\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"164.1053,-9.2985 174.5155,-11.2683 167.6059,-3.2366 164.1053,-9.2985\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-5.6979\">c/2</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "CQJohXehgHOt",
"outputId": "f21ce0da-3a4a-486d-f842-8c0e969ef924"
},
"source": [
"# Request gradients for A2 with `calc_grad=True`:\n",
"A2 = gtn.Graph(calc_grad=True)\n",
"A2.add_node(start=True)\n",
"A2.add_node()\n",
"A2.add_node(accept=True)\n",
"for i in range(3, 5):\n",
" A2.add_arc(\n",
" src_node=0, dst_node=1, ilabel=i, olabel=i, weight=1)\n",
" A2.add_arc(\n",
" src_node=1, dst_node=2, ilabel=i, olabel=i, weight=1)\n",
"draw(A2, isymbols=symbols)"
],
"execution_count": 34,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 222.00 52.00\" width=\"222pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 218,-48 218,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.0263,-22C47.2957,-22 62.0476,-22 74.8373,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.9997,-25.5001 84.9997,-22 74.9996,-18.5001 74.9997,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.8\">d/1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M32.7129,-11.3344C38.1808,-7.9798 44.6133,-4.7072 51,-3 59.158,-.8193 61.842,-.8193 70,-3 73.1933,-3.8536 76.3981,-5.0986 79.4881,-6.5494\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"77.83,-9.6318 88.2871,-11.3344 81.1742,-3.4822 77.83,-9.6318\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-6.8\">e/1</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"192\" 
cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"192\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;2</title>\n<path d=\"M121.0105,-22C132.1524,-22 146.7778,-22 159.8566,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.9179,-25.5001 169.9178,-22 159.9178,-18.5001 159.9179,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-25.8\">d/1</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M117.7129,-11.3344C123.1808,-7.9798 129.6133,-4.7072 136,-3 145.3232,-.5078 155.35,-2.4075 164.3066,-5.8099\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"162.9659,-9.0445 173.5216,-9.9548 165.8375,-2.6606 162.9659,-9.0445\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-6.8\">e/1</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "3eHT83P8gLSk",
"outputId": "47eb0a64-35a9-411f-96ec-55a39062cac2"
},
"source": [
"# Don't request gradients for X:\n",
"X = gtn.Graph(calc_grad=False)\n",
"X.add_node(start=True)\n",
"X.add_node()\n",
"X.add_node(accept=True)\n",
"X.add_arc(src_node=0, dst_node=1, ilabel=2, olabel=2, weight=1.5)\n",
"X.add_arc(src_node=0, dst_node=1, ilabel=4, olabel=4, weight=1.5)\n",
"X.add_arc(src_node=1, dst_node=2, ilabel=0, olabel=0, weight=2.5)\n",
"draw(X, isymbols=symbols)"
],
"execution_count": 35,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 242.00 52.00\" width=\"242pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 238,-48 238,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"113\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"113\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.317,-22C49.9851,-22 68.9167,-22 84.5492,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"84.7393,-25.5001 94.7393,-22 84.7392,-18.5001 84.7393,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"65.5\" y=\"-25.8\">c/1.5</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M32.7129,-11.3344C38.1808,-7.9798 44.6133,-4.7072 51,-3 63.4517,.3284 67.5483,.3284 80,-3 83.1933,-3.8536 86.3981,-5.0986 89.4881,-6.5494\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"87.83,-9.6318 98.2871,-11.3344 91.1742,-3.4822 87.83,-9.6318\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"65.5\" y=\"-6.8\">e/1.5</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"212\" 
cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"212\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"212\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;2</title>\n<path d=\"M131.1581,-22C144.7122,-22 163.5923,-22 179.6984,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"179.7979,-25.5001 189.7979,-22 179.7978,-18.5001 179.7979,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"160.5\" y=\"-25.8\">a/2.5</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Y97XEwLtgNSB"
},
"source": [
"# Compute the function: the forward score of the intersection\n",
"# of union(A1, A2) with closure(X).\n",
"G1 = gtn.union([A1, A2])\n",
"G2 = gtn.closure(X)\n",
"G3 = gtn.intersect(G1, G2)\n",
"G4 = gtn.forward_score(G3)"
],
"execution_count": 38,
"outputs": []
},
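{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, the value of `G4` can be reproduced by hand. The plain-Python sketch below is not part of GTN. The dictionaries are a hypothetical encoding of the arc weights from the cells above, and the sketch assumes that the ε arcs added by `gtn.closure` have weight 0. Every string accepted by `X` has length two, so a two-symbol string from `A1` or `A2` lies in `closure(X)` only if `X` itself accepts it. The only such string is `ca` from `A1`. As a result, `G3` has a single path, and its forward score (a log-sum-exp over path scores) is just that path's total weight, 2 + 0 + 1.5 + 2.5 = 6.\n",
"\n",
"```python\n",
"import itertools\n",
"import math\n",
"\n",
"# Hypothetical encoding of the arc weights defined above.\n",
"A1 = {'a': 0.0, 'b': 1.0, 'c': 2.0}  # same weights on arcs 0->1 and 1->2\n",
"A2 = {'d': 1.0, 'e': 1.0}            # same weights on arcs 0->1 and 1->2\n",
"X_first = {'c': 1.5, 'e': 1.5}       # X arcs 0->1\n",
"X_second = {'a': 2.5}                # X arc 1->2\n",
"\n",
"def logsumexp(scores):\n",
"    m = max(scores)\n",
"    return m + math.log(sum(math.exp(s - m) for s in scores))\n",
"\n",
"def g3_path_scores():\n",
"    # A two-symbol string is in closure(X) exactly when X accepts it.\n",
"    for A in (A1, A2):\n",
"        for x, y in itertools.product(A, repeat=2):\n",
"            if x in X_first and y in X_second:\n",
"                # Intersection adds the weights of the matched arcs.\n",
"                yield A[x] + A[y] + X_first[x] + X_second[y]\n",
"\n",
"print(logsumexp(list(g3_path_scores())))  # 6.0\n",
"```\n",
"\n",
"If those assumptions hold, this value should agree with the score stored in `G4`."
]
},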
{
"cell_type": "code",
"metadata": {
"id": "l1tD5Kx3gTlo"
},
"source": [
"# Compute the gradients:\n",
"gtn.backward(G4)"
],
"execution_count": 39,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
},
"id": "qg7Mf1rigYUd",
"outputId": "5c602efa-cabf-4d13-b80c-7d2639b1d813"
},
"source": [
"# A1_grad has the same structure as A1; each of its arc\n",
"# weights is the derivative of G4 with respect to the\n",
"# corresponding arc weight of A1.\n",
"A1_grad = A1.grad()\n",
"draw(A1_grad, isymbols=symbols)"
],
"execution_count": 40,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"65pt\" viewBox=\"0.00 0.00 222.00 64.90\" width=\"222pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 60.8979)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-60.8979 218,-60.8979 218,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-21.1979\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"103\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103\" y=\"-21.1979\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M33.5774,-34.2322C38.9127,-36.9378 45.0427,-39.5299 51,-40.8979 59.2302,-42.7879 61.7698,-42.7879 70,-40.8979 72.6994,-40.278 75.4343,-39.4068 78.1146,-38.3876\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"79.718,-41.5048 87.4226,-34.2322 76.8644,-35.1128 79.718,-41.5048\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-45.6979\">a/0</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.1352,-22.9515C40.9602,-22.5151 46.1726,-22.1169 51,-21.8979 59.4358,-21.5152 61.5642,-21.5152 70,-21.8979 71.5086,-21.9663 73.0547,-22.0523 74.6149,-22.1514\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"74.6227,-25.6626 84.8648,-22.9515 75.1675,-18.6838 74.6227,-25.6626\" stroke=\"#000000\"/>\n<text fill=\"#000000\" 
font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-25.6979\">b/0</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge3\">\n<title>0-&gt;1</title>\n<path d=\"M31.4382,-12.747C37.0662,-8.4407 43.9378,-4.1162 51,-1.8979 59.0564,.6326 61.9436,.6326 70,-1.8979 73.8621,-3.111 77.6673,-4.9541 81.2432,-7.0796\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"79.3269,-10.009 89.5618,-12.747 83.2682,-4.224 79.3269,-10.009\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60.5\" y=\"-5.6979\">c/1</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"192\" cy=\"-24.8979\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"192\" cy=\"-24.8979\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192\" y=\"-21.1979\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M118.5774,-34.2322C123.9127,-36.9378 130.0427,-39.5299 136,-40.8979 144.7831,-42.9148 154.238,-41.5881 162.8325,-39.0015\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"164.232,-42.2169 172.4346,-35.511 161.8405,-35.638 164.232,-42.2169\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-44.6979\">a/1</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge5\">\n<title>1-&gt;2</title>\n<path d=\"M121.1352,-22.9515C125.9602,-22.5151 131.1726,-22.1169 136,-21.8979 143.6897,-21.5491 151.9711,-21.7019 159.7451,-22.0797\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"159.7317,-25.5854 169.9284,-22.7084 160.1631,-18.5987 159.7317,-25.5854\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"145.5\" y=\"-25.6979\">b/0</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge6\">\n<title>1-&gt;2</title>\n<path d=\"M116.4382,-12.747C122.0662,-8.4407 128.9378,-4.1162 136,-1.8979 144.0564,.6326 146.88,.4202 155,-1.8979 158.6319,-2.9347 162.2548,-4.4423 165.734,-6.1974\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"164.1053,-9.2985 174.5155,-11.2683 167.6059,-3.2366 164.1053,-9.2985\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.5\" y=\"-5.6979\">c/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
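{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why does `A1_grad` put a 1 on the `c` arc (0->1) and on the `a` arc (1->2), and a 0 everywhere else? The forward score is a log-sum-exp over path scores, and each path score is a sum of arc weights. The derivative with respect to an arc weight is therefore the total posterior (softmax) probability of the paths that use that arc. The plain-Python sketch below makes this concrete. The helper names and arc ids are hypothetical, and the paths are listed by hand rather than read from GTN. `G3` has only one path, so that path's posterior is 1. The sketch also reports gradients for the arcs of `X`, which GTN does not store here because `X` was built with `calc_grad=False`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def path_logsumexp(paths):\n",
"    # paths: list of (arc ids on the path, path score); returns the log-sum-exp of the scores.\n",
"    scores = [score for _, score in paths]\n",
"    m = max(scores)\n",
"    return m + math.log(sum(math.exp(s - m) for s in scores))\n",
"\n",
"def arc_gradients(paths):\n",
"    # d(forward score) / d(arc weight) = summed posterior of the paths using that arc\n",
"    # (an arc used k times on one path counts that path's posterior k times).\n",
"    total = path_logsumexp(paths)\n",
"    grads = {}\n",
"    for arcs, score in paths:\n",
"        posterior = math.exp(score - total)\n",
"        for arc in arcs:\n",
"            grads[arc] = grads.get(arc, 0.0) + posterior\n",
"    return grads\n",
"\n",
"# Hypothetical arc ids: (graph, source node, label).\n",
"# The single path through G3 uses A1's c (0->1) and a (1->2) arcs and X's c and a arcs.\n",
"g3_paths = [([('A1', 0, 'c'), ('A1', 1, 'a'), ('X', 0, 'c'), ('X', 1, 'a')],\n",
"             2.0 + 0.0 + 1.5 + 2.5)]\n",
"print(arc_gradients(g3_paths))  # each arc on the path gets 1.0; unlisted arcs get 0\n",
"```\n",
"\n",
"With two equally scored paths, the arcs of each path would instead receive 0.5."
]
},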
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"id": "_NxCXkMhhIU8",
"outputId": "3c935691-849c-469c-80a7-d61715c65f94"
},
"source": [
"# We can also access the gradient of any intermediate graph, e.g. G3:\n",
"G3_grad = G3.grad()\n",
"draw(G3_grad, isymbols=symbols)"
],
"execution_count": 44,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"52pt\" viewBox=\"0.00 0.00 388.00 52.00\" width=\"388pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 48)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-48 384,-48 384,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"102\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M36.2267,-22C47.1242,-22 61.2054,-22 73.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"73.8089,-25.5001 83.8088,-22 73.8088,-18.5001 73.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-25.8\">ε/1</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"186\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"186\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>1-&gt;2</title>\n<path d=\"M120.2267,-22C131.1242,-22 145.2054,-22 157.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"157.8089,-25.5001 167.8088,-22 157.8088,-18.5001 157.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" 
font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"144\" y=\"-25.8\">c/1</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"270\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"270\" y=\"-18.3\">3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge3\">\n<title>2-&gt;3</title>\n<path d=\"M204.2267,-22C215.1242,-22 229.2054,-22 241.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"241.8089,-25.5001 251.8088,-22 241.8088,-18.5001 241.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"228\" y=\"-25.8\">a/1</text>\n</g>\n<!-- 4 -->\n<g class=\"node\" id=\"node5\">\n<title>4</title>\n<ellipse cx=\"358\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"358\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"358\" y=\"-18.3\">4</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge4\">\n<title>3-&gt;4</title>\n<path d=\"M288.2337,-22C299.0103,-22 312.9708,-22 325.5692,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"325.7317,-25.5001 335.7317,-22 325.7317,-18.5001 325.7317,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"312\" y=\"-25.8\">ε/1</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z30dUp8g6sNl"
},
"source": [
"## Sequence Criteria with Automata"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PiFDxqnf7Je_"
},
"source": [
"### Automatic Segmentation Criterion (ASG)"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 129
},
"id": "yfIDk1q7658B",
"outputId": "5b361c49-9dfa-4a5d-f7eb-a7e01087a787"
},
"source": [
"# Construct the target-constrained graph for y = \"ab\":\n",
"Ay = gtn.Graph()\n",
"Ay.add_node(start=True)\n",
"Ay.add_node()\n",
"Ay.add_node(accept=True)\n",
"Ay.add_arc(src_node=0, dst_node=0, label=0)\n",
"Ay.add_arc(src_node=0, dst_node=1, label=0)\n",
"Ay.add_arc(src_node=1, dst_node=1, label=1)\n",
"Ay.add_arc(src_node=1, dst_node=2, label=1)\n",
"draw(Ay, isymbols=symbols)"
],
"execution_count": 49,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"81pt\" viewBox=\"0.00 0.00 221.00 81.00\" width=\"221pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 77)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-77 217,-77 217,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 0&#45;&gt;0 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;0</title>\n<path d=\"M10.6172,-38.6641C8.9766,-48.625 11.4375,-58 18,-58 22.2041,-58 24.7249,-54.1525 25.5625,-48.7682\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"29.0601,-48.6002 25.3828,-38.6641 22.0612,-48.7247 29.0601,-48.6002\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-61.8\">a/0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"102\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.2267,-22C47.1242,-22 61.2054,-22 73.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"73.8089,-25.5001 83.8088,-22 73.8088,-18.5001 73.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 1&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;1</title>\n<path d=\"M94.6172,-38.6641C92.9766,-48.625 
95.4375,-58 102,-58 106.2041,-58 108.7249,-54.1525 109.5625,-48.7682\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"113.0601,-48.6002 109.3828,-38.6641 106.0612,-48.7247 113.0601,-48.6002\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-61.8\">b/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"191\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-18.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M120.0105,-22C131.1524,-22 145.7778,-22 158.8566,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"158.9179,-25.5001 168.9178,-22 158.9178,-18.5001 158.9179,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"144.5\" y=\"-25.8\">b/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
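{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Ay` accepts exactly those frame labelings that consist of one or more `a` frames followed by one or more `b` frames. Intersecting it with an emissions graph and taking the forward score therefore sums, in log space, over every alignment of `ab` to the input frames. The brute-force sketch below enumerates those alignments in plain Python. It covers only this target-constrained term: the arcs of `Ay` have weight 0, so no transition scores are included. The function names are hypothetical. The emission values come from one random draw and match the graph `E` drawn in the next cell.\n",
"\n",
"```python\n",
"import itertools\n",
"import math\n",
"\n",
"def alignments(target, T):\n",
"    # Length-T labelings in which each token of `target` fills one or more\n",
"    # consecutive frames (for target 'ab', these are the strings Ay accepts).\n",
"    n = len(target)\n",
"    for cuts in itertools.combinations(range(1, T), n - 1):\n",
"        bounds = (0,) + cuts + (T,)\n",
"        yield [target[i] for i in range(n) for _ in range(bounds[i + 1] - bounds[i])]\n",
"\n",
"def target_forward_score(target, emissions, alphabet):\n",
"    # log of the sum over alignments of exp(sum_t s_t(label_t)).\n",
"    idx = {sym: k for k, sym in enumerate(alphabet)}\n",
"    scores = [sum(emissions[t][idx[c]] for t, c in enumerate(path))\n",
"              for path in alignments(target, len(emissions))]\n",
"    m = max(scores)\n",
"    return m + math.log(sum(math.exp(s - m) for s in scores))\n",
"\n",
"# Rows are time-steps; columns are the scores for 'a', 'b', 'c'.\n",
"emissions = [[-0.3, 0.3, 1.2],\n",
"             [1.1, -1.8, -0.5],\n",
"             [-2.4, 0.1, -0.1],\n",
"             [1.4, 0.5, -0.8]]\n",
"print(list(alignments('ab', 4)))  # [['a', 'b', 'b', 'b'], ['a', 'a', 'b', 'b'], ['a', 'a', 'a', 'b']]\n",
"print(target_forward_score('ab', emissions, 'abc'))\n",
"```\n",
"\n",
"Enumerating alignments grows combinatorially with the input length. The graph formulation computes the same quantity with a single forward pass over the intersected graph."
]
},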
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 110
},
"id": "lAzoFQPm6zZe",
"outputId": "3d7f28e3-1259-4495-ffc9-d53550509e43"
},
"source": [
"# Construct the emissions graph for an input with four time-steps\n",
"# and randomly sampled emission scores, s_t():\n",
"E = gtn.linear_graph(4, 3)\n",
"E.set_weights(np.random.randn(4*3).round(decimals=1))\n",
"draw(E, isymbols=symbols)"
],
"execution_count": 50,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"67pt\" viewBox=\"0.00 0.00 453.00 67.50\" width=\"453pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 63.4961)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-63.4961 449,-63.4961 449,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-26.4961\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-22.7961\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"118\" cy=\"-26.4961\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"118\" y=\"-22.7961\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M33.5774,-35.8305C38.9127,-38.5361 45.0427,-41.1281 51,-42.4961 65.7278,-45.8782 70.2722,-45.8782 85,-42.4961 87.6994,-41.8762 90.4343,-41.005 93.1146,-39.9858\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"94.718,-43.103 102.4226,-35.8305 91.8644,-36.711 94.718,-43.103\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"68\" y=\"-48.2961\">a/-0.3</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.1352,-24.5497C40.9602,-24.1134 46.1726,-23.7151 51,-23.4961 66.0956,-22.8114 69.9044,-22.8114 85,-23.4961 86.5086,-23.5646 88.0547,-23.6505 89.6149,-23.7496\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"89.6227,-27.2608 99.8648,-24.5497 90.1675,-20.282 89.6227,-27.2608\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" 
font-size=\"14.00\" text-anchor=\"middle\" x=\"68\" y=\"-27.2961\">b/0.3</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge3\">\n<title>0-&gt;1</title>\n<path d=\"M31.4382,-14.3452C37.0662,-10.039 43.9378,-5.7144 51,-3.4961 65.4167,1.0322 70.5833,1.0322 85,-3.4961 88.8621,-4.7092 92.6673,-6.5523 96.2432,-8.6778\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"94.3269,-11.6073 104.5618,-14.3452 98.2682,-5.8223 94.3269,-11.6073\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"68\" y=\"-7.2961\">c/1.2</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"219\" cy=\"-26.4961\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"219\" y=\"-22.7961\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M133.5774,-35.8305C138.9127,-38.5361 145.0427,-41.1281 151,-42.4961 166.1609,-45.9777 170.8391,-45.9777 186,-42.4961 188.6994,-41.8762 191.4343,-41.005 194.1146,-39.9858\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"195.718,-43.103 203.4226,-35.8305 192.8644,-36.711 195.718,-43.103\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"168.5\" y=\"-48.2961\">a/1.1</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge5\">\n<title>1-&gt;2</title>\n<path d=\"M136.1352,-24.5497C140.9602,-24.1134 146.1726,-23.7151 151,-23.4961 166.5396,-22.7912 170.4604,-22.7912 186,-23.4961 187.5086,-23.5646 189.0547,-23.6505 190.6149,-23.7496\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"190.6227,-27.2608 200.8648,-24.5497 191.1675,-20.282 190.6227,-27.2608\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"168.5\" y=\"-27.2961\">b/-1.8</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge6\">\n<title>1-&gt;2</title>\n<path d=\"M131.4382,-14.3452C137.0662,-10.039 143.9378,-5.7144 151,-3.4961 165.8407,1.1654 171.1593,1.1654 186,-3.4961 189.8621,-4.7092 193.6673,-6.5523 197.2432,-8.6778\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"195.3269,-11.6073 205.5618,-14.3452 199.2682,-5.8223 195.3269,-11.6073\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"168.5\" y=\"-7.2961\">c/-0.5</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"319\" cy=\"-26.4961\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"319\" y=\"-22.7961\">3</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge7\">\n<title>2-&gt;3</title>\n<path d=\"M234.5774,-35.8305C239.9127,-38.5361 246.0427,-41.1281 252,-42.4961 266.7278,-45.8782 271.2722,-45.8782 286,-42.4961 288.6994,-41.8762 291.4343,-41.005 294.1146,-39.9858\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"295.718,-43.103 303.4226,-35.8305 292.8644,-36.711 295.718,-43.103\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"269\" y=\"-48.2961\">a/-2.4</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge8\">\n<title>2-&gt;3</title>\n<path d=\"M237.1352,-24.5497C241.9602,-24.1134 247.1726,-23.7151 252,-23.4961 267.0956,-22.8114 270.9044,-22.8114 286,-23.4961 287.5086,-23.5646 289.0547,-23.6505 290.6149,-23.7496\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"290.6227,-27.2608 300.8648,-24.5497 291.1675,-20.282 290.6227,-27.2608\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"269\" y=\"-27.2961\">b/0.1</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge9\">\n<title>2-&gt;3</title>\n<path d=\"M232.4382,-14.3452C238.0662,-10.039 244.9378,-5.7144 252,-3.4961 266.4167,1.0322 271.5833,1.0322 286,-3.4961 289.8621,-4.7092 293.6673,-6.5523 297.2432,-8.6778\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"295.3269,-11.6073 305.5618,-14.3452 299.2682,-5.8223 295.3269,-11.6073\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"269\" y=\"-7.2961\">c/-0.1</text>\n</g>\n<!-- 4 -->\n<g class=\"node\" id=\"node5\">\n<title>4</title>\n<ellipse cx=\"423\" cy=\"-26.4961\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"423\" cy=\"-26.4961\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"423\" y=\"-22.7961\">4</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge10\">\n<title>3-&gt;4</title>\n<path d=\"M334.5774,-35.8305C339.9127,-38.5361 346.0427,-41.1281 352,-42.4961 366.7278,-45.8782 371.2025,-45.5586 386,-42.4961 388.5186,-41.9749 391.0765,-41.2736 393.6107,-40.4546\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"395.0226,-43.6625 403.1327,-36.8452 392.5414,-37.117 395.0226,-43.6625\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"369\" y=\"-48.2961\">a/1.4</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge11\">\n<title>3-&gt;4</title>\n<path d=\"M337.1352,-24.5497C341.9602,-24.1134 347.1726,-23.7151 352,-23.4961 367.0956,-22.8114 370.9013,-22.885 386,-23.4961 387.4891,-23.5564 389.0111,-23.6303 390.5479,-23.7148\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"390.5013,-27.2194 400.7129,-24.3957 390.9692,-20.235 390.5013,-27.2194\" 
stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"369\" y=\"-27.2961\">b/0.5</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge12\">\n<title>3-&gt;4</title>\n<path d=\"M332.4382,-14.3452C338.0662,-10.039 344.9378,-5.7144 352,-3.4961 366.4167,1.0322 371.4694,.6521 386,-3.4961 389.6319,-4.533 393.2548,-6.0405 396.734,-7.7956\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"395.1053,-10.8967 405.5155,-12.8666 398.6059,-4.8348 395.1053,-10.8967\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"369\" y=\"-7.2961\">c/-0.8</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 257
},
"id": "dB5hLvdw7ywM",
"outputId": "a514d857-903f-4f8e-ccc5-29a3afd09c8f"
},
"source": [
"# Make the bigram transition graph for the token set\n",
"# {a, b, c} with gradient computation enabled:\n",
"B = gtn.Graph(calc_grad=True)\n",
"B.add_node(start=True, accept=True)\n",
"B.add_node(accept=True)\n",
"B.add_node(accept=True)\n",
"B.add_node(accept=True)\n",
"for i in range(4):\n",
" for j in range(3):\n",
" B.add_arc(src_node=i, dst_node=(j + 1), label=j)\n",
"draw(B, isymbols=symbols)"
],
"execution_count": 51,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"177pt\" viewBox=\"0.00 0.00 330.00 176.74\" width=\"330pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 172.7391)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-172.7391 326,-172.7391 326,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"22\" cy=\"-29.7391\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<ellipse cx=\"22\" cy=\"-29.7391\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"22\" y=\"-26.0391\">0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"114\" cy=\"-100.7391\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"114\" cy=\"-100.7391\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"114\" y=\"-97.0391\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;1</title>\n<path d=\"M39.7385,-43.4286C53.4842,-54.0367 72.7267,-68.8869 88.1797,-80.8126\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"86.2609,-83.7528 96.3159,-87.0916 90.5376,-78.2112 86.2609,-83.7528\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"68\" y=\"-74.5391\">a/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"207\" cy=\"-39.7391\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"207\" cy=\"-39.7391\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" 
x=\"207\" y=\"-36.0391\">2</text>\n</g>\n<!-- 0&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;2</title>\n<path d=\"M44.3153,-30.9453C76.766,-32.6994 137.1735,-35.9647 174.5479,-37.9849\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"174.7191,-41.4992 184.8935,-38.5442 175.097,-34.5094 174.7191,-41.4992\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"114\" y=\"-39.5391\">b/0</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"300\" cy=\"-77.7391\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"300\" cy=\"-77.7391\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"300\" y=\"-74.0391\">3</text>\n</g>\n<!-- 0&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge3\">\n<title>0-&gt;3</title>\n<path d=\"M42.7978,-22.3585C80.7077,-10.0499 163.208,11.424 229,-8.7391 251.2513,-15.5584 269.8264,-34.466 282.3895,-50.7673\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"279.6678,-52.9747 288.3988,-58.9764 285.3161,-48.8399 279.6678,-52.9747\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"160.5\" y=\"-4.5391\">c/0</text>\n</g>\n<!-- 1&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;1</title>\n<path d=\"M105.6298,-121.3199C104.4716,-131.5838 107.2617,-140.7391 114,-140.7391 118.3167,-140.7391 121.0131,-136.9818 122.0891,-131.5827\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"125.595,-131.412 122.3702,-121.3199 118.5976,-131.2203 125.595,-131.412\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"114\" y=\"-144.5391\">a/0</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" 
id=\"edge5\">\n<title>1-&gt;2</title>\n<path d=\"M135.2811,-94.8202C146.2187,-91.1914 159.4323,-85.8611 170,-78.7391 175.824,-74.8141 181.4211,-69.8268 186.4214,-64.7409\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"189.0549,-67.0482 193.2774,-57.3311 183.9169,-62.2941 189.0549,-67.0482\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"160.5\" y=\"-92.5391\">b/0</text>\n</g>\n<!-- 1&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge6\">\n<title>1-&gt;3</title>\n<path d=\"M125.3761,-119.8591C131.7566,-128.8721 140.5403,-138.9047 151,-144.7391 181.3374,-161.6612 194.3547,-152.2729 229,-149.7391 244.233,-148.625 250.4472,-153.4403 263,-144.7391 275.5665,-136.0284 284.3115,-121.6116 290.1315,-108.3783\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"293.3986,-109.6346 293.8653,-99.05 286.8998,-107.0334 293.3986,-109.6346\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"207\" y=\"-157.5391\">c/0</text>\n</g>\n<!-- 2&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge7\">\n<title>2-&gt;1</title>\n<path d=\"M185.366,-44.762C174.4593,-47.943 161.381,-52.791 151,-59.7391 144.7551,-63.9189 138.866,-69.3822 133.7002,-74.9515\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"131.057,-72.6573 127.1387,-82.501 136.3404,-77.2493 131.057,-72.6573\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"160.5\" y=\"-63.5391\">a/0</text>\n</g>\n<!-- 2&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge8\">\n<title>2-&gt;2</title>\n<path d=\"M198.6298,-60.3199C197.4716,-70.5838 200.2617,-79.7391 207,-79.7391 211.3167,-79.7391 214.0131,-75.9818 215.0891,-70.5827\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"218.595,-70.412 215.3702,-60.3199 211.5976,-70.2203 218.595,-70.412\" stroke=\"#000000\"/>\n<text 
fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"207\" y=\"-83.5391\">b/0</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge9\">\n<title>2-&gt;3</title>\n<path d=\"M228.7976,-36.4054C239.5002,-35.6816 252.3495,-36.2157 263,-40.7391 269.1457,-43.3493 274.8124,-47.5106 279.7685,-52.1172\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"277.4813,-54.7879 286.9386,-59.5638 282.5238,-49.9326 277.4813,-54.7879\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"253.5\" y=\"-44.5391\">c/0</text>\n</g>\n<!-- 3&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge10\">\n<title>3-&gt;1</title>\n<path d=\"M279.9459,-86.819C265.967,-92.713 246.7195,-99.9741 229,-103.7391 203.2915,-109.2015 196.2674,-106.8497 170,-107.7391 161.5604,-108.0249 159.4076,-108.5274 151,-107.7391 149.2995,-107.5797 147.5608,-107.3799 145.8093,-107.1496\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"146.2408,-103.6744 135.8163,-105.5667 145.1455,-110.5882 146.2408,-103.6744\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"207\" y=\"-110.5391\">a/0</text>\n</g>\n<!-- 3&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge11\">\n<title>3-&gt;2</title>\n<path d=\"M279.4531,-69.8587C268.8337,-65.7373 255.6841,-60.5589 244,-55.7391 241.6224,-54.7583 239.1649,-53.7294 236.6995,-52.6864\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"237.8743,-49.3821 227.3037,-48.6657 235.1203,-55.8177 237.8743,-49.3821\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"253.5\" y=\"-66.5391\">b/0</text>\n</g>\n<!-- 3&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge12\">\n<title>3-&gt;3</title>\n<path d=\"M291.6298,-98.3199C290.4716,-108.5838 293.2617,-117.7391 300,-117.7391 304.3167,-117.7391 307.0131,-113.9818 
308.0891,-108.5827\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"311.595,-108.412 308.3702,-98.3199 304.5976,-108.2203 311.595,-108.412\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"300\" y=\"-121.5391\">c/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "fUYlvhF-7mqP"
},
"source": [
"# The ASG loss function, which takes as input the emissions graph E,\n",
"# the transition graph B, and the target-constrained graph Ay:\n",
"def ASG(E, B, Ay):\n",
" # Compute constrained and normalization graphs:\n",
" AXy = gtn.intersect(gtn.intersect(B, Ay), E)\n",
" ZX = gtn.intersect(B, E)\n",
" \n",
" # Forward both graphs:\n",
" AXy_score = gtn.forward_score(AXy)\n",
" ZX_score = gtn.forward_score(ZX)\n",
" \n",
" # Compute the loss:\n",
" loss = gtn.negate(gtn.subtract(AXy_score, ZX_score))\n",
" \n",
" # Clear the previous gradients:\n",
" E.zero_grad()\n",
" B.zero_grad()\n",
" \n",
" # Compute gradients:\n",
" gtn.backward(loss, retain_graph=False)\n",
"\n",
" return loss.item()"
],
"execution_count": 53,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6h8vtKab7185",
"outputId": "89130d55-2f93-4363-8ea4-3ce381dcdd66"
},
"source": [
"# Call the ASG loss:\n",
"loss = ASG(E, B, Ay)\n",
"print(f\"The ASG loss is {loss:.3f}.\")\n",
"\n",
"# Access the graph containing the gradient for B:\n",
"dB = B.grad()"
],
"execution_count": 54,
"outputs": [
{
"output_type": "stream",
"text": [
"The ASG loss is 4.048.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S_K8Zwoy92XH"
},
"source": [
"### Connectionist Temporal Classification (CTC)"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 176
},
"id": "Cr7XiHiw95jM",
"outputId": "f61caa29-a506-4d68-cb45-9e819ce53282"
},
"source": [
"# Add the blank token to the symbol set\n",
"symbols[3] = \"<b>\"\n",
"\n",
"# Construct the CTC alignment graph for \"ab\":\n",
"Ay = gtn.Graph()\n",
"Ay.add_node(start=True)\n",
"Ay.add_node()\n",
"Ay.add_node()\n",
"Ay.add_node(accept=True)\n",
"Ay.add_node(accept=True)\n",
"Ay.add_arc(src_node=0, dst_node=0, label=3)\n",
"Ay.add_arc(src_node=0, dst_node=1, label=0)\n",
"Ay.add_arc(src_node=1, dst_node=1, label=0)\n",
"Ay.add_arc(src_node=1, dst_node=2, label=3)\n",
"Ay.add_arc(src_node=1, dst_node=3, label=1)\n",
"Ay.add_arc(src_node=2, dst_node=2, label=3)\n",
"Ay.add_arc(src_node=2, dst_node=3, label=1)\n",
"Ay.add_arc(src_node=3, dst_node=3, label=1)\n",
"Ay.add_arc(src_node=3, dst_node=4, label=3)\n",
"Ay.add_arc(src_node=4, dst_node=4, label=3)\n",
"draw(Ay, isymbols=symbols)"
],
"execution_count": 55,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
],
"image/svg+xml": "<svg height=\"116pt\" viewBox=\"0.00 0.00 433.00 116.00\" width=\"433pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 112)\">\n<title>FST</title>\n<polygon fill=\"#ffffff\" points=\"-4,4 -4,-112 429,-112 429,4 -4,4\" stroke=\"transparent\"/>\n<!-- 0 -->\n<g class=\"node\" id=\"node1\">\n<title>0</title>\n<ellipse cx=\"18\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\" stroke-width=\"2\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-18.3\">0</text>\n</g>\n<!-- 0&#45;&gt;0 -->\n<g class=\"edge\" id=\"edge1\">\n<title>0-&gt;0</title>\n<path d=\"M10.6172,-38.6641C8.9766,-48.625 11.4375,-58 18,-58 22.2041,-58 24.7249,-54.1525 25.5625,-48.7682\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"29.0601,-48.6002 25.3828,-38.6641 22.0612,-48.7247 29.0601,-48.6002\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"18\" y=\"-61.8\">&lt;b&gt;/0</text>\n</g>\n<!-- 1 -->\n<g class=\"node\" id=\"node2\">\n<title>1</title>\n<ellipse cx=\"102\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-18.3\">1</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge2\">\n<title>0-&gt;1</title>\n<path d=\"M36.2267,-22C47.1242,-22 61.2054,-22 73.5413,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"73.8089,-25.5001 83.8088,-22 73.8088,-18.5001 73.8089,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"60\" y=\"-25.8\">a/0</text>\n</g>\n<!-- 1&#45;&gt;1 -->\n<g class=\"edge\" id=\"edge3\">\n<title>1-&gt;1</title>\n<path 
d=\"M94.6172,-38.6641C92.9766,-48.625 95.4375,-58 102,-58 106.2041,-58 108.7249,-54.1525 109.5625,-48.7682\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"113.0601,-48.6002 109.3828,-38.6641 106.0612,-48.7247 113.0601,-48.6002\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102\" y=\"-61.8\">a/0</text>\n</g>\n<!-- 2 -->\n<g class=\"node\" id=\"node3\">\n<title>2</title>\n<ellipse cx=\"204\" cy=\"-57\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"204\" y=\"-53.3\">2</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge4\">\n<title>1-&gt;2</title>\n<path d=\"M119.3029,-27.9373C135.2247,-33.4006 159.0774,-41.5854 177.4167,-47.8783\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"176.358,-51.2153 186.9526,-51.1504 178.63,-44.5942 176.358,-51.2153\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"153\" y=\"-47.8\">&lt;b&gt;/0</text>\n</g>\n<!-- 3 -->\n<g class=\"node\" id=\"node4\">\n<title>3</title>\n<ellipse cx=\"293\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"293\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"293\" y=\"-18.3\">3</text>\n</g>\n<!-- 1&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge5\">\n<title>1-&gt;3</title>\n<path d=\"M119.8854,-19.7565C143.3289,-17.0913 185.7322,-13.2165 222,-15 234.7259,-15.6258 248.6885,-16.8789 260.8392,-18.1542\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"260.6521,-21.6546 270.974,-19.2653 261.4151,-14.6963 260.6521,-21.6546\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" 
x=\"204\" y=\"-18.8\">b/0</text>\n</g>\n<!-- 2&#45;&gt;2 -->\n<g class=\"edge\" id=\"edge6\">\n<title>2-&gt;2</title>\n<path d=\"M196.6172,-73.6641C194.9766,-83.625 197.4375,-93 204,-93 208.2041,-93 210.7249,-89.1525 211.5625,-83.7682\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"215.0601,-83.6002 211.3828,-73.6641 208.0612,-83.7247 215.0601,-83.6002\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"204\" y=\"-96.8\">&lt;b&gt;/0</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge7\">\n<title>2-&gt;3</title>\n<path d=\"M221.1601,-50.2516C232.9476,-45.6161 248.9138,-39.3373 262.8009,-33.8761\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"264.4306,-36.9962 272.4559,-30.0792 261.8687,-30.4818 264.4306,-36.9962\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"246.5\" y=\"-46.8\">b/0</text>\n</g>\n<!-- 3&#45;&gt;3 -->\n<g class=\"edge\" id=\"edge8\">\n<title>3-&gt;3</title>\n<path d=\"M284.6298,-42.5808C283.4716,-52.8447 286.2617,-62 293,-62 297.3167,-62 300.0131,-58.2427 301.0891,-52.8436\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"304.595,-52.6729 301.3702,-42.5808 297.5976,-52.4812 304.595,-52.6729\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"293\" y=\"-65.8\">b/0</text>\n</g>\n<!-- 4 -->\n<g class=\"node\" id=\"node5\">\n<title>4</title>\n<ellipse cx=\"403\" cy=\"-22\" fill=\"none\" rx=\"18\" ry=\"18\" stroke=\"#000000\"/>\n<ellipse cx=\"403\" cy=\"-22\" fill=\"none\" rx=\"22\" ry=\"22\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"403\" y=\"-18.3\">4</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge9\">\n<title>3-&gt;4</title>\n<path d=\"M315.2601,-22C331.139,-22 
352.7423,-22 370.5387,-22\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"370.773,-25.5001 380.773,-22 370.7729,-18.5001 370.773,-25.5001\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"348\" y=\"-25.8\">&lt;b&gt;/0</text>\n</g>\n<!-- 4&#45;&gt;4 -->\n<g class=\"edge\" id=\"edge10\">\n<title>4-&gt;4</title>\n<path d=\"M393.2928,-41.7575C391.6432,-52.3499 394.8789,-62 403,-62 408.3295,-62 411.555,-57.8441 412.6766,-51.9932\" fill=\"none\" stroke=\"#000000\"/>\n<polygon fill=\"#000000\" points=\"416.1772,-51.768 412.7072,-41.7575 409.1772,-51.7469 416.1772,-51.768\" stroke=\"#000000\"/>\n<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"403\" y=\"-65.8\">&lt;b&gt;/0</text>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "O7yrc1Xh9_aI"
},
"source": [
"# The CTC loss function, which takes as input the emissions graph E\n",
"# and the target-constrained graph Ay. Unlike ASG, there is no\n",
"# transition graph, so the normalization graph is E itself:\n",
"def CTC(E, Ay):\n",
" # Compute constrained and normalization graphs:\n",
" AXy = gtn.intersect(Ay, E)\n",
" ZX = E\n",
" \n",
" # Forward both graphs:\n",
" AXy_score = gtn.forward_score(AXy)\n",
" ZX_score = gtn.forward_score(ZX)\n",
" \n",
" # Compute the loss:\n",
" loss = gtn.negate(gtn.subtract(AXy_score, ZX_score))\n",
" \n",
" # Clear the previous gradients:\n",
" E.zero_grad()\n",
" \n",
" # Compute gradients:\n",
" gtn.backward(loss, retain_graph=False)\n",
"\n",
" return loss.item()"
],
"execution_count": 56,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Scz8AAum-PYS",
"outputId": "fa8304b3-4a90-4bbb-ccab-841fe70e3d8d"
},
"source": [
"# The emissions graph for CTC has the same linear structure as for ASG,\n",
"# but each frame needs a fourth arc for the blank token <b> (label 3):\n",
"E = gtn.linear_graph(4, 4)\n",
"E.set_weights(np.random.randn(4*4).round(decimals=1))\n",
"\n",
"# Compute the CTC loss:\n",
"loss = CTC(E, Ay)\n",
"\n",
"print(f\"The CTC loss is {loss:.3f}.\")"
],
"execution_count": 57,
"outputs": [
{
"output_type": "stream",
"text": [
"The CTC loss is 2.758.\n"
],
"name": "stdout"
}
]
}
]
}
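The sketch below is a minimal pure-Python cross-check of the CTC cells in the notebook above. It needs no GTN installation. It does three things:

- Builds the CTC alignment graph for an arbitrary target. The node numbering is chosen so that, for "ab", it reproduces the hand-built `Ay` arc for arc.
- Runs the forward algorithm in the log semiring over that graph's intersection with a linear emissions graph. This plays the role of `gtn.forward_score(gtn.intersect(Ay, E))`.
- Subtracts the normalizer. For a linear graph, the normalizer is simply the per-frame log-sum-exp summed over frames.

Some assumptions to note:

- The helpers `ctc_alignment_graph`, `constrained_score`, `ctc_loss`, and `brute_force_ctc_loss` are hypothetical and not part of GTN.
- Emissions are assumed to be a plain list of per-frame score lists.
- The blank is assumed to be label 3, as in `symbols[3]`.

```python
import itertools
import math
import random

BLANK = 3  # assumption: "<b>" is label 3, as in symbols[3] in the notebook


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs)) for a non-empty list."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_alignment_graph(target, blank=BLANK):
    """Hypothetical helper returning (arcs, accept) for the CTC graph of `target`.

    Arcs are (src, dst, label). Node 0 is the start state (only blanks so far),
    odd node 2u + 1 emits target[u], and even node 2u (u > 0) emits a blank.
    """
    num_nodes = 2 * len(target) + 1

    def label(node):
        return target[(node - 1) // 2] if node % 2 == 1 else blank

    arcs = []
    for i in range(num_nodes):
        arcs.append((i, i, label(i)))  # stay: repeat the current label
        if i + 1 < num_nodes:
            arcs.append((i, i + 1, label(i + 1)))  # advance to the next state
        # Skip the optional blank between two *different* consecutive tokens.
        if i % 2 == 1 and i + 2 < num_nodes and label(i) != label(i + 2):
            arcs.append((i, i + 2, label(i + 2)))
    # Accept after the last token or after its trailing blanks.
    accept = {num_nodes - 2, num_nodes - 1} if target else {0}
    return arcs, accept


def constrained_score(emissions, arcs, accept):
    """Log-semiring forward score of the alignment graph intersected with a
    linear emissions graph (emissions[t][k] = score of label k at frame t)."""
    alpha = {0: 0.0}  # log-score of each reachable node; start at node 0
    for frame in emissions:
        incoming = {}
        for src, dst, lab in arcs:
            if src in alpha:
                incoming.setdefault(dst, []).append(alpha[src] + frame[lab])
        alpha = {node: logsumexp(scores) for node, scores in incoming.items()}
    return logsumexp([alpha[n] for n in accept if n in alpha])


def ctc_loss(emissions, target, blank=BLANK):
    arcs, accept = ctc_alignment_graph(target, blank)
    # The normalization graph is E itself; for a linear graph its forward
    # score factorizes into a sum of per-frame log-sum-exps.
    log_z = sum(logsumexp(frame) for frame in emissions)
    return -(constrained_score(emissions, arcs, accept) - log_z)


def brute_force_ctc_loss(emissions, target, blank=BLANK):
    """Reference check: enumerate every label sequence (exponential in T)."""

    def collapse(seq):  # merge consecutive repeats, then drop blanks
        out, prev = [], None
        for lab in seq:
            if lab != prev and lab != blank:
                out.append(lab)
            prev = lab
        return out

    all_scores, target_scores = [], []
    for seq in itertools.product(range(len(emissions[0])), repeat=len(emissions)):
        score = sum(frame[lab] for frame, lab in zip(emissions, seq))
        all_scores.append(score)
        if collapse(seq) == list(target):
            target_scores.append(score)
    return -(logsumexp(target_scores) - logsumexp(all_scores))


# Usage: 4 frames x 4 labels (a, b, c, <b>) with random weights, target "ab".
random.seed(0)
E = [[round(random.gauss(0.0, 1.0), 1) for _ in range(4)] for _ in range(4)]
target = [0, 1]
print(f"CTC loss (forward algorithm): {ctc_loss(E, target):.3f}")
print(f"CTC loss (brute force):       {brute_force_ctc_loss(E, target):.3f}")
```

The two printed values should agree up to floating-point rounding. The brute-force version enumerates all 4^T label sequences, so it is only meant to validate the dynamic program on tiny inputs. `constrained_score` raises a `ValueError` if no accepting node is reachable, for example when there are fewer frames than the target needs.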