Created
April 16, 2024 18:44
-
-
Save ariG23498/3043580b7f73313f6657b22c77988079 to your computer and use it in GitHub Desktop.
rnn-diffusion.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"authorship_tag": "ABX9TyP8xftETb2XVvWYheumJWH6", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU", | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"ef0db216de374d8d9fe4ef7b7c27988b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_e096b4adf8d74c5bb62dc62de43d5485", | |
"IPY_MODEL_5947389f26104af2aeadde1aadb56c46", | |
"IPY_MODEL_c9d86aca6d1b4e9092d087faea2a5f23" | |
], | |
"layout": "IPY_MODEL_9302b16aa22647d0ac292e492f944e32" | |
} | |
}, | |
"e096b4adf8d74c5bb62dc62de43d5485": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_d67abee0a02f41b8ac4d28ed89a43188", | |
"placeholder": "", | |
"style": "IPY_MODEL_3ed6412412644d19bd435f5e688072e7", | |
"value": "Downloading data: 100%" | |
} | |
}, | |
"5947389f26104af2aeadde1aadb56c46": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_329a1621b4c9484bae3089c3d0a4f30d", | |
"max": 30931277, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_161fb7ef927a4cb7873d14e0b5fa500b", | |
"value": 30931277 | |
} | |
}, | |
"c9d86aca6d1b4e9092d087faea2a5f23": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_2aa04feb3e48483c84bcbde834897aad", | |
"placeholder": "", | |
"style": "IPY_MODEL_c1d1f44a1aae457fafb01295c6c9c6dc", | |
"value": " 30.9M/30.9M [00:02<00:00, 16.9MB/s]" | |
} | |
}, | |
"9302b16aa22647d0ac292e492f944e32": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"d67abee0a02f41b8ac4d28ed89a43188": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"3ed6412412644d19bd435f5e688072e7": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"329a1621b4c9484bae3089c3d0a4f30d": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"161fb7ef927a4cb7873d14e0b5fa500b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"2aa04feb3e48483c84bcbde834897aad": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"c1d1f44a1aae457fafb01295c6c9c6dc": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"ea998c61f56c49fdbb27229f15c84c3d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_b1d6c88baf604285aa06840605b2b08e", | |
"IPY_MODEL_ada4f71b6a4c449ca86b21449428bf9b", | |
"IPY_MODEL_1828031ff524483aa28bd4b8c6f35981" | |
], | |
"layout": "IPY_MODEL_ff984182fc064033b3095ca08c8220b2" | |
} | |
}, | |
"b1d6c88baf604285aa06840605b2b08e": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_756a2563ac4147d0bf4376a52cd26870", | |
"placeholder": "", | |
"style": "IPY_MODEL_707a7542d7d848cdb1f26cccf6e778d5", | |
"value": "Downloading data: 100%" | |
} | |
}, | |
"ada4f71b6a4c449ca86b21449428bf9b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_fa86fdcadc8c4015987dad3a05152e0e", | |
"max": 5175617, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_adc80c861a914013a40393c44ac50210", | |
"value": 5175617 | |
} | |
}, | |
"1828031ff524483aa28bd4b8c6f35981": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_fbfc5527b1c14d7aa42295d35e2b3620", | |
"placeholder": "", | |
"style": "IPY_MODEL_f107f457eba6446490b5adb8ac983869", | |
"value": " 5.18M/5.18M [00:00<00:00, 4.38MB/s]" | |
} | |
}, | |
"ff984182fc064033b3095ca08c8220b2": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"756a2563ac4147d0bf4376a52cd26870": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"707a7542d7d848cdb1f26cccf6e778d5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"fa86fdcadc8c4015987dad3a05152e0e": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"adc80c861a914013a40393c44ac50210": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"fbfc5527b1c14d7aa42295d35e2b3620": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"f107f457eba6446490b5adb8ac983869": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"53af44e4c7dd4d6fb5e3b385bc5f7473": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_ad5d7fa88a2d48ffbbd5148b1d726d36", | |
"IPY_MODEL_2ce5f202a0ad4c59b3dd54e7778e5ac7", | |
"IPY_MODEL_5235bd0e9078423489d90061ebb23a68" | |
], | |
"layout": "IPY_MODEL_d4426c5e37d04d55bb8107c8672311dc" | |
} | |
}, | |
"ad5d7fa88a2d48ffbbd5148b1d726d36": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_3946b48ad5cd496b8b7323ffd586133c", | |
"placeholder": "", | |
"style": "IPY_MODEL_958423b6803641dc916e1d5669cfad68", | |
"value": "Generating train split: 100%" | |
} | |
}, | |
"2ce5f202a0ad4c59b3dd54e7778e5ac7": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_4e06d3fa7b594b34abf19e44de503e87", | |
"max": 60000, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_48cb9ee8be1e4d938789a90defbfb704", | |
"value": 60000 | |
} | |
}, | |
"5235bd0e9078423489d90061ebb23a68": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_809081ccc38a48f9869cf0ef956ed24c", | |
"placeholder": "", | |
"style": "IPY_MODEL_e840f8e33dd0434fa8bd6d3bcd5efda7", | |
"value": " 60000/60000 [00:00<00:00, 90943.69 examples/s]" | |
} | |
}, | |
"d4426c5e37d04d55bb8107c8672311dc": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"3946b48ad5cd496b8b7323ffd586133c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"958423b6803641dc916e1d5669cfad68": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"4e06d3fa7b594b34abf19e44de503e87": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"48cb9ee8be1e4d938789a90defbfb704": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"809081ccc38a48f9869cf0ef956ed24c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"e840f8e33dd0434fa8bd6d3bcd5efda7": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"f68f3db300a947838e60acd6eef73f56": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_9711ce9ed95f4c57a96cc42c419a6535", | |
"IPY_MODEL_f08c421ac9634861bb7104f9e1f3e5ce", | |
"IPY_MODEL_08d967d6785a4a8c861f23d8259f2ac2" | |
], | |
"layout": "IPY_MODEL_7fbca9fc1fb0413b8848723c8c286b3c" | |
} | |
}, | |
"9711ce9ed95f4c57a96cc42c419a6535": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_c9cc94dc9efd4b73ac5e196de8914180", | |
"placeholder": "", | |
"style": "IPY_MODEL_243b70f554464303843eff056b169773", | |
"value": "Generating test split: 100%" | |
} | |
}, | |
"f08c421ac9634861bb7104f9e1f3e5ce": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_e101f5e375c74a13887a46ce51c1bc0d", | |
"max": 10000, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_b96c61f62aaa4666aa46211b53e7fd62", | |
"value": 10000 | |
} | |
}, | |
"08d967d6785a4a8c861f23d8259f2ac2": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_ecdfef650f07497e975f0edec0153ad4", | |
"placeholder": "", | |
"style": "IPY_MODEL_6cf4e93859f445f0b514c626c6b62cc5", | |
"value": " 10000/10000 [00:00<00:00, 53018.19 examples/s]" | |
} | |
}, | |
"7fbca9fc1fb0413b8848723c8c286b3c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"c9cc94dc9efd4b73ac5e196de8914180": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"243b70f554464303843eff056b169773": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"e101f5e375c74a13887a46ce51c1bc0d": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"b96c61f62aaa4666aa46211b53e7fd62": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"ecdfef650f07497e975f0edec0153ad4": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"6cf4e93859f445f0b514c626c6b62cc5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/3043580b7f73313f6657b22c77988079/rnn-diffusion.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"This notebook is heavily inspired by: https://huggingface.co/blog/annotated-diffusion" | |
], | |
"metadata": { | |
"id": "02vydzVb6VWO" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Setup and Imports" | |
], | |
"metadata": { | |
"id": "-CFKktKV2qqo" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%pip install --upgrade -qq datasets" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "BmiGcVlQ64dD", | |
"outputId": "2201fe6e-d14c-4d0e-f19a-a53531bc30ed" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"id": "hzUHzDmy2kOy" | |
}, | |
"outputs": [], | |
"source": [ | |
"import random\n", | |
"import numpy as np\n", | |
"from matplotlib import pyplot as plt\n", | |
"\n", | |
"from datasets import load_dataset\n", | |
"\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"from torch.nn import functional as F\n", | |
"from torch.utils.data import DataLoader\n", | |
"\n", | |
"from torchvision import transforms" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Configurations" | |
], | |
"metadata": { | |
"id": "rlNouEPu7ART" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"batch_size = 128\n", | |
"image_size = 28\n", | |
"channels = 1\n", | |
"timesteps = 100\n", | |
"hidden_dim = 128\n", | |
"epochs = 5\n", | |
"\n", | |
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | |
"print(f\"{device=}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "3HNObPXp7Bff", | |
"outputId": "264c2094-1e45-44c9-f7c5-a5bef28208a9" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"device='cuda'\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Dataset and Loaders" | |
], | |
"metadata": { | |
"id": "JGEzhe8s7B1p" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# load dataset from the hub\n", | |
"dataset = load_dataset(\"fashion_mnist\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 145, | |
"referenced_widgets": [ | |
"ef0db216de374d8d9fe4ef7b7c27988b", | |
"e096b4adf8d74c5bb62dc62de43d5485", | |
"5947389f26104af2aeadde1aadb56c46", | |
"c9d86aca6d1b4e9092d087faea2a5f23", | |
"9302b16aa22647d0ac292e492f944e32", | |
"d67abee0a02f41b8ac4d28ed89a43188", | |
"3ed6412412644d19bd435f5e688072e7", | |
"329a1621b4c9484bae3089c3d0a4f30d", | |
"161fb7ef927a4cb7873d14e0b5fa500b", | |
"2aa04feb3e48483c84bcbde834897aad", | |
"c1d1f44a1aae457fafb01295c6c9c6dc", | |
"ea998c61f56c49fdbb27229f15c84c3d", | |
"b1d6c88baf604285aa06840605b2b08e", | |
"ada4f71b6a4c449ca86b21449428bf9b", | |
"1828031ff524483aa28bd4b8c6f35981", | |
"ff984182fc064033b3095ca08c8220b2", | |
"756a2563ac4147d0bf4376a52cd26870", | |
"707a7542d7d848cdb1f26cccf6e778d5", | |
"fa86fdcadc8c4015987dad3a05152e0e", | |
"adc80c861a914013a40393c44ac50210", | |
"fbfc5527b1c14d7aa42295d35e2b3620", | |
"f107f457eba6446490b5adb8ac983869", | |
"53af44e4c7dd4d6fb5e3b385bc5f7473", | |
"ad5d7fa88a2d48ffbbd5148b1d726d36", | |
"2ce5f202a0ad4c59b3dd54e7778e5ac7", | |
"5235bd0e9078423489d90061ebb23a68", | |
"d4426c5e37d04d55bb8107c8672311dc", | |
"3946b48ad5cd496b8b7323ffd586133c", | |
"958423b6803641dc916e1d5669cfad68", | |
"4e06d3fa7b594b34abf19e44de503e87", | |
"48cb9ee8be1e4d938789a90defbfb704", | |
"809081ccc38a48f9869cf0ef956ed24c", | |
"e840f8e33dd0434fa8bd6d3bcd5efda7", | |
"f68f3db300a947838e60acd6eef73f56", | |
"9711ce9ed95f4c57a96cc42c419a6535", | |
"f08c421ac9634861bb7104f9e1f3e5ce", | |
"08d967d6785a4a8c861f23d8259f2ac2", | |
"7fbca9fc1fb0413b8848723c8c286b3c", | |
"c9cc94dc9efd4b73ac5e196de8914180", | |
"243b70f554464303843eff056b169773", | |
"e101f5e375c74a13887a46ce51c1bc0d", | |
"b96c61f62aaa4666aa46211b53e7fd62", | |
"ecdfef650f07497e975f0edec0153ad4", | |
"6cf4e93859f445f0b514c626c6b62cc5" | |
] | |
}, | |
"id": "PdIqVYuj6axU", | |
"outputId": "81b1d933-710c-4721-e89a-5e8001a8f771" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading data: 0%| | 0.00/30.9M [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "ef0db216de374d8d9fe4ef7b7c27988b" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading data: 0%| | 0.00/5.18M [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "ea998c61f56c49fdbb27229f15c84c3d" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Generating train split: 0%| | 0/60000 [00:00<?, ? examples/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "53af44e4c7dd4d6fb5e3b385bc5f7473" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Generating test split: 0%| | 0/10000 [00:00<?, ? examples/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "f68f3db300a947838e60acd6eef73f56" | |
} | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# define image transformations\n", | |
"transform = transforms.Compose([\n", | |
"    transforms.RandomHorizontalFlip(),\n", | |
"    transforms.ToTensor(),\n", | |
"    # rescale pixel values from ToTensor's [0, 1] range to [-1, 1]\n", | |
"    transforms.Lambda(lambda t: (t * 2) - 1)\n", | |
"])\n", | |
"\n", | |
"\n", | |
"# This function must not be called `transforms`: that name would rebind the\n", | |
"# imported `torchvision.transforms` module, so re-running this cell would\n", | |
"# fail on `transforms.Compose`.\n", | |
"def apply_transforms(examples):\n", | |
"    \"\"\"Swap the images of a batch for transformed tensors.\n", | |
"\n", | |
"    Args:\n", | |
"        examples: batch dict whose \"image\" entry is a list of images\n", | |
"            supporting `.convert(\"L\")`.\n", | |
"\n", | |
"    Returns:\n", | |
"        The same dict, with \"image\" deleted and the transformed tensors\n", | |
"        stored under \"pixel_values\".\n", | |
"    \"\"\"\n", | |
"    examples[\"pixel_values\"] = [transform(image.convert(\"L\")) for image in examples[\"image\"]]\n", | |
"    del examples[\"image\"]\n", | |
"    return examples\n", | |
"\n", | |
"\n", | |
"transformed_dataset = (\n", | |
"    dataset\n", | |
"    .with_transform(apply_transforms)\n", | |
"    .remove_columns(\"label\")\n", | |
")\n", | |
"\n", | |
"# create dataloader\n", | |
"dataloader = DataLoader(\n", | |
"    transformed_dataset[\"train\"],\n", | |
"    batch_size=batch_size,\n", | |
"    shuffle=True\n", | |
")" | |
], | |
"metadata": { | |
"id": "SbnfSuvR6eel" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Get a batch of data and check the shape\n", | |
"batch = next(iter(dataloader))\n", | |
"print(batch.keys())\n", | |
"print(batch[\"pixel_values\"].shape)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GneDUV5J74KJ", | |
"outputId": "7e5a4cb6-73ee-4b2f-d335-4922d096e4dd" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"dict_keys(['pixel_values'])\n", | |
"torch.Size([128, 1, 28, 28])\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Forward Diffusion Process" | |
], | |
"metadata": { | |
"id": "Ggg6okjO8NFS" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"id": "9_7ipOKVIbFn" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Define a linear schedule\n", | |
"def linear_beta_schedule(timesteps):\n", | |
" beta_start = 0.0001\n", | |
" beta_end = 0.02\n", | |
" return torch.linspace(beta_start, beta_end, timesteps)\n", | |
"\n", | |
"# define beta schedule\n", | |
"betas = linear_beta_schedule(timesteps=timesteps)\n", | |
"\n", | |
"# define alphas = 1 - beta\n", | |
"alphas = 1. - betas\n", | |
"alphas_cumprod = torch.cumprod(alphas, axis=0)\n", | |
"\n", | |
"# batchify the cum prods\n", | |
"sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).repeat(batch_size, 1)\n", | |
"sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).repeat(batch_size, 1)\n", | |
"\n", | |
"def extract(a, t, x_shape):\n", | |
"    \"\"\"Gather per-timestep schedule values and make them broadcastable.\n", | |
"\n", | |
"    The schedule tensors hold cumulative products of the alphas; this\n", | |
"    picks the entries that belong to the requested timesteps.\n", | |
"\n", | |
"    Args:\n", | |
"        a: schedule tensor, gathered along its last dimension.\n", | |
"        t: integer timestep indices of shape (batch, steps).\n", | |
"        x_shape: shape of the tensor the result will be multiplied with.\n", | |
"\n", | |
"    Returns:\n", | |
"        Tensor of shape (batch, steps, 1, ..., 1) with len(x_shape) - 2\n", | |
"        trailing singleton dimensions, moved to the device of `t`.\n", | |
"    \"\"\"\n", | |
"    # Take the sequence length from `t` instead of the global `timesteps`,\n", | |
"    # so the helper also works for index tensors of a different length.\n", | |
"    batch_size, num_steps = t.shape\n", | |
"    # the index is moved to the CPU before gathering; the schedule tensors\n", | |
"    # in this notebook are created there\n", | |
"    out = a.gather(dim=-1, index=t.cpu())\n", | |
"    return out.reshape(\n", | |
"        batch_size, num_steps, *((1,) * (len(x_shape) - 2))\n", | |
"    ).to(t.device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# forward diffusion\n", | |
"def q_sample(x_start, t, noise=None):\n", | |
" if noise is None:\n", | |
" noise = torch.randn_like(x_start)\n", | |
"\n", | |
" sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)\n", | |
" sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)\n", | |
"\n", | |
" return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise" | |
], | |
"metadata": { | |
"id": "pmHZGKa8Lhbz" | |
}, | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# create a batch of noise images\n", | |
"t = torch.arange(0, timesteps).flip((-1,)).repeat(batch_size, 1)\n", | |
"input_images = q_sample(\n", | |
" x_start=batch[\"pixel_values\"].unsqueeze(1), # (B, 1, C, H, W)\n", | |
" t=t, # (B, timesteps)\n", | |
")" | |
], | |
"metadata": { | |
"id": "uWgxpZPj8sY6" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Visualize the forward process" | |
], | |
"metadata": { | |
"id": "aQ0H8DVxEJqF" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# show equally spaced frames of one randomly chosen sample\n", | |
"num_panels = 5\n", | |
"# bound the index by the tensor that is actually plotted\n", | |
"idx = random.randint(0, input_images.shape[0] - 1)\n", | |
"for i in range(num_panels):\n", | |
"    plt.subplot(1, num_panels, i+1)\n", | |
"    # step through the sequence with `timesteps` rather than a hard-coded 100\n", | |
"    plt.imshow(input_images[idx, timesteps // num_panels * i].permute(1, 2, 0), cmap=\"gray\")\n", | |
"    plt.axis(\"off\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 122 | |
}, | |
"id": "I-hoBcj3DNSf", | |
"outputId": "f63af000-1c1b-40f2-caa3-9d5db9eebd5c" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 640x480 with 5 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Define the RNN components" | |
], | |
"metadata": { | |
"id": "-CeUWnzpEoqb" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ImageEncoder(nn.Module):\n", | |
" def __init__(self, in_channels):\n", | |
" super().__init__()\n", | |
" self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)\n", | |
" self.bn1 = nn.BatchNorm2d(32)\n", | |
" self.relu1 = nn.ReLU()\n", | |
"\n", | |
" self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)\n", | |
" self.bn2 = nn.BatchNorm2d(64)\n", | |
" self.relu2 = nn.ReLU()\n", | |
"\n", | |
" self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)\n", | |
" self.bn3 = nn.BatchNorm2d(128)\n", | |
" self.relu3 = nn.ReLU()\n", | |
"\n", | |
" # Global Average Pooling to reduce spatial dimensions to 1x1\n", | |
" self.gap = nn.AdaptiveAvgPool2d((1, 1))\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = self.conv1(x)\n", | |
" x = self.bn1(x)\n", | |
" x = self.relu1(x)\n", | |
"\n", | |
" x = self.conv2(x)\n", | |
" x = self.bn2(x)\n", | |
" x = self.relu2(x)\n", | |
"\n", | |
" x = self.conv3(x)\n", | |
" x = self.bn3(x)\n", | |
" x = self.relu3(x)\n", | |
"\n", | |
" x = self.gap(x)\n", | |
" return x" | |
], | |
"metadata": { | |
"id": "z81MMxklEqhO" | |
}, | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ImageDecoder(nn.Module):\n", | |
"    \"\"\"Decode a 128-channel latent map into an image.\n", | |
"\n", | |
"    Three stride-2 transposed convolutions each double the spatial size;\n", | |
"    an adaptive average pooling then resizes to the requested resolution.\n", | |
"\n", | |
"    Args:\n", | |
"        out_channels: number of channels of the produced image.\n", | |
"        initial_height: height of the produced image.\n", | |
"        initial_width: width of the produced image.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, out_channels, initial_height, initial_width):\n", | |
"        super().__init__()\n", | |
"        self.conv_transpose1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
"        self.bn1 = nn.BatchNorm2d(64)\n", | |
"        self.relu1 = nn.ReLU()\n", | |
"\n", | |
"        self.conv_transpose2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
"        self.bn2 = nn.BatchNorm2d(32)\n", | |
"        self.relu2 = nn.ReLU()\n", | |
"\n", | |
"        self.conv_transpose3 = nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
"        self.bn3 = nn.BatchNorm2d(out_channels)\n", | |
"        # No activation after the last layer: the training targets are images\n", | |
"        # scaled to [-1, 1] with Gaussian noise added, and a ReLU would\n", | |
"        # restrict the output to non-negative values.\n", | |
"\n", | |
"        # Additional layer to ensure correct output dimensions\n", | |
"        # This layer is only needed if the initial size cannot be exactly achieved through the strides and paddings chosen\n", | |
"        self.final_resize = nn.AdaptiveAvgPool2d((initial_height, initial_width))\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Map latents of shape (N, 128, H, W) to (N, out_channels, initial_height, initial_width).\"\"\"\n", | |
"        x = self.conv_transpose1(x)\n", | |
"        x = self.bn1(x)\n", | |
"        x = self.relu1(x)\n", | |
"\n", | |
"        x = self.conv_transpose2(x)\n", | |
"        x = self.bn2(x)\n", | |
"        x = self.relu2(x)\n", | |
"\n", | |
"        x = self.conv_transpose3(x)\n", | |
"        x = self.bn3(x)  # output stays unbounded so negative pixel values are reachable\n", | |
"\n", | |
"        x = self.final_resize(x)  # Ensure the output has the same HxW dimensions as the original input\n", | |
"        return x" | |
], | |
"metadata": { | |
"id": "gzQ35KxeE2Cz" | |
}, | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class CustomRecurrence(nn.Module):\n", | |
"    \"\"\"Encode every frame, run an RNN over the sequence and decode each step.\n", | |
"\n", | |
"    Args:\n", | |
"        in_channels: channels of the input and of the reconstructed images.\n", | |
"        initial_height: height the decoder resizes its output to.\n", | |
"        initial_width: width the decoder resizes its output to.\n", | |
"        hidden_dim: RNN hidden size and width of the positional embedding.\n", | |
"        num_layers: number of stacked RNN layers.\n", | |
"        training: when True, forward() returns only the reconstruction;\n", | |
"            otherwise it also returns the RNN hidden state.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, in_channels, initial_height, initial_width, hidden_dim, num_layers=1, training=False):\n", | |
"        super().__init__()\n", | |
"        self.image_encoder = ImageEncoder(in_channels=in_channels)\n", | |
"        self.positional_encoder = nn.Embedding(timesteps, hidden_dim)\n", | |
"        self.rnn = nn.RNN(\n", | |
"            input_size=128,  # hardcoded for the time being\n", | |
"            hidden_size=hidden_dim,\n", | |
"            num_layers=num_layers,\n", | |
"            batch_first=True\n", | |
"        )\n", | |
"        self.image_decoder = ImageDecoder(\n", | |
"            out_channels=in_channels,\n", | |
"            initial_height=initial_height,\n", | |
"            initial_width=initial_width\n", | |
"        )\n", | |
"        # NOTE: `training` is the same attribute that nn.Module.train() and\n", | |
"        # nn.Module.eval() set, so calling either of them also switches the\n", | |
"        # return type of forward().\n", | |
"        self.training = training\n", | |
"\n", | |
"    def forward(self, x, hidden_states=None, timesteps=timesteps):\n", | |
"        \"\"\"Reconstruct one image per input frame.\n", | |
"\n", | |
"        Args:\n", | |
"            x: frames of shape (b, timesteps - 1, c, h, w).\n", | |
"            hidden_states: initial RNN hidden state; only passed to the RNN\n", | |
"                when `self.training` is False.\n", | |
"            timesteps: sequence length of `x` plus one.\n", | |
"\n", | |
"        Returns:\n", | |
"            The reconstruction of shape (b, timesteps - 1, c, h, w) when\n", | |
"            `self.training` is True, otherwise a tuple of the reconstruction\n", | |
"            and the final RNN hidden state.\n", | |
"        \"\"\"\n", | |
"        batch_size = x.shape[0]\n", | |
"        seq_len = timesteps - 1\n", | |
"\n", | |
"        # x : (b, t-1, c, h, w) -> (b*(t-1), c, h, w)\n", | |
"        # image dimensions come from the input itself, not from notebook globals\n", | |
"        x = x.reshape(batch_size * seq_len, *x.shape[2:])\n", | |
"\n", | |
"        latent_vectors = self.image_encoder(x)  # (b*(t-1), 128, 1, 1)\n", | |
"        latent_vectors = latent_vectors.reshape(batch_size, seq_len, -1)  # (b, t-1, 128)\n", | |
"\n", | |
"        # positions are created on the input's device; the (t-1, hidden_dim)\n", | |
"        # embedding broadcasts over the batch dimension\n", | |
"        positions = torch.arange(seq_len, device=x.device)\n", | |
"        latent_vectors = latent_vectors + self.positional_encoder(positions)\n", | |
"\n", | |
"        if self.training:\n", | |
"            rnn_outputs, _ = self.rnn(latent_vectors)  # (b, t-1, hidden_dim)\n", | |
"        else:\n", | |
"            rnn_outputs, hidden_states = self.rnn(latent_vectors, hidden_states)  # (b, t-1, hidden_dim)\n", | |
"\n", | |
"        rnn_outputs = rnn_outputs.reshape(batch_size * seq_len, -1, 1, 1)  # (b*(t-1), hidden_dim, 1, 1)\n", | |
"        reconstructed_x = self.image_decoder(rnn_outputs)\n", | |
"        # restore the sequence dimension using the decoder's own output shape\n", | |
"        reconstructed_x = reconstructed_x.reshape(batch_size, seq_len, *reconstructed_x.shape[1:])\n", | |
"\n", | |
"        if self.training:\n", | |
"            return reconstructed_x\n", | |
"        else:\n", | |
"            return reconstructed_x, hidden_states" | |
], | |
"metadata": { | |
"id": "chqbXzZaE5Ok" | |
}, | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Training Loop" | |
], | |
"metadata": { | |
"id": "KD_os31TMhEN" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = CustomRecurrence(\n", | |
" in_channels=channels,\n", | |
" initial_height=image_size,\n", | |
" initial_width=image_size,\n", | |
" hidden_dim=hidden_dim,\n", | |
" training=True, # important parameter\n", | |
")\n", | |
"model.to(device)\n", | |
"\n", | |
"optimizer = torch.optim.Adam(\n", | |
" model.parameters(),\n", | |
" lr=1e-3\n", | |
")\n", | |
"\n", | |
"loss_fn = nn.MSELoss()" | |
], | |
"metadata": { | |
"id": "TYo28dSj49K_" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for epoch in range(epochs):\n", | |
"    for step, batch in enumerate(dataloader):\n", | |
"        optimizer.zero_grad()\n", | |
"\n", | |
"        # Use local names: overwriting the notebook-level `batch_size` and\n", | |
"        # `batch` would change what the earlier forward-diffusion cells see\n", | |
"        # when they are re-run (the final batch can be smaller than `batch_size`).\n", | |
"        images = batch[\"pixel_values\"].to(device)\n", | |
"        num_images = images.shape[0]\n", | |
"\n", | |
"        # timestep indices in descending order, repeated for every image\n", | |
"        t = torch.arange(0, timesteps).flip((-1,)).repeat(num_images, 1).to(device)\n", | |
"\n", | |
"        input_images = q_sample(\n", | |
"            x_start=images.unsqueeze(1),  # (B, 1, C, H, W)\n", | |
"            t=t,  # (B, timesteps)\n", | |
"        )\n", | |
"        # the model sees frames 0..T-2 and is trained to predict frames 1..T-1\n", | |
"        reconstructed_images = model(input_images[:, :timesteps-1, ...])\n", | |
"        loss = loss_fn(reconstructed_images, input_images[:, 1:, ...])\n", | |
"\n", | |
"        if step % 100 == 0:\n", | |
"            print(\"Loss:\", loss.item())\n", | |
"\n", | |
"        loss.backward()\n", | |
"        optimizer.step()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ImaLvHwjMiDf", | |
"outputId": "444f58e7-29ac-4fda-d74f-6714344ba205" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Loss: 1.4240089654922485\n", | |
"Loss: 0.7346835732460022\n", | |
"Loss: 0.7231897115707397\n", | |
"Loss: 0.7261417508125305\n", | |
"Loss: 0.7263044714927673\n", | |
"Loss: 0.7309923768043518\n", | |
"Loss: 0.7178468704223633\n", | |
"Loss: 0.7321939468383789\n", | |
"Loss: 0.706930935382843\n", | |
"Loss: 0.7191136479377747\n", | |
"Loss: 0.7089812159538269\n", | |
"Loss: 0.7193553447723389\n", | |
"Loss: 0.7181675434112549\n", | |
"Loss: 0.7335854172706604\n", | |
"Loss: 0.7130610942840576\n", | |
"Loss: 0.714571475982666\n", | |
"Loss: 0.7261708378791809\n", | |
"Loss: 0.7198599576950073\n", | |
"Loss: 0.7220982909202576\n", | |
"Loss: 0.7202709317207336\n", | |
"Loss: 0.7199481129646301\n", | |
"Loss: 0.71610426902771\n", | |
"Loss: 0.6913580894470215\n", | |
"Loss: 0.7348539233207703\n", | |
"Loss: 0.7180370688438416\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Inference Loop" | |
], | |
"metadata": { | |
"id": "PP6CFN4vMidh" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model.training=False\n", | |
"model = model.eval()" | |
], | |
"metadata": { | |
"id": "iNwii9gJ-Utu" | |
}, | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"with torch.no_grad():\n", | |
"    # start from pure noise: one sample holding a single frame\n", | |
"    current_input = torch.randn(\n", | |
"        1, 1, channels, image_size, image_size\n", | |
"    ).to(device)\n", | |
"    hidden_state = None  # Start with no hidden state\n", | |
"    generated_sequence = []\n", | |
"\n", | |
"    for _ in range(timesteps):\n", | |
"        # Forward pass; timesteps=2 makes the model treat the input as a\n", | |
"        # sequence of length 1.\n", | |
"        # NOTE(review): the model builds its position indices from\n", | |
"        # arange(timesteps - 1), so every step here is embedded with\n", | |
"        # position 0 - confirm this is intended.\n", | |
"        output, hidden_state = model(current_input, hidden_state, timesteps=2)\n", | |
"\n", | |
"        # keep every intermediate frame (the list was previously never filled)\n", | |
"        generated_sequence.append(output.cpu())\n", | |
"\n", | |
"        # Use output as next input\n", | |
"        current_input = output" | |
], | |
"metadata": { | |
"id": "8NxLub4ZMjpH" | |
}, | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"output.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "rNmgRa7m-ZEW", | |
"outputId": "8cafec2a-40ec-4e8f-ca55-94d6e78512ea" | |
}, | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 1, 1, 28, 28])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"plt.imshow(output[0, 0, 0].detach().cpu().numpy(), cmap=\"gray\")\n", | |
"plt.axis(\"off\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 406 | |
}, | |
"id": "OkJeUd45JooW", | |
"outputId": "be0a63f0-2584-43f2-bee7-a5fb0aa0bd2d" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAHXUlEQVR4nO3cMY6NfR/HYUcmkZEggw6NiERiDyqFBViDBViBxjroprQAQUGrkikUFoDMIGGGDOftPnkLuTPHc875z/O4rvr+Jd8pZj7uwj2bz+fzEwBw4sSJk6MHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQDaO+uBsNlvlDhju6tWroydM2tvbW8vNumxtbY2eMGl/f3/hm4ODgxUsWZ6j/F9lbwoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgDZGD0Ajou9vb3REybt7++PnrBUx/3nOTw8HD1hCG8KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgs/l8Pj/Sg7PZqrcAsEJH+XPvTQGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4vSdPnoyeMGl7e3stN+v0+fPn0RMm7e7uLnyzt7e3giXL8fDhw9ETJu3s7Kzl5rjxpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4ve3t7dETJr1582bhm5Mnj/e/QR49ejR6wqRPnz6t5WZddnZ2Rk+Y9OHDh9EThjjev6UArJUoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAZvP5fH6kB2ezVW/hX+RPPm533D+Id/PmzdETJn358mUtN+vy8ePH0RP+Okf5c3+8f0sBWCtRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAZGP0AH7v2rVroydM2t3dXcvNOt25c2f0hEnfvn1b+Obr168rWLIcV65cGT1h0vPnzxe+efHixfKHrJk3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgGwc9cFTp06tcsc/dnh4uPDNz58/V7BkOXZ3d0dPmLS/vz96wtK9evVq9IRJv379Gj1hqc6dOzd6wqSDg4PRE4bwpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFAPKf+SDen/BBPP7fy5cvR0+YtLm5uZabdblx48boCZN+/PgxesIQ3hQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAY
CIAgDZGD1gWW7fvr3wza1bt1awZDkeP348esKk9+/fr+Vmne7duzd6wqTz588vfLO1tbWCJctx//790RMmPXjwYOGb169fr2DJenlTACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAyMZRH/z+/fsqd/xjb9++XfjmOP9MFy9eHD1h0oULFxa+uX79+gqWLM/ly5dHT5h09uzZhW/OnDmzgiXL8ezZs9ETJr179270hCG8KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgPzVH8T7k5t1uXv37ugJk/7kg3h/crNOly5dGj1h0unTpxe+2dzcXMGS5Xj69OnoCZN8EA+Av54oABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAMpvP5/PRIwA4HrwpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQ/wHzssF0wLHC6wAAAABJRU5ErkJggg==\n" | |
}, | |
"metadata": {} | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment