Created
April 16, 2024 18:44
-
-
Save ariG23498/3043580b7f73313f6657b22c77988079 to your computer and use it in GitHub Desktop.
rnn-diffusion.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"authorship_tag": "ABX9TyP8xftETb2XVvWYheumJWH6", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU", | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"ef0db216de374d8d9fe4ef7b7c27988b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_e096b4adf8d74c5bb62dc62de43d5485", | |
"IPY_MODEL_5947389f26104af2aeadde1aadb56c46", | |
"IPY_MODEL_c9d86aca6d1b4e9092d087faea2a5f23" | |
], | |
"layout": "IPY_MODEL_9302b16aa22647d0ac292e492f944e32" | |
} | |
}, | |
"e096b4adf8d74c5bb62dc62de43d5485": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_d67abee0a02f41b8ac4d28ed89a43188", | |
"placeholder": "", | |
"style": "IPY_MODEL_3ed6412412644d19bd435f5e688072e7", | |
"value": "Downloading data: 100%" | |
} | |
}, | |
"5947389f26104af2aeadde1aadb56c46": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_329a1621b4c9484bae3089c3d0a4f30d", | |
"max": 30931277, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_161fb7ef927a4cb7873d14e0b5fa500b", | |
"value": 30931277 | |
} | |
}, | |
"c9d86aca6d1b4e9092d087faea2a5f23": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_2aa04feb3e48483c84bcbde834897aad", | |
"placeholder": "", | |
"style": "IPY_MODEL_c1d1f44a1aae457fafb01295c6c9c6dc", | |
"value": " 30.9M/30.9M [00:02<00:00, 16.9MB/s]" | |
} | |
}, | |
"9302b16aa22647d0ac292e492f944e32": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"d67abee0a02f41b8ac4d28ed89a43188": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"3ed6412412644d19bd435f5e688072e7": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"329a1621b4c9484bae3089c3d0a4f30d": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"161fb7ef927a4cb7873d14e0b5fa500b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"2aa04feb3e48483c84bcbde834897aad": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"c1d1f44a1aae457fafb01295c6c9c6dc": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"ea998c61f56c49fdbb27229f15c84c3d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_b1d6c88baf604285aa06840605b2b08e", | |
"IPY_MODEL_ada4f71b6a4c449ca86b21449428bf9b", | |
"IPY_MODEL_1828031ff524483aa28bd4b8c6f35981" | |
], | |
"layout": "IPY_MODEL_ff984182fc064033b3095ca08c8220b2" | |
} | |
}, | |
"b1d6c88baf604285aa06840605b2b08e": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_756a2563ac4147d0bf4376a52cd26870", | |
"placeholder": "", | |
"style": "IPY_MODEL_707a7542d7d848cdb1f26cccf6e778d5", | |
"value": "Downloading data: 100%" | |
} | |
}, | |
"ada4f71b6a4c449ca86b21449428bf9b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_fa86fdcadc8c4015987dad3a05152e0e", | |
"max": 5175617, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_adc80c861a914013a40393c44ac50210", | |
"value": 5175617 | |
} | |
}, | |
"1828031ff524483aa28bd4b8c6f35981": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_fbfc5527b1c14d7aa42295d35e2b3620", | |
"placeholder": "", | |
"style": "IPY_MODEL_f107f457eba6446490b5adb8ac983869", | |
"value": " 5.18M/5.18M [00:00<00:00, 4.38MB/s]" | |
} | |
}, | |
"ff984182fc064033b3095ca08c8220b2": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"756a2563ac4147d0bf4376a52cd26870": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"707a7542d7d848cdb1f26cccf6e778d5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"fa86fdcadc8c4015987dad3a05152e0e": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"adc80c861a914013a40393c44ac50210": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"fbfc5527b1c14d7aa42295d35e2b3620": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"f107f457eba6446490b5adb8ac983869": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"53af44e4c7dd4d6fb5e3b385bc5f7473": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_ad5d7fa88a2d48ffbbd5148b1d726d36", | |
"IPY_MODEL_2ce5f202a0ad4c59b3dd54e7778e5ac7", | |
"IPY_MODEL_5235bd0e9078423489d90061ebb23a68" | |
], | |
"layout": "IPY_MODEL_d4426c5e37d04d55bb8107c8672311dc" | |
} | |
}, | |
"ad5d7fa88a2d48ffbbd5148b1d726d36": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_3946b48ad5cd496b8b7323ffd586133c", | |
"placeholder": "", | |
"style": "IPY_MODEL_958423b6803641dc916e1d5669cfad68", | |
"value": "Generating train split: 100%" | |
} | |
}, | |
"2ce5f202a0ad4c59b3dd54e7778e5ac7": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_4e06d3fa7b594b34abf19e44de503e87", | |
"max": 60000, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_48cb9ee8be1e4d938789a90defbfb704", | |
"value": 60000 | |
} | |
}, | |
"5235bd0e9078423489d90061ebb23a68": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_809081ccc38a48f9869cf0ef956ed24c", | |
"placeholder": "", | |
"style": "IPY_MODEL_e840f8e33dd0434fa8bd6d3bcd5efda7", | |
"value": " 60000/60000 [00:00<00:00, 90943.69 examples/s]" | |
} | |
}, | |
"d4426c5e37d04d55bb8107c8672311dc": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"3946b48ad5cd496b8b7323ffd586133c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"958423b6803641dc916e1d5669cfad68": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"4e06d3fa7b594b34abf19e44de503e87": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"48cb9ee8be1e4d938789a90defbfb704": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"809081ccc38a48f9869cf0ef956ed24c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"e840f8e33dd0434fa8bd6d3bcd5efda7": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"f68f3db300a947838e60acd6eef73f56": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_9711ce9ed95f4c57a96cc42c419a6535", | |
"IPY_MODEL_f08c421ac9634861bb7104f9e1f3e5ce", | |
"IPY_MODEL_08d967d6785a4a8c861f23d8259f2ac2" | |
], | |
"layout": "IPY_MODEL_7fbca9fc1fb0413b8848723c8c286b3c" | |
} | |
}, | |
"9711ce9ed95f4c57a96cc42c419a6535": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_c9cc94dc9efd4b73ac5e196de8914180", | |
"placeholder": "", | |
"style": "IPY_MODEL_243b70f554464303843eff056b169773", | |
"value": "Generating test split: 100%" | |
} | |
}, | |
"f08c421ac9634861bb7104f9e1f3e5ce": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_e101f5e375c74a13887a46ce51c1bc0d", | |
"max": 10000, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_b96c61f62aaa4666aa46211b53e7fd62", | |
"value": 10000 | |
} | |
}, | |
"08d967d6785a4a8c861f23d8259f2ac2": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_ecdfef650f07497e975f0edec0153ad4", | |
"placeholder": "", | |
"style": "IPY_MODEL_6cf4e93859f445f0b514c626c6b62cc5", | |
"value": " 10000/10000 [00:00<00:00, 53018.19 examples/s]" | |
} | |
}, | |
"7fbca9fc1fb0413b8848723c8c286b3c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"c9cc94dc9efd4b73ac5e196de8914180": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"243b70f554464303843eff056b169773": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"e101f5e375c74a13887a46ce51c1bc0d": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"b96c61f62aaa4666aa46211b53e7fd62": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"ecdfef650f07497e975f0edec0153ad4": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"6cf4e93859f445f0b514c626c6b62cc5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/3043580b7f73313f6657b22c77988079/rnn-diffusion.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"This notebook is heavily inspired by: https://huggingface.co/blog/annotated-diffusion" | |
], | |
"metadata": { | |
"id": "02vydzVb6VWO" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Setup and Imports" | |
], | |
"metadata": { | |
"id": "-CFKktKV2qqo" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install --upgrade -qq datasets" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "BmiGcVlQ64dD", | |
"outputId": "2201fe6e-d14c-4d0e-f19a-a53531bc30ed" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"id": "hzUHzDmy2kOy" | |
}, | |
"outputs": [], | |
"source": [ | |
"import random\n", | |
"import numpy as np\n", | |
"from matplotlib import pyplot as plt\n", | |
"\n", | |
"from datasets import load_dataset\n", | |
"\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"from torch.nn import functional as F\n", | |
"from torch.utils.data import DataLoader\n", | |
"\n", | |
"from torchvision import transforms" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Configurations" | |
], | |
"metadata": { | |
"id": "rlNouEPu7ART" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"batch_size = 128\n", | |
"image_size = 28\n", | |
"channels = 1\n", | |
"timesteps = 100\n", | |
"hidden_dim = 128\n", | |
"epochs = 5\n", | |
"\n", | |
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | |
"print(f\"{device=}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "3HNObPXp7Bff", | |
"outputId": "264c2094-1e45-44c9-f7c5-a5bef28208a9" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"device='cuda'\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Dataset and Loaders" | |
], | |
"metadata": { | |
"id": "JGEzhe8s7B1p" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# load dataset from the hub\n", | |
"dataset = load_dataset(\"fashion_mnist\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 145, | |
"referenced_widgets": [ | |
"ef0db216de374d8d9fe4ef7b7c27988b", | |
"e096b4adf8d74c5bb62dc62de43d5485", | |
"5947389f26104af2aeadde1aadb56c46", | |
"c9d86aca6d1b4e9092d087faea2a5f23", | |
"9302b16aa22647d0ac292e492f944e32", | |
"d67abee0a02f41b8ac4d28ed89a43188", | |
"3ed6412412644d19bd435f5e688072e7", | |
"329a1621b4c9484bae3089c3d0a4f30d", | |
"161fb7ef927a4cb7873d14e0b5fa500b", | |
"2aa04feb3e48483c84bcbde834897aad", | |
"c1d1f44a1aae457fafb01295c6c9c6dc", | |
"ea998c61f56c49fdbb27229f15c84c3d", | |
"b1d6c88baf604285aa06840605b2b08e", | |
"ada4f71b6a4c449ca86b21449428bf9b", | |
"1828031ff524483aa28bd4b8c6f35981", | |
"ff984182fc064033b3095ca08c8220b2", | |
"756a2563ac4147d0bf4376a52cd26870", | |
"707a7542d7d848cdb1f26cccf6e778d5", | |
"fa86fdcadc8c4015987dad3a05152e0e", | |
"adc80c861a914013a40393c44ac50210", | |
"fbfc5527b1c14d7aa42295d35e2b3620", | |
"f107f457eba6446490b5adb8ac983869", | |
"53af44e4c7dd4d6fb5e3b385bc5f7473", | |
"ad5d7fa88a2d48ffbbd5148b1d726d36", | |
"2ce5f202a0ad4c59b3dd54e7778e5ac7", | |
"5235bd0e9078423489d90061ebb23a68", | |
"d4426c5e37d04d55bb8107c8672311dc", | |
"3946b48ad5cd496b8b7323ffd586133c", | |
"958423b6803641dc916e1d5669cfad68", | |
"4e06d3fa7b594b34abf19e44de503e87", | |
"48cb9ee8be1e4d938789a90defbfb704", | |
"809081ccc38a48f9869cf0ef956ed24c", | |
"e840f8e33dd0434fa8bd6d3bcd5efda7", | |
"f68f3db300a947838e60acd6eef73f56", | |
"9711ce9ed95f4c57a96cc42c419a6535", | |
"f08c421ac9634861bb7104f9e1f3e5ce", | |
"08d967d6785a4a8c861f23d8259f2ac2", | |
"7fbca9fc1fb0413b8848723c8c286b3c", | |
"c9cc94dc9efd4b73ac5e196de8914180", | |
"243b70f554464303843eff056b169773", | |
"e101f5e375c74a13887a46ce51c1bc0d", | |
"b96c61f62aaa4666aa46211b53e7fd62", | |
"ecdfef650f07497e975f0edec0153ad4", | |
"6cf4e93859f445f0b514c626c6b62cc5" | |
] | |
}, | |
"id": "PdIqVYuj6axU", | |
"outputId": "81b1d933-710c-4721-e89a-5e8001a8f771" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading data: 0%| | 0.00/30.9M [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "ef0db216de374d8d9fe4ef7b7c27988b" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading data: 0%| | 0.00/5.18M [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "ea998c61f56c49fdbb27229f15c84c3d" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Generating train split: 0%| | 0/60000 [00:00<?, ? examples/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "53af44e4c7dd4d6fb5e3b385bc5f7473" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Generating test split: 0%| | 0/10000 [00:00<?, ? examples/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "f68f3db300a947838e60acd6eef73f56" | |
} | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# define image transformations\n", | |
"transform = transforms.Compose([\n", | |
"    transforms.RandomHorizontalFlip(),\n", | |
"    transforms.ToTensor(),\n", | |
"    # shift and scale the tensor values: t -> 2 * t - 1\n", | |
"    transforms.Lambda(lambda t: (t * 2) - 1)\n", | |
"])\n", | |
"\n", | |
"# NOTE: this function must not be named `transforms`; that would rebind the\n", | |
"# `torchvision.transforms` module name used above, so re-running this cell\n", | |
"# (or using `transforms.*` in any later cell) would fail\n", | |
"def apply_transforms(examples):\n", | |
"    \"\"\"Apply `transform` to every image of a batch.\n", | |
"\n", | |
"    Each entry of `examples[\"image\"]` is converted with `.convert(\"L\")`,\n", | |
"    passed through `transform` and stored under `examples[\"pixel_values\"]`;\n", | |
"    the `image` key is removed. The (mutated) batch dict is returned.\n", | |
"    \"\"\"\n", | |
"    examples[\"pixel_values\"] = [transform(image.convert(\"L\")) for image in examples[\"image\"]]\n", | |
"    del examples[\"image\"]\n", | |
"    return examples\n", | |
"\n", | |
"transformed_dataset = (\n", | |
"    dataset\n", | |
"    .with_transform(apply_transforms)\n", | |
"    .remove_columns(\"label\")\n", | |
")\n", | |
"\n", | |
"# create dataloader\n", | |
"dataloader = DataLoader(\n", | |
"    transformed_dataset[\"train\"],\n", | |
"    batch_size=batch_size,\n", | |
"    shuffle=True\n", | |
")" | |
], | |
"metadata": { | |
"id": "SbnfSuvR6eel" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Get a batch of data and check the shape\n", | |
"batch = next(iter(dataloader))\n", | |
"print(batch.keys())\n", | |
"print(batch[\"pixel_values\"].shape)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GneDUV5J74KJ", | |
"outputId": "7e5a4cb6-73ee-4b2f-d335-4922d096e4dd" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"dict_keys(['pixel_values'])\n", | |
"torch.Size([128, 1, 28, 28])\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Forward Diffusion Process" | |
], | |
"metadata": { | |
"id": "Ggg6okjO8NFS" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"id": "9_7ipOKVIbFn" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Define a linear schedule\n", | |
"def linear_beta_schedule(timesteps):\n", | |
"    \"\"\"Return `timesteps` betas linearly spaced from 1e-4 to 0.02.\"\"\"\n", | |
"    beta_start = 0.0001\n", | |
"    beta_end = 0.02\n", | |
"    return torch.linspace(beta_start, beta_end, timesteps)\n", | |
"\n", | |
"# define beta schedule\n", | |
"betas = linear_beta_schedule(timesteps=timesteps)\n", | |
"\n", | |
"# define alphas = 1 - beta\n", | |
"alphas = 1. - betas\n", | |
"alphas_cumprod = torch.cumprod(alphas, dim=0)\n", | |
"\n", | |
"# batchify the cum prods: one row of per-timestep coefficients per batch element\n", | |
"sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).repeat(batch_size, 1)\n", | |
"sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).repeat(batch_size, 1)\n", | |
"\n", | |
"# define the extraction function\n", | |
"# `a` holds coefficients derived from the cumulative products of the alphas;\n", | |
"# for every entry of `t` we pick the coefficient of that time step\n", | |
"def extract(a, t, x_shape):\n", | |
"    \"\"\"Gather `a` along its last dim at the indices `t`.\n", | |
"\n", | |
"    The index tensor is moved to the CPU for the gather. The result is\n", | |
"    reshaped to `t.shape + (1,) * (len(x_shape) - 2)` and moved to\n", | |
"    `t.device`.\n", | |
"    \"\"\"\n", | |
"    out = a.gather(dim=-1, index=t.cpu())\n", | |
"    # reshape from the shape of `t` itself rather than the global `timesteps`,\n", | |
"    # so index tensors covering only some of the time steps also work\n", | |
"    trailing_dims = (1,) * (len(x_shape) - 2)\n", | |
"    return out.reshape(*t.shape, *trailing_dims).to(t.device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# forward diffusion\n", | |
"def q_sample(x_start, t, noise=None):\n", | |
" if noise is None:\n", | |
" noise = torch.randn_like(x_start)\n", | |
"\n", | |
" sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)\n", | |
" sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)\n", | |
"\n", | |
" return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise" | |
], | |
"metadata": { | |
"id": "pmHZGKa8Lhbz" | |
}, | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# create a batch of noise images\n", | |
"t = torch.arange(0, timesteps).flip((-1,)).repeat(batch_size, 1)\n", | |
"input_images = q_sample(\n", | |
" x_start=batch[\"pixel_values\"].unsqueeze(1), # (B, 1, C, H, W)\n", | |
" t=t, # (B, timesteps)\n", | |
")" | |
], | |
"metadata": { | |
"id": "uWgxpZPj8sY6" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Visualize the forward process" | |
], | |
"metadata": { | |
"id": "aQ0H8DVxEJqF" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# show evenly spaced steps of the noised sequence for one random sample;\n", | |
"# sizes are read from `input_images` instead of being hard-coded, so the cell\n", | |
"# keeps working when `timesteps` or the batch size changes\n", | |
"num_plots = 5\n", | |
"stride = input_images.shape[1] // num_plots\n", | |
"idx = random.randint(0, input_images.shape[0] - 1)\n", | |
"for i in range(num_plots):\n", | |
"    plt.subplot(1, num_plots, i + 1)\n", | |
"    # (C, H, W) -> (H, W, C) for imshow\n", | |
"    plt.imshow(input_images[idx, stride * i].permute(1, 2, 0), cmap=\"gray\")\n", | |
"    plt.axis(\"off\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 122 | |
}, | |
"id": "I-hoBcj3DNSf", | |
"outputId": "f63af000-1c1b-40f2-caa3-9d5db9eebd5c" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 640x480 with 5 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgQAAABpCAYAAABF9zs7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAnT0lEQVR4nO3dV9DlRbU28IUBc85ZFCMqggEEhjDkISPISChF1DJeiFqWXlreKJZVIlWocAEjJdGRAcaBQXKQKEkRRcyYc86cK/v8/v29PfW+s/dXp+qc57la887e+9/da3XvXc+z1upN7rvvvvsqCIIgCIL/07jf//QAgiAIgiD4n0d+EARBEARBkB8EQRAEQRDkB0EQBEEQBJUfBEEQBEEQVH4QBEEQBEFQ+UEQBEEQBEHlB0EQBEEQBFX1gMW+cPny5c1+4Qtf2OwHP/jBzf7Vr37V7L333rvZV155ZbP/8pe/NPtJT3rS8Hm//vWvm/3vf/+72Y95zGOafe655y74WTvuuOOCz3vGM57R7LvvvrvZT37yk5v9t7/9rdk//vGPm33FFVdMxnfooYc2+8UvfnGz161bt+BYv/Od7zT7iU98YrNPO+20mhU77bRTs5///Oc3e+Sb3XffvdnXXntts10rx9jjN7/5TbP1zeMe97hmr127ttmPfexjm73DDjs0+1//+lezn/KUpzT7G9/4RrP1zV//+tdm//SnP11wDlVVBxxwQLON1YsvvrjZj370o5v9ve99r9lPeMITmn3mmWfWLHjNa17T7M0337zZD3rQg5ptnO+yyy7NvvHGG5utXx7/+McPn/e73/2u2frFNbzoooua/fCHP7zZ2267bbMf+MAHNtv1+PrXv77g390zP//5z5t9ww03TMbnmfC85z2v2ZdffnmzH/nIRzb7hz/8YbONLff9xuIVr3hFs5/1rGc1e9NNN232b3/722a7PrfffnuzjUnjvMfvf//7ZtsLzjPpqquuavYDHvDfR/NWW23VbNdB/91yyy0LjuMf//hHs3/5y182+7bbbpuMb+edd272s5/97GZfd911zdY39957b7M959avX1+zwLjwTHc9/vSnPzX7RS96UbM9Y43JRz3qUZNnuDfcW/rFM/2OO+5o9h/+8Idme9a+5CUvabZr7tn0iEc8otn//Oc/m+2+dQ5VVVtuuWWzXY+77rqr2Z7znifGx80331yLQRiCIAiCIAjygyAIgiAIgiVIBlKs97///Zst9SF1Id0n1SgtKoVVNaXP7rzzzmZLM0t9H3PMMc2W9vVzpGAct69xHH/84x8XfK6UXFXVy172smY7v7///e/N3nrrrZut/CCNPQ9Iid3vfv/9G2/kG+k+6Xmllt43T3va05otXSUtLlX29re/vdnf/e53m/2CF7yg2T/4wQ+aLb31zGc+c8FxSBVut912zZbSq5rSd85P37z85S9vtvKDNOCsGPlFqUS//OQnP2n2t771rWa7xsZ/VdVTn/rUZiuDvepVr1pwHMcee2yz3RtSpMaqNKXUpn7585//vOBz3RdV07h3fsbNS1/60mYrPzznOc+pecJ4E+4ZpR3HYjxL5ysxVE1lN98jDWysKst++9vfbrY0upSw8sEvfvGLBW0pcc8sY6Wq6rnPfe6CY9U3jkP/K7nMCvfDJpts0mz3jPKBa+4+dqzS/FXT7zL96hooO3qWffOb32z205/+9GYbNw972MOa7feSflHSMLbde1VTuc/5+Ty/y5zrhiT5EcIQBEEQBEGQHwRBEARBECxBMpB+MGP6hBNOaPa+++7bbOlFacSvfe1rze7pWekSaRezes1o3WOPPZotBSiFKUVtBvH111/f7Ne+9rXNlkrzuVKD/Wf5PDPwjz/++GYrlfjseUBqUtrrM5/5TLP33HPPZusbZQ1lGn1cNaWqzTD//ve/32yzkHfddddm6xuzsqXKpAHNsN9///2brVThWKWZq6a+kTKVXjzxxB
ObrVRy00031bxgRYCZ26eeemqzpYnNRJfelab0c6qmMoq+NyaVyqxIechDHtJs5RTlId/rsw488MBmKw1IdW+xxRaTsVq9YBwYj6ecckqzlUrci/OAmfhSrl/4whearSzl3JU+lF2UuqqmMe1cPNuUKH2NcoW+USYwK975SHd/9atfbbZ+klLvxyqd7fjOOeecZnseuhdnhWeI++fqq69utvKWcaQf/S7pK6ak+j2XlRLci54v+l4Zw3XyO2CvvfZqthURzsdKBCXAqumeUb7xXLM6xXPbs3mxCEMQBEEQBEF+EARBEARBsATJwKxs6RhlAhsvSDWaUSyd1TfH8BlmZEot21jl7LPPbraUl+OwCYVZq9JRP/vZz5otrXzwwQc3W6mjqmr16tXNNjNW2s8MU+nZ/rNmhZSWmerKBFJxZrzfc889zT7kkEOa3dOAZkabYSyFfckllzRb6tVYcBxmygrpT+NAGnzlypXN7rPZzz///GYrH7z61a9uthUO+saqhFkhtWmMKRMYq/rF2DnqqKOa3WeHS2Eq60iLSk/a1MfnWYnQN3JZ6O/GgHvs6KOPbnYf5/pFSlzZ6kc/+lGz3VfKJvOA1KpzUSaQ2vfs0JfveMc7mm31Tv8MzxUpbKuXbOalb9wzxotwrFLW/v3d7353s/vKHBukue5WQXi2+BqrEmaFlLm0vXtpVIngvD/84Q83W7mnairtSPUrbXo+XHrppQuO1XNKal949vks5b0VK1Y0u5fGjjvuuGb7Hee56753HFYlLBZhCIIgCIIgyA+CIAiCIAiWIBlYTSAtKC0hdWhvejNjv/KVrzRbOrhqSodJEUrvm6kpXSSVqsQgnWXGrfSxc7ARxJe+9KVmL1u2bDJW5yrtZ2a1TSKkJa1qmAc++9nPNvvII49sttSmGcnOUdv1sWFHVdVDH/rQZptJLm1mRYdQlpDqMgv21ltvbbYUsr6Rar3wwgubLc1bNa0asJ++vpG+k3KzqmFWrFq1qtnKMca9fnF+0ppKIr1flDuMdel994PP8/VKDDYVk1q28sHYsvLH+yK22WabyVi9Q8PMaONGetcmPGZrzwNKflK2xoJr5RlmHBnDvW+kb5VC3EtS5J4pfq6vkcpWilU6chzS4N7H4h6rmlaNuLeMHX2jBOJ+mxXXXHNNsz2jlQ9cJ+PcigHlxd4vQn8bh8o07rHRHQ76V9t95XeAkplVVX3TOr9rP/e5zzXb71crLfSR+3WxCEMQBEEQBEF+EARBEARBsATJwEYzZi1LUdhoSGpLesO/91e5SqmM+m5Lcx5++OHNNmPdBhM2YpFaUZKw6ZIyhLSyFGnVNAvf7Fupb8ckFSSFPg84R68rVbKwEYmUsJSgVJcUadWUvjXz2MYsZsZ7PbQNNaTipPr1jZKEMaJvjLueEjRelBysatE3xm1//egscH5eP2pcjfqYG//+XZqyappZrZwm5eyeUxKx4Y1y3Stf+cpm6xf3nvFhRrx+6feM49BnSiLGipnYG9NkZUPw+mNlEcdvo5tRlYjzMLarxvtPf0pBe8baHMiseiuK9J972hjRN9LofQa61Lm2soRUvdKf58ascJ2MN9fDxmjGv82SlOWk/KumMS2lrwyiX92LNgQSSnyeLa6l8o0VA35mv2c884wn19wGScqO/X00i0EYgiAIgiAI8oMgCIIgCIIlSAZmk0tXjJqvSMlJA40o6qoppSKFKQUm1XLeeec1WyrNTFUpZxv12D96RLdJ+dpop6pq5513brZZw2aqOj4pd3vFzwNKJKNrlqWP9I0Z09KAfZMV/aGfXC+bcNhkxRjxHgdpZ+lSZQ/lGLO+t91222avWbNmMlaznh2TFLQVJNKOyluzQr+MrmaVUpSql6aUjnQ9qsZ+MfvaNVi/fn2zlUqUNBy3a+lrRtVFxpbZ6lVV22+/fbOlgN0Pjk8K12qYecC1kl7ebLPNmm1ceB+LZ43+66WrkW+UFqSEr7zyymZ7pbANnv
wcqziUPZyP9LX7UNm3/yx9ayz4Hu+fkMKfFZ6l+l9ZV8nNcSuZuQaevVXT9bdxk3EslBo9s5TA3GOeu8oertnoDoa+MZrN+oxB5SLjQ/lh1GBsQwhDEARBEARBfhAEQRAEQVC1yX3ysBuAjSv222+/ZpsV/OUvf7nZZjtKj0i1S/1UTWlEh2V2rBSWVJU0lxKAlJx0j/S42ZtSLqOGIFVTOlFK18ZLwkzhyy67rNnz6J3vfRI2cJEuvPzyy5s9yojdYYcdmu2aVE1945q6XlKkI98oaYwav4x6jRsH+qaXN/SNr7Mxi9A30raz9s/3em4lEfeM0pVUoxTkqEFL1dQvvsc9Jy2q9CfFPWpi5evNyNYvxoDUbk+hSwEr5fXSwn+gX1wnq1A2FjYaU8pwP3gV9uj6Yild465qSuvapGhUQeDZJuXt55pR7uuVuqyU0DfuyT6O9I3zO/PMM2shOG4bmllptDFQjlHOdN5WPljNJPWuPNjHoftPicK/u5d8jb7Qj6OKHc81feFZpuzafx0rMzi+T37yk7UQjDOrrfrqhRHCEARBEARBkB8EQRAEQRDkB0EQBEEQBLWEskM1d0vK1J5G93arXXkfu3pR1VRX8f1qQ5aQWEajZmq3OfU0X29JoNqoGo4lhH05mmV8jmnUjcrPNQdjHlBzv+iii5qtTuh6Wo6y6aabNvuCCy5otmtVNdWM9bMlY+prlrSpnakx6hv1c0sCzQFwDS0xUvetmvrGXAi1UbU29cm999675gXjyjvVR35xn6hhemFQf/nJaP9Zpudr1FbdM+bIuA8333zzZlsS6OfrF+Om10PVctVf1WhHftltt91qnnD8lik7d8flGeS4LMVzrarGXQzNr9I3noeOw7wb/25porkvxpGldH6+HQyrphqzMekaeJ75uf3lYrPAnBU7WBqr5i/YNdUzyvf2Zd6up3kY2u5Rfe/6e+77d/Nj7FxrDBhDluo6tv7fnoWe254V7jlzMBaLMARBEARBEOQHQRAEQRAES5AMpPikVqShpIAtTZPqkIrpS+78XJ9ntydpX6WBu+66q9nSgdJclt5JQUqzWD5y2mmnNdsypaop7WdZnbSVVLZ0kTTXPOCYXUOfI22vb5RCHHtf2qUU4vOUAPSNf7dMyHXTN1Kv3msuNSnNdtZZZzXbsryqKWVn+ZYxZTmPFGkvlcwC10mZRtrfUiLnNyoj68u6vFfdePMyIGlfpYHRJVL6xa6RlqW6Z/SRXSP7+9h9hrKOn2XnP/1iB8F5wLXyOdL+7hN9OSqp7rtICqUBy6K9eMfOrfpvVGp9yy23NNsul66n0oBl4UrA/TP0gfS8PjOm7D47K/SLz/A7RH+NpEz3Un8hkXPy/dLz7itlDH3sazwf3VdKp6MxKCv4mf3nKkG5Hn4P+v6N6boahiAIgiAIgvwgCIIgCIJgCZ0KP/CBDzTbjFYvarEbm1mrZnpLYW255ZaTZ5x88snN9iIiJQApSWkuM26FdJG0o/TL6J5pabieQpfykZ5yOaW27E4otX7++ecvOO6l4D3veU+zR1nPZs9LCSrhbMg3q1atarbdJu1MZ/WJvjHbXki3SmHqG+lyKb2tt9662b1vpGH1jbZzveqqq5otTblu3boFx71YvO1tb2u2lzoZhwcddFCzjUMlDbOIexr+9NNPb7aUvjSnVLG0qrSjUFoxw9q40S/udeNGKrNqvGekiY1Hu366Z5zPxuLwww9v9u23395sY/XII49ccLzS//qmlwI9q4xXfeB+9bxQlvBMUWLSf47JfTWqyup9I9U86kRq5zvlCv1qxcbGYPny5QuOUbnp/e9/f7ONC+eq7GVFUtV0v7sXfc+dd97ZbM8TfSeUzRzrSEIenX1KRVXTqhsrC6zosnLPs9DXW3WxIYQhCIIgCIIgPwiCIAiCIFhClYFZtmZ1S+NK2ZiRKj0r/dXfyS39LGUjTSadLMXmxTUf+tCHmi1Va1WDzXxGTW6k3HvaWw
pROsa5SjVJ4Xqv9TwgxavMYUaslRDSTdJkUoVeJlM1vaTDqoGRb8yO1c/ve9/7mu3lMVaJSFNLc5rxrkTk66um6yEdrG/1jZeLSO3OCveMl+BIsVs14FpKFYreL479nnvuabaUpJUWztv1fOc739lss56lIKWobdp14403Nlt5r89kVxLxQhzlBylg59ZLJbPCWHD/e255z7zx7BrqJ9ezf4+VBVLKyjB+rpT8UUcd1WzPMGPHs1MZV186Hi/eqpqeD0o1niHKaZ69PSU/CzxX3ddKSTbqMi5GzdeMz6rpGe13lr5UsvPv7jElcv3rmnnuK094cZp7dZ999pmM1e8W5z2qTNBHfsctFmEIgiAIgiDID4IgCIIgCJYgGZhNKyUl9SGdJT1oBuzrXve6ZksbVk1paZ+hfKBcYVa1VKo916VhvdNcuk36UIr685//fLP75ha+zjWQJrUnv9SRY5oHpP19vnTvyDdSY4cddlizb7755skz9I398KW0lEK23XbbZivb2NPfddhmm22abaatr5GmtupBiaBqGjtmRktBGiNS0/P0jXtG+s7saWNb6lyKdOXKlc2WsqyazklKUglGuUIf6WMzr90P9lmXFnWdbL5ipZD0b9VU5nEeynE2D5P27e+rmBX6xgYuNm7SN/aqN3P8iCOOaLYSQ1XVVltt1Wz9r2+sTLDJ2x133NFspTXXwddLUytV2TTo+OOPb7ZnRtVUxtDPSkBKCf+/fCOdbxMmpRzPHGNY6v0Nb3hDs3uZ7dBDD2228/McN1ZdQ/euZ4uxrtSsJOQ5Zcx/8IMfbLayR9V0PZQaPY8925Xc+v23GIQhCIIgCIIgPwiCIAiCIFiCZCDFZ2WB1+2+9a1vbbZXTlpZIPXTZyFLj0iBmUErdWezHykvYbauGeRSP9JzNt2RjpWKqZpm5UopmWkvlbqhJkezwrseXDcbuBx99NHNHjWkcU595r5UltSVvlFWkYJWLhFWAEhrK+dI23qfhPFldnbV1AfGmLGn7GFcuJazQvrZNbOx1+tf//pmj64U9nN8TdV0/a0mcD9IpUr7KpUIKVlpaaUc6VLlHun33i+OTwpdWlV/mUHeN2yZFTYgc1w33HBDs/fff/9mS18L6VqbcVVNfWA2u+suva+EKiUvrPDxTgzPZCl/19BmRzYT6sfn+WBFkmfC6Nmzwj2q1KzMohzpmvt6ZQKrp6qmZ5nyqU2HlEGMdbP4R+P2vca28qxr6XhsJtSPb9TwSL/6uY5psQhDEARBEARBfhAEQRAEQbAEycDmEzZF2HHHHZttlu0WW2zRbDPyzbzfb7/9Js+QnpS2lB4+++yzm22mpRSkFJaZlieddFKz3/KWtzTb7FJpbD9H6qZ/3V577dVsJQ2pJtdPCnceMPte+WO77bZrttS7Y1HWUP7x7oOqKUXovKSwv/jFLzZbmcCMWiUN+3+feuqpzTZDeCTtKCv0vnGuXnOtb5TAnNs8fbOYShkpYH3he5XGbIZSNa0skAZ3b3jHhHS11LCShnLFGWec0WyrHZQx9It/d72rprSv54ZrrmTjGkitzwOum/E8uovB1ytrKMFsv/32k2dI5UoX2yDIqhspa/fViPq94IILmr3vvvs2W8lM33ie9VVT7gfjU98oRbgG8/SN8zaelWWVNn29tlJOf6W50rRSsNS933GjBnNWAgllJ+UN94PjU1bo11Kfja5NV/JSQtyY6o8wBEEQBEEQ5AdBEARBEARLkAxstiDlIq0mnWIWpNmfZrv3DSOkkM2yld6y2YdQGpD+Gt0hYKa3Gbc2h1GG6Ck2s0HNYjUbWfpxQ1nxs0I6zbGYmW02rr6R9vK6V5sJVU37eTtHKUXnK+2sb6SEHZOZ/l6hqm8cq+vZ0/z29hbS6MpTG8qMnwWjXuJSr8a5lKJztZGKTWr697smNkGx+c3IL1L9fqY+lQo129qxKkP0fhndm6GMoZwl/TnvPeNaORYpWilk56JEZe956fmqqY
Q6kj+UiYS+uffeexcck1S4EqaShFKF69k381qMb/wO8AzoKxZmgXtUWAVhLDgPz2jP+r5ySInYChnPReND+B2llON+UE7xu8jzTpnANe5pfhto6RelDvece7dvPrUYhCEIgiAIgiA/CIIgCIIgWIJk4BW2ZkZL8UldSEFK5UhZmg1bNaV/pFGkdH2PVLmfKw0n5eK4bfIgtSKt5hx6Kmd0ZbIZn2Zu9/LIf/Cud71rwb8vBWY62/TI6gPXUMpamnrUQKpqmhUrfWcm9cg30mxSp0pPjluZR/rfv/v6vme3/vfZjk+7l0f+g/e+970L/n2xkN6XtpR6Nd6khqWo9YtrXzWNafeZ+0df2qPdtTGL3j2jJOhn6heboUib+/lV03mPqFft/j6NeeK2225rtvvBdfCMcC97TjkPpdGq6dopuUn36k/X0c/17/rG80k6WjnTzHtjsPfNKC5sTOR+7eWRecExukedh9UYjm90D4lSQNVUFrJiwXkrg7lHPWs8+0Z3Dijx+D2jv4wn93r/DP9P2Up71sZqYQiCIAiCIMgPgiAIgiAIliAZSGEuX7682dddd12zpdjM3pT6sZmGGcxVU8pHylkqVcp51O/dRkGrV69uttclr1mzZsH5jGi1vgmFTTMchz35zzrrrGZLeY0qJTYW+saGLzfeeGOzR9UWShw2JjIzvWpKVUrJuw5Swq6JzYhsrHPeeec12wYeNq/ybgllgg35xnhxHFYW2ERJStDe77NCKtCrt71fQxreuDAL+ZJLLmm2zUmqpvTpyC/S4K6H1KZxYxx4B4d3Y7iXlC38/N4vUrX6zyZm+t7zxNfMA541XvFsQzQztn2+NLXnn/djVE3jSh8oUViJMpLsbBR0xRVXNNt4VpL07B2teU+je0Y7Dve0DbI8T0aVEhsDKzA8662YUnIxo985KXH2FQOO3XPZZ+g7YUw7bxuP+d3l35Vtnafr3VezKUP5OudkEyUlChvCLRZhCIIgCIIgyA+CIAiCIAiWIBmYFek1nTaUkV7y+lvfK9UhDV01zZYcXZUqrSN1J/UjfSaNJx0jRarcIG27fv36ZvfXH0u/OSc/y+ugpaildeYBaWezp5UvpLFcH2lRKa3eN9Lf+kZaSipUqkv6zYx+x6Ts4R0MrqfzUXrqrySVWnNO0ug2yDr//PObbVOXWeHa2qRGOUbK08Y/xoix3TcbscJFmtn3mEltoxqby7inzTjXL1LXyjJKf8oHG/LLqLrC66AvvPDCZvf7b1a4vjaqsTJnMbSsska/Z0bZ+qPKHM8/G/SY0e9+c08qdelvY8291zd60jfOyfF5HbTSxTybRulnzxN94fis5HAcSmlWAFRN95ASgPvVv/u94ed6n4pjUg7wjFMGNM68n6e//liMqiuUI21QNWrQtiGEIQiCIAiCID8IgiAIgiBYgmQgfWNWvs1XpMhsxGFjDbM/++Y3UjbSHVIwUqE2idCW5vQZygE25PH+Ailc53zIIYdMxmrTFGkaqT5pc+lZs8znAce50047LfgcfSNlJgUpdSq9VTVdd2ktactRAw9tX29TFjOm9Y10nb6RZj7wwAMnY3XeZo2PfCPdpy9nhRneyiCj5j3uGSlS+6H3jYnMxNYvZtEbe/rCZ+sX6Uiz691X+sX97ZxXrFgxGavz9hzQL0oa/l0/zgNWWFhJ4RXNoz0j5a+E2cOKEOUE19E5+jxtz79RlrvxrG+UW6X/d9lll8lYnbc0vGeykoYysDLRrDD2RlUzQjnAMenfHlYZiJE0oJyiX5TrPC89U/1ucY+5b/VdX+Xk2W61ltKK41by9IxcLMIQBEEQBEGQHwRBEARBECxBMjj44IObbTMOaTGpQ6k/M+ylPG1GUzXN3pX+lMaV8rTBjk1uvGfAz5HKkT6zEsG/77bbbs22YUrVlLKRfpOmMcva9XDc88B+++3XbCUcqz70jbSj6yZ11Y9R2sw1tfHGqEGTa2e8jHwj/aZMoPxjw6KLL7
54MlazjZctW9Zs6Wx947ylj2fFnnvu2WxlDOPNKggpWZsDGTt9UyvXXMlB+UeYdW7DI8fk5zg+q0uMJ9dPaeTyyy+fPNtrcn2dWfH6zjjtm5jNCquMjGHnpUxm0xuvTtdnXkdcNb2uVsnABlu+xsxzrwC36kPfKF1ItXunjGtobHv/SdXUN8aY56GvcT08b2eFlLlZ/M5VaUD5xe8JZUfPkKr/txrkP3CurrNnk9VCjsnPtOrLdVKqcNxKI0ppVVPp3PPcM8Fz1M81lheLMARBEARBEOQHQRAEQRAES5AMbNIjlSZFLqUrbWKmqlmvPXUjBSOtJn1jVqkZnNKza9eubbbZwdKlZsaa4f7pT3+62WbN22Ciakqx2jzHsUqxKj9I98wDl156abNdQ6lYpZqRb6SY+mYeZqqb9S5dZdatlJ1zX7du3YLvlW7V9/rs5JNPbraSjTFVNfXVyDfKElbN9P33Z4HNuaxkGPlFXyifSQFLMVdN/eI9HCO/jO69MIbcM9LYZtRLE69atarZZsT3vdSlbs0a1y9S396XIDU/D4zu+TAujAX9YSa3Z0R/nulzs8X1zajSw/sIvHretTLOXU8z2G16o/+UgqqmvjJD3+eNZM/+XoRZYLWDmf7KUj5P6U+5yTOux6jxkn5xn7l/jHurZkbrNKo+8F4Q19uqk6rp/vasMNb8HtT37vXFIgxBEARBEAT5QRAEQRAEQX4QBEEQBEFQS8gh2H777ZutdmJHNMtSLJdQ1953332b/fGPf3w6GLRLcxO8r1ytxnKcNWvWNPuYY45ptiWS6lNqYOrRXhThpTfOoR+rOqI6rlqXnQ37z5oVlm9aQqVv9IdamRqVeRif+tSnJs+w1MfyPcsLR90pLTt84xvf2Gx1XHM6vDBLjdDLddRGnUPV4rrK6Rsv9uk/axYYt5YBqkmO/KLOaY7DSSedNHmGOrJ7xtyLUQdEyzUPO+ywZlsiab6IJWH6xXmecsopC86hauoX/0/9Wz101AF0HjBvyM587hn94RqqT3tenH766ZNn2FnOPTPqLmlulvknlhW7v82PUdv2HHaeloG6n6umvnF+zsGOeJ6lG9Lrl4rNNttswWebayHU0s3zcD0uu+yyyXtGl+iZq6NfzDHx8jjPXXNEzEswv8kYsoTQnKL+4jv94v+Zm2CugDGxMRe1hSEIgiAIgiA/CIIgCIIgqNrkPludbQCWjlleKJ1pqZnlQ9KLvtcym6qqQw89tNmWvVmG5jOke6TApCOVDCx7sqRpdJmO9GV/uYa0kLTqUUcd1WzpJcu3pHg+8pGP1Kywa59UkhcdOUdLePSN7+19c8QRRzRbms65u3Z+luVDUr9KBso/UnF+puvm50uTVU1jTDp45cqVzbbjmBSu8/7Yxz5Ws0CZzZIhaWbnZJmbJVTupd4vb37zm5utX+zyNvK9lKK0qJKB0o+d7hy3trJMf/GPa+D/HXTQQc22U5v0sVR+L2dtDJyLZWLKVc5LmllK3vdacllVdeyxxzZb33iW6M/RvrQcVclAils5x890TEpjxkf/bKWnPfbYo9nudeNC35x66qk1C9yLxu1oj+ojaX4lQSW6qum+1i+eG87Jcfh6949nkN8/xrBddh3TOeec02zXuH+233fGr3KUEqlz6Lu5jhCGIAiCIAiC/CAIgiAIgmAJVQZ2krOzkt3OpBqlIM28N3O4v0DHu9eleqVPvQjCzk/Sz0oDvlcKxTlI9zhu6UsrCfpn+zxpUSUHKSKziecBLwxSypCydt2do9Std85LnVZN/e97pDal9cxmt8JCaWDkG6kxs3cdt5R/n80ulerzpHel4vTNPCtAjFXnNKoKUfaStjUO+0t+lKJcK2ltKxHsUHfrrbc223j2va6NlKo0pX6xmkKpqH/26LItJSIzy3sfz4rRnrVqxsoXaXuz36V43YdV026mrpfUtuvgGeNedN30jeujrOTZZkwpE5jxXjWNe58n9S49L/2t/2fF6E
KpUTWNtL1roI/6S36MdWN6VJ3k/Pxe8u9WCwnPmVFXUuVPx1A1reDQdp8ooeiXvmJhMQhDEARBEARBfhAEQRAEQbAEyUAK2Lu0P/rRjzbbzEepa6lTacr9999/8gybzVhZ4LNtHCNFpG0DGzNrR3dTS4/bJEIqRkq7aloFIZUmHWUFhvNxbeYBKfLrr7++2Z/4xCea7Rylr6X5bTi1YsWKyTPOPffcZkuTSi+6pvrDDGbHocQktSY1JkXuuKWy++xuaTd97liVuqQBfcaskAI2Dk888cRmmx0udShFqqxg1nfVtOmT0pzPdj31hRKK4zAOpCb9HCUkqVBp7N4v0tTKaY7VCgylKddmHvAckZ4/7bTTmq0c47rZbObuu+9uttU+VdNGQKMsdKlj/aGM4jhsCOT55OcoVRj/nln9eaZv3IuO1b0rbe/azAolDv3v5VtKAM7POLKSw8ZZVdP4VspR4lIK1Rc+z+8lxzqi7R2f+0TptF9Lqxr0t+vkGe4+2ZhLp8IQBEEQBEGQHwRBEARBECxBMrAxh5SUtKAZ3WZ/SjHvvffezT7jjDMmz7D5kZSwNL5UjvSZVK+fc9FFFzVb6kj6WApfClAKXKmjakplS6vddNNNzV62bNmCn9Vnks4KfWMjDNdQms27G6TLXbfVq1dPnrHPPvs0W0ps1Idbnysf+AwpVelIM431jWOVvpYCrJr2b/d1+lBqWvqtz76eBVJ20szStdKO+kUK0gZTa9eunTzjkEMOabY9xlw35ze6N2DXXXdttrKT0oBx7n6TDh+td9W0v7xj0n9WHvkaqdN5QFrcs8DYNg49d6TLt9lmm2Ybz1VVb3rTmxZ8trSu0oU+k1JWelQ+8r4K6e5RNYiyglVZVVPa2ddZNeK+8nyf53mm/GoTM8+yUXa/87biQ0moatoIy7hSsvG7TBrfdXZfesYp4SrT+DlWBPmZSotVUwna88H3e+Ybp5tsskktFWEIgiAIgiDID4IgCIIgCJZwl8Hhhx/ebCkl6SIz6aVspEulGv171ZR2MXtamcFmIVJsBxxwQLPNiJeSs6mH0/YqyquvvrrZ0m3SnVVTetz1cA1cG58nPTePuwy8A8JxSX+bge58lRKkMqWqqqaZtvbFHvU6l75TbvBKaX1jsxbXygYy1157bbP1jVnYVVNqTdrZ8Y0oaOWw4447bsHXLBbGpH7x2UoDVhYoJegX6ciqKfV9zTXXNHuXXXZptg1prAjQd+vXr2+2lKwxJJ0upepzbQjjWlZNaWbXw/GN9oxxcMIJJ9Ss2H333ZttvEjxGmPO16Y10sB98zJpbuWT7bbbrtmul5S1VTBXXnllsz1rpIRdKyUQ94zPcj5VU6pZPykxuTbCs2LVqlULvmax2GqrrRYck8/2bHHN3evKdf31zH4PuCZKIvrCShulYpujjc4v94y+s1LPRlJ9YzTHqjzimTDyixJf7jIIgiAIgmDRyA+CIAiCIAgWLxkEQRAEQfC/F2EIgiAIgiDID4IgCIIgCPKDIAiCIAiCyg+CIAiCIAgqPwiCIAiCIKj8IAiCIAiCoPKDIAiCIAiCyg+CIAiCIAgqPwiCIAiCIKiq/wKl5x+1j31Q2gAAAABJRU5ErkJggg==\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Define the RNN goodies" | |
], | |
"metadata": { | |
"id": "-CeUWnzpEoqb" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ImageEncoder(nn.Module):\n", | |
" def __init__(self, in_channels):\n", | |
" super().__init__()\n", | |
" self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)\n", | |
" self.bn1 = nn.BatchNorm2d(32)\n", | |
" self.relu1 = nn.ReLU()\n", | |
"\n", | |
" self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)\n", | |
" self.bn2 = nn.BatchNorm2d(64)\n", | |
" self.relu2 = nn.ReLU()\n", | |
"\n", | |
" self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)\n", | |
" self.bn3 = nn.BatchNorm2d(128)\n", | |
" self.relu3 = nn.ReLU()\n", | |
"\n", | |
" # Global Average Pooling to reduce spatial dimensions to 1x1\n", | |
" self.gap = nn.AdaptiveAvgPool2d((1, 1))\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = self.conv1(x)\n", | |
" x = self.bn1(x)\n", | |
" x = self.relu1(x)\n", | |
"\n", | |
" x = self.conv2(x)\n", | |
" x = self.bn2(x)\n", | |
" x = self.relu2(x)\n", | |
"\n", | |
" x = self.conv3(x)\n", | |
" x = self.bn3(x)\n", | |
" x = self.relu3(x)\n", | |
"\n", | |
" x = self.gap(x)\n", | |
" return x" | |
], | |
"metadata": { | |
"id": "z81MMxklEqhO" | |
}, | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ImageDecoder(nn.Module):\n", | |
" def __init__(self, out_channels, initial_height, initial_width):\n", | |
" super(ImageDecoder, self).__init__()\n", | |
" self.conv_transpose1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
" self.bn1 = nn.BatchNorm2d(64)\n", | |
" self.relu1 = nn.ReLU()\n", | |
"\n", | |
" self.conv_transpose2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
" self.bn2 = nn.BatchNorm2d(32)\n", | |
" self.relu2 = nn.ReLU()\n", | |
"\n", | |
" self.conv_transpose3 = nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
" self.bn3 = nn.BatchNorm2d(out_channels)\n", | |
" self.relu3 = nn.ReLU()\n", | |
"\n", | |
" # Additional layer to ensure correct output dimensions\n", | |
" # This layer is only needed if the initial size cannot be exactly achieved through the strides and paddings chosen\n", | |
" self.final_resize = nn.AdaptiveAvgPool2d((initial_height, initial_width))\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = self.conv_transpose1(x)\n", | |
" x = self.bn1(x)\n", | |
" x = self.relu1(x)\n", | |
"\n", | |
" x = self.conv_transpose2(x)\n", | |
" x = self.bn2(x)\n", | |
" x = self.relu2(x)\n", | |
"\n", | |
" x = self.conv_transpose3(x)\n", | |
" x = self.bn3(x)\n", | |
" x = self.relu3(x)\n", | |
"\n", | |
" x = self.final_resize(x) # Ensure the output has the same HxW dimensions as the original input\n", | |
" return x" | |
], | |
"metadata": { | |
"id": "gzQ35KxeE2Cz" | |
}, | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class CustomRecurrence(nn.Module):\n", | |
" def __init__(self, in_channels, initial_height, initial_width, hidden_dim, num_layers=1, training=False):\n", | |
" super().__init__()\n", | |
" self.image_encoder = ImageEncoder(in_channels=in_channels)\n", | |
" self.positional_encoder = nn.Embedding(timesteps, hidden_dim)\n", | |
" self.rnn = nn.RNN(\n", | |
" input_size=128, # hardcoded for the time being\n", | |
" hidden_size=hidden_dim,\n", | |
" num_layers=num_layers,\n", | |
" batch_first=True\n", | |
" )\n", | |
" self.image_decoder = ImageDecoder(\n", | |
" out_channels=in_channels,\n", | |
" initial_height=initial_height,\n", | |
" initial_width=initial_width\n", | |
" )\n", | |
" self.training=training\n", | |
"\n", | |
" def forward(self, x, hidden_states=None, timesteps=timesteps):\n", | |
" batch_size = x.shape[0]\n", | |
"\n", | |
" # x : (b, t-1, c, h, w)\n", | |
" x = x.reshape(batch_size * (timesteps-1), channels, image_size, image_size)\n", | |
"\n", | |
" latent_vectors = self.image_encoder(x) # (b*t, c, h, w)\n", | |
" latent_vectors = latent_vectors.reshape(batch_size, timesteps-1, -1) # (b, t, c*h*w)\n", | |
"\n", | |
" pos_embeds = torch.arange(timesteps-1).unsqueeze(0).repeat(batch_size, 1).to(device) # (b, t)\n", | |
" latent_vectors += self.positional_encoder(pos_embeds) # (b, t, c*h*w)\n", | |
"\n", | |
" if self.training:\n", | |
" rnn_outputs, _ = self.rnn(latent_vectors) # (b, t, c*h*w)\n", | |
" else:\n", | |
" rnn_outputs, hidden_states = self.rnn(latent_vectors, hidden_states) # (b, t, c*h*w)\n", | |
"\n", | |
" rnn_outputs = rnn_outputs.reshape(batch_size * (timesteps-1), 128, 1, 1) # (b*t, c, 1, 1)\n", | |
" reconstructed_x = self.image_decoder(rnn_outputs)\n", | |
" reconstructed_x = reconstructed_x.reshape(batch_size, timesteps-1, channels, image_size, image_size)\n", | |
"\n", | |
" if self.training:\n", | |
" return reconstructed_x\n", | |
" else:\n", | |
" return reconstructed_x, hidden_states" | |
], | |
"metadata": { | |
"id": "chqbXzZaE5Ok" | |
}, | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Training Loop" | |
], | |
"metadata": { | |
"id": "KD_os31TMhEN" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = CustomRecurrence(\n", | |
" in_channels=channels,\n", | |
" initial_height=image_size,\n", | |
" initial_width=image_size,\n", | |
" hidden_dim=hidden_dim,\n", | |
" training=True, # important parameter\n", | |
")\n", | |
"model.to(device)\n", | |
"\n", | |
"optimizer = torch.optim.Adam(\n", | |
" model.parameters(),\n", | |
" lr=1e-3\n", | |
")\n", | |
"\n", | |
"loss_fn = nn.MSELoss()" | |
], | |
"metadata": { | |
"id": "TYo28dSj49K_" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for epoch in range(epochs):\n", | |
" for step, batch in enumerate(dataloader):\n", | |
" optimizer.zero_grad()\n", | |
"\n", | |
" batch_size = batch[\"pixel_values\"].shape[0]\n", | |
" batch = batch[\"pixel_values\"].to(device)\n", | |
"\n", | |
" t = torch.arange(0, timesteps).flip((-1,)).repeat(batch_size, 1).to(device)\n", | |
"\n", | |
" input_images = q_sample(\n", | |
" x_start=batch.unsqueeze(1), # (B, 1, C, H, W)\n", | |
" t=t, # (B, timesteps)\n", | |
" )\n", | |
" reconstructed_images = model(input_images[:, :timesteps-1, ...])\n", | |
" loss = loss_fn(reconstructed_images, input_images[:, 1:, ...])\n", | |
"\n", | |
"\n", | |
" if step % 100 == 0:\n", | |
" print(\"Loss:\", loss.item())\n", | |
"\n", | |
" loss.backward()\n", | |
" optimizer.step()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ImaLvHwjMiDf", | |
"outputId": "444f58e7-29ac-4fda-d74f-6714344ba205" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Loss: 1.4240089654922485\n", | |
"Loss: 0.7346835732460022\n", | |
"Loss: 0.7231897115707397\n", | |
"Loss: 0.7261417508125305\n", | |
"Loss: 0.7263044714927673\n", | |
"Loss: 0.7309923768043518\n", | |
"Loss: 0.7178468704223633\n", | |
"Loss: 0.7321939468383789\n", | |
"Loss: 0.706930935382843\n", | |
"Loss: 0.7191136479377747\n", | |
"Loss: 0.7089812159538269\n", | |
"Loss: 0.7193553447723389\n", | |
"Loss: 0.7181675434112549\n", | |
"Loss: 0.7335854172706604\n", | |
"Loss: 0.7130610942840576\n", | |
"Loss: 0.714571475982666\n", | |
"Loss: 0.7261708378791809\n", | |
"Loss: 0.7198599576950073\n", | |
"Loss: 0.7220982909202576\n", | |
"Loss: 0.7202709317207336\n", | |
"Loss: 0.7199481129646301\n", | |
"Loss: 0.71610426902771\n", | |
"Loss: 0.6913580894470215\n", | |
"Loss: 0.7348539233207703\n", | |
"Loss: 0.7180370688438416\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Inference Loop" | |
], | |
"metadata": { | |
"id": "PP6CFN4vMidh" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model.training=False\n", | |
"model = model.eval()" | |
], | |
"metadata": { | |
"id": "iNwii9gJ-Utu" | |
}, | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"with torch.no_grad():\n", | |
" current_input = torch.randn(\n", | |
" 1, 1, channels, image_size, image_size\n", | |
" ).to(device)\n", | |
" hidden_state = None # Start with no hidden state\n", | |
" generated_sequence = []\n", | |
"\n", | |
" for _ in range(timesteps):\n", | |
" # Forward pass\n", | |
" output, hidden_state = model(current_input, hidden_state, timesteps=2)\n", | |
"\n", | |
" # Use output as next input\n", | |
" current_input = output" | |
], | |
"metadata": { | |
"id": "8NxLub4ZMjpH" | |
}, | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"output.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "rNmgRa7m-ZEW", | |
"outputId": "8cafec2a-40ec-4e8f-ca55-94d6e78512ea" | |
}, | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 1, 1, 28, 28])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"plt.imshow(output[0, 0, 0].detach().cpu().numpy(), cmap=\"gray\")\n", | |
"plt.axis(\"off\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 406 | |
}, | |
"id": "OkJeUd45JooW", | |
"outputId": "be0a63f0-2584-43f2-bee7-a5fb0aa0bd2d" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAHXUlEQVR4nO3cMY6NfR/HYUcmkZEggw6NiERiDyqFBViDBViBxjroprQAQUGrkikUFoDMIGGGDOftPnkLuTPHc875z/O4rvr+Jd8pZj7uwj2bz+fzEwBw4sSJk6MHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQDaO+uBsNlvlDhju6tWroydM2tvbW8vNumxtbY2eMGl/f3/hm4ODgxUsWZ6j/F9lbwoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgDZGD0Ajou9vb3REybt7++PnrBUx/3nOTw8HD1hCG8KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgs/l8Pj/Sg7PZqrcAsEJH+XPvTQGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4vSdPnoyeMGl7e3stN+v0+fPn0RMm7e7uLnyzt7e3giXL8fDhw9ETJu3s7Kzl5rjxpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4ve3t7dETJr1582bhm5Mnj/e/QR49ejR6wqRPnz6t5WZddnZ2Rk+Y9OHDh9EThjjev6UArJUoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAZvP5fH6kB2ezVW/hX+RPPm533D+Id/PmzdETJn358mUtN+vy8ePH0RP+Okf5c3+8f0sBWCtRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAZGP0AH7v2rVroydM2t3dXcvNOt25c2f0hEnfvn1b+Obr168rWLIcV65cGT1h0vPnzxe+efHixfKHrJk3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgGwc9cFTp06tcsc/dnh4uPDNz58/V7BkOXZ3d0dPmLS/vz96wtK9evVq9IRJv379Gj1hqc6dOzd6wqSDg4PRE4bwpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFAPKf+SDen/BBPP7fy5cvR0+YtLm5uZabdblx48boCZN+/PgxesIQ3hQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAY
CIAgDZGD1gWW7fvr3wza1bt1awZDkeP348esKk9+/fr+Vmne7duzd6wqTz588vfLO1tbWCJctx//790RMmPXjwYOGb169fr2DJenlTACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAyMZRH/z+/fsqd/xjb9++XfjmOP9MFy9eHD1h0oULFxa+uX79+gqWLM/ly5dHT5h09uzZhW/OnDmzgiXL8ezZs9ETJr179270hCG8KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgPzVH8T7k5t1uXv37ugJk/7kg3h/crNOly5dGj1h0unTpxe+2dzcXMGS5Xj69OnoCZN8EA+Av54oABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAMpvP5/PRIwA4HrwpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQ/wHzssF0wLHC6wAAAABJRU5ErkJggg==\n" | |
}, | |
"metadata": {} | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment