Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Created April 16, 2024 18:44
Show Gist options
  • Save ariG23498/3043580b7f73313f6657b22c77988079 to your computer and use it in GitHub Desktop.
Save ariG23498/3043580b7f73313f6657b22c77988079 to your computer and use it in GitHub Desktop.
rnn-diffusion.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyP8xftETb2XVvWYheumJWH6",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"ef0db216de374d8d9fe4ef7b7c27988b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_e096b4adf8d74c5bb62dc62de43d5485",
"IPY_MODEL_5947389f26104af2aeadde1aadb56c46",
"IPY_MODEL_c9d86aca6d1b4e9092d087faea2a5f23"
],
"layout": "IPY_MODEL_9302b16aa22647d0ac292e492f944e32"
}
},
"e096b4adf8d74c5bb62dc62de43d5485": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d67abee0a02f41b8ac4d28ed89a43188",
"placeholder": "​",
"style": "IPY_MODEL_3ed6412412644d19bd435f5e688072e7",
"value": "Downloading data: 100%"
}
},
"5947389f26104af2aeadde1aadb56c46": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_329a1621b4c9484bae3089c3d0a4f30d",
"max": 30931277,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_161fb7ef927a4cb7873d14e0b5fa500b",
"value": 30931277
}
},
"c9d86aca6d1b4e9092d087faea2a5f23": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2aa04feb3e48483c84bcbde834897aad",
"placeholder": "​",
"style": "IPY_MODEL_c1d1f44a1aae457fafb01295c6c9c6dc",
"value": " 30.9M/30.9M [00:02<00:00, 16.9MB/s]"
}
},
"9302b16aa22647d0ac292e492f944e32": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d67abee0a02f41b8ac4d28ed89a43188": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3ed6412412644d19bd435f5e688072e7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"329a1621b4c9484bae3089c3d0a4f30d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"161fb7ef927a4cb7873d14e0b5fa500b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"2aa04feb3e48483c84bcbde834897aad": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c1d1f44a1aae457fafb01295c6c9c6dc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"ea998c61f56c49fdbb27229f15c84c3d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_b1d6c88baf604285aa06840605b2b08e",
"IPY_MODEL_ada4f71b6a4c449ca86b21449428bf9b",
"IPY_MODEL_1828031ff524483aa28bd4b8c6f35981"
],
"layout": "IPY_MODEL_ff984182fc064033b3095ca08c8220b2"
}
},
"b1d6c88baf604285aa06840605b2b08e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_756a2563ac4147d0bf4376a52cd26870",
"placeholder": "​",
"style": "IPY_MODEL_707a7542d7d848cdb1f26cccf6e778d5",
"value": "Downloading data: 100%"
}
},
"ada4f71b6a4c449ca86b21449428bf9b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_fa86fdcadc8c4015987dad3a05152e0e",
"max": 5175617,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_adc80c861a914013a40393c44ac50210",
"value": 5175617
}
},
"1828031ff524483aa28bd4b8c6f35981": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_fbfc5527b1c14d7aa42295d35e2b3620",
"placeholder": "​",
"style": "IPY_MODEL_f107f457eba6446490b5adb8ac983869",
"value": " 5.18M/5.18M [00:00<00:00, 4.38MB/s]"
}
},
"ff984182fc064033b3095ca08c8220b2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"756a2563ac4147d0bf4376a52cd26870": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"707a7542d7d848cdb1f26cccf6e778d5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"fa86fdcadc8c4015987dad3a05152e0e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"adc80c861a914013a40393c44ac50210": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"fbfc5527b1c14d7aa42295d35e2b3620": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f107f457eba6446490b5adb8ac983869": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"53af44e4c7dd4d6fb5e3b385bc5f7473": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_ad5d7fa88a2d48ffbbd5148b1d726d36",
"IPY_MODEL_2ce5f202a0ad4c59b3dd54e7778e5ac7",
"IPY_MODEL_5235bd0e9078423489d90061ebb23a68"
],
"layout": "IPY_MODEL_d4426c5e37d04d55bb8107c8672311dc"
}
},
"ad5d7fa88a2d48ffbbd5148b1d726d36": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_3946b48ad5cd496b8b7323ffd586133c",
"placeholder": "​",
"style": "IPY_MODEL_958423b6803641dc916e1d5669cfad68",
"value": "Generating train split: 100%"
}
},
"2ce5f202a0ad4c59b3dd54e7778e5ac7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4e06d3fa7b594b34abf19e44de503e87",
"max": 60000,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_48cb9ee8be1e4d938789a90defbfb704",
"value": 60000
}
},
"5235bd0e9078423489d90061ebb23a68": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_809081ccc38a48f9869cf0ef956ed24c",
"placeholder": "​",
"style": "IPY_MODEL_e840f8e33dd0434fa8bd6d3bcd5efda7",
"value": " 60000/60000 [00:00<00:00, 90943.69 examples/s]"
}
},
"d4426c5e37d04d55bb8107c8672311dc": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3946b48ad5cd496b8b7323ffd586133c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"958423b6803641dc916e1d5669cfad68": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"4e06d3fa7b594b34abf19e44de503e87": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"48cb9ee8be1e4d938789a90defbfb704": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"809081ccc38a48f9869cf0ef956ed24c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e840f8e33dd0434fa8bd6d3bcd5efda7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"f68f3db300a947838e60acd6eef73f56": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_9711ce9ed95f4c57a96cc42c419a6535",
"IPY_MODEL_f08c421ac9634861bb7104f9e1f3e5ce",
"IPY_MODEL_08d967d6785a4a8c861f23d8259f2ac2"
],
"layout": "IPY_MODEL_7fbca9fc1fb0413b8848723c8c286b3c"
}
},
"9711ce9ed95f4c57a96cc42c419a6535": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c9cc94dc9efd4b73ac5e196de8914180",
"placeholder": "​",
"style": "IPY_MODEL_243b70f554464303843eff056b169773",
"value": "Generating test split: 100%"
}
},
"f08c421ac9634861bb7104f9e1f3e5ce": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e101f5e375c74a13887a46ce51c1bc0d",
"max": 10000,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b96c61f62aaa4666aa46211b53e7fd62",
"value": 10000
}
},
"08d967d6785a4a8c861f23d8259f2ac2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ecdfef650f07497e975f0edec0153ad4",
"placeholder": "​",
"style": "IPY_MODEL_6cf4e93859f445f0b514c626c6b62cc5",
"value": " 10000/10000 [00:00<00:00, 53018.19 examples/s]"
}
},
"7fbca9fc1fb0413b8848723c8c286b3c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c9cc94dc9efd4b73ac5e196de8914180": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"243b70f554464303843eff056b169773": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"e101f5e375c74a13887a46ce51c1bc0d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b96c61f62aaa4666aa46211b53e7fd62": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"ecdfef650f07497e975f0edec0153ad4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6cf4e93859f445f0b514c626c6b62cc5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/3043580b7f73313f6657b22c77988079/rnn-diffusion.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"This notebook is heavily inspired by https://huggingface.co/blog/annotated-diffusion"
],
"metadata": {
"id": "02vydzVb6VWO"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup and Imports"
],
"metadata": {
"id": "-CFKktKV2qqo"
}
},
{
"cell_type": "code",
"source": [
"%pip install --upgrade -qq datasets"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BmiGcVlQ64dD",
"outputId": "2201fe6e-d14c-4d0e-f19a-a53531bc30ed"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "hzUHzDmy2kOy"
},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from datasets import load_dataset\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from torchvision import transforms"
]
},
{
"cell_type": "markdown",
"source": [
"## Configurations"
],
"metadata": {
"id": "rlNouEPu7ART"
}
},
{
"cell_type": "code",
"source": [
"# data\n",
"batch_size = 128\n",
"image_size = 28  # Fashion-MNIST frames are 28x28\n",
"channels = 1  # grayscale\n",
"\n",
"# diffusion chain / model\n",
"timesteps = 100  # number of noise levels in the forward process\n",
"hidden_dim = 128  # RNN hidden size\n",
"\n",
"# optimisation\n",
"epochs = 5\n",
"\n",
"# prefer the GPU whenever torch can see one\n",
"if torch.cuda.is_available():\n",
"    device = \"cuda\"\n",
"else:\n",
"    device = \"cpu\"\n",
"print(f\"{device=}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3HNObPXp7Bff",
"outputId": "264c2094-1e45-44c9-f7c5-a5bef28208a9"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"device='cuda'\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Dataset and Loaders"
],
"metadata": {
"id": "JGEzhe8s7B1p"
}
},
{
"cell_type": "code",
"source": [
"# Fashion-MNIST from the Hugging Face Hub (provides train and test splits)\n",
"dataset = load_dataset(path=\"fashion_mnist\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 145,
"referenced_widgets": [
"ef0db216de374d8d9fe4ef7b7c27988b",
"e096b4adf8d74c5bb62dc62de43d5485",
"5947389f26104af2aeadde1aadb56c46",
"c9d86aca6d1b4e9092d087faea2a5f23",
"9302b16aa22647d0ac292e492f944e32",
"d67abee0a02f41b8ac4d28ed89a43188",
"3ed6412412644d19bd435f5e688072e7",
"329a1621b4c9484bae3089c3d0a4f30d",
"161fb7ef927a4cb7873d14e0b5fa500b",
"2aa04feb3e48483c84bcbde834897aad",
"c1d1f44a1aae457fafb01295c6c9c6dc",
"ea998c61f56c49fdbb27229f15c84c3d",
"b1d6c88baf604285aa06840605b2b08e",
"ada4f71b6a4c449ca86b21449428bf9b",
"1828031ff524483aa28bd4b8c6f35981",
"ff984182fc064033b3095ca08c8220b2",
"756a2563ac4147d0bf4376a52cd26870",
"707a7542d7d848cdb1f26cccf6e778d5",
"fa86fdcadc8c4015987dad3a05152e0e",
"adc80c861a914013a40393c44ac50210",
"fbfc5527b1c14d7aa42295d35e2b3620",
"f107f457eba6446490b5adb8ac983869",
"53af44e4c7dd4d6fb5e3b385bc5f7473",
"ad5d7fa88a2d48ffbbd5148b1d726d36",
"2ce5f202a0ad4c59b3dd54e7778e5ac7",
"5235bd0e9078423489d90061ebb23a68",
"d4426c5e37d04d55bb8107c8672311dc",
"3946b48ad5cd496b8b7323ffd586133c",
"958423b6803641dc916e1d5669cfad68",
"4e06d3fa7b594b34abf19e44de503e87",
"48cb9ee8be1e4d938789a90defbfb704",
"809081ccc38a48f9869cf0ef956ed24c",
"e840f8e33dd0434fa8bd6d3bcd5efda7",
"f68f3db300a947838e60acd6eef73f56",
"9711ce9ed95f4c57a96cc42c419a6535",
"f08c421ac9634861bb7104f9e1f3e5ce",
"08d967d6785a4a8c861f23d8259f2ac2",
"7fbca9fc1fb0413b8848723c8c286b3c",
"c9cc94dc9efd4b73ac5e196de8914180",
"243b70f554464303843eff056b169773",
"e101f5e375c74a13887a46ce51c1bc0d",
"b96c61f62aaa4666aa46211b53e7fd62",
"ecdfef650f07497e975f0edec0153ad4",
"6cf4e93859f445f0b514c626c6b62cc5"
]
},
"id": "PdIqVYuj6axU",
"outputId": "81b1d933-710c-4721-e89a-5e8001a8f771"
},
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading data: 0%| | 0.00/30.9M [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "ef0db216de374d8d9fe4ef7b7c27988b"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading data: 0%| | 0.00/5.18M [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "ea998c61f56c49fdbb27229f15c84c3d"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Generating train split: 0%| | 0/60000 [00:00<?, ? examples/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "53af44e4c7dd4d6fb5e3b385bc5f7473"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Generating test split: 0%| | 0/10000 [00:00<?, ? examples/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "f68f3db300a947838e60acd6eef73f56"
}
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# define image transformations: random horizontal flip (augmentation),\n",
"# convert to a float tensor in [0, 1], then rescale to [-1, 1]\n",
"transform = transforms.Compose([\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.ToTensor(),\n",
"    transforms.Lambda(lambda t: (t * 2) - 1)\n",
"])\n",
"\n",
"# Per-batch transform for `Dataset.with_transform`. It is named\n",
"# `transform_examples` rather than `transforms` so that it does not shadow\n",
"# the `torchvision.transforms` module (with the old name, re-running this\n",
"# cell failed at `transforms.Compose`).\n",
"def transform_examples(examples):\n",
"    \"\"\"Replace the 'image' column with 'pixel_values': one (1, H, W) tensor in [-1, 1] per image.\"\"\"\n",
"    examples[\"pixel_values\"] = [transform(image.convert(\"L\")) for image in examples[\"image\"]]\n",
"    del examples[\"image\"]\n",
"    return examples\n",
"\n",
"transformed_dataset = (\n",
"    dataset\n",
"    .with_transform(transform_examples)\n",
"    .remove_columns(\"label\")\n",
")\n",
"\n",
"# create dataloader over the training split\n",
"dataloader = DataLoader(\n",
"    transformed_dataset[\"train\"],\n",
"    batch_size=batch_size,\n",
"    shuffle=True\n",
")"
],
"metadata": {
"id": "SbnfSuvR6eel"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Sanity check: pull one batch and inspect its keys and the image tensor shape\n",
"batch = next(iter(dataloader))\n",
"print(batch.keys(), batch[\"pixel_values\"].shape, sep=\"\\n\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GneDUV5J74KJ",
"outputId": "7e5a4cb6-73ee-4b2f-d335-4922d096e4dd"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"dict_keys(['pixel_values'])\n",
"torch.Size([128, 1, 28, 28])\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Forward Diffusion Process"
],
"metadata": {
"id": "Ggg6okjO8NFS"
}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "9_7ipOKVIbFn"
},
"outputs": [],
"source": [
"# Define a linear schedule\n",
"def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):\n",
"    \"\"\"Return `timesteps` noise variances spaced linearly from beta_start to beta_end.\"\"\"\n",
"    return torch.linspace(beta_start, beta_end, timesteps)\n",
"\n",
"# define beta schedule\n",
"betas = linear_beta_schedule(timesteps=timesteps)\n",
"\n",
"# define alphas = 1 - beta and their cumulative product alpha_bar_t\n",
"alphas = 1. - betas\n",
"alphas_cumprod = torch.cumprod(alphas, dim=0)\n",
"\n",
"# Coefficients of the closed-form forward process, one value per timestep.\n",
"# They are kept 1-D (length `timesteps`) instead of being repeated per batch\n",
"# row, so any batch size works (the repeated (batch_size, timesteps) tables\n",
"# could not serve batches larger than `batch_size`).\n",
"sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)\n",
"sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)\n",
"\n",
"def extract(a, t, x_shape):\n",
"    \"\"\"Look up schedule values `a` at timestep indices `t`, shaped to broadcast against x.\n",
"\n",
"    a: 1-D per-timestep schedule (a 2-D table whose rows all repeat the same\n",
"       schedule is also accepted; its first row is used).\n",
"    t: integer timestep indices of any shape, e.g. (B, T).\n",
"    x_shape: shape of the tensor the result will multiply, e.g. (B, 1, C, H, W).\n",
"\n",
"    Returns a tensor of shape t.shape + (1,) * (len(x_shape) - t.dim()) on t's device.\n",
"    \"\"\"\n",
"    schedule = a[0] if a.dim() > 1 else a\n",
"    out = schedule[t.to(schedule.device)]\n",
"    trailing_ones = (1,) * (len(x_shape) - t.dim())\n",
"    return out.reshape(*t.shape, *trailing_ones).to(t.device)"
]
},
{
"cell_type": "code",
"source": [
"# forward diffusion: sample x_t ~ q(x_t | x_0) in closed form\n",
"def q_sample(x_start, t, noise=None):\n",
"    \"\"\"Noise `x_start` to the timestep(s) `t`.\n",
"\n",
"    Returns sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise, drawing\n",
"    standard-normal noise shaped like `x_start` when none is given. In this\n",
"    notebook x_start is (B, 1, C, H, W) and t is (B, timesteps), so the result\n",
"    broadcasts to (B, timesteps, C, H, W).\n",
"    \"\"\"\n",
"    noise = torch.randn_like(x_start) if noise is None else noise\n",
"    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)\n",
"    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)\n",
"    return signal_scale * x_start + noise_scale * noise"
],
"metadata": {
"id": "pmHZGKa8Lhbz"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# create a batch of noised sequences: for every image, timestep indices run\n",
"# from the noisiest (timesteps - 1) down to the cleanest (0)\n",
"x_start = batch[\"pixel_values\"].unsqueeze(1)  # (B, 1, C, H, W)\n",
"# size t from the data itself, not the global batch_size, so they always agree\n",
"t = torch.arange(0, timesteps).flip((-1,)).repeat(x_start.shape[0], 1)  # (B, timesteps)\n",
"input_images = q_sample(\n",
"    x_start=x_start,\n",
"    t=t,\n",
")  # (B, timesteps, C, H, W)"
],
"metadata": {
"id": "uWgxpZPj8sY6"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Visualize the forward process"
],
"metadata": {
"id": "aQ0H8DVxEJqF"
}
},
{
"cell_type": "code",
"source": [
"# Show evenly spaced frames of one random sequence, noisiest on the left\n",
"idx = random.randint(0, input_images.shape[0] - 1)\n",
"num_frames = 5\n",
"for i in range(num_frames):\n",
"    plt.subplot(1, num_frames, i + 1)\n",
"    # frame index scales with the chain length instead of a hard-coded 100\n",
"    plt.imshow(input_images[idx, timesteps // num_frames * i].permute(1, 2, 0), cmap=\"gray\")\n",
"    plt.axis(\"off\")\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"id": "I-hoBcj3DNSf",
"outputId": "f63af000-1c1b-40f2-caa3-9d5db9eebd5c"
},
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 5 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Define the RNN model components"
],
"metadata": {
"id": "-CeUWnzpEoqb"
}
},
{
"cell_type": "code",
"source": [
"class ImageEncoder(nn.Module):\n",
"    \"\"\"Convolutional frame encoder: (N, in_channels, H, W) -> (N, 128, 1, 1).\n",
"\n",
"    Three stride-2 conv -> batch-norm -> ReLU stages (32, 64, 128 channels)\n",
"    halve the spatial size (rounding up) each time; global average pooling\n",
"    then collapses what is left to one 128-d feature per image.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels):\n",
"        super(ImageEncoder, self).__init__()\n",
"        # stage 1: in_channels -> 32\n",
"        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)\n",
"        self.bn1 = nn.BatchNorm2d(32)\n",
"        self.relu1 = nn.ReLU()\n",
"        # stage 2: 32 -> 64\n",
"        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)\n",
"        self.bn2 = nn.BatchNorm2d(64)\n",
"        self.relu2 = nn.ReLU()\n",
"        # stage 3: 64 -> 128\n",
"        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)\n",
"        self.bn3 = nn.BatchNorm2d(128)\n",
"        self.relu3 = nn.ReLU()\n",
"        # global average pooling down to 1x1\n",
"        self.gap = nn.AdaptiveAvgPool2d((1, 1))\n",
"\n",
"    def forward(self, x):\n",
"        stages = (\n",
"            (self.conv1, self.bn1, self.relu1),\n",
"            (self.conv2, self.bn2, self.relu2),\n",
"            (self.conv3, self.bn3, self.relu3),\n",
"        )\n",
"        for conv, norm, act in stages:\n",
"            x = act(norm(conv(x)))\n",
"        return self.gap(x)"
],
"metadata": {
"id": "z81MMxklEqhO"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class ImageDecoder(nn.Module):\n",
"    \"\"\"Convolutional frame decoder: (N, 128, 1, 1) -> (N, out_channels, initial_height, initial_width).\n",
"\n",
"    Three stride-2 transposed-conv stages (128 -> 64 -> 32 -> out_channels)\n",
"    each exactly double the spatial size (1x1 -> 2x2 -> 4x4 -> 8x8); adaptive\n",
"    average pooling then resamples that to initial_height x initial_width.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, out_channels, initial_height, initial_width):\n",
"        super(ImageDecoder, self).__init__()\n",
"        self.conv_transpose1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
"        self.bn1 = nn.BatchNorm2d(64)\n",
"        self.relu1 = nn.ReLU()\n",
"\n",
"        self.conv_transpose2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
"        self.bn2 = nn.BatchNorm2d(32)\n",
"        self.relu2 = nn.ReLU()\n",
"\n",
"        # Output stage: deliberately no ReLU. The decoded frames are regressed onto\n",
"        # noised images (pixels rescaled to [-1, 1] plus Gaussian noise), many of\n",
"        # which are negative; a final ReLU clamped every prediction to >= 0.\n",
"        self.conv_transpose3 = nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
"        self.bn3 = nn.BatchNorm2d(out_channels)\n",
"\n",
"        # Resample to the requested HxW. Starting from 1x1 the transposed convs\n",
"        # only reach 8x8, so for 28x28 frames this step always runs (it upsamples).\n",
"        self.final_resize = nn.AdaptiveAvgPool2d((initial_height, initial_width))\n",
"\n",
"    def forward(self, x):\n",
"        x = self.relu1(self.bn1(self.conv_transpose1(x)))\n",
"        x = self.relu2(self.bn2(self.conv_transpose2(x)))\n",
"        x = self.bn3(self.conv_transpose3(x))  # linear output, may be negative\n",
"        return self.final_resize(x)  # (N, out_channels, initial_height, initial_width)"
],
"metadata": {
"id": "gzQ35KxeE2Cz"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class CustomRecurrence(nn.Module):\n",
"    \"\"\"Frame-sequence model: CNN-encode each frame, run an RNN over time, decode every step.\n",
"\n",
"    Input frames have shape (B, S, C, H, W) and the reconstructions have the\n",
"    same layout. The training loop in this notebook feeds frames 0 .. S-1 and\n",
"    uses output position k as the prediction of input frame k + 1.\n",
"    \"\"\"\n",
"\n",
"    # channel width of the ImageEncoder output and of the ImageDecoder input\n",
"    latent_dim = 128\n",
"\n",
"    def __init__(self, in_channels, initial_height, initial_width, hidden_dim, num_layers=1, training=False):\n",
"        super().__init__()\n",
"        self.image_encoder = ImageEncoder(in_channels=in_channels)\n",
"        # learned per-position embedding added to each frame's latent, so it\n",
"        # must have the encoder's width (not the RNN's hidden width)\n",
"        self.positional_encoder = nn.Embedding(timesteps, self.latent_dim)\n",
"        self.rnn = nn.RNN(\n",
"            input_size=self.latent_dim,\n",
"            hidden_size=hidden_dim,\n",
"            num_layers=num_layers,\n",
"            batch_first=True\n",
"        )\n",
"        # Map RNN outputs back to the decoder width. For hidden_dim == latent_dim\n",
"        # this is a parameter-free no-op (the original architecture); any other\n",
"        # hidden_dim previously crashed in forward().\n",
"        if hidden_dim == self.latent_dim:\n",
"            self.output_projection = nn.Identity()\n",
"        else:\n",
"            self.output_projection = nn.Linear(hidden_dim, self.latent_dim)\n",
"        self.image_decoder = ImageDecoder(\n",
"            out_channels=in_channels,\n",
"            initial_height=initial_height,\n",
"            initial_width=initial_width\n",
"        )\n",
"        # overrides nn.Module's flag: while True, forward() returns only the\n",
"        # reconstructions; model.eval() / model.train() also toggle it\n",
"        self.training = training\n",
"\n",
"    def forward(self, x, hidden_states=None, timesteps=timesteps):\n",
"        \"\"\"Reconstruct the next frame at every position of a frame sequence.\n",
"\n",
"        Args:\n",
"            x: frames of shape (B, S, C, H, W). The sequence length S is read\n",
"                from x itself.\n",
"            hidden_states: optional initial RNN hidden state of shape\n",
"                (num_layers, B, hidden_dim); used in both modes.\n",
"            timesteps: kept for backward compatibility (callers pass S + 1);\n",
"                no longer needed for reshaping.\n",
"\n",
"        Returns:\n",
"            (B, S, C, H, W) reconstructions while `self.training` is True,\n",
"            otherwise a tuple (reconstructions, final RNN hidden state).\n",
"        \"\"\"\n",
"        batch_size, seq_len = x.shape[:2]\n",
"\n",
"        # fold time into the batch so the CNN encoder sees individual frames\n",
"        frames = x.reshape(batch_size * seq_len, *x.shape[2:])\n",
"        latent_vectors = self.image_encoder(frames)  # (B*S, latent_dim, 1, 1)\n",
"        latent_vectors = latent_vectors.reshape(batch_size, seq_len, -1)  # (B, S, latent_dim)\n",
"\n",
"        # position ids 0 .. S-1 on x's device, broadcast over the batch\n",
"        positions = torch.arange(seq_len, device=x.device)\n",
"        latent_vectors = latent_vectors + self.positional_encoder(positions)  # (B, S, latent_dim)\n",
"\n",
"        rnn_outputs, hidden_states = self.rnn(latent_vectors, hidden_states)  # (B, S, hidden_dim)\n",
"        rnn_outputs = self.output_projection(rnn_outputs)  # (B, S, latent_dim)\n",
"\n",
"        decoder_input = rnn_outputs.reshape(batch_size * seq_len, self.latent_dim, 1, 1)\n",
"        reconstructed_x = self.image_decoder(decoder_input)  # (B*S, C, H, W)\n",
"        reconstructed_x = reconstructed_x.reshape(batch_size, seq_len, *reconstructed_x.shape[1:])\n",
"\n",
"        if self.training:\n",
"            return reconstructed_x\n",
"        return reconstructed_x, hidden_states"
],
"metadata": {
"id": "chqbXzZaE5Ok"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Training Loop"
],
"metadata": {
"id": "KD_os31TMhEN"
}
},
{
"cell_type": "code",
"source": [
"# build the model in training mode and move it to the selected device\n",
"model = CustomRecurrence(\n",
"    in_channels=channels,\n",
"    initial_height=image_size,\n",
"    initial_width=image_size,\n",
"    hidden_dim=hidden_dim,\n",
"    training=True,  # forward() returns only the reconstructions while this is set\n",
").to(device)\n",
"\n",
"# Adam over every model parameter, with MSE as the reconstruction objective\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
"loss_fn = nn.MSELoss()"
],
"metadata": {
"id": "TYo28dSj49K_"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for epoch in range(epochs):\n",
"    for step, batch_dict in enumerate(dataloader):\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Loop-local names only: rebinding the notebook-level `batch`,\n",
"        # `batch_size`, `t` or `input_images` here broke re-running the\n",
"        # earlier cells (they index `batch[\"pixel_values\"]` on the CPU).\n",
"        clean_images = batch_dict[\"pixel_values\"].to(device)  # (B, C, H, W)\n",
"        current_batch_size = clean_images.shape[0]  # the last batch can be smaller\n",
"\n",
"        # timestep indices from noisiest (timesteps - 1) down to cleanest (0)\n",
"        timestep_ids = torch.arange(0, timesteps).flip((-1,)).repeat(current_batch_size, 1).to(device)\n",
"\n",
"        noisy_sequences = q_sample(\n",
"            x_start=clean_images.unsqueeze(1),  # (B, 1, C, H, W)\n",
"            t=timestep_ids,  # (B, timesteps)\n",
"        )  # (B, timesteps, C, H, W)\n",
"\n",
"        # teacher forcing: frames 0 .. T-2 are inputs, frames 1 .. T-1 targets,\n",
"        # so each position learns to remove one level of noise\n",
"        reconstructed_images = model(noisy_sequences[:, :timesteps-1, ...])\n",
"        loss = loss_fn(reconstructed_images, noisy_sequences[:, 1:, ...])\n",
"\n",
"        if step % 100 == 0:\n",
"            print(\"Loss:\", loss.item())\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ImaLvHwjMiDf",
"outputId": "444f58e7-29ac-4fda-d74f-6714344ba205"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loss: 1.4240089654922485\n",
"Loss: 0.7346835732460022\n",
"Loss: 0.7231897115707397\n",
"Loss: 0.7261417508125305\n",
"Loss: 0.7263044714927673\n",
"Loss: 0.7309923768043518\n",
"Loss: 0.7178468704223633\n",
"Loss: 0.7321939468383789\n",
"Loss: 0.706930935382843\n",
"Loss: 0.7191136479377747\n",
"Loss: 0.7089812159538269\n",
"Loss: 0.7193553447723389\n",
"Loss: 0.7181675434112549\n",
"Loss: 0.7335854172706604\n",
"Loss: 0.7130610942840576\n",
"Loss: 0.714571475982666\n",
"Loss: 0.7261708378791809\n",
"Loss: 0.7198599576950073\n",
"Loss: 0.7220982909202576\n",
"Loss: 0.7202709317207336\n",
"Loss: 0.7199481129646301\n",
"Loss: 0.71610426902771\n",
"Loss: 0.6913580894470215\n",
"Loss: 0.7348539233207703\n",
"Loss: 0.7180370688438416\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Inference Loop"
],
"metadata": {
"id": "PP6CFN4vMidh"
}
},
{
"cell_type": "code",
"source": [
"# switch to inference: eval() makes BatchNorm use its running statistics, and\n",
"# training=False makes CustomRecurrence.forward also return the RNN hidden state\n",
"model = model.eval()\n",
"model.training = False"
],
"metadata": {
"id": "iNwii9gJ-Utu"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"with torch.no_grad():\n",
"    # Start the reverse chain from standard-normal noise, shaped as a\n",
"    # one-image, one-frame sequence: (1, 1, C, H, W).\n",
"    current_input = torch.randn(\n",
"        1, 1, channels, image_size, image_size\n",
"    ).to(device)\n",
"    hidden_state = None\n",
"    generated_sequence = [current_input]  # every frame, noisiest first\n",
"\n",
"    # During training the model always saw position ids 0 .. timesteps-2.\n",
"    # Feeding one frame per call gave every step position 0, so instead the\n",
"    # whole generated prefix is re-run each step (the RNN is causal, so the\n",
"    # earlier outputs are unchanged) and only the last prediction is kept.\n",
"    # timesteps - 1 steps mirror the timesteps - 1 transitions seen in training.\n",
"    for _ in range(timesteps - 1):\n",
"        sequence = torch.cat(generated_sequence, dim=1)  # (1, S, C, H, W)\n",
"        predictions, hidden_state = model(sequence, None, timesteps=sequence.shape[1] + 1)\n",
"        output = predictions[:, -1:]  # next, less noisy frame: (1, 1, C, H, W)\n",
"        generated_sequence.append(output)\n",
"        current_input = output"
],
"metadata": {
"id": "8NxLub4ZMjpH"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"output.size()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rNmgRa7m-ZEW",
"outputId": "8cafec2a-40ec-4e8f-ca55-94d6e78512ea"
},
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1, 1, 1, 28, 28])"
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"source": [
"# show the final generated frame (first channel) as a grayscale image\n",
"final_frame = output[0, 0, 0].detach().cpu().numpy()\n",
"plt.imshow(final_frame, cmap=\"gray\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 406
},
"id": "OkJeUd45JooW",
"outputId": "be0a63f0-2584-43f2-bee7-a5fb0aa0bd2d"
},
"execution_count": 20,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAHXUlEQVR4nO3cMY6NfR/HYUcmkZEggw6NiERiDyqFBViDBViBxjroprQAQUGrkikUFoDMIGGGDOftPnkLuTPHc875z/O4rvr+Jd8pZj7uwj2bz+fzEwBw4sSJk6MHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQDaO+uBsNlvlDhju6tWroydM2tvbW8vNumxtbY2eMGl/f3/hm4ODgxUsWZ6j/F9lbwoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgDZGD0Ajou9vb3REybt7++PnrBUx/3nOTw8HD1hCG8KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgs/l8Pj/Sg7PZqrcAsEJH+XPvTQGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4vSdPnoyeMGl7e3stN+v0+fPn0RMm7e7uLnyzt7e3giXL8fDhw9ETJu3s7Kzl5rjxpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4ve3t7dETJr1582bhm5Mnj/e/QR49ejR6wqRPnz6t5WZddnZ2Rk+Y9OHDh9EThjjev6UArJUoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAZvP5fH6kB2ezVW/hX+RPPm533D+Id/PmzdETJn358mUtN+vy8ePH0RP+Okf5c3+8f0sBWCtRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAZGP0AH7v2rVroydM2t3dXcvNOt25c2f0hEnfvn1b+Obr168rWLIcV65cGT1h0vPnzxe+efHixfKHrJk3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgGwc9cFTp06tcsc/dnh4uPDNz58/V7BkOXZ3d0dPmLS/vz96wtK9evVq9IRJv379Gj1hqc6dOzd6wqSDg4PRE4bwpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFAPKf+SDen/BBPP7fy5cvR0+YtLm5uZabdblx48boCZN+/PgxesIQ3hQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAY
CIAgDZGD1gWW7fvr3wza1bt1awZDkeP348esKk9+/fr+Vmne7duzd6wqTz588vfLO1tbWCJctx//790RMmPXjwYOGb169fr2DJenlTACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAyMZRH/z+/fsqd/xjb9++XfjmOP9MFy9eHD1h0oULFxa+uX79+gqWLM/ly5dHT5h09uzZhW/OnDmzgiXL8ezZs9ETJr179270hCG8KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgPzVH8T7k5t1uXv37ugJk/7kg3h/crNOly5dGj1h0unTpxe+2dzcXMGS5Xj69OnoCZN8EA+Av54oABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAMpvP5/PRIwA4HrwpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQ/wHzssF0wLHC6wAAAABJRU5ErkJggg==\n"
},
"metadata": {}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment